"""
MedGRPO Demo — Medical Video Understanding with uAI-NEXUS-MedVLM-1.0b-4B-RL

A Gradio demo showcasing the 4B Qwen3-VL based MedGRPO RL model on MedVidBench
across 8 medical video understanding tasks. Includes pre-computed examples and
live inference.
"""

import json
import math
import os
import traceback
from pathlib import Path

import gradio as gr
import spaces
import torch
from PIL import Image

ROOT = Path(__file__).parent
MODEL_ID = "UII-AI/uAI-NEXUS-MedVLM-1.0b-4B-RL"

# ── Load examples ─────────────────────────────────────────────────────────────

# Explicit UTF-8 so loading does not depend on the host's default locale.
with open(ROOT / "examples.json", encoding="utf-8") as f:
    EXAMPLES = json.load(f)

with open(ROOT / "live_examples.json", encoding="utf-8") as f:
    LIVE_EXAMPLES = json.load(f)

# Display metadata per task: emoji icon, short acronym, one-line description.
TASK_INFO = {
    "Temporal Action Localization": {
        "icon": "\u23f1\ufe0f",
        "short": "TAL",
        "desc": "Identify when specific surgical actions occur in the video (start\u2013end times).",
    },
    "Spatiotemporal Grounding": {
        "icon": "\U0001f4cd",
        "short": "STG",
        "desc": "Locate instruments or anatomy in both space (bounding boxes) and time.",
    },
    "Dense Captioning": {
        "icon": "\U0001f4dd",
        "short": "DC",
        "desc": "Generate detailed, time-stamped descriptions of each action segment.",
    },
    "Next Action Prediction": {
        "icon": "\U0001f52e",
        "short": "NAP",
        "desc": "Predict the next procedural step given the current video context.",
    },
    "Video Summary": {
        "icon": "\U0001f4cb",
        "short": "VS",
        "desc": "Produce a concise summary of the entire surgical procedure shown.",
    },
    "Region Caption": {
        "icon": "\U0001f50d",
        "short": "RC",
        "desc": "Describe the activity of a specific instrument or region across the clip.",
    },
    "CVS Assessment": {
        "icon": "\u2705",
        "short": "CVS",
        "desc": "Score the three Critical View of Safety criteria for cholecystectomy.",
    },
    "Skill Assessment": {
        "icon": "\U0001f3af",
        "short": "SA",
        "desc": "Rate surgical skill on multiple dimensions (1\u20135 scale).",
    },
}

TASKS = list(TASK_INFO.keys())

# ── Model loading (lazy, cached) ──────────────────────────────────────────────

_model = None
_processor = None


def get_model_and_processor():
    """Return the cached ``(model, processor)`` pair, loading it on first use.

    The model and processor are fetched from ``MODEL_ID`` using the optional
    ``HF_TOKEN`` environment variable. The module-level cache is only
    populated once both objects have loaded, so a failed load is retried in
    full on the next call.
    """
    global _model, _processor
    if _model is None:
        # Imported lazily so the Examples tab works without paying this cost.
        from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

        hf_token = os.environ.get("HF_TOKEN")
        print(f"[MedGRPO] Loading model from {MODEL_ID}...")
        print(f"[MedGRPO] HF_TOKEN present: {bool(hf_token)}")
        processor = AutoProcessor.from_pretrained(
            MODEL_ID, trust_remote_code=True, token=hf_token
        )
        model = Qwen3VLForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
            trust_remote_code=True,
            token=hf_token,
        )
        model.eval()
        _processor, _model = processor, model
        print("[MedGRPO] Model loaded.")
    return _model, _processor


# ── Examples tab helpers ──────────────────────────────────────────────────────


def get_examples_for_task(task_name: str) -> list[dict]:
    """Return every pre-computed example whose ``task`` equals *task_name*."""
    return [ex for ex in EXAMPLES if ex["task"] == task_name]


def _example_choices(task_name: str) -> list[str]:
    """Return dropdown labels ("Example 1", ...) for the examples of a task."""
    count = len(get_examples_for_task(task_name))
    return [f"Example {i + 1}" for i in range(count)]


def _task_heading(task_name: str) -> str:
    """Return the markdown heading + description shown for *task_name*."""
    meta = TASK_INFO[task_name]
    return f"### {meta['icon']} {task_name} ({meta['short']})\n{meta['desc']}"


def load_example(task_name: str, example_idx: int):
    """Load one pre-computed example for display.

    Returns:
        ``(image_paths, question, ground_truth, prediction, info_markdown)``.
        All five values are empty when the task has no examples or the index
        is out of range (negative indices included, so they cannot silently
        wrap around to the last example).
    """
    task_examples = get_examples_for_task(task_name)
    if not 0 <= example_idx < len(task_examples):
        return [], "", "", "", ""

    ex = task_examples[example_idx]
    # Only show frames that are actually present on disk.
    images = [str(ROOT / fp) for fp in ex["frames"] if (ROOT / fp).exists()]
    # Two trailing spaces before "\n" form a markdown hard line break, so the
    # three fields render on separate lines.
    info = (
        f"**Task:** {ex['task']}  \n"
        f"**Data Source:** {ex['data_source']}  \n"
        f"**Original Frames:** {ex['n_original_frames']} "
        f"(showing {len(images)} sampled)"
    )
    return images, ex["question"], ex["ground_truth"], ex["prediction"], info


def on_task_change(task_name):
    """Refresh the Examples tab after a task is selected.

    Returns the dropdown update, the first example's five display values and
    the task heading markdown, in the order expected by the event outputs.
    """
    choices = _example_choices(task_name)
    images, question, gt, pred, info = load_example(task_name, 0)
    return (
        gr.update(choices=choices, value=choices[0] if choices else None),
        images,
        question,
        gt,
        pred,
        info,
        _task_heading(task_name),
    )


def on_example_change(task_name, example_choice):
    """Load the example named by a dropdown label such as ``"Example 2"``."""
    if not example_choice:
        return [], "", "", "", ""
    # Labels are 1-based; the trailing number maps to a 0-based list index.
    idx = int(example_choice.split()[-1]) - 1
    return load_example(task_name, idx)


# ── Live inference ────────────────────────────────────────────────────────────

MAX_FRAMES = 60  # 1fps × 60s = up to 60 frames
SAMPLE_FPS = 1.0  # Real wall-clock fps the extracted frames represent. Passed
#                   to the processor via VideoMetadata so the model's emitted
#                   timestamps match real time. Without it, Qwen3-VL defaults to
#                   fps=24 and compresses the timeline (e.g. 0.0–1.1s for a 27s clip).


def _open_rgb(path) -> Image.Image:
    """Open the image at *path* and return an RGB copy.

    The file is closed before returning; ``convert`` yields a new in-memory
    image, so no file handle is left open.
    """
    with Image.open(path) as img:
        return img.convert("RGB")


def make_load_live_example(example_idx):
    """Create a loader function for a specific live example index."""

    def _load():
        ex = LIVE_EXAMPLES[example_idx]
        # Frames missing on disk are skipped rather than raising.
        images = [
            _open_rgb(ROOT / fp) for fp in ex["frames"] if (ROOT / fp).exists()
        ]
        return images, ex["question"]

    return _load


def extract_frames_1fps(video_path: str, max_frames: int = MAX_FRAMES) -> list:
    """Extract frames at ~1fps from a video, up to *max_frames*.

    Args:
        video_path: Path of the video file to read.
        max_frames: Maximum number of frames to return.

    Returns:
        A list of RGB ``PIL.Image`` frames; empty if the video cannot be
        opened or read.
    """
    import cv2

    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Fall back to 30 when the container reports no usable frame rate
        # (zero, negative, NaN or infinite); round() would raise on NaN/inf.
        if not math.isfinite(fps) or fps <= 0:
            fps = 30.0
        frame_interval = max(1, round(fps))  # keep every Nth frame for ~1fps

        frames = []
        frame_idx = 0
        while cap.isOpened() and len(frames) < max_frames:
            # grab() advances the stream without decoding the frame; only
            # frames that are kept pay for retrieve().
            if not cap.grab():
                break
            if frame_idx % frame_interval == 0:
                ok, frame = cap.retrieve()
                if not ok:
                    break
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            frame_idx += 1
        return frames
    finally:
        # Always release the capture handle, even if decoding raises.
        cap.release()


def _frames_from_gallery(items) -> list:
    """Convert a Gallery value into a list of RGB PIL images.

    Gallery returns different formats depending on the Gradio version. The
    shapes handled here are: a path, a ``(media, caption)`` tuple or list, a
    dict carrying ``name``/``path``/``url`` (optionally nested under
    ``image``), and a PIL image. Unrecognised items are skipped.
    """
    frames = []
    for item in items or []:
        if isinstance(item, (tuple, list)):
            item = item[0] if item else None
        if isinstance(item, dict):
            nested = item.get("image")
            if isinstance(nested, dict):
                item = nested
            item = item.get("name") or item.get("path") or item.get("url")
        if isinstance(item, Image.Image):
            frames.append(item.convert("RGB"))
        elif isinstance(item, (str, os.PathLike)) and item:
            frames.append(_open_rgb(item))
    return frames


def _build_video_metadata(frames):
    """Describe pre-sampled *frames* as a video sampled at ``SAMPLE_FPS``."""
    from transformers.video_utils import VideoMetadata

    n_frames = len(frames)
    first_w, first_h = frames[0].size
    return VideoMetadata(
        total_num_frames=n_frames,
        fps=SAMPLE_FPS,
        width=first_w,
        height=first_h,
        duration=float(n_frames) / SAMPLE_FPS,
        video_backend="opencv",
        frames_indices=list(range(n_frames)),
    )


def _to_generate_kwargs(inputs, model, device) -> dict:
    """Turn processor outputs into keyword arguments for ``model.generate``.

    ``video_metadata`` is dropped because ``generate()`` does not accept it,
    ``second_per_grid_ts`` is passed as a plain list, floating-point tensors
    are cast to the model dtype, and all tensors are moved to *device*.
    """
    gen_kwargs = {}
    for key, value in inputs.items():
        if key == "video_metadata":
            continue
        if key == "second_per_grid_ts":
            gen_kwargs[key] = value if isinstance(value, list) else value.tolist()
        elif isinstance(value, torch.Tensor):
            if torch.is_floating_point(value):
                value = value.to(model.dtype)
            gen_kwargs[key] = value.to(device)
        else:
            gen_kwargs[key] = value
    return gen_kwargs


@spaces.GPU(duration=120)
def run_inference(video_file, uploaded_images, question, max_tokens):
    """Run model inference on uploaded video or images.

    Args:
        video_file: Path of an uploaded video, or ``None``. Takes priority
            over *uploaded_images* when present.
        uploaded_images: Value of the frame gallery, or ``None``.
        question: The user's question about the clip.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded model response, or a human-readable message when input is
        missing or an error occurs (errors are also printed with traceback).
    """
    if not question or not question.strip():
        return "Please enter a question."

    try:
        # Collect frames from video or uploaded images
        if video_file is not None:
            frames = extract_frames_1fps(video_file, MAX_FRAMES)
        else:
            frames = _frames_from_gallery(uploaded_images)

        if not frames:
            return "Please upload a video or images."

        print(f"[MedGRPO] Collected {len(frames)} frames")

        model, processor = get_model_and_processor()

        # ZeroGPU provides GPU only inside @spaces.GPU — move model here
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        print(f"[MedGRPO] Model on {device}")

        # Build chat prompt
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": frames},
                    {"type": "text", "text": question.strip()},
                ],
            }
        ]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(f"[MedGRPO] Chat template applied, prompt length: {len(text)} chars")

        # When you pass pre-sampled PIL frames to Qwen3-VL, the processor cannot
        # infer the source fps and silently defaults to fps=24 — which compresses
        # the model's perceived timeline (e.g. 0.0–1.1s for a 27s clip). Pass
        # explicit VideoMetadata + do_sample_frames=False so the time grid matches
        # the real sampling rate.
        video_meta = _build_video_metadata(frames)

        # Use processor() to handle text tokenization + video processing together.
        # This correctly expands <|video_pad|> placeholders in input_ids to match
        # the number of visual patches — separate tokenizer + video_processor calls
        # would produce mismatched input_ids (no placeholder expansion).
        inputs = processor(
            text=[text],
            videos=[frames],
            video_metadata=[video_meta],
            do_sample_frames=False,
            padding=True,
            return_tensors="pt",
        )
        print(f"[MedGRPO] Processed inputs keys: {list(inputs.keys())}")
        print(f"[MedGRPO] input_ids shape: {inputs['input_ids'].shape}")

        gen_kwargs = _to_generate_kwargs(inputs, model, device)
        gen_kwargs["max_new_tokens"] = int(max_tokens)
        gen_kwargs["do_sample"] = False

        print("[MedGRPO] Starting generation...")
        with torch.inference_mode():
            generated_ids = model.generate(**gen_kwargs)

        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_len = inputs["input_ids"].shape[1]
        output_ids = generated_ids[:, prompt_len:]
        print(f"[MedGRPO] Generated {output_ids.shape[1]} new tokens")

        response = processor.batch_decode(
            output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        print(f"[MedGRPO] Response: {response[:200]}...")
        return response

    except Exception as e:
        # Top-level UI boundary: log the traceback and show the error to the
        # user instead of crashing the request.
        traceback.print_exc()
        return f"Error: {e}"


# ── Build UI ──────────────────────────────────────────────────────────────────

TITLE = "MedGRPO Demo — Medical Video Understanding"
DESCRIPTION = """\
This demo runs **[uAI-NEXUS-MedVLM-1.0b-4B-RL](https://huggingface.co/UII-AI/uAI-NEXUS-MedVLM-1.0b-4B-RL)** \
(base: Qwen3-VL-4B), part of the uAI-NEXUS-MedVLM 1.0 family trained with SFT + MedGRPO on \
[MedVidBench](https://huggingface.co/datasets/UII-AI/MedVidBench), \
for medical video question answering across **8 tasks**: temporal reasoning, \
spatial grounding, captioning, and clinical assessment. Sibling release: \
[uAI-NEXUS-MedVLM-1.0a-7B-RL](https://huggingface.co/UII-AI/uAI-NEXUS-MedVLM-1.0a-7B-RL) (Qwen2.5-VL-7B base).

📄 [Paper](https://arxiv.org/abs/2512.06581)   🌐 [Project Page](https://uii-ai.github.io/MedGRPO/)   💾 [Dataset](https://huggingface.co/datasets/UII-AI/MedVidBench)   🤖 [Model](https://huggingface.co/UII-AI/uAI-NEXUS-MedVLM-1.0b-4B-RL)   💻 [GitHub](https://github.com/UII-AI/MedGRPO-Code)   📊 [Leaderboard](https://huggingface.co/spaces/UII-AI/MedVidBench-Leaderboard)
"""

CSS = """
.output-box { min-height: 120px; }
#gallery { height: 380px !important; }
.example-card-img { cursor: pointer !important; }
.example-card-img:hover { opacity: 0.8; }
"""

with gr.Blocks(
    title="MedGRPO Demo",
    css=CSS,
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        # ── Tab 1: Pre-computed Examples ──
        with gr.TabItem("Examples"):
            gr.Markdown(
                "> Browse pre-computed predictions from the test set "
                "(no GPU needed)."
            )

            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Select Task")
                    task_radio = gr.Radio(
                        choices=TASKS, value=TASKS[0], label="Task", interactive=True
                    )
                    task_desc = gr.Markdown(_task_heading(TASKS[0]))
                    # Derive the initial choices from the data so the dropdown
                    # is correct even before the load event fires.
                    initial_choices = _example_choices(TASKS[0])
                    example_dropdown = gr.Dropdown(
                        choices=initial_choices,
                        value=initial_choices[0] if initial_choices else None,
                        label="Choose Example",
                        interactive=True,
                    )
                    info_md = gr.Markdown("")

                with gr.Column(scale=3):
                    gallery = gr.Gallery(
                        label="Video Frames",
                        columns=4,
                        rows=2,
                        height=380,
                        object_fit="contain",
                        elem_id="gallery",
                    )

            with gr.Row():
                with gr.Column():
                    question_box = gr.Textbox(
                        label="Question",
                        lines=4,
                        interactive=False,
                        elem_classes="output-box",
                    )
                with gr.Column():
                    gt_box = gr.Textbox(
                        label="Ground Truth",
                        lines=4,
                        interactive=False,
                        elem_classes="output-box",
                    )
                with gr.Column():
                    pred_box = gr.Textbox(
                        label="Model Prediction",
                        lines=4,
                        interactive=False,
                        elem_classes="output-box",
                    )

            task_radio.change(
                fn=on_task_change,
                inputs=[task_radio],
                outputs=[
                    example_dropdown,
                    gallery,
                    question_box,
                    gt_box,
                    pred_box,
                    info_md,
                    task_desc,
                ],
            )
            example_dropdown.change(
                fn=on_example_change,
                inputs=[task_radio, example_dropdown],
                outputs=[gallery, question_box, gt_box, pred_box, info_md],
            )
            # Populate the tab with the first task's first example on page load.
            demo.load(
                fn=on_task_change,
                inputs=[task_radio],
                outputs=[
                    example_dropdown,
                    gallery,
                    question_box,
                    gt_box,
                    pred_box,
                    info_md,
                    task_desc,
                ],
            )

        # ── Tab 2: Live Inference ──
        with gr.TabItem("Live Inference"):
            gr.Markdown(
                "> Upload a medical video or frames and ask a question, "
                "or try a pre-loaded example. "
                "The model runs on ZeroGPU (may take 30\u201360s on first load)."
            )

            # Example cards - clickable thumbnails
            gr.Markdown("**Try a Pre-loaded Example** (click a card below):")
            with gr.Row(equal_height=True):
                example_btns = []
                for i, ex in enumerate(LIVE_EXAMPLES):
                    # Tolerate an example without frames: show an empty card.
                    thumb = ROOT / ex["frames"][0] if ex["frames"] else None
                    task_label = ex["task"].replace("_", " ").title()
                    with gr.Column(min_width=180):
                        img = gr.Image(
                            value=(
                                str(thumb)
                                if thumb is not None and thumb.exists()
                                else None
                            ),
                            label=f"{task_label} ({ex['data_source']}, {ex['n_frames']}f)",
                            height=160,
                            interactive=False,
                            show_download_button=False,
                            show_fullscreen_button=False,
                            elem_classes="example-card-img",
                        )
                        btn = gr.Button(
                            f"Load {task_label}",
                            size="sm",
                            variant="secondary",
                        )
                    example_btns.append((i, img, btn))

            gr.Markdown("---")

            with gr.Row():
                with gr.Column(scale=2):
                    video_input = gr.Video(label="Upload Video (mp4)")
                    frame_preview = gr.Gallery(
                        label="Loaded Frames",
                        columns=8,
                        rows=2,
                        height=200,
                        interactive=False,
                    )
                with gr.Column(scale=1):
                    infer_question = gr.Textbox(
                        label="Question",
                        lines=5,
                        placeholder="e.g., What surgical actions are being performed?",
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        # The value is passed to generate() as max_new_tokens,
                        # so the unit is tokens.
                        max_tokens = gr.Slider(
                            minimum=32,
                            maximum=512,
                            value=256,
                            step=32,
                            label="Max Response Length (tokens)",
                        )
                    infer_btn = gr.Button(
                        "Run Inference", variant="primary", size="lg"
                    )

            infer_output = gr.Textbox(
                label="Model Response",
                lines=8,
                interactive=False,
            )

            # Both the button and the thumbnail of each card load that example.
            for idx, img, btn in example_btns:
                loader = make_load_live_example(idx)
                btn.click(fn=loader, inputs=[], outputs=[frame_preview, infer_question])
                img.select(fn=loader, inputs=[], outputs=[frame_preview, infer_question])

            infer_btn.click(
                fn=run_inference,
                inputs=[video_input, frame_preview, infer_question, max_tokens],
                outputs=[infer_output],
            )


if __name__ == "__main__":
    demo.launch()