Update app.py
app.py CHANGED
@@ -49,6 +49,29 @@ CUSTOM_CSS = """
 }
 """
 
+MODEL_PATH = "IVUL-KAUST/VideoAuto-R1-Qwen3-VL-8B"
+
+
+# ============================================================================
+# Global Model Variables
+# ============================================================================
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load model
+model = (
+    Qwen3VLForConditionalGeneration.from_pretrained(
+        MODEL_PATH,
+        dtype="bfloat16",
+        attn_implementation="sdpa",
+    )
+    .to("cuda")
+    .eval()
+)
+
+processor = AutoProcessor.from_pretrained(MODEL_PATH)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+
 
 # ============================================================================
 # Utility Functions
@@ -82,233 +105,206 @@ def detect_media_type(file_path: str | None) -> str | None:
     return "video"
 
 
[old lines 85-94: content lost in the rendered diff]
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        # Load model
-        self.model = Qwen3VLForConditionalGeneration.from_pretrained(
-            model_path,
-            dtype="bfloat16",
-            attn_implementation="sdpa",
-        ).to('cuda').eval()
-
-        self.processor = AutoProcessor.from_pretrained(model_path)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-        self.system_prompt = COT_SYSTEM_PROMPT_ANSWER_TWICE
-
-    def process_image(
-        self,
-        image_path: str,
-        image_min_pixels: int = 128 * 28 * 28,
-        image_max_pixels: int = 16384 * 28 * 28,
-    ) -> dict | None:
-        """
-        Process image file to base64 format.
-
-        Args:
-            image_path: Path to image file
-            image_min_pixels: Minimum pixel count
-            image_max_pixels: Maximum pixel count
-
-        Returns:
-            Dictionary with image data or None
-        """
-        if image_path is None:
-            return None
-
-        image = Image.open(image_path).convert("RGB")
-        buffer = BytesIO()
-        image.save(buffer, format="JPEG")
-        base64_bytes = base64.b64encode(buffer.getvalue())
-        base64_string = base64_bytes.decode("utf-8")
-
-        return {
-            "type": "image",
-            "image": f"data:image/jpeg;base64,{base64_string}",
-            "min_pixels": image_min_pixels,
-            "max_pixels": image_max_pixels,
-        }
-
-    def process_video(
-        self,
-        video_path: str,
-        video_min_pixels: int = 16 * 28 * 28,
-        video_max_pixels: int = 768 * 28 * 28,
-        video_total_pixels: int = 128000 * 28 * 28,
-        min_frames: int = 4,
-        max_frames: int = 64,
-        fps: float = 2.0,
-    ) -> dict | None:
-        """
-        Process video file configuration.
-
-        Args:
-            video_path: Path to video file
-            video_min_pixels: Minimum pixels per frame
-            video_max_pixels: Maximum pixels per frame
-            video_total_pixels: Total pixels across all frames
-            min_frames: Minimum number of frames
-            max_frames: Maximum number of frames
-            fps: Frames per second for sampling
-
-        Returns:
-            Dictionary with video configuration or None
-        """
-        if video_path is None:
-            return None
-
-        return {
-            "type": "video",
-            "video": video_path,
-            "min_pixels": video_min_pixels,
-            "max_pixels": video_max_pixels,
-            "total_pixels": video_total_pixels,
-            "min_frames": min_frames,
-            "max_frames": max_frames,
-            "fps": fps,
-        }
-
-    @spaces.GPU(duration=120)
-    def generate(
-        self,
-        media_input: str | None,
-        prompt: str,
-        early_exit_thresh: float,
-        temperature: float,
-        max_new_tokens: int = 4096,
-    ) -> dict:
-        """
-        Generate response with adaptive inference.
-
-        Args:
-            media_input: Path to media file
-            prompt: Text prompt
-            early_exit_thresh: Confidence threshold for early exit
-            temperature: Sampling temperature
-            max_new_tokens: Maximum tokens to generate
-
-        Returns:
-            Dictionary containing response and metadata
-        """
-        # if self.model.device.type != "cuda":
-        #     self.model.to("cuda")
-
-        # Prepare message
-        message = [{"role": "system", "content": self.system_prompt}]
-        content_parts = []
-
-        # Process media input
-        if media_input is not None:
-            media_type = detect_media_type(media_input)
-
-            if media_type == "video":
-                video_dict = self.process_video(media_input)
-                if video_dict:
-                    content_parts.append(video_dict)
-            elif media_type == "image":
-                image_dict = self.process_image(media_input)
-                if image_dict:
-                    content_parts.append(image_dict)
-
-        # Add text prompt
-        content_parts.append({"type": "text", "text": prompt})
-        message.append({"role": "user", "content": content_parts})
-
-        # Apply chat template
-        text = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
-
-        # Process vision inputs
-        image_inputs, video_inputs, video_kwargs = process_vision_info(
-            [message],
-            image_patch_size=16,
-            return_video_kwargs=True,
-            return_video_metadata=True,
-        )
[old lines 236-311: content lost in the rendered diff]
+def process_image(
+    image_path: str,
+    image_min_pixels: int = 128 * 28 * 28,
+    image_max_pixels: int = 16384 * 28 * 28,
+) -> dict | None:
+    """
+    Process image file to base64 format.
+
+    Args:
+        image_path: Path to image file
+        image_min_pixels: Minimum pixel count
+        image_max_pixels: Maximum pixel count
+
+    Returns:
+        Dictionary with image data or None
+    """
+    if image_path is None:
+        return None
+
+    image = Image.open(image_path).convert("RGB")
+    buffer = BytesIO()
+    image.save(buffer, format="JPEG")
+    base64_bytes = base64.b64encode(buffer.getvalue())
+    base64_string = base64_bytes.decode("utf-8")
+
+    return {
+        "type": "image",
+        "image": f"data:image/jpeg;base64,{base64_string}",
+        "min_pixels": image_min_pixels,
+        "max_pixels": image_max_pixels,
+    }
+
+
+def process_video(
+    video_path: str,
+    video_min_pixels: int = 16 * 28 * 28,
+    video_max_pixels: int = 768 * 28 * 28,
+    video_total_pixels: int = 128000 * 28 * 28,
+    min_frames: int = 4,
+    max_frames: int = 64,
+    fps: float = 2.0,
+) -> dict | None:
+    """
+    Process video file configuration.
+
+    Args:
+        video_path: Path to video file
+        video_min_pixels: Minimum pixels per frame
+        video_max_pixels: Maximum pixels per frame
+        video_total_pixels: Total pixels across all frames
+        min_frames: Minimum number of frames
+        max_frames: Maximum number of frames
+        fps: Frames per second for sampling
+
+    Returns:
+        Dictionary with video configuration or None
+    """
+    if video_path is None:
+        return None
+
+    return {
+        "type": "video",
+        "video": video_path,
+        "min_pixels": video_min_pixels,
+        "max_pixels": video_max_pixels,
+        "total_pixels": video_total_pixels,
+        "min_frames": min_frames,
+        "max_frames": max_frames,
+        "fps": fps,
+    }
+
+
+@spaces.GPU(duration=180)
+def generate(
+    media_input: str | None,
+    prompt: str,
+    early_exit_thresh: float,
+    temperature: float,
+    max_new_tokens: int = 4096,
+) -> dict:
+    """
+    Generate response with adaptive inference.
+
+    Args:
+        media_input: Path to media file
+        prompt: Text prompt
+        early_exit_thresh: Confidence threshold for early exit
+        temperature: Sampling temperature
+        max_new_tokens: Maximum tokens to generate
+
+    Returns:
+        Dictionary containing response and metadata
+    """
+    # Prepare message
+    message = [{"role": "system", "content": COT_SYSTEM_PROMPT_ANSWER_TWICE}]
+    content_parts = []
+
+    # Process media input
+    if media_input is not None:
+        media_type = detect_media_type(media_input)
+
+        if media_type == "video":
+            video_dict = process_video(media_input)
+            if video_dict:
+                content_parts.append(video_dict)
+        elif media_type == "image":
+            image_dict = process_image(media_input)
+            if image_dict:
+                content_parts.append(image_dict)
+
+    # Add text prompt
+    content_parts.append({"type": "text", "text": prompt})
+    message.append({"role": "user", "content": content_parts})
+
+    # Apply chat template
+    text = processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
+
+    # Process vision inputs
+    image_inputs, video_inputs, video_kwargs = process_vision_info(
+        [message],
+        image_patch_size=16,
+        return_video_kwargs=True,
+        return_video_metadata=True,
+    )
+
+    if video_inputs is not None:
+        video_inputs, video_metadatas = zip(*video_inputs)
+        video_inputs = list(video_inputs)
+        video_metadatas = list(video_metadatas)
+    else:
+        video_metadatas = None
+
+    # Prepare inputs
+    inputs = processor(
+        text=text,
+        images=image_inputs,
+        videos=video_inputs,
+        video_metadata=video_metadatas,
+        do_resize=False,
+        padding=True,
+        return_tensors="pt",
+        **video_kwargs,
+    )
+    inputs = inputs.to(device)
+
+    # Generation configuration
+    gen_kwargs = {
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature if temperature > 0 else None,
+        "do_sample": temperature > 0,
+        "top_p": 0.9 if temperature > 0 else None,
+        "num_beams": 1,
+        "use_cache": True,
+        "return_dict_in_generate": True,
+        "output_scores": True,
+    }
+
+    # Generate response
+    with torch.no_grad():
+        gen_out = model.generate(
+            **inputs,
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.pad_token_id,
+            **gen_kwargs,
+        )
+
+    # Decode output
+    generated_ids = gen_out.sequences[0][len(inputs.input_ids[0]) :]
+    answer = processor.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+
+    # Compute confidence
+    first_box_probs = compute_first_boxed_answer_probs(
+        b=0,
+        gen_ids=generated_ids,
+        gen_out=gen_out,
+        ans=answer,
+        task="",
+        tokenizer=tokenizer,
+    )
+
+    # Parse response
+    first_answer = answer.split("<think>")[0]
+    second_answer = answer.split("</think>")[-1] if "</think>" in answer else first_answer
+    reasoning = answer.split("<think>")[-1].split("</think>")[0] if "<think>" in answer else "N/A"
+
+    # Determine inference mode
+    if first_box_probs >= early_exit_thresh:
+        need_cot = False
+        reasoning = False
+    else:
+        need_cot = True
+
+    return {
+        "full_response": answer,
+        "first_answer": first_answer,
+        "confidence": f"{first_box_probs:.4f}",
+        "need_cot": need_cot,
+        "reasoning": reasoning,
+        "second_answer": second_answer,
+    }
 
 
 # ============================================================================
@@ -361,18 +357,18 @@ def chat_generate(
 
     # Initialize system prompt
     if len(messages_state) == 0:
-        messages_state.append({"role": "system", "content":
+        messages_state.append({"role": "system", "content": COT_SYSTEM_PROMPT_ANSWER_TWICE})
 
     # Prepare user message
     content_parts = []
     if media_path is not None:
         mtype = detect_media_type(media_path)
         if mtype == "video":
-            vd =
+            vd = process_video(media_path)
             if vd:
                 content_parts.append(vd)
         elif mtype == "image":
-            imd =
+            imd = process_image(media_path)
             if imd:
                 content_parts.append(imd)
 
@@ -380,7 +376,7 @@ def chat_generate(
     messages_state.append({"role": "user", "content": content_parts})
 
     # Generate response
-    result =
+    result = generate(media_path, user_text, early_exit_thresh, temperature)
 
     # Format assistant response
     first_ans = (result.get("first_answer") or "").strip()
@@ -465,155 +461,145 @@ EXAMPLES = [
 # Gradio Interface
 # ============================================================================
 
[old lines 468-470: content lost in the rendered diff]
-with gr.Blocks(title="VideoAuto-R1 Demo") as demo:
-    gr.Markdown("# VideoAuto-R1 (Qwen3-VL-8B) Demo")
[old lines 473-501: content lost in the rendered diff]
-        temperature = gr.Slider(
-            minimum=0.0,
-            maximum=2.0,
-            value=0.0,
-            step=0.1,
-            label="Temperature",
-        )
-
-        # Right column: Chat interface
-        with gr.Column(scale=7):
-            chatbot = gr.Chatbot(
-                label="Chat",
-                elem_id="chatbot",
-                height=600,
-                sanitize_html=False,
-            )
-            textbox = gr.Textbox(
-                show_label=False,
-                placeholder="Enter text and press ENTER",
-                lines=2,
-            )
[old lines 523-546: content lost in the rendered diff]
-            early_exit_thresh,
-            temperature,
-        ],
-        outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
-    ).then(
-        fn=lambda cs: cs,
-        inputs=[chatbot_state],
-        outputs=[chatbot],
-    )
-
-    textbox.submit(
-        fn=chat_generate,
-        inputs=[
-            media_input,
-            textbox,
-            messages_state,
-            chatbot_state,
-            last_media_state,
-            early_exit_thresh,
-            temperature,
-        ],
-        outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
-    ).then(
-        fn=lambda cs: cs,
-        inputs=[chatbot_state],
-        outputs=[chatbot],
-    )
-
[old lines 576-581: content lost in the rendered diff]
-            chatbot_state,
-            last_media_state,
-            media_input,
-            image_preview,
-            video_preview,
-            textbox,
-            send_btn,
-        ],
-    ).then(
-        fn=lambda cs: cs,
-        inputs=[chatbot_state],
-        outputs=[chatbot],
-    )
-
[old lines 596-609: content lost in the rendered diff]
-if __name__ == "__main__":
-    # Initialize model
-    demo_model = Qwen3VLAutoThinkDemo()
[old lines 613-619: content lost in the rendered diff]
+demo = gr.Blocks(title="VideoAuto-R1 Demo")
+
+with demo:
+    gr.Markdown("# VideoAuto-R1 (Qwen3-VL-8B) Demo")
+
+    # Display system prompt
+    with gr.Accordion("System Prompt", open=False):
+        gr.Markdown(f"```\n{COT_SYSTEM_PROMPT_ANSWER_TWICE}\n```")
+
+    # State variables
+    messages_state = gr.State([])
+    chatbot_state = gr.State([])
+    last_media_state = gr.State(None)
+
+    with gr.Row():
+        # Left column: Media input and settings
+        with gr.Column(scale=3):
+            media_input = gr.File(
+                label="Upload Image or Video",
+                file_types=["image", "video"],
+                type="filepath",
+            )
+            image_preview = gr.Image(label="Image Preview", visible=False)
+            video_preview = gr.Video(label="Video Preview", visible=False)
+
+            with gr.Accordion("Advanced Settings", open=True):
+                early_exit_thresh = gr.Slider(
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.98,
+                    step=0.01,
+                    label="Early Exit Threshold",
+                )
+                temperature = gr.Slider(
+                    minimum=0.0,
+                    maximum=2.0,
+                    value=0.0,
+                    step=0.1,
+                    label="Temperature",
+                )
+
+        # Right column: Chat interface
+        with gr.Column(scale=7):
+            chatbot = gr.Chatbot(
+                label="Chat",
+                elem_id="chatbot",
+                height=600,
+                sanitize_html=False,
+            )
+            textbox = gr.Textbox(
+                show_label=False,
+                placeholder="Enter text and press ENTER",
+                lines=2,
+            )
+            with gr.Row():
+                send_btn = gr.Button("Send", variant="primary")
+                clear_btn = gr.Button("Clear")
+
+    gr.Markdown("Please click the **Clear** button before starting a new conversation or trying a new example.")
+
+    # Event handlers
+    media_input.change(
+        fn=update_preview,
+        inputs=[media_input],
+        outputs=[image_preview, video_preview],
+    )
+
+    # Send button click: generate response and disable input controls
+    send_btn.click(
+        fn=chat_generate,
+        inputs=[
+            media_input,
+            textbox,
+            messages_state,
+            chatbot_state,
+            last_media_state,
+            early_exit_thresh,
+            temperature,
+        ],
+        outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
+    ).then(
+        fn=lambda cs: cs,
+        inputs=[chatbot_state],
+        outputs=[chatbot],
+    )
+
+    # Textbox submit: generate response and disable input controls
+    textbox.submit(
+        fn=chat_generate,
+        inputs=[
+            media_input,
+            textbox,
+            messages_state,
+            chatbot_state,
+            last_media_state,
+            early_exit_thresh,
+            temperature,
+        ],
+        outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
+    ).then(
+        fn=lambda cs: cs,
+        inputs=[chatbot_state],
+        outputs=[chatbot],
+    )
+
+    # Clear button: reset all states and re-enable input controls
+    clear_btn.click(
+        fn=clear_history,
+        inputs=[],
+        outputs=[
+            messages_state,
+            chatbot_state,
+            last_media_state,
+            media_input,
+            image_preview,
+            video_preview,
+            textbox,
+            send_btn,
+        ],
+    ).then(
+        fn=lambda cs: cs,
+        inputs=[chatbot_state],
+        outputs=[chatbot],
+    )
+
+    gr.Examples(
+        examples=EXAMPLES,
+        inputs=[media_input, textbox],
+        label="Examples",
+        cache_examples=False,
+    )
+
+
+# Launch demo
+demo.launch(
+    share=True,
+    server_name="0.0.0.0",
+    server_port=7860,
+    allowed_paths=["assets"],
+    debug=True,
+    css=CUSTOM_CSS,
+)
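The main refactor in this commit replaces the removed Qwen3VLAutoThinkDemo class with module-level globals, which is the usual shape for a ZeroGPU Space: weights are loaded once at import time, and the GPU is attached only while a @spaces.GPU-decorated function runs (the bump from duration=120 to duration=180 simply widens that per-call reservation window). A minimal sketch of the pattern, reusing the names from the diff; the text-only ask helper is illustrative, not part of the app:

import spaces
import torch
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

MODEL_PATH = "IVUL-KAUST/VideoAuto-R1-Qwen3-VL-8B"

# Loaded once at import time; ZeroGPU permits moving weights to CUDA here.
model = (
    Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        dtype="bfloat16",
        attn_implementation="sdpa",
    )
    .to("cuda")
    .eval()
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)


@spaces.GPU(duration=180)  # GPU is reserved only while this function executes
def ask(prompt: str) -> str:
    # "ask" is a made-up helper for illustration; the app's real entry point
    # is generate(), which also handles images/videos and confidence scoring.
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=64)
    return processor.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)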
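The adaptive piece of the new generate() is the early-exit rule: the system prompt (COT_SYSTEM_PROMPT_ANSWER_TWICE) makes the model answer first and reason afterwards, compute_first_boxed_answer_probs scores the confidence of that first boxed answer, and the threshold decides whether the quick answer is kept or the chain-of-thought answer is used. A toy illustration of the selection logic; pick_answer and its values are made up:

def pick_answer(first_answer: str, second_answer: str,
                confidence: float, early_exit_thresh: float = 0.98) -> tuple[str, bool]:
    """Return the chosen answer and whether chain-of-thought was needed."""
    if confidence >= early_exit_thresh:
        # Confident enough: keep the direct answer, skip the <think> block.
        return first_answer, False
    # Not confident: fall back to the answer produced after reasoning.
    return second_answer, True


print(pick_answer("cat", "dog", confidence=0.995))  # ('cat', False)
print(pick_answer("cat", "dog", confidence=0.42))   # ('dog', True)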
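All three UI handlers (send_btn.click, textbox.submit, clear_btn.click) use the same Gradio chaining idiom: the first callback updates the gr.State holders, then a chained .then(fn=lambda cs: cs, ...) mirrors the stored history into the visible Chatbot. A self-contained sketch of that idiom, assuming tuple-style chat history; the respond echo callback stands in for the app's chat_generate:

import gradio as gr

with gr.Blocks() as demo:
    chatbot_state = gr.State([])        # source of truth for the conversation
    chatbot = gr.Chatbot(label="Chat")  # what the user actually sees
    textbox = gr.Textbox(show_label=False)

    def respond(msg, history):
        # Stand-in for chat_generate: append an echo reply to the state.
        history = history + [(msg, f"echo: {msg}")]
        return history, ""  # updated state, cleared textbox

    textbox.submit(
        fn=respond,
        inputs=[textbox, chatbot_state],
        outputs=[chatbot_state, textbox],
    ).then(
        fn=lambda cs: cs,  # copy state into the visible component
        inputs=[chatbot_state],
        outputs=[chatbot],
    )

demo.launch()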