Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| VideoAuto-R1 (Qwen3-VL) Demo | |
| A Gradio-based chat interface for adaptive inference with image/video inputs. | |
| """ | |
| import spaces | |
| import os | |
| import base64 | |
| from io import BytesIO | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoTokenizer | |
| from videoauto_r1.qwen_vl_utils.vision_process import process_vision_info | |
| from videoauto_r1.modeling_qwen3_vl_patched import Qwen3VLForConditionalGeneration | |
| from videoauto_r1.early_exit import compute_first_boxed_answer_probs | |
# ============================================================================
# Constants
# ============================================================================
# System prompt for the "answer twice" protocol: a first boxed answer, then a
# <think> monologue, then a final boxed answer. Built line by line and joined
# with newlines.
COT_SYSTEM_PROMPT_ANSWER_TWICE = "\n".join(
    [
        "You are a helpful assistant.",
        "FIRST: Output your initial answer inside the first \\boxed{...} without any analysis or explanations. "
        "If you cannot determine the answer without reasoning, output \\boxed{Let's analyze the problem step by step.} instead.",
        "THEN: Think through the reasoning as an internal monologue enclosed within <think>...</think>.",
        "AT LAST: Output the final answer again inside \\boxed{...}. If you believe the previous answer was correct, repeat it; otherwise, correct it.",
        "Output format: \\boxed{...}<think>...</think>\\boxed{...}",
    ]
)
# Lower-case file suffixes checked by detect_media_type(). They are tuples so
# they can be handed straight to str.endswith().
VIDEO_EXTS = tuple(f".{ext}" for ext in ("mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"))
IMAGE_EXTS = tuple(f".{ext}" for ext in ("jpg", "jpeg", "png", "bmp", "gif", "webp", "tiff"))
# Width limits for chat bubbles in the element with elem_id="chatbot".
CUSTOM_CSS = """
#chatbot .message[class*="user"] {
max-width: 50% !important;
}
#chatbot .message[class*="bot"],
#chatbot .message[class*="assistant"] {
max-width: 60% !important;
}
#chatbot .message > div {
width: 100% !important;
max-width: 100% !important;
}
"""
# Hugging Face Hub checkpoint loaded for the model, processor and tokenizer.
MODEL_PATH = "IVUL-KAUST/VideoAuto-R1-Qwen3-VL-8B"
# ============================================================================
# Global Model Variables
# ============================================================================
# Preferred torch device for this process.
# NOTE(review): the model below is always moved to "cuda" regardless of
# `device`; code that places tensors next to the model may prefer
# `model.device` over this value.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load the patched Qwen3-VL checkpoint in bfloat16 with SDPA attention, then
# move it to CUDA and switch it to eval mode.
model = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    dtype="bfloat16",
    attn_implementation="sdpa",
)
model = model.to("cuda").eval()
# Processor and tokenizer come from the same checkpoint as the model.
processor = AutoProcessor.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
| # ============================================================================ | |
| # Utility Functions | |
| # ============================================================================ | |
def detect_media_type(file_path: str | None) -> str | None:
    """
    Detect media type from file extension, falling back to content sniffing.
    Args:
        file_path: Path to the media file (may be None or empty)
    Returns:
        'image', 'video', or None when no path is given
    """
    if not file_path:
        return None
    lowered = file_path.lower()
    if lowered.endswith(VIDEO_EXTS):
        return "video"
    if lowered.endswith(IMAGE_EXTS):
        return "image"
    # Unknown extension: try to open it as an image. Image.open reads lazily and
    # keeps the file handle open, so use a context manager to release it
    # (the original leaked one descriptor per call here).
    try:
        with Image.open(file_path):
            pass
    except Exception:
        # Anything PIL cannot open is assumed to be a video (deliberately broad,
        # as in the original best-effort fallback).
        return "video"
    return "image"
def process_image(
    image_path: str,
    image_min_pixels: int = 128 * 28 * 28,
    image_max_pixels: int = 16384 * 28 * 28,
) -> dict | None:
    """
    Load an image and wrap it as a base64 JPEG data-URI content part.
    Args:
        image_path: Path to image file; None yields None
        image_min_pixels: Stored under "min_pixels" (presumably read by process_vision_info)
        image_max_pixels: Stored under "max_pixels" (presumably read by process_vision_info)
    Returns:
        {"type": "image", "image": "data:image/jpeg;base64,...", "min_pixels": ...,
        "max_pixels": ...}, or None when image_path is None
    """
    if image_path is None:
        return None
    # Image.open is lazy and holds the file open; convert() materialises an
    # independent RGB copy, after which the source handle can be closed.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    base64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return {
        "type": "image",
        "image": f"data:image/jpeg;base64,{base64_string}",
        "min_pixels": image_min_pixels,
        "max_pixels": image_max_pixels,
    }
| def process_video( | |
| video_path: str, | |
| video_min_pixels: int = 16 * 28 * 28, | |
| video_max_pixels: int = 768 * 28 * 28, | |
| video_total_pixels: int = 128000 * 28 * 28, | |
| min_frames: int = 4, | |
| max_frames: int = 64, | |
| fps: float = 2.0, | |
| ) -> dict | None: | |
| """ | |
| Process video file configuration. | |
| Args: | |
| video_path: Path to video file | |
| video_min_pixels: Minimum pixels per frame | |
| video_max_pixels: Maximum pixels per frame | |
| video_total_pixels: Total pixels across all frames | |
| min_frames: Minimum number of frames | |
| max_frames: Maximum number of frames | |
| fps: Frames per second for sampling | |
| Returns: | |
| Dictionary with video configuration or None | |
| """ | |
| if video_path is None: | |
| return None | |
| return { | |
| "type": "video", | |
| "video": video_path, | |
| "min_pixels": video_min_pixels, | |
| "max_pixels": video_max_pixels, | |
| "total_pixels": video_total_pixels, | |
| "min_frames": min_frames, | |
| "max_frames": max_frames, | |
| "fps": fps, | |
| } | |
@spaces.GPU  # ZeroGPU: attach a GPU for the duration of this call (no-op elsewhere)
def generate(
    media_input: str | None,
    prompt: str,
    early_exit_thresh: float,
    temperature: float,
    max_new_tokens: int = 4096,
) -> dict:
    """
    Run one single-turn generation and apply the adaptive (early-exit) decision.
    The model is prompted with COT_SYSTEM_PROMPT_ANSWER_TWICE to produce a first
    boxed answer, a <think> block and a final boxed answer. Generation always
    runs to completion; the early-exit decision is made afterwards from the
    confidence of the first boxed answer.
    Args:
        media_input: Path to an image/video file, or None for text-only
        prompt: Text prompt
        early_exit_thresh: Confidence at or above which the first answer is accepted
        temperature: Sampling temperature (0 means greedy decoding)
        max_new_tokens: Maximum tokens to generate
    Returns:
        Dictionary with keys "full_response", "first_answer", "confidence"
        (formatted to 4 decimals), "need_cot", "reasoning" (empty string on
        early exit) and "second_answer"
    """
    # Fresh single-turn conversation: system prompt + one user turn.
    message = [{"role": "system", "content": COT_SYSTEM_PROMPT_ANSWER_TWICE}]
    content_parts = []
    # Process media input
    if media_input is not None:
        media_type = detect_media_type(media_input)
        if media_type == "video":
            video_dict = process_video(media_input)
            if video_dict:
                content_parts.append(video_dict)
        elif media_type == "image":
            image_dict = process_image(media_input)
            if image_dict:
                content_parts.append(image_dict)
    # Add text prompt
    content_parts.append({"type": "text", "text": prompt})
    message.append({"role": "user", "content": content_parts})
    # Apply chat template
    text = processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
    # Process vision inputs
    image_inputs, video_inputs, video_kwargs = process_vision_info(
        [message],
        image_patch_size=16,
        return_video_kwargs=True,
        return_video_metadata=True,
    )
    if video_inputs is not None:
        # Entries are unpacked as (frames, metadata) pairs when metadata is requested.
        video_inputs, video_metadatas = zip(*video_inputs)
        video_inputs = list(video_inputs)
        video_metadatas = list(video_metadatas)
    else:
        video_metadatas = None
    # Prepare inputs
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        video_metadata=video_metadatas,
        do_resize=False,
        padding=True,
        return_tensors="pt",
        **video_kwargs,
    )
    # Place inputs wherever the model actually lives. The module-level `device`
    # can disagree with the model, which is always moved to "cuda" at load time.
    inputs = inputs.to(model.device)
    # Generation configuration: greedy when temperature == 0, nucleus sampling otherwise.
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature if temperature > 0 else None,
        "do_sample": temperature > 0,
        "top_p": 0.9 if temperature > 0 else None,
        "num_beams": 1,
        "use_cache": True,
        "return_dict_in_generate": True,
        "output_scores": True,  # per-step scores feed the first-answer confidence
    }
    # Generate response
    with torch.no_grad():
        gen_out = model.generate(
            **inputs,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            **gen_kwargs,
        )
    # Decode only the newly generated tokens (drop the prompt prefix).
    generated_ids = gen_out.sequences[0][len(inputs.input_ids[0]) :]
    answer = processor.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Compute confidence
    first_box_probs = compute_first_boxed_answer_probs(
        b=0,
        gen_ids=generated_ids,
        gen_out=gen_out,
        ans=answer,
        task="",
        tokenizer=tokenizer,
    )
    # Parse response: <first answer><think>reasoning</think><second answer>
    first_answer = answer.split("<think>")[0]
    second_answer = answer.split("</think>")[-1] if "</think>" in answer else first_answer
    reasoning = answer.split("<think>")[-1].split("</think>")[0] if "<think>" in answer else "N/A"
    # Determine inference mode
    if first_box_probs >= early_exit_thresh:
        need_cot = False
        # Early exit: reasoning is discarded. Keep it a (falsy) string so the
        # "reasoning" field has the same type on every path.
        reasoning = ""
    else:
        need_cot = True
    return {
        "full_response": answer,
        "first_answer": first_answer,
        "confidence": f"{first_box_probs:.4f}",
        "need_cot": need_cot,
        "reasoning": reasoning,
        "second_answer": second_answer,
    }
| # ============================================================================ | |
| # Gradio Callback Functions | |
| # ============================================================================ | |
def update_preview(file_path: str | None):
    """Show the preview widget matching the media type of ``file_path`` and hide the other.

    Returns a pair of updates ``(image_preview, video_preview)``; the visible
    widget receives ``file_path`` as its value, the hidden one is cleared.
    Unknown or missing media hides both.
    """
    kind = detect_media_type(file_path)
    show_image = kind == "image"
    show_video = kind == "video"
    image_update = gr.update(value=file_path if show_image else None, visible=show_image)
    video_update = gr.update(value=file_path if show_video else None, visible=show_video)
    return image_update, video_update
def chat_generate(
    media_path,
    user_text,
    messages_state,
    chatbot_state,
    last_media_state,
    early_exit_thresh,
    temperature,
):
    """Handle one chat turn: run generate() and append the result to the chat states.

    Args:
        media_path: Path of the uploaded image/video, or None for text-only.
        user_text: The user's message.
        messages_state: Chat-template style message list; appended to in place
            (replaced by a new list when the media changed).
        chatbot_state: Role/content dicts shown in the Chatbot; appended to in place
            (replaced by a new list when the media changed).
        last_media_state: Media path used by the previous turn.
        early_exit_thresh: Confidence threshold passed to generate().
        temperature: Sampling temperature passed to generate().

    Returns:
        Tuple ``(messages_state, chatbot_state, media_path, textbox_update,
        send_button_update)``. The textbox is cleared and both input controls are
        disabled, so only one turn can be sent before Clear is pressed.

    Raises:
        gr.Error: If ``user_text`` is None, empty or whitespace only.
    """
    if user_text is None or not str(user_text).strip():
        raise gr.Error("Chat message cannot be empty.")
    # Clear history if media changed (compared by file name only)
    if (
        (media_path is not None)
        and (last_media_state is not None)
        and (os.path.basename(media_path) != os.path.basename(last_media_state))
    ):
        messages_state = []
        chatbot_state = []
    # Initialize system prompt
    if len(messages_state) == 0:
        messages_state.append({"role": "system", "content": COT_SYSTEM_PROMPT_ANSWER_TWICE})
    # Prepare user message (recorded in messages_state; generate() builds its own)
    content_parts = []
    if media_path is not None:
        mtype = detect_media_type(media_path)
        if mtype == "video":
            vd = process_video(media_path)
            if vd:
                content_parts.append(vd)
        elif mtype == "image":
            imd = process_image(media_path)
            if imd:
                content_parts.append(imd)
    content_parts.append({"type": "text", "text": user_text})
    messages_state.append({"role": "user", "content": content_parts})
    # Generate response
    result = generate(media_path, user_text, early_exit_thresh, temperature)
    # Format assistant response
    first_ans = (result.get("first_answer") or "").strip()
    conf = result.get("confidence", "N/A")
    need_cot = result.get("need_cot", "")
    reasoning = result.get("reasoning", "")
    final_ans = (result.get("second_answer") or "").strip()
    if need_cot:
        decision_prompt = f"Continue CoT Reasoning (confidence = {conf})"
    else:
        decision_prompt = f"Early Exit (confidence = {conf})"
    assistant_display_1 = f"**Initial Answer:**\n{first_ans}\n\n" f"**{decision_prompt}**\n\n"
    # Update state
    messages_state.append({"role": "assistant", "content": assistant_display_1})
    chatbot_state.append({"role": "user", "content": user_text})
    chatbot_state.append({"role": "assistant", "content": assistant_display_1})
    if need_cot:
        # The Chatbot renders raw HTML (sanitize_html=False), so a bare <think>
        # would be parsed as an (invisible) HTML tag. Code spans show it literally.
        assistant_display_2 = (
            f"\n\n**`<think>`**\n\n{reasoning}\n**`</think>`**\n\n" f"**Reviewed Answer:**\n{final_ans}\n\n"
        )
        messages_state.append({"role": "assistant", "content": assistant_display_2})
        chatbot_state.append({"role": "assistant", "content": assistant_display_2})
    # Disable textbox and send button after generation to prevent interleaved conversation
    return (
        messages_state,
        chatbot_state,
        media_path,
        gr.update(value="", interactive=False),  # Disable and clear textbox
        gr.update(interactive=False),  # Disable send button
    )
def clear_history():
    """Reset the conversation and re-enable the input controls.

    Returns an 8-tuple ordered like the outputs wired to the Clear button:
    messages_state, chatbot_state, last_media_state, file input,
    image_preview, video_preview, textbox, send button.
    """
    hidden_preview = dict(value=None, visible=False)
    empty_messages, empty_chatbot, no_media = [], [], None
    file_reset = gr.update(value=None)
    image_reset = gr.update(**hidden_preview)
    video_reset = gr.update(**hidden_preview)
    textbox_reset = gr.update(value="", interactive=True)  # cleared and editable again
    send_reset = gr.update(interactive=True)  # clickable again
    return (
        empty_messages,
        empty_chatbot,
        no_media,
        file_reset,
        image_reset,
        video_reset,
        textbox_reset,
        send_reset,
    )
# ============================================================================
# Example Data
# ============================================================================
# Each row is (media path or None, prompt), matching the examples Dataset's
# components=[media_input, textbox]. Ground-truth answers are noted for reference.
_EXAMPLE_ROWS = (
    (  # GT is B
        "assets/yt--MAYaJ5cyOE_70.mp4",
        "Question: Which one of these descriptions correctly matches the actions in the video?\nOptions:\n(A) officiating\n(B) skating\n(C) stopping\n(D) playing sports\nPut your final answer in \\boxed{}.",
    ),
    (  # GT is I
        "assets/validation_Finance_2.mp4",
        "Using the Arbitrage Pricing Theory model shown above, calculate the expected return E(rp) if the risk-free rate increases to 5%. All other risk premiums (RP) and beta (\\beta) values remain unchanged.\nOptions:\nA. 13.4%\nB. 14.8%\nC. 15.6%\nD. 16.1%\nE. 16.5%\nF. 16.9%\nG. 17.5%\nH. 17.8%\nI. 17.2%\nJ. 18.1%\nPut your final answer in \\boxed{}.",
    ),
    (  # GT is B
        "assets/M3CoT-25169-0.png",
        "Within the image, you'll notice several purchased items. And we assume that the water temperature is 4 ° C at this time.\nWithin the image, can you identify the count of items among the provided options that will go below the waterline?\nA. 0\nB. 1\nC. 2\nD. 3\nPut your final answer in \\boxed{}.",
    ),
    (  # GT is 3
        None,
        "Determine the value of the parameter $m$ such that the equation $(m-2)x^2 + (m^2-4m+3)x - (6m^2-2) = 0$ has real solutions, and the sum of the cubes of these solutions is equal to zero.\nPut your final answer in \\boxed{}.",
    ),
)
# Gradio's Dataset takes a list of lists.
EXAMPLES = [list(row) for row in _EXAMPLE_ROWS]
| # ============================================================================ | |
| # Gradio Interface | |
| # ============================================================================ | |
# Build the Gradio UI. Components appear in the order they are created inside
# the nested ``with`` contexts, so statement order here defines the layout.
demo = gr.Blocks(title="VideoAuto-R1 Demo")
with demo:
    gr.Markdown("# [VideoAuto-R1 Demo](https://github.com/IVUL-KAUST/VideoAuto-R1/)")
    # Display system prompt
    with gr.Accordion("System Prompt", open=False):
        gr.Markdown(f"```\n{COT_SYSTEM_PROMPT_ANSWER_TWICE}\n```")
    # State variables
    # messages_state: chat-template style message list appended to by chat_generate.
    # chatbot_state: role/content dicts that are copied into the Chatbot after each event.
    # last_media_state: media path of the previous turn, used to detect a media change.
    messages_state = gr.State([])
    chatbot_state = gr.State([])
    last_media_state = gr.State(None)
    with gr.Row():
        # Left column: Media input and settings
        with gr.Column(scale=3):
            media_input = gr.File(
                label="Upload Image or Video",
                file_types=["image", "video"],
                type="filepath",
            )
            # Previews start hidden; update_preview shows the one matching the upload.
            image_preview = gr.Image(label="Image Preview", visible=False)
            video_preview = gr.Video(label="Video Preview", visible=False)
            with gr.Accordion("Advanced Settings", open=True):
                # Confidence at or above which generate() reports an early exit.
                early_exit_thresh = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.98,
                    step=0.01,
                    label="Early Exit Threshold",
                )
                # 0.0 selects greedy decoding in generate().
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Temperature",
                )
        # Right column: Chat interface
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Chat",
                elem_id="chatbot",  # targeted by CUSTOM_CSS
                height=600,
                sanitize_html=False,  # message HTML is rendered as-is
            )
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press ENTER",
                lines=2,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
            gr.Markdown("Please click the **Clear** button before starting a new conversation or trying a new example.")
    # Event handlers
    # Keep the image/video preview in sync with the uploaded file.
    media_input.change(
        fn=update_preview,
        inputs=[media_input],
        outputs=[image_preview, video_preview],
    )
    # Send button click: generate response and disable input controls
    send_btn.click(
        fn=chat_generate,
        inputs=[
            media_input,
            textbox,
            messages_state,
            chatbot_state,
            last_media_state,
            early_exit_thresh,
            temperature,
        ],
        outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
    ).then(
        # Copy the updated chatbot_state into the visible Chatbot.
        fn=lambda cs: cs,
        inputs=[chatbot_state],
        outputs=[chatbot],
    )
    # Textbox submit: generate response and disable input controls
    # (same wiring as the Send button).
    textbox.submit(
        fn=chat_generate,
        inputs=[
            media_input,
            textbox,
            messages_state,
            chatbot_state,
            last_media_state,
            early_exit_thresh,
            temperature,
        ],
        outputs=[messages_state, chatbot_state, last_media_state, textbox, send_btn],
    ).then(
        fn=lambda cs: cs,
        inputs=[chatbot_state],
        outputs=[chatbot],
    )
    # Clear button: reset all states and re-enable input controls
    # (outputs are in the same order as clear_history's return tuple).
    clear_btn.click(
        fn=clear_history,
        inputs=[],
        outputs=[
            messages_state,
            chatbot_state,
            last_media_state,
            media_input,
            image_preview,
            video_preview,
            textbox,
            send_btn,
        ],
    ).then(
        fn=lambda cs: cs,
        inputs=[chatbot_state],
        outputs=[chatbot],
    )
    # Clickable example rows; each row is [media path or None, prompt text].
    examples_ds = gr.Dataset(
        components=[media_input, textbox],
        samples=EXAMPLES,
        label="Examples",
        type="index",  # important: pass selected row index to fn
    )

    def load_example(idx: int | None):
        """Reset the session and load example row ``idx`` into the inputs.

        Returns the same 8-tuple layout as clear_history(), with the file,
        textbox, previews and last_media_state set from the example.
        """
        # idx can be None when deselecting
        if idx is None:
            # just clear everything
            return clear_history()
        media, text = EXAMPLES[idx][0], EXAMPLES[idx][1]
        # 1) clear all states + re-enable inputs
        ms, cs, last, file_u, img_u, vid_u, tb_u, send_u = clear_history()
        # 2) set selected example values
        file_u = gr.update(value=media)
        tb_u = gr.update(value=text, interactive=True)
        send_u = gr.update(interactive=True)
        # 3) update preview explicitly (don't rely on File.change always firing)
        img_u, vid_u = update_preview(media)
        # 4) optionally set last_media_state to current media
        last = media
        return ms, cs, last, file_u, img_u, vid_u, tb_u, send_u

    examples_ds.select(
        fn=load_example,
        inputs=[examples_ds],
        outputs=[
            messages_state,
            chatbot_state,
            last_media_state,
            media_input,
            image_preview,
            video_preview,
            textbox,
            send_btn,
        ],
    ).then(
        fn=lambda cs: cs,
        inputs=[chatbot_state],
        outputs=[chatbot],
    )
# Launch demo with the server settings collected in one mapping.
launch_kwargs = {
    "share": True,
    "server_name": "0.0.0.0",
    "server_port": 7860,
    "allowed_paths": ["assets"],  # the EXAMPLES media files live under assets/
    "debug": True,
    "css": CUSTOM_CSS,
}
demo.launch(**launch_kwargs)