import torch import av import numpy as np import os import requests import tempfile import gc import time import threading import uuid from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration class EndpointHandler: def __init__(self, path=""): # 1. SETUP model_id = "LanguageBind/Video-LLaVA-7B-hf" print(f"Loading model: {model_id}...") # Using bfloat16 to match your local script's success self.processor = VideoLlavaProcessor.from_pretrained(model_id) self.model = VideoLlavaForConditionalGeneration.from_pretrained( model_id, torch_dtype=torch.bfloat16, device_map="auto", low_cpu_mem_usage=True ) self.model.eval() print("Model loaded successfully.") def download_video(self, video_url): # Exact logic from your script, adapted for class structure suffix = os.path.splitext(video_url)[1] or '.mp4' temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) temp_path = temp_file.name temp_file.close() try: # Added 30s timeout to prevent hanging, otherwise logic matches response = requests.get(video_url, stream=True, timeout=60) response.raise_for_status() # Helper to get size for logging file_size = int(response.headers.get('content-length', 0)) with open(temp_path, 'wb') as f: for chunk in response.iter_content(chunk_size=8192): if chunk: f.write(chunk) if file_size == 0: file_size = os.path.getsize(temp_path) print(f"Downloaded video ({file_size/1024/1024:.2f} MB) to {temp_path}") return temp_path except Exception as e: if os.path.exists(temp_path): os.unlink(temp_path) raise Exception(f"Failed to download video: {str(e)}") def read_video_pyav(self, container, indices): # The logic expected by VideoLlava frames = [] container.seek(0) start_index = indices[0] end_index = indices[-1] for i, frame in enumerate(container.decode(video=0)): if i > end_index: break if i >= start_index and i in indices: frames.append(frame) if not frames: raise ValueError("Video decoding failed: No frames found.") # Return list of numpy arrays (RGB) return [x.to_ndarray(format="rgb24") for x in frames] def trigger_webhook(self, url, payload): """ Sends payload to callback_url. Fire-and-forget style: catches errors so main execution doesn't fail. """ if not url: return print(f"Sending webhook to {url}") try: # 5s timeout ensures the HF Endpoint doesn't hang if your server is slow resp = requests.post(url, json=payload, timeout=5) resp.raise_for_status() print(f"Webhook success: {resp.status_code}") except Exception as e: # We print the error but do NOT raise it, ensuring the user still gets their result print(f"Webhook failed: {str(e)}") def _process_video(self, inputs, video_url, parameters, callback_url=None, request_id=None): """ Core video processing logic. Used by both sync and async paths. If callback_url is provided, sends result via webhook. Returns the response payload. """ # Start timing exactly like your script predict_start = time.time() print(f"\nStarting prediction at {time.strftime('%H:%M:%S')}") container = None video_path = None try: # 1. CONFIGURATION matches your script defaults # Your script defaulted to 10 frames num_frames = parameters.get("num_frames", 10) # Your script defaults: max 500, temp 0.1, top_p 0.9 max_new_tokens = parameters.get("max_new_tokens", 500) temperature = parameters.get("temperature", 0.1) top_p = parameters.get("top_p", 0.9) print(f"Prompt: {inputs}") # 2. DOWNLOAD video_path = self.download_video(video_url) container = av.open(video_path) # 3. FRAME EXTRACTION total_frames = container.streams.video[0].frames if total_frames == 0: total_frames = sum(1 for _ in container.decode(video=0)) container.seek(0) # Logic: frames_to_use = min(total_frames, num_frames) frames_to_use = min(total_frames, num_frames) if total_frames > 0 else num_frames print(f"Using {frames_to_use} frames") indices = np.linspace(0, total_frames - 1, frames_to_use, dtype=int) print(f"Using indices: {indices}") clip = self.read_video_pyav(container, indices) print(f"Extracted {len(clip)} frames") # 4. PROMPT CONSTRUCTION # We check if 'USER:' exists to allow your custom full prompts to pass through. # If it's a simple string, we apply your script's formatting exactly. if "USER:" in inputs: full_prompt = inputs else: full_prompt = f"USER: