Spaces:
Runtime error
Runtime error
fix
Browse files- demo/infer.py +54 -48
demo/infer.py
CHANGED
|
@@ -156,57 +156,63 @@ class LiveCCDemoInfer:
|
|
| 156 |
state['past_ids'] = outputs.sequences[:, :-1]
|
| 157 |
yield (start_timestamp, stop_timestamp), self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True), state
|
| 158 |
|
|
|
|
| 159 |
def video_qa(
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
| 185 |
else:
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
"
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
inputs = processor(
|
| 198 |
-
text=
|
| 199 |
images=image_inputs,
|
| 200 |
videos=video_inputs,
|
| 201 |
return_tensors="pt",
|
|
|
|
| 202 |
)
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
|
|
|
|
|
|
|
|
| 156 |
state['past_ids'] = outputs.sequences[:, :-1]
|
| 157 |
yield (start_timestamp, stop_timestamp), self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True), state
|
| 158 |
|
| 159 |
+
@torch.inference_mode()
def video_qa(
    self,
    query: str,
    state: dict,
    default_query: str = 'Please describe the video.',
    do_sample: bool = False,
    repetition_penalty: float = 1.05,
    **kwargs,
):
    """Answer a single user query, optionally about a video, and update ``state``.

    Args:
        query: User question. When empty/falsy, ``default_query`` is used.
        state: Mutable conversation state. Keys read here:
            ``video_path`` (optional) - video attached to the user turn;
            ``past_ids`` (optional) - token ids kept from the previous call;
            ``past_key_values`` (optional) - cache forwarded to ``generate``.
            Keys written here: ``past_key_values`` and ``past_ids``.
        default_query: Fallback prompt when ``query`` is empty.
        do_sample: Forwarded to ``self.model.generate``.
        repetition_penalty: Forwarded to ``self.model.generate``.
        **kwargs: Accepted but not used by this method.

    Returns:
        A ``(text, state)`` tuple: the decoded newly generated tokens and the
        same ``state`` dict, updated in place.
    """
    # Build one user turn; the video entry (if any) precedes the text entry.
    video_path = state.get('video_path', None)
    content = [{"type": "text", "text": query if query else default_query}]
    if video_path:
        content.insert(0, {"type": "video", "video": video_path})
    message = {"role": "user", "content": content}

    image_inputs, video_inputs = process_vision_info([message])
    texts = self.processor.apply_chat_template(
        [message],
        tokenize=False,
        add_generation_prompt=True,
        return_tensors='pt',
    )

    past_ids = state.get('past_ids', None)
    has_history = past_ids is not None
    if has_history:
        # Continuing a conversation: drop the first `system_prompt_offset`
        # characters of the rendered prompt and prefix an end-of-turn marker.
        # NOTE(review): presumably this strips the system prompt and closes
        # the previous assistant turn -- confirm against the chat template.
        texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]

    inputs = self.processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        return_tensors="pt",
        return_attention_mask=False,
    )
    # Return value is discarded; this relies on `.to` updating `inputs`
    # in place -- TODO confirm for the processor output type in use.
    inputs.to('cuda')
    if has_history:
        # Prepend the ids kept from the previous call to the new prompt ids.
        inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)

    outputs = self.model.generate(
        **inputs,
        past_key_values=state.get('past_key_values', None),
        return_dict_in_generate=True,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        max_new_tokens=512,
    )

    # Persist the cache and every token except the last one for the next call.
    state['past_key_values'] = outputs.past_key_values
    state['past_ids'] = outputs.sequences[:, :-1]

    # Decode only the tokens that follow the (possibly history-extended) prompt.
    prompt_length = inputs.input_ids.size(1)
    answer = self.processor.decode(
        outputs.sequences[0, prompt_length:],
        skip_special_tokens=True,
    )
    return answer, state
|