File size: 9,917 Bytes
641ac1a
 
 
 
 
 
23fc12f
 
8b840af
8165459
641ac1a
 
 
 
17e17dc
641ac1a
 
fcd38b0
17e17dc
641ac1a
 
 
000f3a6
641ac1a
 
 
 
 
 
 
17e17dc
641ac1a
 
 
 
17e17dc
641ac1a
17e17dc
 
 
 
 
 
 
641ac1a
 
 
 
17e17dc
 
 
 
 
641ac1a
17e17dc
641ac1a
17e17dc
 
 
641ac1a
fcd38b0
17e17dc
fcd38b0
 
 
 
 
 
 
 
 
23fc12f
 
 
 
17e17dc
fcd38b0
641ac1a
8b840af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8165459
8b840af
 
 
 
 
17e17dc
 
 
23fc12f
 
 
 
fcd38b0
8b840af
17e17dc
8b840af
23fc12f
17e17dc
8b840af
 
 
641ac1a
17e17dc
 
8b840af
641ac1a
 
 
8b840af
641ac1a
 
 
 
 
17e17dc
 
 
23fc12f
fcd38b0
17e17dc
 
641ac1a
17e17dc
fcd38b0
8b840af
17e17dc
 
23fc12f
 
 
000f3a6
641ac1a
8b840af
641ac1a
 
fcd38b0
641ac1a
 
 
8b840af
641ac1a
 
 
17e17dc
 
 
 
641ac1a
 
8b840af
17e17dc
 
 
 
 
641ac1a
23fc12f
 
 
 
 
8b840af
 
 
23fc12f
8b840af
 
 
 
 
 
 
 
 
8165459
8b840af
 
 
 
 
 
 
641ac1a
 
fcd38b0
 
8b840af
 
 
 
 
 
 
8165459
8b840af
 
 
 
 
 
 
23fc12f
641ac1a
17e17dc
23fc12f
 
 
 
8b840af
 
 
 
 
 
 
 
 
8165459
 
 
8b840af
 
8165459
8b840af
 
 
8165459
8b840af
 
 
 
8165459
8b840af
 
 
 
 
 
8165459
8b840af
 
 
 
 
 
 
8165459
 
8b840af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import gc
import os
import tempfile
import threading
import time
import uuid
from urllib.parse import urlparse

import av
import numpy as np
import requests
import torch
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration

class EndpointHandler:
    def __init__(self, path=""):
        """Load the Video-LLaVA model and processor once at endpoint startup."""
        repo_id = "LanguageBind/Video-LLaVA-7B-hf"
        print(f"Loading model: {repo_id}...")

        # bfloat16 keeps the memory footprint manageable while matching the
        # precision that worked in local testing.
        self.processor = VideoLlavaProcessor.from_pretrained(repo_id)
        self.model = VideoLlavaForConditionalGeneration.from_pretrained(
            repo_id,
            low_cpu_mem_usage=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        self.model.eval()
        print("Model loaded successfully.")

    def download_video(self, video_url):
        """Download the video at `video_url` into a named temp file.

        Returns the temp-file path; the caller is responsible for deleting it.
        Raises Exception with a descriptive message on any failure (the
        partially-written temp file is removed first).
        """
        # Derive the suffix from the URL *path* only, so query strings or
        # fragments (e.g. "video.mp4?token=abc") don't leak into the filename.
        suffix = os.path.splitext(urlparse(video_url).path)[1] or '.mp4'
        temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        temp_path = temp_file.name
        temp_file.close()

        try:
            # 60s timeout prevents the endpoint from hanging on a dead host.
            response = requests.get(video_url, stream=True, timeout=60)
            response.raise_for_status()

            # Content-Length is advisory; fall back to the on-disk size below.
            file_size = int(response.headers.get('content-length', 0))

            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

            if file_size == 0:
                file_size = os.path.getsize(temp_path)

            print(f"Downloaded video ({file_size/1024/1024:.2f} MB) to {temp_path}")
            return temp_path

        except Exception as e:
            # Clean up the partial file before surfacing the error.
            if os.path.exists(temp_path):
                os.unlink(temp_path)
            raise Exception(f"Failed to download video: {str(e)}")

    def read_video_pyav(self, container, indices):
        """Decode the frames at `indices` from an open PyAV container.

        Args:
            container: an opened `av` container (will be rewound with seek(0)).
            indices: iterable of frame indices to keep (assumed ascending,
                as produced by np.linspace in the caller).

        Returns:
            list of HxWx3 RGB numpy arrays, one per selected frame.

        Raises:
            ValueError: if no frame could be decoded.
        """
        # Set membership is O(1) per decoded frame; `i in indices` on a numpy
        # array is a linear scan for every frame in the video.
        wanted = {int(i) for i in indices}
        last_index = max(wanted)
        frames = []
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > last_index:
                break  # nothing later is needed; stop decoding early
            if i in wanted:
                frames.append(frame)

        if not frames:
            raise ValueError("Video decoding failed: No frames found.")

        # Convert only the kept frames to RGB ndarrays.
        return [x.to_ndarray(format="rgb24") for x in frames]

    def trigger_webhook(self, url, payload):
        """
        POST `payload` as JSON to `url` (the caller's callback endpoint).

        Best-effort delivery: any failure is logged and swallowed so the
        main inference flow is never broken by a bad callback server.
        """
        if not url:
            return

        print(f"Sending webhook to {url}")
        try:
            # A short 5s timeout keeps the HF Endpoint from hanging on a
            # slow or unresponsive receiver.
            response = requests.post(url, json=payload, timeout=5)
            response.raise_for_status()
            print(f"Webhook success: {response.status_code}")
        except Exception as err:
            # Log but never raise — the user must still get their result.
            print(f"Webhook failed: {str(err)}")

    def _process_video(self, inputs, video_url, parameters, callback_url=None, request_id=None):
        """
        Core video processing logic, shared by the sync and async paths.

        Downloads the video, samples evenly-spaced frames, runs Video-LLaVA
        generation, and (optionally) POSTs the result to `callback_url`.

        Args:
            inputs: prompt string; may already contain full "USER: ... ASSISTANT:" framing.
            video_url: HTTP(S) URL of the video to analyze.
            parameters: dict of overrides — num_frames, max_new_tokens,
                temperature, top_p.
            callback_url: if set, the result payload is also sent as a webhook.
            request_id: correlation id echoed back in the webhook payload.

        Returns:
            On success: {"generated_text", "status": "success", "execution_time"}.
            On failure: {"error", "status": "failed"}. Never raises — errors
            are captured so the endpoint always returns a payload.
        """
        # Wall-clock timing for the whole pipeline (download + decode + generate).
        predict_start = time.time()
        print(f"\nStarting prediction at {time.strftime('%H:%M:%S')}")
        
        container = None
        video_path = None
        
        try:
            # 1. CONFIGURATION — defaults: 10 frames, 500 new tokens,
            # temperature 0.1, top_p 0.9.
            num_frames = parameters.get("num_frames", 10) 
            
            max_new_tokens = parameters.get("max_new_tokens", 500)
            temperature = parameters.get("temperature", 0.1)
            top_p = parameters.get("top_p", 0.9)

            print(f"Prompt: {inputs}")

            # 2. DOWNLOAD the video to a temp file and open it with PyAV.
            video_path = self.download_video(video_url)
            container = av.open(video_path)
            
            # 3. FRAME EXTRACTION
            # Stream metadata may report 0 frames; fall back to counting by
            # decoding the whole stream once, then rewind.
            total_frames = container.streams.video[0].frames
            if total_frames == 0:
                total_frames = sum(1 for _ in container.decode(video=0))
                container.seek(0)
            
            # Never request more frames than the video actually has.
            frames_to_use = min(total_frames, num_frames) if total_frames > 0 else num_frames
            print(f"Using {frames_to_use} frames")
            
            # Evenly-spaced frame indices across the whole clip.
            indices = np.linspace(0, total_frames - 1, frames_to_use, dtype=int)
            print(f"Using indices: {indices}")
            
            clip = self.read_video_pyav(container, indices)
            print(f"Extracted {len(clip)} frames")

            # 4. PROMPT CONSTRUCTION
            # If the caller already supplied full "USER: ..." framing, pass it
            # through untouched; otherwise wrap in the Video-LLaVA template.
            if "USER:" in inputs:
                full_prompt = inputs
            else:
                full_prompt = f"USER: <video>{inputs} ASSISTANT:"
            
            # 5. TOKENIZE text + frames into tensors on the model's device.
            model_inputs = self.processor(
                text=full_prompt,
                videos=clip, 
                return_tensors="pt"
            ).to(self.model.device)

            # 6. GENERATE — sampling only when temperature > 0.
            with torch.inference_mode():
                generate_ids = self.model.generate(
                    **model_inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True if temperature > 0 else False
                )

            # 7. DECODE the full sequence (prompt + completion) back to text.
            result = self.processor.batch_decode(
                generate_ids, 
                skip_special_tokens=True, 
                clean_up_tokenization_spaces=False
            )[0]
            
            # Keep only the assistant's reply when the template marker survives
            # decoding; otherwise return the raw decoded text.
            if "ASSISTANT:" in result:
                final_output = result.split("ASSISTANT:")[-1].strip()
            else:
                final_output = result

            # 8. END TIMING
            execution_time = f"{time.time() - predict_start:.2f}s"
            print(f"Total prediction time: {execution_time}")
            
            response_payload = {
                "generated_text": final_output,
                "status": "success",
                "execution_time": execution_time
            }

            # 9. SEND WEBHOOK (if callback_url provided)
            if callback_url:
                webhook_data = {
                    "request_id": request_id,
                    "input_prompt": inputs,
                    "video_url": video_url,
                    "result": response_payload
                }
                self.trigger_webhook(callback_url, webhook_data)
            
            return response_payload

        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f"Inference failed: {str(e)}")
            
            error_payload = {"error": str(e), "status": "failed"}
            
            # Mirror the success path: async callers learn of failures via webhook.
            if callback_url:
                webhook_data = {
                    "request_id": request_id,
                    "input_prompt": inputs,
                    "video_url": video_url,
                    "result": error_payload
                }
                self.trigger_webhook(callback_url, webhook_data)
            
            return error_payload
            
        finally:
            # Cleanup runs on both paths: close the container, delete the temp
            # file, and release GPU memory so back-to-back requests don't OOM.
            if container: container.close()
            if video_path and os.path.exists(video_path):
                os.unlink(video_path)
            torch.cuda.empty_cache()
            gc.collect()

    def __call__(self, data):
        # --- EXTRACT DATA ---
        callback_url = data.get("callback_url", None)
        inputs = data.get("inputs", "What is happening in this video?")
        video_url = data.get("video", None)
        parameters = data.get("parameters", {})
        
        # Generate unique request ID
        request_id = str(uuid.uuid4())
        
        # Validation
        if not video_url:
            return {"error": "Missing 'video' URL.", "status": "failed", "request_id": request_id}

        # --- ASYNC MODE: Return early, process in background ---
        if callback_url:
            print(f"Async mode: request_id={request_id}, will send result to {callback_url}")
            
            # Spawn background thread for processing
            thread = threading.Thread(
                target=self._process_video,
                args=(inputs, video_url, parameters, callback_url, request_id),
                daemon=True  # Daemon thread won't block process exit
            )
            thread.start()
            
            # Return immediately with acknowledgment
            return [{
                "request_id": request_id,
                "status": "accepted",
                "message": "Processing started. Result will be sent to callback_url.",
                "callback_url": callback_url
            }]
        
        # --- SYNC MODE: Process and return result ---
        else:
            result = self._process_video(inputs, video_url, parameters, request_id=request_id)
            result["request_id"] = request_id
            return [result]