uncensored-com committed on
Commit
fcd38b0
·
verified ·
1 Parent(s): 542ecdb

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +52 -79
handler.py CHANGED
@@ -8,37 +8,21 @@ from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
- # 1. LOAD MODEL
12
  model_id = "LanguageBind/Video-LLaVA-7B-hf"
13
-
14
  print(f"Loading model: {model_id}...")
 
15
  self.processor = VideoLlavaProcessor.from_pretrained(model_id)
16
  self.model = VideoLlavaForConditionalGeneration.from_pretrained(
17
  model_id,
18
- torch_dtype=torch.float16, # bfloat16 is better if your GPU supports it (A10G/A100), otherwise float16
19
  device_map="auto",
20
  low_cpu_mem_usage=True
21
  )
22
  self.model.eval()
23
  print("Model loaded successfully.")
24
 
25
- def read_video_pyav(self, container, indices):
26
- '''
27
- Decode the video with PyAV decoder.
28
- '''
29
- frames = []
30
- container.seek(0)
31
- start_index = indices[0]
32
- end_index = indices[-1]
33
- for i, frame in enumerate(container.decode(video=0)):
34
- if i > end_index:
35
- break
36
- if i >= start_index and i in indices:
37
- frames.append(frame)
38
- return np.stack([x.to_ndarray(format="rgb24") for x in frames])
39
-
40
  def download_video(self, video_url):
41
- # Your specific download logic
42
  suffix = os.path.splitext(video_url)[1] or '.mp4'
43
  temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
44
  temp_path = temp_file.name
@@ -51,98 +35,87 @@ class EndpointHandler:
51
  f.write(chunk)
52
  return temp_path
53
  except Exception as e:
54
- if os.path.exists(temp_path):
55
- os.unlink(temp_path)
56
  raise e
57
- def __call__(self, data):
58
- """
59
- The endpoint calls this function when you send a JSON request.
60
- Expected JSON: {"inputs": "text prompt", "video": "url", "parameters": {...}}
61
- """
62
- print("\n--- NEW REQUEST RECEIVED ---") # LOG
63
 
64
- # 1. EXTRACT DATA
65
- inputs = data.pop("inputs", "What is happening in this video?")
66
- video_url = data.pop("video", None)
67
- parameters = data.pop("parameters", {})
68
-
69
- # Log the inputs so you can verify overrides in the dashboard logs
70
- print(f"Input Prompt: {inputs}")
71
- print(f"Video URL: {video_url}")
72
- print(f"Raw Parameters: {parameters}")
 
 
 
73
 
74
- # Default parameters from your script
75
- max_new_tokens = parameters.get("max_new_tokens", 500)
76
- temperature = parameters.get("temperature", 0.1)
77
- top_p = parameters.get("top_p", 0.9)
78
- num_frames = parameters.get("num_frames", 10)
79
-
80
- # Log effective parameters
81
- print(f"Effective Config -> Frames: {num_frames}, Max Tokens: {max_new_tokens}, Temp: {temperature}")
 
 
 
82
 
83
- if not video_url:
84
- print("Error: No video URL provided.")
85
- return {"error": "No 'video' key provided in the payload."}
86
 
87
- video_path = None
88
- container = None
89
-
90
- try:
91
- # 2. DOWNLOAD & PROCESS
92
- print(f"Downloading video from {video_url}...")
93
  video_path = self.download_video(video_url)
94
  container = av.open(video_path)
95
 
 
96
  total_frames = container.streams.video[0].frames
97
  if total_frames == 0:
98
  total_frames = sum(1 for _ in container.decode(video=0))
99
  container.seek(0)
100
-
101
- print(f"Video Info -> Total Frames: {total_frames}")
102
-
103
- frames_to_use = min(total_frames, num_frames) if total_frames > 0 else num_frames
104
- indices = np.linspace(0, total_frames - 1, frames_to_use, dtype=int)
105
 
 
 
 
 
 
106
  clip = self.read_video_pyav(container, indices)
107
- print(f"Extracted {len(clip)} frames for inference.")
108
-
109
- # 3. PREPARE PROMPT
 
110
  full_prompt = f"USER: <video>{inputs} ASSISTANT:"
111
 
112
  model_inputs = self.processor(
113
  text=full_prompt,
114
- videos=clip,
115
  return_tensors="pt"
116
  ).to(self.model.device)
117
 
118
- # 4. GENERATE
119
- print("Starting generation...")
120
  with torch.inference_mode():
121
  generate_ids = self.model.generate(
122
  **model_inputs,
123
- max_length=max_new_tokens,
124
  temperature=temperature,
125
- top_p=top_p,
126
  do_sample=True if temperature > 0 else False
127
  )
128
 
129
- result = self.processor.batch_decode(
130
- generate_ids,
131
- skip_special_tokens=True,
132
- clean_up_tokenization_spaces=False
133
- )[0]
134
-
135
  final_output = result.split("ASSISTANT:")[-1].strip()
136
- print(f"Generation complete. Result: {final_output[:50]}...") # Log first 50 chars
137
 
 
138
  return [{"generated_text": final_output}]
139
 
140
  except Exception as e:
141
- print(f"CRITICAL ERROR: {str(e)}")
 
142
  return {"error": str(e)}
143
  finally:
144
- if container: container.close()
145
- if video_path and os.path.exists(video_path):
146
- os.unlink(video_path)
147
-
148
-
 
8
 
9
  class EndpointHandler:
10
def __init__(self, path=""):
    """Load the Video-LLaVA model and its processor once at endpoint start-up.

    `path` is accepted for Inference-Endpoints compatibility but unused: the
    weights are always pulled from the hub repo below.
    """
    model_id = "LanguageBind/Video-LLaVA-7B-hf"

    print(f"Loading model: {model_id}...")

    # The processor prepares both the chat-formatted text and the video frames.
    self.processor = VideoLlavaProcessor.from_pretrained(model_id)

    # fp16 halves the memory footprint; device_map="auto" lets accelerate
    # place layers on whatever GPU(s) the endpoint provides.
    load_kwargs = dict(
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    self.model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, **load_kwargs)

    self.model.eval()
    print("Model loaded successfully.")
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def download_video(self, video_url):
 
26
  suffix = os.path.splitext(video_url)[1] or '.mp4'
27
  temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
28
  temp_path = temp_file.name
 
35
  f.write(chunk)
36
  return temp_path
37
  except Exception as e:
38
+ if os.path.exists(temp_path): os.unlink(temp_path)
 
39
  raise e
 
 
 
 
 
 
40
 
41
+ def read_video_pyav(self, container, indices):
42
+ frames = []
43
+ container.seek(0)
44
+ start_index = indices[0]
45
+ end_index = indices[-1]
46
+ for i, frame in enumerate(container.decode(video=0)):
47
+ if i > end_index:
48
+ break
49
+ if i >= start_index and i in indices:
50
+ frames.append(frame)
51
+ # Return as list of numpy arrays, which acts like a "list of images" for the processor
52
+ return [x.to_ndarray(format="rgb24") for x in frames]
53
 
54
def __call__(self, data):
    """Handle one inference request.

    Expected payload::

        {"inputs": "<prompt>", "video": "<url>", "parameters": {...}}

    Returns ``[{"generated_text": ...}]`` on success or ``{"error": ...}``
    on failure. The temp video file and the av container are always cleaned
    up in ``finally``.
    """
    print("\n--- NEW REQUEST ---")
    # Initialize up front so the finally-block can test them directly
    # instead of the fragile `'name' in locals()` idiom.
    video_path = None
    container = None
    try:
        # 1. EXTRACT DATA (pop so leftover keys can't leak into generate kwargs)
        inputs = data.pop("inputs", "What is happening in this video?")
        video_url = data.pop("video", None)
        parameters = data.pop("parameters", {})

        num_frames = parameters.get("num_frames", 8)
        max_new_tokens = parameters.get("max_new_tokens", 250)
        temperature = parameters.get("temperature", 0.1)

        if not video_url:
            return {"error": "Missing 'video' URL."}

        # 2. DOWNLOAD
        print(f"Downloading: {video_url}")
        video_path = self.download_video(video_url)
        container = av.open(video_path)

        # 3. SAMPLE FRAMES
        total_frames = container.streams.video[0].frames
        if total_frames == 0:
            # Some containers don't report a frame count; count by decoding.
            total_frames = sum(1 for _ in container.decode(video=0))
            container.seek(0)

        # Never request more frames than exist, but always at least one.
        frames_to_use = max(1, min(total_frames, num_frames))
        indices = np.linspace(0, total_frames - 1, frames_to_use, dtype=int)
        clip = self.read_video_pyav(container, indices)
        print(f"Processed {len(clip)} frames.")

        # 4. PREPARE INPUTS
        # Video-LLaVA expects this exact USER/<video>/ASSISTANT template.
        full_prompt = f"USER: <video>{inputs} ASSISTANT:"
        model_inputs = self.processor(
            text=full_prompt,
            videos=clip,
            return_tensors="pt",
        ).to(self.model.device)

        # 5. GENERATE
        print("Generating...")
        with torch.inference_mode():
            generate_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                # Greedy decode when temperature is 0 (sampling would divide by it).
                do_sample=temperature > 0,
            )

        result = self.processor.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        # The decoded text echoes the prompt; keep only the assistant's reply.
        final_output = result.split("ASSISTANT:")[-1].strip()

        print(f"Result: {final_output[:50]}...")
        return [{"generated_text": final_output}]

    except Exception as e:
        import traceback
        traceback.print_exc()
        return {"error": str(e)}
    finally:
        if container is not None:
            container.close()
        if video_path is not None and os.path.exists(video_path):
            os.unlink(video_path)