uncensored-com committed on
Commit
000f3a6
·
verified ·
1 Parent(s): 23fc12f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -25
handler.py CHANGED
@@ -10,14 +10,14 @@ from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
- # Load Model
14
  model_id = "LanguageBind/Video-LLaVA-7B-hf"
15
  print(f"Loading model: {model_id}...")
16
 
 
17
  self.processor = VideoLlavaProcessor.from_pretrained(model_id)
18
  self.model = VideoLlavaForConditionalGeneration.from_pretrained(
19
  model_id,
20
- torch_dtype=torch.float16,
21
  device_map="auto",
22
  low_cpu_mem_usage=True
23
  )
@@ -30,7 +30,7 @@ class EndpointHandler:
30
  temp_path = temp_file.name
31
  temp_file.close()
32
  try:
33
- # Added timeout (30s) to prevent hanging on bad URLs
34
  response = requests.get(video_url, stream=True, timeout=30)
35
  if response.status_code != 200:
36
  raise ValueError(f"Failed to download: {response.status_code}")
@@ -55,11 +55,9 @@ class EndpointHandler:
55
  if i >= start_index and i in indices:
56
  frames.append(frame)
57
 
58
- # Guard clause: If video is corrupted or empty
59
  if not frames:
60
  raise ValueError("Video decoding failed: No frames found.")
61
 
62
- # Return as list of numpy arrays
63
  return [x.to_ndarray(format="rgb24") for x in frames]
64
 
65
  def __call__(self, data):
@@ -70,40 +68,34 @@ class EndpointHandler:
70
  video_path = None
71
 
72
  try:
73
- # 1. EXTRACT DATA
74
  inputs = data.pop("inputs", "Describe this video.")
75
  video_url = data.pop("video", None)
76
  parameters = data.pop("parameters", {})
77
 
78
- # Default to 8 frames because LanguageBind is trained on 8 frames.
79
- # Only change this if you are sure the model handles interpolation.
80
  num_frames = parameters.pop("num_frames", 8)
81
 
82
- # Clean parameters for generation (pass everything else to the model)
83
  gen_kwargs = {
84
  "max_new_tokens": parameters.pop("max_new_tokens", 500),
85
- "temperature": parameters.pop("temperature", 0.7),
86
- "top_p": parameters.pop("top_p", 0.9),
87
  "do_sample": True
88
  }
89
- gen_kwargs.update(parameters) # Merge any other params user sent
90
 
91
  if not video_url:
92
  return {"error": "Missing 'video' URL."}
93
 
94
- # 2. DOWNLOAD
95
  print(f"Downloading: {video_url}")
96
  video_path = self.download_video(video_url)
97
  container = av.open(video_path)
98
 
99
- # 3. SMART FRAME SAMPLING
100
  total_frames = container.streams.video[0].frames
101
  if total_frames == 0:
102
- # Fallback for videos with missing metadata
103
  total_frames = sum(1 for _ in container.decode(video=0))
104
  container.seek(0)
105
 
106
- # Clamp frames to available count
107
  frames_to_use = min(total_frames, num_frames)
108
  if frames_to_use < 1: frames_to_use = 1
109
 
@@ -111,20 +103,18 @@ class EndpointHandler:
111
  clip = self.read_video_pyav(container, indices)
112
  print(f"Processed {len(clip)} frames.")
113
 
114
- # 4. PREPARE PROMPT
115
- # Check if user already added the template to avoid double-templating
116
  if "USER:" in inputs:
117
  full_prompt = inputs
118
  else:
119
- full_prompt = f"USER: <video>\n{inputs}\nASSISTANT:"
120
 
 
121
  model_inputs = self.processor(
122
  text=full_prompt,
123
  videos=clip,
124
  return_tensors="pt"
125
  ).to(self.model.device)
126
 
127
- # 5. GENERATE
128
  print(f"Generating with params: {gen_kwargs}")
129
  with torch.inference_mode():
130
  generate_ids = self.model.generate(
@@ -134,13 +124,11 @@ class EndpointHandler:
134
 
135
  result = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
136
 
137
- # Clean output based on prompt structure
138
  if "ASSISTANT:" in result:
139
  final_output = result.split("ASSISTANT:")[-1].strip()
140
  else:
141
  final_output = result
142
 
143
- # LOG TIME
144
  duration = time.time() - start_time
145
  print(f"✅ Success! Total time: {duration:.2f} seconds.")
146
  print(f"Result preview: {final_output[:50]}...")
@@ -153,11 +141,8 @@ class EndpointHandler:
153
  return {"error": str(e)}
154
 
155
  finally:
156
- # 6. CLEANUP (Crucial for long-running endpoints)
157
  if container: container.close()
158
  if video_path and os.path.exists(video_path):
159
  os.unlink(video_path)
160
-
161
- # Clear GPU memory
162
  torch.cuda.empty_cache()
163
  gc.collect()
 
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
 
13
  model_id = "LanguageBind/Video-LLaVA-7B-hf"
14
  print(f"Loading model: {model_id}...")
15
 
16
+ # 1. Use bfloat16 (Matches your local script for better precision)
17
  self.processor = VideoLlavaProcessor.from_pretrained(model_id)
18
  self.model = VideoLlavaForConditionalGeneration.from_pretrained(
19
  model_id,
20
+ torch_dtype=torch.bfloat16,
21
  device_map="auto",
22
  low_cpu_mem_usage=True
23
  )
 
30
  temp_path = temp_file.name
31
  temp_file.close()
32
  try:
33
+ # 30s timeout prevents hanging
34
  response = requests.get(video_url, stream=True, timeout=30)
35
  if response.status_code != 200:
36
  raise ValueError(f"Failed to download: {response.status_code}")
 
55
  if i >= start_index and i in indices:
56
  frames.append(frame)
57
 
 
58
  if not frames:
59
  raise ValueError("Video decoding failed: No frames found.")
60
 
 
61
  return [x.to_ndarray(format="rgb24") for x in frames]
62
 
63
  def __call__(self, data):
 
68
  video_path = None
69
 
70
  try:
 
71
  inputs = data.pop("inputs", "Describe this video.")
72
  video_url = data.pop("video", None)
73
  parameters = data.pop("parameters", {})
74
 
75
+ # Default to 8 frames (Native to this model architecture)
 
76
  num_frames = parameters.pop("num_frames", 8)
77
 
78
+ # 2. Configuration that matches your script's logic
79
  gen_kwargs = {
80
  "max_new_tokens": parameters.pop("max_new_tokens", 500),
81
+ "temperature": parameters.pop("temperature", 0.1), # Defaulted to your 0.1
82
+ "top_p": parameters.pop("top_p", 0.9), # Defaulted to your 0.9
83
  "do_sample": True
84
  }
85
+ gen_kwargs.update(parameters)
86
 
87
  if not video_url:
88
  return {"error": "Missing 'video' URL."}
89
 
 
90
  print(f"Downloading: {video_url}")
91
  video_path = self.download_video(video_url)
92
  container = av.open(video_path)
93
 
 
94
  total_frames = container.streams.video[0].frames
95
  if total_frames == 0:
 
96
  total_frames = sum(1 for _ in container.decode(video=0))
97
  container.seek(0)
98
 
 
99
  frames_to_use = min(total_frames, num_frames)
100
  if frames_to_use < 1: frames_to_use = 1
101
 
 
103
  clip = self.read_video_pyav(container, indices)
104
  print(f"Processed {len(clip)} frames.")
105
 
 
 
106
  if "USER:" in inputs:
107
  full_prompt = inputs
108
  else:
109
+ full_prompt = f"USER: <video>{inputs} ASSISTANT:"
110
 
111
+ # 3. Ensure input tensors are also bfloat16
112
  model_inputs = self.processor(
113
  text=full_prompt,
114
  videos=clip,
115
  return_tensors="pt"
116
  ).to(self.model.device)
117
 
 
118
  print(f"Generating with params: {gen_kwargs}")
119
  with torch.inference_mode():
120
  generate_ids = self.model.generate(
 
124
 
125
  result = self.processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
126
 
 
127
  if "ASSISTANT:" in result:
128
  final_output = result.split("ASSISTANT:")[-1].strip()
129
  else:
130
  final_output = result
131
 
 
132
  duration = time.time() - start_time
133
  print(f"✅ Success! Total time: {duration:.2f} seconds.")
134
  print(f"Result preview: {final_output[:50]}...")
 
141
  return {"error": str(e)}
142
 
143
  finally:
 
144
  if container: container.close()
145
  if video_path and os.path.exists(video_path):
146
  os.unlink(video_path)
 
 
147
  torch.cuda.empty_cache()
148
  gc.collect()