"""
VideoAuto-R1 (Qwen3-VL) Demo
A Gradio-based chat interface for adaptive inference with image/video inputs.
"""

# NOTE: `spaces` is intentionally kept as the first import, ahead of torch.
import spaces
import os
import base64
from io import BytesIO

import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

from videoauto_r1.qwen_vl_utils.vision_process import process_vision_info
from videoauto_r1.modeling_qwen3_vl_patched import Qwen3VLForConditionalGeneration
from videoauto_r1.early_exit import compute_first_boxed_answer_probs

# ============================================================================
# Constants
# ============================================================================

# Delimiters the model is told to wrap its reasoning in: the word "think" in
# angle brackets, and the same with a leading slash for the closing form.
# The angle brackets are spelled as hex escapes on purpose: markup-stripping
# tools have silently deleted the literal tags from this file before, which
# left `str.split("")` calls that raise "ValueError: empty separator".
THINK_OPEN = "\x3cthink\x3e"
THINK_CLOSE = "\x3c/think\x3e"

COT_SYSTEM_PROMPT_ANSWER_TWICE = "".join(
    [
        "You are a helpful assistant.\n",
        "FIRST: Output your initial answer inside the first \\boxed{...} without any analysis or explanations. ",
        "If you cannot determine the answer without reasoning, output \\boxed{Let's analyze the problem step by step.} instead.\n",
        "THEN: Think through the reasoning as an internal monologue enclosed within ",
        THINK_OPEN,
        "...",
        THINK_CLOSE,
        ".\n",
        "AT LAST: Output the final answer again inside \\boxed{...}. ",
        "If you believe the previous answer was correct, repeat it; otherwise, correct it.\n",
        "Output format: \\boxed{...}",
        THINK_OPEN,
        "...",
        THINK_CLOSE,
        "\\boxed{...}",
    ]
)

VIDEO_EXTS = (".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff")

CUSTOM_CSS = """
#chatbot .message[class*="user"] { max-width: 50% !important; }
#chatbot .message[class*="bot"],
#chatbot .message[class*="assistant"] { max-width: 60% !important; }
#chatbot .message > div { width: 100% !important; max-width: 100% !important; }
"""

MODEL_PATH = "IVUL-KAUST/VideoAuto-R1-Qwen3-VL-8B"

# ============================================================================
# Global Model Variables
# ============================================================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model. It is moved to `device` (the same target the inputs are sent to
# in `generate`) instead of a hard-coded "cuda", so both always agree.
model = (
    Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        dtype="bfloat16",
        attn_implementation="sdpa",
    )
    .to(device)
    .eval()
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# ============================================================================
# Utility Functions
# ============================================================================


def detect_media_type(file_path: str | None) -> str | None:
    """
    Detect media type from file extension.

    Args:
        file_path: Path to the media file

    Returns:
        'image', 'video', or None (only when file_path is empty/None)
    """
    if not file_path:
        return None

    lowered = file_path.lower()
    if lowered.endswith(VIDEO_EXTS):
        return "video"
    if lowered.endswith(IMAGE_EXTS):
        return "image"

    # Fallback for unknown extensions: anything PIL can open is an image,
    # everything else is treated as a video. The broad except is a deliberate
    # best-effort probe; the context manager closes the probed file handle.
    try:
        with Image.open(file_path):
            pass
    except Exception:
        return "video"
    return "image"


def process_image(
    image_path: str,
    image_min_pixels: int = 128 * 28 * 28,
    image_max_pixels: int = 16384 * 28 * 28,
) -> dict | None:
    """
    Process image file to base64 format.

    The image is converted to RGB and re-encoded as JPEG before being
    embedded as a base64 data URI.

    Args:
        image_path: Path to image file
        image_min_pixels: Minimum pixel count
        image_max_pixels: Maximum pixel count

    Returns:
        Dictionary with image data or None (when image_path is None)
    """
    if image_path is None:
        return None

    # `convert` returns an independent in-memory copy, so the source file can
    # be closed right away instead of leaking the handle.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    base64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")

    return {
        "type": "image",
        "image": f"data:image/jpeg;base64,{base64_string}",
        "min_pixels": image_min_pixels,
        "max_pixels": image_max_pixels,
    }


def process_video(
    video_path: str,
    video_min_pixels: int = 16 * 28 * 28,
    video_max_pixels: int = 768 * 28 * 28,
    video_total_pixels: int = 128000 * 28 * 28,
    min_frames: int = 4,
    max_frames: int = 64,
    fps: float = 2.0,
) -> dict | None:
    """
    Process video file configuration.

    The video itself is not read here; only a configuration dictionary
    pointing at the file is produced.

    Args:
        video_path: Path to video file
        video_min_pixels: Minimum pixels per frame
        video_max_pixels: Maximum pixels per frame
        video_total_pixels: Total pixels across all frames
        min_frames: Minimum number of frames
        max_frames: Maximum number of frames
        fps: Frames per second for sampling

    Returns:
        Dictionary with video configuration or None (when video_path is None)
    """
    if video_path is None:
        return None

    return {
        "type": "video",
        "video": video_path,
        "min_pixels": video_min_pixels,
        "max_pixels": video_max_pixels,
        "total_pixels": video_total_pixels,
        "min_frames": min_frames,
        "max_frames": max_frames,
        "fps": fps,
    }


def _build_user_content(media_path: str | None, text: str) -> list:
    """Build the user-turn content list: optional media entry, then the text."""
    parts = []
    if media_path is not None:
        media_type = detect_media_type(media_path)
        if media_type == "video":
            media_entry = process_video(media_path)
        elif media_type == "image":
            media_entry = process_image(media_path)
        else:
            media_entry = None
        if media_entry:
            parts.append(media_entry)
    parts.append({"type": "text", "text": text})
    return parts


def _split_response(answer: str) -> tuple:
    """Split a decoded response into (first_answer, reasoning, second_answer).

    - first_answer: everything before the first opening think delimiter.
    - reasoning: text after the last opening delimiter, up to the first
      closing delimiter that follows it; "N/A" if there is no opening one.
    - second_answer: everything after the last closing delimiter; falls back
      to first_answer when the response has no closing delimiter.
    """
    first_answer = answer.split(THINK_OPEN)[0]
    if THINK_CLOSE in answer:
        second_answer = answer.split(THINK_CLOSE)[-1]
    else:
        second_answer = first_answer
    if THINK_OPEN in answer:
        reasoning = answer.split(THINK_OPEN)[-1].split(THINK_CLOSE)[0]
    else:
        reasoning = "N/A"
    return first_answer, reasoning, second_answer


@spaces.GPU(duration=180)
def generate(
    media_input: str | None,
    prompt: str,
    early_exit_thresh: float,
    temperature: float,
    max_new_tokens: int = 4096,
) -> dict:
    """
    Generate response with adaptive inference.

    Args:
        media_input: Path to media file
        prompt: Text prompt
        early_exit_thresh: Confidence threshold for early exit
        temperature: Sampling temperature (greedy decoding when not > 0)
        max_new_tokens: Maximum tokens to generate

    Returns:
        Dictionary containing response and metadata, with keys:
        "full_response", "first_answer", "confidence" (string, 4 decimals),
        "need_cot" (bool), "reasoning" (str, or False on early exit) and
        "second_answer".
    """
    # Prepare message
    message = [
        {"role": "system", "content": COT_SYSTEM_PROMPT_ANSWER_TWICE},
        {"role": "user", "content": _build_user_content(media_input, prompt)},
    ]

    # Apply chat template
    text = processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True)

    # Process vision inputs
    image_inputs, video_inputs, video_kwargs = process_vision_info(
        [message],
        image_patch_size=16,
        return_video_kwargs=True,
        return_video_metadata=True,
    )

    # With return_video_metadata=True each video entry is a (frames, metadata)
    # pair; split them into two parallel lists for the processor.
    if video_inputs is not None:
        video_inputs, video_metadatas = zip(*video_inputs)
        video_inputs = list(video_inputs)
        video_metadatas = list(video_metadatas)
    else:
        video_metadatas = None

    # Prepare inputs
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        video_metadata=video_metadatas,
        do_resize=False,
        padding=True,
        return_tensors="pt",
        **video_kwargs,
    )
    inputs = inputs.to(device)

    # Generation configuration: sampling parameters are only passed when
    # sampling is actually enabled.
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature if do_sample else None,
        "do_sample": do_sample,
        "top_p": 0.9 if do_sample else None,
        "num_beams": 1,
        "use_cache": True,
        "return_dict_in_generate": True,
        "output_scores": True,
    }

    # Generate response
    with torch.no_grad():
        gen_out = model.generate(
            **inputs,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            **gen_kwargs,
        )

    # Decode output: drop the prompt tokens, keep only the newly generated ones
    prompt_length = len(inputs.input_ids[0])
    generated_ids = gen_out.sequences[0][prompt_length:]
    answer = processor.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # Compute confidence of the first boxed answer.
    # NOTE(review): presumably a probability in [0, 1]; it is compared with
    # the slider threshold below - confirm against the early_exit module.
    first_box_probs = compute_first_boxed_answer_probs(
        b=0,
        gen_ids=generated_ids,
        gen_out=gen_out,
        ans=answer,
        task="",
        tokenizer=tokenizer,
    )

    # Parse response
    first_answer, reasoning, second_answer = _split_response(answer)

    # Determine inference mode
    need_cot = not (first_box_probs >= early_exit_thresh)
    if not need_cot:
        # Kept as False (not a string) for backward compatibility with callers.
        reasoning = False

    return {
        "full_response": answer,
        "first_answer": first_answer,
        "confidence": f"{first_box_probs:.4f}",
        "need_cot": need_cot,
        "reasoning": reasoning,
        "second_answer": second_answer,
    }


# ============================================================================
# Gradio Callback Functions
# ============================================================================


def update_preview(file_path: str | None):
    """Update preview widgets based on media type.

    Returns a pair of updates (image_preview, video_preview); at most one of
    them is made visible.
    """
    mtype = detect_media_type(file_path)
    image_update = gr.update(value=None, visible=False)
    video_update = gr.update(value=None, visible=False)

    if mtype == "image":
        image_update = gr.update(value=file_path, visible=True)
    elif mtype == "video":
        video_update = gr.update(value=file_path, visible=True)

    return image_update, video_update


def chat_generate(
    media_path,
    user_text,
    messages_state,
    chatbot_state,
    last_media_state,
    early_exit_thresh,
    temperature,
):
    """Handle chat message generation.

    Runs one `generate` call for the current media and text, appends the
    formatted result to both histories, and disables the textbox and send
    button.

    Returns:
        (messages_state, chatbot_state, media_path, textbox update,
        send button update)

    Raises:
        gr.Error: if the chat message is empty or whitespace only.
    """
    if user_text is None or str(user_text).strip() == "":
        raise gr.Error("Chat message cannot be empty.")

    # Clear history if media changed (compared by file name only)
    media_changed = (
        media_path is not None
        and last_media_state is not None
        and os.path.basename(media_path) != os.path.basename(last_media_state)
    )
    if media_changed:
        messages_state = []
        chatbot_state = []

    # Initialize system prompt
    if not messages_state:
        messages_state.append({"role": "system", "content": COT_SYSTEM_PROMPT_ANSWER_TWICE})

    # Prepare user message
    messages_state.append({"role": "user", "content": _build_user_content(media_path, user_text)})

    # Generate response
    result = generate(media_path, user_text, early_exit_thresh, temperature)

    # Format assistant response
    first_ans = (result.get("first_answer") or "").strip()
    conf = result.get("confidence", "N/A")
    need_cot = result.get("need_cot", "")
    reasoning = result.get("reasoning", "")
    final_ans = (result.get("second_answer") or "").strip()

    if need_cot:
        decision_prompt = f"Continue CoT Reasoning (confidence = {conf})"
    else:
        decision_prompt = f"Early Exit (confidence = {conf})"

    assistant_display_1 = f"**Initial Answer:**\n{first_ans}\n\n" f"**{decision_prompt}**\n\n"

    # Update state
    messages_state.append({"role": "assistant", "content": assistant_display_1})
    chatbot_state.append({"role": "user", "content": user_text})
    chatbot_state.append({"role": "assistant", "content": assistant_display_1})

    if need_cot:
        # NOTE(review): the bare "****" lines presumably act as Markdown
        # separators around the reasoning - confirm the intended rendering.
        assistant_display_2 = (
            f"\n\n****\n\n{reasoning}\n****\n\n" f"**Reviewed Answer:**\n{final_ans}\n\n"
        )
        messages_state.append({"role": "assistant", "content": assistant_display_2})
        chatbot_state.append({"role": "assistant", "content": assistant_display_2})

    # Disable textbox and send button after generation to prevent interleaved conversation
    return (
        messages_state,
        chatbot_state,
        media_path,
        gr.update(value="", interactive=False),  # Disable and clear textbox
        gr.update(interactive=False),  # Disable send button
    )


def clear_history():
    """Clear all chat history and reset interface."""
    return (
        [],  # messages_state
        [],  # chatbot_state
        None,  # last_media_state
        gr.update(value=None),  # file
        gr.update(value=None, visible=False),  # image_preview
        gr.update(value=None, visible=False),  # video_preview
        gr.update(value="", interactive=True),  # Re-enable and clear textbox
        gr.update(interactive=True),  # Re-enable send button
    )


# ============================================================================
# Example Data
# ============================================================================

EXAMPLES = [
    [
        "assets/yt--MAYaJ5cyOE_70.mp4",
        "Question: Which one of these descriptions correctly matches the actions in the video?\nOptions:\n(A) officiating\n(B) skating\n(C) stopping\n(D) playing sports\nPut your final answer in \\boxed{}.",
        # GT is B
    ],
    [
        "assets/validation_Finance_2.mp4",
        "Using the Arbitrage Pricing Theory model shown above, calculate the expected return E(rp) if the risk-free rate increases to 5%. All other risk premiums (RP) and beta (\\beta) values remain unchanged.\nOptions:\nA. 13.4%\nB. 14.8%\nC. 15.6%\nD. 16.1%\nE. 16.5%\nF. 16.9%\nG. 17.5%\nH. 17.8%\nI. 17.2%\nJ. 18.1%\nPut your final answer in \\boxed{}.",
        # GT is I
    ],
    [
        "assets/M3CoT-25169-0.png",
        "Within the image, you'll notice several purchased items. And we assume that the water temperature is 4 ° C at this time.\nWithin the image, can you identify the count of items among the provided options that will go below the waterline?\nA. 0\nB. 1\nC. 2\nD. 3\nPut your final answer in \\boxed{}.",
        # GT is B
    ],
    [
        None,
        "Determine the value of the parameter $m$ such that the equation $(m-2)x^2 + (m^2-4m+3)x - (6m^2-2) = 0$ has real solutions, and the sum of the cubes of these solutions is equal to zero.\nPut your final answer in \\boxed{}.",
        # GT is 3
    ],
]


def load_example(idx: int | None):
    """Load the example at row `idx` into a freshly reset interface.

    Returns the same 8-tuple shape as `clear_history`.
    """
    # idx can be None when deselecting: just clear everything
    if idx is None:
        return clear_history()

    media, text = EXAMPLES[idx][0], EXAMPLES[idx][1]

    # 1) clear all states + re-enable inputs
    ms, cs, _last, _file_u, _img_u, _vid_u, _tb_u, _send_u = clear_history()

    # 2) set selected example values
    file_u = gr.update(value=media)
    tb_u = gr.update(value=text, interactive=True)
    send_u = gr.update(interactive=True)

    # 3) update preview explicitly (don't rely on File.change always firing)
    img_u, vid_u = update_preview(media)

    # 4) remember the example's media as the current one
    last = media

    return ms, cs, last, file_u, img_u, vid_u, tb_u, send_u


# ============================================================================
# Gradio Interface
# ============================================================================

demo = gr.Blocks(title="VideoAuto-R1 Demo")

with demo:
    gr.Markdown("# [VideoAuto-R1 Demo](https://github.com/IVUL-KAUST/VideoAuto-R1/)")

    # Display system prompt
    with gr.Accordion("System Prompt", open=False):
        gr.Markdown(f"```\n{COT_SYSTEM_PROMPT_ANSWER_TWICE}\n```")

    # State variables
    messages_state = gr.State([])
    chatbot_state = gr.State([])
    last_media_state = gr.State(None)

    with gr.Row():
        # Left column: Media input and settings
        with gr.Column(scale=3):
            media_input = gr.File(
                label="Upload Image or Video",
                file_types=["image", "video"],
                type="filepath",
            )
            image_preview = gr.Image(label="Image Preview", visible=False)
            video_preview = gr.Video(label="Video Preview", visible=False)

            with gr.Accordion("Advanced Settings", open=True):
                early_exit_thresh = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.98,
                    step=0.01,
                    label="Early Exit Threshold",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Temperature",
                )

        # Right column: Chat interface
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(
                label="Chat",
                elem_id="chatbot",
                height=600,
                sanitize_html=False,
            )
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press ENTER",
                lines=2,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
            gr.Markdown("Please click the **Clear** button before starting a new conversation or trying a new example.")

    # Event handlers
    media_input.change(
        fn=update_preview,
        inputs=[media_input],
        outputs=[image_preview, video_preview],
    )

    # Inputs/outputs shared by the send button and the textbox submit event
    chat_inputs = [
        media_input,
        textbox,
        messages_state,
        chatbot_state,
        last_media_state,
        early_exit_thresh,
        temperature,
    ]
    chat_outputs = [messages_state, chatbot_state, last_media_state, textbox, send_btn]

    # Outputs shared by everything that resets the interface
    reset_outputs = [
        messages_state,
        chatbot_state,
        last_media_state,
        media_input,
        image_preview,
        video_preview,
        textbox,
        send_btn,
    ]

    # Send button click / textbox submit: generate response and disable input
    # controls, then mirror the chat history state into the chatbot widget.
    for chat_trigger in (send_btn.click, textbox.submit):
        chat_trigger(
            fn=chat_generate,
            inputs=chat_inputs,
            outputs=chat_outputs,
        ).then(
            fn=lambda cs: cs,
            inputs=[chatbot_state],
            outputs=[chatbot],
        )

    # Clear button: reset all states and re-enable input controls
    clear_btn.click(
        fn=clear_history,
        inputs=[],
        outputs=reset_outputs,
    ).then(
        fn=lambda cs: cs,
        inputs=[chatbot_state],
        outputs=[chatbot],
    )

    examples_ds = gr.Dataset(
        components=[media_input, textbox],
        samples=EXAMPLES,
        label="Examples",
        type="index",  # important: pass selected row index to fn
    )

    examples_ds.select(
        fn=load_example,
        inputs=[examples_ds],
        outputs=reset_outputs,
    ).then(
        fn=lambda cs: cs,
        inputs=[chatbot_state],
        outputs=[chatbot],
    )

# Launch demo
demo.launch(
    share=True,
    server_name="0.0.0.0",
    server_port=7860,
    allowed_paths=["assets"],
    debug=True,
    css=CUSTOM_CSS,
)