Spaces:
Runtime error
Runtime error
Update demo/infer.py
Browse files
- demo/infer.py +29 -26
demo/infer.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import functools, torch
|
| 2 |
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
|
| 3 |
apply_liger_kernel_to_qwen2_vl()
|
| 4 |
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, LogitsProcessor, logging
|
|
@@ -62,6 +62,7 @@ class LiveCCDemoInfer:
|
|
| 62 |
repetition_penalty: float = 1.05,
|
| 63 |
streaming_eos_base_threshold: float = None,
|
| 64 |
streaming_eos_threshold_step: float = None,
|
|
|
|
| 65 |
**kwargs,
|
| 66 |
):
|
| 67 |
"""
|
|
@@ -83,6 +84,7 @@ class LiveCCDemoInfer:
|
|
| 83 |
state['video_pts'] = torch.from_numpy(video_reader._frame_pts[:, 1])
|
| 84 |
state['last_video_pts_index'] = -1
|
| 85 |
video_pts = state['video_pts']
|
|
|
|
| 86 |
if last_timestamp + self.frame_time_interval > video_pts[-1]:
|
| 87 |
state['video_end'] = True
|
| 88 |
return
|
|
@@ -140,7 +142,7 @@ class LiveCCDemoInfer:
|
|
| 140 |
return_tensors="pt",
|
| 141 |
return_attention_mask=False
|
| 142 |
)
|
| 143 |
-
inputs.to(
|
| 144 |
if past_ids is not None:
|
| 145 |
inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
|
| 146 |
if streaming_eos_base_threshold is not None:
|
|
@@ -153,9 +155,11 @@ class LiveCCDemoInfer:
|
|
| 153 |
repetition_penalty=repetition_penalty,
|
| 154 |
logits_processor=logits_processor,
|
| 155 |
)
|
| 156 |
-
state['past_key_values'] = outputs.past_key_values
|
| 157 |
-
state['past_ids'] = outputs.sequences[:, :-1]
|
| 158 |
-
|
|
|
|
|
|
|
| 159 |
|
| 160 |
@torch.inference_mode()
|
| 161 |
def video_qa(
|
|
@@ -165,7 +169,7 @@ class LiveCCDemoInfer:
|
|
| 165 |
state: dict,
|
| 166 |
do_sample: bool = False,
|
| 167 |
repetition_penalty: float = 1.05,
|
| 168 |
-
|
| 169 |
**kwargs,
|
| 170 |
):
|
| 171 |
"""
|
|
@@ -178,25 +182,24 @@ class LiveCCDemoInfer:
|
|
| 178 |
last_history: list, last processed history
|
| 179 |
"""
|
| 180 |
video_path = state.get('video_path', None)
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
{"type": "
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
else:
|
| 191 |
-
|
| 192 |
-
"role": "user",
|
| 193 |
-
"content": [
|
| 194 |
-
{"type": "text", "text": query},
|
| 195 |
-
],
|
| 196 |
-
}
|
| 197 |
-
image_inputs, video_inputs = process_vision_info([message])
|
| 198 |
-
texts = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
|
| 199 |
past_ids = state.get('past_ids', None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
if past_ids is not None:
|
| 201 |
texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
|
| 202 |
inputs = self.processor(
|
|
@@ -204,6 +207,7 @@ class LiveCCDemoInfer:
|
|
| 204 |
images=image_inputs,
|
| 205 |
videos=video_inputs,
|
| 206 |
return_tensors="pt",
|
|
|
|
| 207 |
)
|
| 208 |
inputs.to(self.model.device)
|
| 209 |
if past_ids is not None:
|
|
@@ -214,9 +218,8 @@ class LiveCCDemoInfer:
|
|
| 214 |
repetition_penalty=repetition_penalty,
|
| 215 |
max_new_tokens=512,
|
| 216 |
)
|
| 217 |
-
state['past_key_values'] = outputs.past_key_values
|
| 218 |
-
state['past_ids'] = outputs.sequences[:, :-1]
|
| 219 |
response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
|
| 220 |
print(response)
|
| 221 |
-
state.pop('past_key_values')
|
| 222 |
return response, state
|
|
|
|
| 1 |
+
import functools, torch
|
| 2 |
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
|
| 3 |
apply_liger_kernel_to_qwen2_vl()
|
| 4 |
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, LogitsProcessor, logging
|
|
|
|
| 62 |
repetition_penalty: float = 1.05,
|
| 63 |
streaming_eos_base_threshold: float = None,
|
| 64 |
streaming_eos_threshold_step: float = None,
|
| 65 |
+
hf_spaces: bool = False,
|
| 66 |
**kwargs,
|
| 67 |
):
|
| 68 |
"""
|
|
|
|
| 84 |
state['video_pts'] = torch.from_numpy(video_reader._frame_pts[:, 1])
|
| 85 |
state['last_video_pts_index'] = -1
|
| 86 |
video_pts = state['video_pts']
|
| 87 |
+
video_timestamp = min(video_timestamp, video_pts[-1])
|
| 88 |
if last_timestamp + self.frame_time_interval > video_pts[-1]:
|
| 89 |
state['video_end'] = True
|
| 90 |
return
|
|
|
|
| 142 |
return_tensors="pt",
|
| 143 |
return_attention_mask=False
|
| 144 |
)
|
| 145 |
+
inputs.to(self.model.device)
|
| 146 |
if past_ids is not None:
|
| 147 |
inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
|
| 148 |
if streaming_eos_base_threshold is not None:
|
|
|
|
| 155 |
repetition_penalty=repetition_penalty,
|
| 156 |
logits_processor=logits_processor,
|
| 157 |
)
|
| 158 |
+
state['past_key_values'] = outputs.past_key_values if not hf_spaces else None
|
| 159 |
+
state['past_ids'] = outputs.sequences[:, :-1] if not hf_spaces else None
|
| 160 |
+
response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
|
| 161 |
+
print(response)
|
| 162 |
+
yield (start_timestamp, stop_timestamp), response, state
|
| 163 |
|
| 164 |
@torch.inference_mode()
|
| 165 |
def video_qa(
|
|
|
|
| 169 |
state: dict,
|
| 170 |
do_sample: bool = False,
|
| 171 |
repetition_penalty: float = 1.05,
|
| 172 |
+
hf_spaces: bool = False,
|
| 173 |
**kwargs,
|
| 174 |
):
|
| 175 |
"""
|
|
|
|
| 182 |
last_history: list, last processed history
|
| 183 |
"""
|
| 184 |
video_path = state.get('video_path', None)
|
| 185 |
+
conversation = []
|
| 186 |
+
if hf_spaces:
|
| 187 |
+
for past_message in history:
|
| 188 |
+
content = [{"type": "text", "text": past_message['content']}]
|
| 189 |
+
if video_path: # only use once
|
| 190 |
+
content.insert(0, {"type": "video", "video": video_path})
|
| 191 |
+
video_path = None
|
| 192 |
+
conversation.append({"role": past_message["role"], "content": content})
|
|
|
|
| 193 |
else:
|
| 194 |
+
pass # use past_key_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
past_ids = state.get('past_ids', None)
|
| 196 |
+
content = [{"type": "text", "text": message}]
|
| 197 |
+
if past_ids is None and video_path: # only use once
|
| 198 |
+
content.insert(0, {"type": "video", "video": video_path})
|
| 199 |
+
conversation.append({"role": "user", "content": content})
|
| 200 |
+
print(conversation)
|
| 201 |
+
image_inputs, video_inputs = process_vision_info(conversation)
|
| 202 |
+
texts = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, return_tensors='pt')
|
| 203 |
if past_ids is not None:
|
| 204 |
texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
|
| 205 |
inputs = self.processor(
|
|
|
|
| 207 |
images=image_inputs,
|
| 208 |
videos=video_inputs,
|
| 209 |
return_tensors="pt",
|
| 210 |
+
return_attention_mask=False
|
| 211 |
)
|
| 212 |
inputs.to(self.model.device)
|
| 213 |
if past_ids is not None:
|
|
|
|
| 218 |
repetition_penalty=repetition_penalty,
|
| 219 |
max_new_tokens=512,
|
| 220 |
)
|
| 221 |
+
state['past_key_values'] = outputs.past_key_values if not hf_spaces else None
|
| 222 |
+
state['past_ids'] = outputs.sequences[:, :-1] if not hf_spaces else None
|
| 223 |
response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
|
| 224 |
print(response)
|
|
|
|
| 225 |
return response, state
|