uncensored-com committed on
Commit
23fc12f
·
verified ·
1 Parent(s): fcd38b0

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +63 -21
handler.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
  import os
5
  import requests
6
  import tempfile
 
 
7
  from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration
8
 
9
  class EndpointHandler:
@@ -28,7 +30,11 @@ class EndpointHandler:
28
  temp_path = temp_file.name
29
  temp_file.close()
30
  try:
31
- response = requests.get(video_url, stream=True)
 
 
 
 
32
  with open(temp_path, 'wb') as f:
33
  for chunk in response.iter_content(chunk_size=8192):
34
  if chunk:
@@ -48,20 +54,39 @@ class EndpointHandler:
48
  break
49
  if i >= start_index and i in indices:
50
  frames.append(frame)
51
- # Return as list of numpy arrays, which acts like a "list of images" for the processor
 
 
 
 
 
52
  return [x.to_ndarray(format="rgb24") for x in frames]
53
 
54
  def __call__(self, data):
 
55
  print("\n--- NEW REQUEST ---")
 
 
 
 
56
  try:
57
  # 1. EXTRACT DATA
58
- inputs = data.pop("inputs", "What is happening in this video?")
59
  video_url = data.pop("video", None)
60
  parameters = data.pop("parameters", {})
61
 
62
- num_frames = parameters.get("num_frames", 8)
63
- max_new_tokens = parameters.get("max_new_tokens", 250)
64
- temperature = parameters.get("temperature", 0.1)
 
 
 
 
 
 
 
 
 
65
 
66
  if not video_url:
67
  return {"error": "Missing 'video' URL."}
@@ -71,23 +96,27 @@ class EndpointHandler:
71
  video_path = self.download_video(video_url)
72
  container = av.open(video_path)
73
 
74
- # 3. SAMPLE FRAMES
75
  total_frames = container.streams.video[0].frames
76
  if total_frames == 0:
 
77
  total_frames = sum(1 for _ in container.decode(video=0))
78
  container.seek(0)
79
 
80
- # Ensure we don't request more frames than exist
81
  frames_to_use = min(total_frames, num_frames)
82
  if frames_to_use < 1: frames_to_use = 1
83
-
84
  indices = np.linspace(0, total_frames - 1, frames_to_use, dtype=int)
85
  clip = self.read_video_pyav(container, indices)
86
  print(f"Processed {len(clip)} frames.")
87
 
88
- # 4. PREPARE INPUTS
89
- # Note: VideoLlava expects specific prompt formatting
90
- full_prompt = f"USER: <video>{inputs} ASSISTANT:"
 
 
 
91
 
92
  model_inputs = self.processor(
93
  text=full_prompt,
@@ -96,26 +125,39 @@ class EndpointHandler:
96
  ).to(self.model.device)
97
 
98
  # 5. GENERATE
99
- print("Generating...")
100
  with torch.inference_mode():
101
  generate_ids = self.model.generate(
102
  **model_inputs,
103
- max_new_tokens=max_new_tokens,
104
- temperature=temperature,
105
- do_sample=True if temperature > 0 else False
106
  )
107
 
108
  result = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
109
- final_output = result.split("ASSISTANT:")[-1].strip()
110
 
111
- print(f"Result: {final_output[:50]}...")
 
 
 
 
 
 
 
 
 
 
112
  return [{"generated_text": final_output}]
113
 
114
  except Exception as e:
115
  import traceback
116
  traceback.print_exc()
117
  return {"error": str(e)}
 
118
  finally:
119
- if 'container' in locals() and container: container.close()
120
- if 'video_path' in locals() and video_path and os.path.exists(video_path):
121
- os.unlink(video_path)
 
 
 
 
 
 
4
  import os
5
  import requests
6
  import tempfile
7
+ import gc
8
+ import time
9
  from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration
10
 
11
  class EndpointHandler:
 
30
  temp_path = temp_file.name
31
  temp_file.close()
32
  try:
33
+ # Added timeout (30s) to prevent hanging on bad URLs
34
+ response = requests.get(video_url, stream=True, timeout=30)
35
+ if response.status_code != 200:
36
+ raise ValueError(f"Failed to download: {response.status_code}")
37
+
38
  with open(temp_path, 'wb') as f:
39
  for chunk in response.iter_content(chunk_size=8192):
40
  if chunk:
 
54
  break
55
  if i >= start_index and i in indices:
56
  frames.append(frame)
57
+
58
+ # Guard clause: If video is corrupted or empty
59
+ if not frames:
60
+ raise ValueError("Video decoding failed: No frames found.")
61
+
62
+ # Return as list of numpy arrays
63
  return [x.to_ndarray(format="rgb24") for x in frames]
64
 
65
  def __call__(self, data):
66
+ start_time = time.time()
67
  print("\n--- NEW REQUEST ---")
68
+
69
+ container = None
70
+ video_path = None
71
+
72
  try:
73
  # 1. EXTRACT DATA
74
+ inputs = data.pop("inputs", "Describe this video.")
75
  video_url = data.pop("video", None)
76
  parameters = data.pop("parameters", {})
77
 
78
+ # Default to 8 frames because LanguageBind is trained on 8 frames.
79
+ # Only change this if you are sure the model handles interpolation.
80
+ num_frames = parameters.pop("num_frames", 8)
81
+
82
+ # Clean parameters for generation (pass everything else to the model)
83
+ gen_kwargs = {
84
+ "max_new_tokens": parameters.pop("max_new_tokens", 500),
85
+ "temperature": parameters.pop("temperature", 0.7),
86
+ "top_p": parameters.pop("top_p", 0.9),
87
+ "do_sample": True
88
+ }
89
+ gen_kwargs.update(parameters) # Merge any other params user sent
90
 
91
  if not video_url:
92
  return {"error": "Missing 'video' URL."}
 
96
  video_path = self.download_video(video_url)
97
  container = av.open(video_path)
98
 
99
+ # 3. SMART FRAME SAMPLING
100
  total_frames = container.streams.video[0].frames
101
  if total_frames == 0:
102
+ # Fallback for videos with missing metadata
103
  total_frames = sum(1 for _ in container.decode(video=0))
104
  container.seek(0)
105
 
106
+ # Clamp frames to available count
107
  frames_to_use = min(total_frames, num_frames)
108
  if frames_to_use < 1: frames_to_use = 1
109
+
110
  indices = np.linspace(0, total_frames - 1, frames_to_use, dtype=int)
111
  clip = self.read_video_pyav(container, indices)
112
  print(f"Processed {len(clip)} frames.")
113
 
114
+ # 4. PREPARE PROMPT
115
+ # Check if user already added the template to avoid double-templating
116
+ if "USER:" in inputs:
117
+ full_prompt = inputs
118
+ else:
119
+ full_prompt = f"USER: <video>\n{inputs}\nASSISTANT:"
120
 
121
  model_inputs = self.processor(
122
  text=full_prompt,
 
125
  ).to(self.model.device)
126
 
127
  # 5. GENERATE
128
+ print(f"Generating with params: {gen_kwargs}")
129
  with torch.inference_mode():
130
  generate_ids = self.model.generate(
131
  **model_inputs,
132
+ **gen_kwargs
 
 
133
  )
134
 
135
  result = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
136
 
137
+ # Clean output based on prompt structure
138
+ if "ASSISTANT:" in result:
139
+ final_output = result.split("ASSISTANT:")[-1].strip()
140
+ else:
141
+ final_output = result
142
+
143
+ # LOG TIME
144
+ duration = time.time() - start_time
145
+ print(f"✅ Success! Total time: {duration:.2f} seconds.")
146
+ print(f"Result preview: {final_output[:50]}...")
147
+
148
  return [{"generated_text": final_output}]
149
 
150
  except Exception as e:
151
  import traceback
152
  traceback.print_exc()
153
  return {"error": str(e)}
154
+
155
  finally:
156
+ # 6. CLEANUP (Crucial for long-running endpoints)
157
+ if container: container.close()
158
+ if video_path and os.path.exists(video_path):
159
+ os.unlink(video_path)
160
+
161
+ # Clear GPU memory
162
+ torch.cuda.empty_cache()
163
+ gc.collect()