|
|
import gc
import os
import tempfile
import threading
import time
import uuid
from urllib.parse import urlparse

import av
import numpy as np
import requests
import torch
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for Video-LLaVA.

    Accepts a video URL plus a text prompt, samples evenly spaced frames
    from the video, and generates a textual answer. Supports an optional
    async mode: when the request carries a ``callback_url``, processing
    runs on a daemon thread and the result is POSTed to that URL.
    """

    def __init__(self, path=""):
        """Load the Video-LLaVA model and processor onto available devices.

        Args:
            path: Unused; present for Inference Endpoints handler compatibility.
        """
        model_id = "LanguageBind/Video-LLaVA-7B-hf"
        print(f"Loading model: {model_id}...")

        self.processor = VideoLlavaProcessor.from_pretrained(model_id)
        self.model = VideoLlavaForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True
        )
        self.model.eval()
        print("Model loaded successfully.")

    def download_video(self, video_url):
        """Stream a remote video to a local temp file and return its path.

        The caller is responsible for deleting the returned file.

        Args:
            video_url: HTTP(S) URL of the video.

        Returns:
            str: Path to the downloaded temp file.

        Raises:
            RuntimeError: if the download fails for any reason (any partial
                temp file is removed before raising).
        """
        # Derive the suffix from the URL *path* only, so query strings like
        # ".../clip.mp4?token=abc" don't leak into the temp filename.
        suffix = os.path.splitext(urlparse(video_url).path)[1] or '.mp4'
        temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        temp_path = temp_file.name
        temp_file.close()

        try:
            response = requests.get(video_url, stream=True, timeout=60)
            response.raise_for_status()

            file_size = int(response.headers.get('content-length', 0))

            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)

            # Server may not send content-length; fall back to on-disk size
            # so the log line below stays meaningful.
            if file_size == 0:
                file_size = os.path.getsize(temp_path)

            print(f"Downloaded video ({file_size/1024/1024:.2f} MB) to {temp_path}")
            return temp_path

        except Exception as e:
            # Never leave a partial download behind.
            if os.path.exists(temp_path):
                os.unlink(temp_path)
            raise RuntimeError(f"Failed to download video: {str(e)}") from e

    def read_video_pyav(self, container, indices):
        """Decode the frames at the given indices as RGB numpy arrays.

        Args:
            container: An open ``av`` container with at least one video stream.
            indices: Iterable of ascending frame indices to extract.

        Returns:
            list[np.ndarray]: One HxWx3 uint8 ("rgb24") array per frame.

        Raises:
            ValueError: if no frames could be decoded.
        """
        # A set gives O(1) membership tests; `i in ndarray` scans the whole
        # array once per decoded frame.
        wanted = {int(i) for i in indices}
        last_index = max(wanted)

        frames = []
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > last_index:
                break  # everything we need has been decoded
            if i in wanted:
                frames.append(frame)

        if not frames:
            raise ValueError("Video decoding failed: No frames found.")

        return [x.to_ndarray(format="rgb24") for x in frames]

    def trigger_webhook(self, url, payload):
        """
        Sends payload to callback_url.
        Fire-and-forget style: catches errors so main execution doesn't fail.
        No-op when ``url`` is falsy.
        """
        if not url:
            return

        print(f"Sending webhook to {url}")
        try:
            resp = requests.post(url, json=payload, timeout=5)
            resp.raise_for_status()
            print(f"Webhook success: {resp.status_code}")
        except Exception as e:
            # Deliberate best-effort: a webhook failure must never break
            # or mask the inference result.
            print(f"Webhook failed: {str(e)}")

    @staticmethod
    def _webhook_payload(request_id, inputs, video_url, result):
        """Build the envelope delivered to the callback webhook."""
        return {
            "request_id": request_id,
            "input_prompt": inputs,
            "video_url": video_url,
            "result": result
        }

    def _process_video(self, inputs, video_url, parameters, callback_url=None, request_id=None):
        """
        Core video processing logic. Used by both sync and async paths.
        If callback_url is provided, sends result via webhook.
        Returns the response payload.

        Args:
            inputs: Prompt text (plain, or a full "USER: ..." template).
            video_url: URL of the video to analyze.
            parameters: Optional generation settings: ``num_frames`` (10),
                ``max_new_tokens`` (500), ``temperature`` (0.1), ``top_p`` (0.9).
            callback_url: Optional webhook endpoint for the result.
            request_id: Correlation id echoed in the webhook envelope.

        Returns:
            dict: ``{"generated_text", "status", "execution_time"}`` on
            success, or ``{"error", "status"}`` on failure. All exceptions
            are caught and reported in the payload, never raised.
        """
        predict_start = time.time()
        print(f"\nStarting prediction at {time.strftime('%H:%M:%S')}")

        container = None
        video_path = None

        try:
            num_frames = parameters.get("num_frames", 10)
            max_new_tokens = parameters.get("max_new_tokens", 500)
            temperature = parameters.get("temperature", 0.1)
            top_p = parameters.get("top_p", 0.9)

            print(f"Prompt: {inputs}")

            video_path = self.download_video(video_url)
            container = av.open(video_path)

            # Some containers don't report a frame count (0); fall back to a
            # full decode pass to count frames, then rewind.
            total_frames = container.streams.video[0].frames
            if total_frames == 0:
                total_frames = sum(1 for _ in container.decode(video=0))
                container.seek(0)

            frames_to_use = min(total_frames, num_frames) if total_frames > 0 else num_frames
            print(f"Using {frames_to_use} frames")

            # Evenly spaced sample across the whole clip.
            indices = np.linspace(0, total_frames - 1, frames_to_use, dtype=int)
            print(f"Using indices: {indices}")

            clip = self.read_video_pyav(container, indices)
            print(f"Extracted {len(clip)} frames")

            # Wrap the raw prompt in the Video-LLaVA chat template unless the
            # caller already supplied one.
            if "USER:" in inputs:
                full_prompt = inputs
            else:
                full_prompt = f"USER: <video>{inputs} ASSISTANT:"

            model_inputs = self.processor(
                text=full_prompt,
                videos=clip,
                return_tensors="pt"
            ).to(self.model.device)

            # Greedy decode when temperature == 0, sampling otherwise.
            with torch.inference_mode():
                generate_ids = self.model.generate(
                    **model_inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=temperature > 0
                )

            result = self.processor.batch_decode(
                generate_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]

            # Keep only the assistant's answer, not the echoed prompt.
            if "ASSISTANT:" in result:
                final_output = result.split("ASSISTANT:")[-1].strip()
            else:
                final_output = result

            execution_time = f"{time.time() - predict_start:.2f}s"
            print(f"Total prediction time: {execution_time}")

            response_payload = {
                "generated_text": final_output,
                "status": "success",
                "execution_time": execution_time
            }

            if callback_url:
                self.trigger_webhook(
                    callback_url,
                    self._webhook_payload(request_id, inputs, video_url, response_payload)
                )

            return response_payload

        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f"Inference failed: {str(e)}")

            error_payload = {"error": str(e), "status": "failed"}

            if callback_url:
                self.trigger_webhook(
                    callback_url,
                    self._webhook_payload(request_id, inputs, video_url, error_payload)
                )

            return error_payload

        finally:
            # Always release the decoder, delete the temp file, and reclaim
            # GPU memory regardless of success or failure.
            if container:
                container.close()
            if video_path and os.path.exists(video_path):
                os.unlink(video_path)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    def __call__(self, data):
        """Endpoint entry point.

        Expected ``data`` keys:
            inputs: Prompt text (default: "What is happening in this video?").
            video: Video URL (required).
            parameters: Optional generation settings dict.
            callback_url: If present, process asynchronously on a daemon
                thread and POST the result there; the immediate response is
                an "accepted" acknowledgement.

        Returns:
            A one-element list with the result dict — except the
            missing-'video' error, which is returned as a bare dict (kept
            as-is for backward compatibility with existing clients).
        """
        callback_url = data.get("callback_url", None)
        inputs = data.get("inputs", "What is happening in this video?")
        video_url = data.get("video", None)
        parameters = data.get("parameters", {})

        # Correlation id so async clients can match webhooks to requests.
        request_id = str(uuid.uuid4())

        if not video_url:
            return {"error": "Missing 'video' URL.", "status": "failed", "request_id": request_id}

        if callback_url:
            print(f"Async mode: request_id={request_id}, will send result to {callback_url}")

            # Daemon thread: don't block the HTTP response while the model runs.
            thread = threading.Thread(
                target=self._process_video,
                args=(inputs, video_url, parameters, callback_url, request_id),
                daemon=True
            )
            thread.start()

            return [{
                "request_id": request_id,
                "status": "accepted",
                "message": "Processing started. Result will be sent to callback_url.",
                "callback_url": callback_url
            }]

        else:
            result = self._process_video(inputs, video_url, parameters, request_id=request_id)
            result["request_id"] = request_id
            return [result]
|
|
|