Jonwi0706 uncensored-com committed on
Commit
e505c6b
·
0 Parent(s):

Duplicate from uncensored-com/video-llava-7b-deployable

Browse files

Co-authored-by: uncensored ai <uncensored-com@users.noreply.huggingface.co>

Files changed (3) hide show
  1. .gitattributes +35 -0
  2. handler.py +266 -0
  3. requirements.txt +6 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
handler.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import av
3
+ import numpy as np
4
+ import os
5
+ import requests
6
+ import tempfile
7
+ import gc
8
+ import time
9
+ import threading
10
+ import uuid
11
+ from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration
12
+
13
+ class EndpointHandler:
14
+ def __init__(self, path=""):
15
+ # 1. SETUP
16
+ model_id = "LanguageBind/Video-LLaVA-7B-hf"
17
+ print(f"Loading model: {model_id}...")
18
+
19
+ # Using bfloat16 to match your local script's success
20
+ self.processor = VideoLlavaProcessor.from_pretrained(model_id)
21
+ self.model = VideoLlavaForConditionalGeneration.from_pretrained(
22
+ model_id,
23
+ torch_dtype=torch.bfloat16,
24
+ device_map="auto",
25
+ low_cpu_mem_usage=True
26
+ )
27
+ self.model.eval()
28
+ print("Model loaded successfully.")
29
+
30
+ def download_video(self, video_url):
31
+ # Exact logic from your script, adapted for class structure
32
+ suffix = os.path.splitext(video_url)[1] or '.mp4'
33
+ temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
34
+ temp_path = temp_file.name
35
+ temp_file.close()
36
+
37
+ try:
38
+ # Added 30s timeout to prevent hanging, otherwise logic matches
39
+ response = requests.get(video_url, stream=True, timeout=60)
40
+ response.raise_for_status()
41
+
42
+ # Helper to get size for logging
43
+ file_size = int(response.headers.get('content-length', 0))
44
+
45
+ with open(temp_path, 'wb') as f:
46
+ for chunk in response.iter_content(chunk_size=8192):
47
+ if chunk:
48
+ f.write(chunk)
49
+
50
+ if file_size == 0:
51
+ file_size = os.path.getsize(temp_path)
52
+
53
+ print(f"Downloaded video ({file_size/1024/1024:.2f} MB) to {temp_path}")
54
+ return temp_path
55
+
56
+ except Exception as e:
57
+ if os.path.exists(temp_path):
58
+ os.unlink(temp_path)
59
+ raise Exception(f"Failed to download video: {str(e)}")
60
+
61
+ def read_video_pyav(self, container, indices):
62
+ # The logic expected by VideoLlava
63
+ frames = []
64
+ container.seek(0)
65
+ start_index = indices[0]
66
+ end_index = indices[-1]
67
+ for i, frame in enumerate(container.decode(video=0)):
68
+ if i > end_index:
69
+ break
70
+ if i >= start_index and i in indices:
71
+ frames.append(frame)
72
+
73
+ if not frames:
74
+ raise ValueError("Video decoding failed: No frames found.")
75
+
76
+ # Return list of numpy arrays (RGB)
77
+ return [x.to_ndarray(format="rgb24") for x in frames]
78
+
79
+ def trigger_webhook(self, url, payload):
80
+ """
81
+ Sends payload to callback_url.
82
+ Fire-and-forget style: catches errors so main execution doesn't fail.
83
+ """
84
+ if not url:
85
+ return
86
+
87
+ print(f"Sending webhook to {url}")
88
+ try:
89
+ # 5s timeout ensures the HF Endpoint doesn't hang if your server is slow
90
+ resp = requests.post(url, json=payload, timeout=5)
91
+ resp.raise_for_status()
92
+ print(f"Webhook success: {resp.status_code}")
93
+ except Exception as e:
94
+ # We print the error but do NOT raise it, ensuring the user still gets their result
95
+ print(f"Webhook failed: {str(e)}")
96
+
97
    def _process_video(self, inputs, video_url, parameters, callback_url=None, request_id=None):
        """
        Core video processing logic. Used by both sync and async paths.
        If callback_url is provided, sends result via webhook.
        Returns the response payload.

        Args:
            inputs: The user prompt (plain text, or a full "USER: ..." prompt).
            video_url: HTTP(S) URL of the video to analyze.
            parameters: Optional generation overrides (num_frames,
                max_new_tokens, temperature, top_p).
            callback_url: If set, the result/error payload is also POSTed
                there via trigger_webhook.
            request_id: Caller-assigned id echoed back in webhook payloads.

        Returns:
            dict with either {"generated_text", "status", "execution_time"}
            on success or {"error", "status"} on failure. Never raises —
            all exceptions are converted into the error payload.
        """
        # Start timing exactly like your script
        predict_start = time.time()
        print(f"\nStarting prediction at {time.strftime('%H:%M:%S')}")

        # Declared up front so the finally block can clean up regardless of
        # where an exception occurs.
        container = None
        video_path = None

        try:
            # 1. CONFIGURATION matches your script defaults
            # Your script defaulted to 10 frames
            num_frames = parameters.get("num_frames", 10)

            # Your script defaults: max 500, temp 0.1, top_p 0.9
            max_new_tokens = parameters.get("max_new_tokens", 500)
            temperature = parameters.get("temperature", 0.1)
            top_p = parameters.get("top_p", 0.9)

            print(f"Prompt: {inputs}")

            # 2. DOWNLOAD — fetch to a temp file, then open with PyAV.
            video_path = self.download_video(video_url)
            container = av.open(video_path)

            # 3. FRAME EXTRACTION
            # Some containers report 0 in stream metadata; fall back to
            # counting frames by decoding once, then rewind for the real pass.
            total_frames = container.streams.video[0].frames
            if total_frames == 0:
                total_frames = sum(1 for _ in container.decode(video=0))
                container.seek(0)

            # Logic: frames_to_use = min(total_frames, num_frames)
            frames_to_use = min(total_frames, num_frames) if total_frames > 0 else num_frames
            print(f"Using {frames_to_use} frames")

            # Evenly spaced frame indices across the whole video.
            indices = np.linspace(0, total_frames - 1, frames_to_use, dtype=int)
            print(f"Using indices: {indices}")

            clip = self.read_video_pyav(container, indices)
            print(f"Extracted {len(clip)} frames")

            # 4. PROMPT CONSTRUCTION
            # We check if 'USER:' exists to allow your custom full prompts to pass through.
            # If it's a simple string, we apply your script's formatting exactly.
            if "USER:" in inputs:
                full_prompt = inputs
            else:
                full_prompt = f"USER: <video>{inputs} ASSISTANT:"

            # 5. TOKENIZE — processor handles both text and video frames;
            # move tensors to wherever the model was placed by device_map.
            model_inputs = self.processor(
                text=full_prompt,
                videos=clip,
                return_tensors="pt"
            ).to(self.model.device)

            # 6. GENERATE (inference_mode: no autograd bookkeeping)
            with torch.inference_mode():
                generate_ids = self.model.generate(
                    **model_inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    # Greedy decoding when temperature is 0, sampling otherwise.
                    do_sample=True if temperature > 0 else False
                )

            # 7. DECODE the generated token ids back to text.
            result = self.processor.batch_decode(
                generate_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]

            # The decoded text echoes the prompt; keep only the answer part.
            if "ASSISTANT:" in result:
                final_output = result.split("ASSISTANT:")[-1].strip()
            else:
                final_output = result

            # 8. END TIMING
            execution_time = f"{time.time() - predict_start:.2f}s"
            print(f"Total prediction time: {execution_time}")

            response_payload = {
                "generated_text": final_output,
                "status": "success",
                "execution_time": execution_time
            }

            # 9. SEND WEBHOOK (if callback_url provided)
            if callback_url:
                webhook_data = {
                    "request_id": request_id,
                    "input_prompt": inputs,
                    "video_url": video_url,
                    "result": response_payload
                }
                self.trigger_webhook(callback_url, webhook_data)

            return response_payload

        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f"Inference failed: {str(e)}")

            error_payload = {"error": str(e), "status": "failed"}

            # Send error via webhook if callback_url provided, so async
            # callers are informed of failures too.
            if callback_url:
                webhook_data = {
                    "request_id": request_id,
                    "input_prompt": inputs,
                    "video_url": video_url,
                    "result": error_payload
                }
                self.trigger_webhook(callback_url, webhook_data)

            return error_payload

        finally:
            # Cleanup: close the decoder, delete the temp video, and release
            # GPU/Python memory between requests.
            if container: container.close()
            if video_path and os.path.exists(video_path):
                os.unlink(video_path)
            torch.cuda.empty_cache()
            gc.collect()
227
+
228
+ def __call__(self, data):
229
+ # --- EXTRACT DATA ---
230
+ callback_url = data.get("callback_url", None)
231
+ inputs = data.get("inputs", "What is happening in this video?")
232
+ video_url = data.get("video", None)
233
+ parameters = data.get("parameters", {})
234
+
235
+ # Generate unique request ID
236
+ request_id = str(uuid.uuid4())
237
+
238
+ # Validation
239
+ if not video_url:
240
+ return {"error": "Missing 'video' URL.", "status": "failed", "request_id": request_id}
241
+
242
+ # --- ASYNC MODE: Return early, process in background ---
243
+ if callback_url:
244
+ print(f"Async mode: request_id={request_id}, will send result to {callback_url}")
245
+
246
+ # Spawn background thread for processing
247
+ thread = threading.Thread(
248
+ target=self._process_video,
249
+ args=(inputs, video_url, parameters, callback_url, request_id),
250
+ daemon=True # Daemon thread won't block process exit
251
+ )
252
+ thread.start()
253
+
254
+ # Return immediately with acknowledgment
255
+ return [{
256
+ "request_id": request_id,
257
+ "status": "accepted",
258
+ "message": "Processing started. Result will be sent to callback_url.",
259
+ "callback_url": callback_url
260
+ }]
261
+
262
+ # --- SYNC MODE: Process and return result ---
263
+ else:
264
+ result = self._process_video(inputs, video_url, parameters, request_id=request_id)
265
+ result["request_id"] = request_id
266
+ return [result]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ av
2
+ numpy
3
+ requests
4
+ transformers>=4.42.0
5
+ accelerate
6
+ protobuf