Spaces:
Runtime error
Runtime error
Update demo/infer.py
Browse files- demo/infer.py +6 -6
demo/infer.py
CHANGED
|
@@ -54,7 +54,7 @@ class LiveCCDemoInfer:
|
|
| 54 |
@torch.inference_mode()
|
| 55 |
def live_cc(
|
| 56 |
self,
|
| 57 |
-
|
| 58 |
state: dict,
|
| 59 |
max_pixels: int = 384 * 28 * 28,
|
| 60 |
default_query: str = 'Please describe the video.',
|
|
@@ -129,12 +129,12 @@ class LiveCCDemoInfer:
|
|
| 129 |
{"type": "video", "video": clip}
|
| 130 |
]
|
| 131 |
}
|
| 132 |
-
if not
|
| 133 |
-
|
| 134 |
logger.warning(f'No query provided, use default_query={default_query}')
|
| 135 |
-
if
|
| 136 |
-
message['content'].append({"type": "text", "text":
|
| 137 |
-
state['
|
| 138 |
texts = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
|
| 139 |
past_ids = state.get('past_ids', None)
|
| 140 |
if past_ids is not None:
|
|
|
|
| 54 |
@torch.inference_mode()
|
| 55 |
def live_cc(
|
| 56 |
self,
|
| 57 |
+
message: str,
|
| 58 |
state: dict,
|
| 59 |
max_pixels: int = 384 * 28 * 28,
|
| 60 |
default_query: str = 'Please describe the video.',
|
|
|
|
| 129 |
{"type": "video", "video": clip}
|
| 130 |
]
|
| 131 |
}
|
| 132 |
+
if not message and not state.get('message', None):
|
| 133 |
+
message = default_query
|
| 134 |
logger.warning(f'No query provided, use default_query={default_query}')
|
| 135 |
+
if message and state.get('message', None) != message:
|
| 136 |
+
message['content'].append({"type": "text", "text": message})
|
| 137 |
+
state['message'] = message
|
| 138 |
texts = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
|
| 139 |
past_ids = state.get('past_ids', None)
|
| 140 |
if past_ids is not None:
|