| import torch |
| import av |
| import numpy as np |
| import os |
| import requests |
| import tempfile |
| import gc |
| import time |
| import threading |
| import uuid |
| from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration |
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for Video-LLaVA.

    Accepts a payload of the form::

        {"inputs": "<prompt>", "video": "<video url>",
         "parameters": {...}, "callback_url": "<optional url>"}

    When ``callback_url`` is present the request is processed on a
    background thread and the result is POSTed to that URL (async mode);
    otherwise the result is returned synchronously.
    """

    def __init__(self, path=""):
        """Load the Video-LLaVA model and processor.

        `path` is part of the Inference Endpoints handler contract but is
        unused here: the model id is hard-coded below.
        """
        model_id = "LanguageBind/Video-LLaVA-7B-hf"
        print(f"Loading model: {model_id}...")

        self.processor = VideoLlavaProcessor.from_pretrained(model_id)
        self.model = VideoLlavaForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,  # half the memory of fp32; bf16 keeps fp32's range
            device_map="auto",           # place layers on available GPU(s)/CPU automatically
            low_cpu_mem_usage=True
        )
        self.model.eval()
        print("Model loaded successfully.")

    def download_video(self, video_url):
        """Stream `video_url` to a temporary file and return its path.

        The caller owns (and must delete) the returned file. On any
        failure the partial download is removed and an Exception is
        raised with the underlying error message.
        """
        suffix = os.path.splitext(video_url)[1] or '.mp4'
        temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        temp_path = temp_file.name
        temp_file.close()

        try:
            # `with` releases the pooled HTTP connection even on error
            # (fix: the original never closed the streamed response).
            with requests.get(video_url, stream=True, timeout=60) as response:
                response.raise_for_status()

                file_size = int(response.headers.get('content-length', 0))

                # Stream in chunks so large videos never sit fully in memory.
                with open(temp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)

            if file_size == 0:
                # Server sent no content-length; measure what we wrote.
                file_size = os.path.getsize(temp_path)

            print(f"Downloaded video ({file_size/1024/1024:.2f} MB) to {temp_path}")
            return temp_path

        except Exception as e:
            # Never leave a partial download behind.
            if os.path.exists(temp_path):
                os.unlink(temp_path)
            raise Exception(f"Failed to download video: {str(e)}")

    def read_video_pyav(self, container, indices):
        """Decode the frames whose positions appear in `indices`.

        Args:
            container: an open PyAV container.
            indices: iterable of int frame positions (decode order).

        Returns:
            List of HxWx3 RGB numpy arrays, in decode order.

        Raises:
            ValueError: if `indices` is empty or none of the requested
                frames could be decoded.
        """
        # Fix: `i in indices` on a numpy array was an O(k) scan per
        # decoded frame; a set gives O(1) membership. Also guards the
        # empty case, which previously crashed on `indices[0]`.
        wanted = {int(i) for i in indices}
        if not wanted:
            raise ValueError("Video decoding failed: No frames found.")
        end_index = max(wanted)

        frames = []
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break  # past the last wanted frame; stop decoding early
            if i in wanted:
                frames.append(frame)

        if not frames:
            raise ValueError("Video decoding failed: No frames found.")

        return [x.to_ndarray(format="rgb24") for x in frames]

    def trigger_webhook(self, url, payload):
        """POST `payload` as JSON to `url`, fire-and-forget.

        Errors are logged but never propagated, so a dead callback
        endpoint cannot fail the main request. No-op when `url` is falsy.
        """
        if not url:
            return

        print(f"Sending webhook to {url}")
        try:
            resp = requests.post(url, json=payload, timeout=5)
            resp.raise_for_status()
            print(f"Webhook success: {resp.status_code}")
        except Exception as e:
            print(f"Webhook failed: {str(e)}")

    def _process_video(self, inputs, video_url, parameters, callback_url=None, request_id=None):
        """Core pipeline: download -> sample frames -> generate -> decode.

        Used by both the sync and async paths. If `callback_url` is set,
        the result (success or error) is also delivered via webhook.

        Returns:
            dict: {"generated_text", "status", "execution_time"} on
            success, or {"error", "status": "failed"} on failure.
        """
        predict_start = time.time()
        print(f"\nStarting prediction at {time.strftime('%H:%M:%S')}")

        container = None
        video_path = None

        try:
            # Sampling / generation knobs with sensible defaults.
            num_frames = parameters.get("num_frames", 10)
            max_new_tokens = parameters.get("max_new_tokens", 500)
            temperature = parameters.get("temperature", 0.1)
            top_p = parameters.get("top_p", 0.9)

            print(f"Prompt: {inputs}")

            video_path = self.download_video(video_url)
            container = av.open(video_path)

            # Some containers don't report a frame count; fall back to
            # counting by decoding once, then rewind.
            total_frames = container.streams.video[0].frames
            if total_frames == 0:
                total_frames = sum(1 for _ in container.decode(video=0))
                container.seek(0)

            if total_frames == 0:
                # Fix: previously this fell through to np.linspace with a
                # negative endpoint, producing bogus frame indices.
                raise ValueError("Video decoding failed: No frames found.")

            frames_to_use = min(total_frames, num_frames)
            print(f"Using {frames_to_use} frames")

            # Evenly spaced frame positions across the whole clip.
            indices = np.linspace(0, total_frames - 1, frames_to_use, dtype=int)
            print(f"Using indices: {indices}")

            clip = self.read_video_pyav(container, indices)
            print(f"Extracted {len(clip)} frames")

            # Video-LLaVA expects the "USER: <video>... ASSISTANT:" chat
            # template; wrap raw prompts, pass pre-formatted ones through.
            if "USER:" in inputs:
                full_prompt = inputs
            else:
                full_prompt = f"USER: <video>{inputs} ASSISTANT:"

            model_inputs = self.processor(
                text=full_prompt,
                videos=clip,
                return_tensors="pt"
            ).to(self.model.device)

            with torch.inference_mode():
                generate_ids = self.model.generate(
                    **model_inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    # Greedy decode when temperature is 0; sampling with
                    # temperature=0 is undefined.
                    do_sample=temperature > 0
                )

            result = self.processor.batch_decode(
                generate_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]

            # Keep only the assistant's reply, not the echoed prompt.
            if "ASSISTANT:" in result:
                final_output = result.split("ASSISTANT:")[-1].strip()
            else:
                final_output = result

            execution_time = f"{time.time() - predict_start:.2f}s"
            print(f"Total prediction time: {execution_time}")

            response_payload = {
                "generated_text": final_output,
                "status": "success",
                "execution_time": execution_time
            }

            if callback_url:
                self.trigger_webhook(callback_url, {
                    "request_id": request_id,
                    "input_prompt": inputs,
                    "video_url": video_url,
                    "result": response_payload
                })

            return response_payload

        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f"Inference failed: {str(e)}")

            error_payload = {"error": str(e), "status": "failed"}

            # Async callers still get a webhook on failure.
            if callback_url:
                self.trigger_webhook(callback_url, {
                    "request_id": request_id,
                    "input_prompt": inputs,
                    "video_url": video_url,
                    "result": error_payload
                })

            return error_payload

        finally:
            # Always release the decoder, the temp file and cached GPU memory.
            if container:
                container.close()
            if video_path and os.path.exists(video_path):
                os.unlink(video_path)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    def __call__(self, data):
        """Endpoint entry point.

        Args:
            data: request payload with keys "inputs", "video",
                "parameters" and optional "callback_url".

        Returns:
            A one-element list with the result dict (sync mode) or an
            "accepted" dict (async mode); a bare error dict when the
            required "video" key is missing.
        """
        callback_url = data.get("callback_url", None)
        inputs = data.get("inputs", "What is happening in this video?")
        video_url = data.get("video", None)
        parameters = data.get("parameters", {})

        # Correlates async results with their webhook deliveries.
        request_id = str(uuid.uuid4())

        if not video_url:
            return {"error": "Missing 'video' URL.", "status": "failed", "request_id": request_id}

        if callback_url:
            # Async mode: acknowledge immediately, do the work on a
            # daemon thread and deliver the result via webhook.
            print(f"Async mode: request_id={request_id}, will send result to {callback_url}")

            thread = threading.Thread(
                target=self._process_video,
                args=(inputs, video_url, parameters, callback_url, request_id),
                daemon=True
            )
            thread.start()

            return [{
                "request_id": request_id,
                "status": "accepted",
                "message": "Processing started. Result will be sent to callback_url.",
                "callback_url": callback_url
            }]

        else:
            # Sync mode: block until generation finishes.
            result = self._process_video(inputs, video_url, parameters, request_id=request_id)
            result["request_id"] = request_id
            return [result]
|
|