Update README.md
Browse files
README.md
CHANGED
|
@@ -226,7 +226,7 @@ class LiveCCDemoInfer:
|
|
| 226 |
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
| 227 |
model_path, torch_dtype="auto",
|
| 228 |
device_map=device,
|
| 229 |
-
attn_implementation='
|
| 230 |
)
|
| 231 |
self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False)
|
| 232 |
self.streaming_eos_token_id = self.processor.tokenizer(' ...').input_ids[-1]
|
|
@@ -239,17 +239,13 @@ class LiveCCDemoInfer:
|
|
| 239 |
}
|
| 240 |
texts = self.processor.apply_chat_template([message], tokenize=False)
|
| 241 |
self.system_prompt_offset = texts.index('<|im_start|>user')
|
| 242 |
-
self._cached_video_readers_with_hw = {}
|
| 243 |
|
| 244 |
-
@torch.inference_mode()
|
| 245 |
def video_qa(
|
| 246 |
self,
|
| 247 |
message: str,
|
| 248 |
state: dict,
|
| 249 |
-
history: list = [],
|
| 250 |
do_sample: bool = False,
|
| 251 |
repetition_penalty: float = 1.05,
|
| 252 |
-
hf_spaces: bool = False,
|
| 253 |
**kwargs,
|
| 254 |
):
|
| 255 |
"""
|
|
@@ -263,15 +259,6 @@ class LiveCCDemoInfer:
|
|
| 263 |
"""
|
| 264 |
video_path = state.get('video_path', None)
|
| 265 |
conversation = []
|
| 266 |
-
if hf_spaces:
|
| 267 |
-
for past_message in history:
|
| 268 |
-
content = [{"type": "text", "text": past_message['content']}]
|
| 269 |
-
if video_path: # only use once
|
| 270 |
-
content.insert(0, {"type": "video", "video": video_path})
|
| 271 |
-
video_path = None
|
| 272 |
-
conversation.append({"role": past_message["role"], "content": content})
|
| 273 |
-
else:
|
| 274 |
-
pass # use past_key_values
|
| 275 |
past_ids = state.get('past_ids', None)
|
| 276 |
content = [{"type": "text", "text": message}]
|
| 277 |
if past_ids is None and video_path: # only use once
|
|
@@ -297,20 +284,25 @@ class LiveCCDemoInfer:
|
|
| 297 |
repetition_penalty=repetition_penalty,
|
| 298 |
max_new_tokens=512,
|
| 299 |
)
|
| 300 |
-
state['past_key_values'] = outputs.past_key_values
|
| 301 |
-
state['past_ids'] = outputs.sequences[:, :-1]
|
| 302 |
response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
|
| 303 |
return response, state
|
| 304 |
|
| 305 |
model_path = 'chenjoya/LiveCC-7B-Instruct'
|
| 306 |
-
|
|
|
|
| 307 |
|
| 308 |
infer = LiveCCDemoInfer(model_path=model_path)
|
| 309 |
state = {'video_path': video_path}
|
| 310 |
# first round
|
| 311 |
-
|
|
|
|
|
|
|
| 312 |
# second round
|
| 313 |
-
|
|
|
|
|
|
|
| 314 |
```
|
| 315 |
|
| 316 |
## Limitations
|
|
|
|
| 226 |
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
| 227 |
model_path, torch_dtype="auto",
|
| 228 |
device_map=device,
|
| 229 |
+
attn_implementation='flash_attention_2'
|
| 230 |
)
|
| 231 |
self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False)
|
| 232 |
self.streaming_eos_token_id = self.processor.tokenizer(' ...').input_ids[-1]
|
|
|
|
| 239 |
}
|
| 240 |
texts = self.processor.apply_chat_template([message], tokenize=False)
|
| 241 |
self.system_prompt_offset = texts.index('<|im_start|>user')
|
|
|
|
| 242 |
|
|
|
|
| 243 |
def video_qa(
|
| 244 |
self,
|
| 245 |
message: str,
|
| 246 |
state: dict,
|
|
|
|
| 247 |
do_sample: bool = False,
|
| 248 |
repetition_penalty: float = 1.05,
|
|
|
|
| 249 |
**kwargs,
|
| 250 |
):
|
| 251 |
"""
|
|
|
|
| 259 |
"""
|
| 260 |
video_path = state.get('video_path', None)
|
| 261 |
conversation = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
past_ids = state.get('past_ids', None)
|
| 263 |
content = [{"type": "text", "text": message}]
|
| 264 |
if past_ids is None and video_path: # only use once
|
|
|
|
| 284 |
repetition_penalty=repetition_penalty,
|
| 285 |
max_new_tokens=512,
|
| 286 |
)
|
| 287 |
+
state['past_key_values'] = outputs.past_key_values
|
| 288 |
+
state['past_ids'] = outputs.sequences[:, :-1]
|
| 289 |
response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
|
| 290 |
return response, state
|
| 291 |
|
| 292 |
model_path = 'chenjoya/LiveCC-7B-Instruct'
|
| 293 |
+
# Download the test video from https://github.com/showlab/livecc/blob/main/demo/sources/howto_fix_laptop_mute_1080p.mp4 and save it at the relative path assigned to `video_path` below.
|
| 294 |
+
video_path = "demo/sources/howto_fix_laptop_mute_1080p.mp4"
|
| 295 |
|
| 296 |
infer = LiveCCDemoInfer(model_path=model_path)
|
| 297 |
state = {'video_path': video_path}
|
| 298 |
# first round
|
| 299 |
+
query1 = 'What is the video?'
|
| 300 |
+
response1, state = infer.video_qa(message=query1, state=state)
|
| 301 |
+
print(f'Q1: {query1}\nA1: {response1}')
|
| 302 |
# second round
|
| 303 |
+
query2 = 'How do you know that?'
|
| 304 |
+
response2, state = infer.video_qa(message=query2, state=state)
|
| 305 |
+
print(f'Q2: {query2}\nA2: {response2}')
|
| 306 |
```
|
| 307 |
|
| 308 |
## Limitations
|