Spaces:
Running on Zero
Running on Zero
| """ | |
| MedGRPO Demo — Medical Video Understanding with uAI-NEXUS-MedVLM-1.0b-4B-RL | |
| A Gradio demo showcasing the 4B Qwen3-VL based MedGRPO RL model on MedVidBench | |
| across 8 medical video understanding tasks. Includes pre-computed examples | |
| and live inference. | |
| """ | |
| import json | |
| import os | |
| from pathlib import Path | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from PIL import Image | |
ROOT = Path(__file__).parent
MODEL_ID = "UII-AI/uAI-NEXUS-MedVLM-1.0b-4B-RL"

# ── Load examples ─────────────────────────────────────────────────────────────
# Both files are read as UTF-8 explicitly: JSON text is UTF-8 by specification,
# and relying on the platform default encoding can break on non-ASCII content.
# EXAMPLES feeds the "Examples" tab (pre-computed predictions).
with open(ROOT / "examples.json", encoding="utf-8") as f:
    EXAMPLES = json.load(f)
# LIVE_EXAMPLES feeds the clickable cards in the "Live Inference" tab.
with open(ROOT / "live_examples.json", encoding="utf-8") as f:
    LIVE_EXAMPLES = json.load(f)
# One row per task, in display order: (name, icon, short code, description).
_TASK_ROWS = (
    (
        "Temporal Action Localization",
        "\u23f1\ufe0f",
        "TAL",
        "Identify when specific surgical actions occur in the video (start\u2013end times).",
    ),
    (
        "Spatiotemporal Grounding",
        "\U0001f4cd",
        "STG",
        "Locate instruments or anatomy in both space (bounding boxes) and time.",
    ),
    (
        "Dense Captioning",
        "\U0001f4dd",
        "DC",
        "Generate detailed, time-stamped descriptions of each action segment.",
    ),
    (
        "Next Action Prediction",
        "\U0001f52e",
        "NAP",
        "Predict the next procedural step given the current video context.",
    ),
    (
        "Video Summary",
        "\U0001f4cb",
        "VS",
        "Produce a concise summary of the entire surgical procedure shown.",
    ),
    (
        "Region Caption",
        "\U0001f50d",
        "RC",
        "Describe the activity of a specific instrument or region across the clip.",
    ),
    (
        "CVS Assessment",
        "\u2705",
        "CVS",
        "Score the three Critical View of Safety criteria for cholecystectomy.",
    ),
    (
        "Skill Assessment",
        "\U0001f3af",
        "SA",
        "Rate surgical skill on multiple dimensions (1\u20135 scale).",
    ),
)

# Task name -> {"icon", "short", "desc"}; dict order follows _TASK_ROWS.
TASK_INFO = {
    name: {"icon": icon, "short": short, "desc": desc}
    for name, icon, short, desc in _TASK_ROWS
}
# Task names in display order (used for the task selector).
TASKS = list(TASK_INFO)
# ── Model loading (lazy, cached) ──────────────────────────────────────────────
# Populated by get_model_and_processor() on first use; None until then.
_model = None
_processor = None


def get_model_and_processor():
    """Return the ``(model, processor)`` pair, loading it on first call.

    The first call imports transformers and loads the processor and then the
    Qwen3-VL model from ``MODEL_ID`` (bfloat16 weights, SDPA attention,
    ``HF_TOKEN`` taken from the environment when set), and puts the model in
    eval mode. Later calls return the cached objects.
    """
    global _model, _processor
    if _model is not None:
        return _model, _processor

    # Imported lazily so the module can be imported without loading transformers.
    from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

    token = os.environ.get("HF_TOKEN")
    print(f"[MedGRPO] Loading model from {MODEL_ID}...")
    print(f"[MedGRPO] HF_TOKEN present: {bool(token)}")

    _processor = AutoProcessor.from_pretrained(
        MODEL_ID, trust_remote_code=True, token=token
    )
    _model = Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        trust_remote_code=True,
        token=token,
    )
    _model.eval()
    print("[MedGRPO] Model loaded.")
    return _model, _processor
# ── Examples tab helpers ──────────────────────────────────────────────────────
def get_examples_for_task(task_name: str) -> list[dict]:
    """Return the EXAMPLES entries whose ``"task"`` field equals *task_name*, in file order."""
    return list(filter(lambda example: example["task"] == task_name, EXAMPLES))
def load_example(task_name: str, example_idx: int):
    """Load one pre-computed example of *task_name* for the Examples tab.

    Args:
        task_name: Task name, matched against each example's ``"task"`` field.
        example_idx: Zero-based position among that task's examples.

    Returns:
        ``(images, question, ground_truth, prediction, info_markdown)``, where
        ``images`` lists the example's frame paths (resolved under ROOT) that
        exist on disk. If the index is out of range (negative included) or the
        task has no examples, returns ``([], "", "", "", "")``.
    """
    task_examples = get_examples_for_task(task_name)
    # Explicit lower bound: a negative index would otherwise wrap around and
    # silently display one of the last examples.
    if not 0 <= example_idx < len(task_examples):
        return [], "", "", "", ""
    ex = task_examples[example_idx]
    frame_paths = (ROOT / fp for fp in ex["frames"])
    images = [str(p) for p in frame_paths if p.exists()]
    info = (
        f"**Task:** {ex['task']}  \n"
        f"**Data Source:** {ex['data_source']}  \n"
        f"**Original Frames:** {ex['n_original_frames']} "
        f"(showing {len(images)} sampled)"
    )
    return images, ex["question"], ex["ground_truth"], ex["prediction"], info
def on_task_change(task_name):
    """Refresh the Examples tab for a newly selected task.

    Returns a 7-tuple: a dropdown update whose choices are "Example 1" ..
    "Example N" (first selected, or None when the task has no examples), then
    the images, question, ground truth, prediction and info markdown of the
    task's first example, and finally the task-description markdown.
    """
    n_examples = len(get_examples_for_task(task_name))
    labels = [f"Example {n}" for n in range(1, n_examples + 1)]
    dropdown = gr.update(choices=labels, value=labels[0] if labels else None)
    first_example = load_example(task_name, 0)
    meta = TASK_INFO[task_name]
    heading = f"### {meta['icon']} {task_name} ({meta['short']})\n{meta['desc']}"
    return (dropdown, *first_example, heading)
def on_example_change(task_name, example_choice):
    """Display the example picked in the dropdown.

    *example_choice* is a label such as ``"Example 3"``; its last
    whitespace-separated token is parsed as a 1-based number. An empty or
    ``None`` choice yields an empty display ``([], "", "", "", "")``.
    """
    if example_choice:
        position = int(example_choice.split()[-1])
        return load_example(task_name, position - 1)
    return [], "", "", "", ""
# ── Live inference ────────────────────────────────────────────────────────────
# Upper bound on frames sent to the model: at 1 fps this is up to 60 s of video.
MAX_FRAMES: int = 60

# Wall-clock rate (frames per second) that the extracted frames represent. It is
# passed to the processor through VideoMetadata so the timestamps the model
# emits line up with real time. Without it, Qwen3-VL falls back to fps=24 and
# compresses the timeline (e.g. 0.0–1.1s for a 27s clip).
SAMPLE_FPS: float = 1.0
def make_load_live_example(example_idx):
    """Build a zero-argument callback that loads live example *example_idx*.

    The callback returns ``(frames, question)``. ``frames`` holds the example's
    frame files that exist under ROOT, opened as RGB PIL images in listed
    order. Capturing the index in this factory keeps each callback bound to
    its own example when callbacks are created inside a loop.
    """
    def _load():
        example = LIVE_EXAMPLES[example_idx]
        candidate_paths = (ROOT / rel_path for rel_path in example["frames"])
        frames = [
            Image.open(str(path)).convert("RGB")
            for path in candidate_paths
            if path.exists()
        ]
        return frames, example["question"]

    return _load
def extract_frames_1fps(video_path: str, max_frames: int = MAX_FRAMES) -> list:
    """Sample about one frame per second from a video, up to *max_frames*.

    Sample k is the first frame whose index is >= k * fps. Sampling therefore
    stays aligned with whole seconds even for fractional frame rates such as
    29.97 or 12.5 fps, where a fixed rounded stride would drift. Downstream
    code labels these frames as exactly 1 fps.

    Args:
        video_path: Path to a video file readable by OpenCV.
        max_frames: Maximum number of frames to return.

    Returns:
        A list of RGB ``PIL.Image`` frames. It is empty if the video cannot be
        opened or decoded. When the container reports no usable fps, 30 fps is
        assumed.
    """
    import cv2

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also rejects NaN, which fails every comparison
            fps = 30.0
        frame_idx = 0
        while cap.isOpened() and len(frames) < max_frames:
            # grab() advances to the next frame. retrieve() (the decode step)
            # is only called for frames that are actually sampled.
            if not cap.grab():
                break
            if frame_idx >= len(frames) * fps:
                ok, frame = cap.retrieve()
                if not ok:
                    break
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            frame_idx += 1
    finally:
        cap.release()
    return frames
@spaces.GPU(duration=120)
def run_inference(video_file, uploaded_images, question, max_tokens):
    """Run model inference on an uploaded video or a set of frames.

    On ZeroGPU a GPU is only attached while a ``@spaces.GPU``-decorated
    function runs, so this handler must carry the decorator to see CUDA.
    ``duration=120`` (seconds) leaves headroom for the first call, which also
    loads the model.

    Args:
        video_file: Path of an uploaded video, or None. When given, frames are
            sampled at ~1 fps (up to MAX_FRAMES) and ``uploaded_images`` is
            ignored.
        uploaded_images: Gallery value, or None. Items may be paths, tuples
            whose first element is a path or PIL image, dicts with a
            ``name``/``path``/``url`` key, or PIL images.
        question: Prompt text; must be non-blank.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded model response. On missing input it returns a
        human-readable message. On any exception it prints the traceback and
        returns ``"Error: ..."``.
    """
    import traceback

    if not question or not question.strip():
        return "Please enter a question."
    try:
        # Collect frames from the video (preferred) or from the gallery.
        frames = []
        if video_file is not None:
            frames = extract_frames_1fps(video_file, MAX_FRAMES)
        elif uploaded_images is not None and len(uploaded_images) > 0:
            for item in uploaded_images:
                # Gallery returns different formats depending on Gradio version
                path = None
                if isinstance(item, str):
                    path = item
                elif isinstance(item, tuple):
                    path = item[0]
                elif isinstance(item, dict):
                    path = item.get("name") or item.get("path") or item.get("url")
                elif isinstance(item, Image.Image):
                    path = item
                # Normalize every frame to RGB, whether it came as an image
                # object or as a file path.
                if isinstance(path, Image.Image):
                    frames.append(path.convert("RGB"))
                elif path:
                    frames.append(Image.open(path).convert("RGB"))
        if not frames:
            return "Please upload a video or images."
        print(f"[MedGRPO] Collected {len(frames)} frames")

        model, processor = get_model_and_processor()
        # ZeroGPU provides GPU only inside @spaces.GPU — move model here
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        print(f"[MedGRPO] Model on {device}")

        # Build chat prompt
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": frames},
                    {"type": "text", "text": question.strip()},
                ],
            }
        ]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(f"[MedGRPO] Chat template applied, prompt length: {len(text)} chars")

        # When you pass pre-sampled PIL frames to Qwen3-VL, the processor cannot
        # infer the source fps and silently defaults to fps=24 — which compresses
        # the model's perceived timeline (e.g. 0.0–1.1s for a 27s clip). Pass
        # explicit VideoMetadata + do_sample_frames=False so the time grid matches
        # the real sampling rate.
        from transformers.video_utils import VideoMetadata

        n_frames = len(frames)
        first_w, first_h = frames[0].size
        video_meta = VideoMetadata(
            total_num_frames=n_frames,
            fps=SAMPLE_FPS,
            width=first_w,
            height=first_h,
            duration=float(n_frames) / SAMPLE_FPS,
            video_backend="opencv",
            frames_indices=list(range(n_frames)),
        )

        # Use processor() to handle text tokenization + video processing together.
        # This correctly expands <|video_pad|> placeholders in input_ids to match
        # the number of visual patches — separate tokenizer + video_processor calls
        # would produce mismatched input_ids (no placeholder expansion).
        inputs = processor(
            text=[text],
            videos=[frames],
            video_metadata=[video_meta],
            do_sample_frames=False,
            padding=True,
            return_tensors="pt",
        )
        print(f"[MedGRPO] Processed inputs keys: {list(inputs.keys())}")
        print(f"[MedGRPO] input_ids shape: {inputs['input_ids'].shape}")

        # Build generate() kwargs.
        # Strip video_metadata — generate() doesn't accept it as an input.
        gen_kwargs = {}
        for key, value in inputs.items():
            if key == "video_metadata":
                continue
            if key == "second_per_grid_ts":
                gen_kwargs[key] = value if isinstance(value, list) else value.tolist()
            elif isinstance(value, torch.Tensor):
                # Float tensors (e.g. pixel values) must match the model dtype.
                if torch.is_floating_point(value):
                    value = value.to(model.dtype)
                gen_kwargs[key] = value.to(device)
            else:
                gen_kwargs[key] = value
        gen_kwargs["max_new_tokens"] = int(max_tokens)
        gen_kwargs["do_sample"] = False

        print("[MedGRPO] Starting generation...")
        with torch.inference_mode():
            generated_ids = model.generate(**gen_kwargs)

        # generate() returns prompt + continuation; keep only the new tokens.
        output_ids = generated_ids[:, inputs["input_ids"].shape[1]:]
        print(f"[MedGRPO] Generated {output_ids.shape[1]} new tokens")
        response = processor.batch_decode(
            output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        print(f"[MedGRPO] Response: {response[:200]}...")
        return response
    except Exception as e:
        # UI boundary: log the full traceback, show a short message to the user.
        traceback.print_exc()
        return f"Error: {e}"
# ── Build UI ──────────────────────────────────────────────────────────────────
TITLE = "MedGRPO Demo — Medical Video Understanding"

# Markdown shown under the title. A trailing backslash inside the string joins
# two source lines into one paragraph. The blank line before the link bar is
# required: Markdown treats a single newline as a space, which would run the
# links into the preceding sentence.
DESCRIPTION = """\
This demo runs **[uAI-NEXUS-MedVLM-1.0b-4B-RL](https://huggingface.co/UII-AI/uAI-NEXUS-MedVLM-1.0b-4B-RL)** \
(base: Qwen3-VL-4B), part of the uAI-NEXUS-MedVLM 1.0 family trained with SFT + MedGRPO on \
[MedVidBench](https://huggingface.co/datasets/UII-AI/MedVidBench), \
for medical video question answering across **8 tasks**: temporal reasoning, \
spatial grounding, captioning, and clinical assessment. Sibling release: \
[uAI-NEXUS-MedVLM-1.0a-7B-RL](https://huggingface.co/UII-AI/uAI-NEXUS-MedVLM-1.0a-7B-RL) (Qwen2.5-VL-7B base).

📄 [Paper](https://arxiv.org/abs/2512.06581) 🌐 [Project Page](https://uii-ai.github.io/MedGRPO/) 💾 [Dataset](https://huggingface.co/datasets/UII-AI/MedVidBench) 🤖 [Model](https://huggingface.co/UII-AI/uAI-NEXUS-MedVLM-1.0b-4B-RL) 💻 [GitHub](https://github.com/UII-AI/MedGRPO-Code) 📊 [Leaderboard](https://huggingface.co/spaces/UII-AI/MedVidBench-Leaderboard)
"""

# Custom styles referenced via elem_id / elem_classes on UI components.
CSS = """
.output-box { min-height: 120px; }
#gallery { height: 380px !important; }
.example-card-img { cursor: pointer !important; }
.example-card-img:hover { opacity: 0.8; }
"""
with gr.Blocks(
    title="MedGRPO Demo",
    css=CSS,
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        # ── Tab 1: Pre-computed Examples ──
        with gr.TabItem("Examples"):
            gr.Markdown(
                "> Browse pre-computed predictions from the test set "
                "(no GPU needed)."
            )
            # Seed the dropdown with the default task's real example labels
            # (same format as on_task_change) instead of a hard-coded pair.
            _initial_choices = [
                f"Example {i + 1}"
                for i in range(len(get_examples_for_task(TASKS[0])))
            ]
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Select Task")
                    task_radio = gr.Radio(
                        choices=TASKS, value=TASKS[0], label="Task", interactive=True
                    )
                    task_desc = gr.Markdown(
                        f"### {TASK_INFO[TASKS[0]]['icon']} {TASKS[0]} "
                        f"({TASK_INFO[TASKS[0]]['short']})\n"
                        f"{TASK_INFO[TASKS[0]]['desc']}"
                    )
                    example_dropdown = gr.Dropdown(
                        choices=_initial_choices,
                        value=_initial_choices[0] if _initial_choices else None,
                        label="Choose Example",
                        interactive=True,
                    )
                    info_md = gr.Markdown("")
                with gr.Column(scale=3):
                    gallery = gr.Gallery(
                        label="Video Frames",
                        columns=4,
                        rows=2,
                        height=380,
                        object_fit="contain",
                        elem_id="gallery",
                    )
            with gr.Row():
                with gr.Column():
                    question_box = gr.Textbox(
                        label="Question",
                        lines=4,
                        interactive=False,
                        elem_classes="output-box",
                    )
                with gr.Column():
                    gt_box = gr.Textbox(
                        label="Ground Truth",
                        lines=4,
                        interactive=False,
                        elem_classes="output-box",
                    )
                with gr.Column():
                    pred_box = gr.Textbox(
                        label="Model Prediction",
                        lines=4,
                        interactive=False,
                        elem_classes="output-box",
                    )

            # Order must match the 7-tuple returned by on_task_change.
            task_radio.change(
                fn=on_task_change,
                inputs=[task_radio],
                outputs=[
                    example_dropdown,
                    gallery,
                    question_box,
                    gt_box,
                    pred_box,
                    info_md,
                    task_desc,
                ],
            )
            example_dropdown.change(
                fn=on_example_change,
                inputs=[task_radio, example_dropdown],
                outputs=[gallery, question_box, gt_box, pred_box, info_md],
            )
            # Populate the default task's first example when the page loads.
            demo.load(
                fn=on_task_change,
                inputs=[task_radio],
                outputs=[
                    example_dropdown,
                    gallery,
                    question_box,
                    gt_box,
                    pred_box,
                    info_md,
                    task_desc,
                ],
            )

        # ── Tab 2: Live Inference ──
        with gr.TabItem("Live Inference"):
            gr.Markdown(
                "> Upload a medical video or frames and ask a question, "
                "or try a pre-loaded example. "
                "The model runs on ZeroGPU (may take 30\u201360s on first load)."
            )

            # Example cards - clickable thumbnails
            gr.Markdown("**Try a Pre-loaded Example** (click a card below):")
            with gr.Row(equal_height=True):
                example_btns = []
                for i, ex in enumerate(LIVE_EXAMPLES):
                    thumb = ROOT / ex["frames"][0]
                    task_label = ex["task"].replace("_", " ").title()
                    with gr.Column(min_width=180):
                        img = gr.Image(
                            value=str(thumb) if thumb.exists() else None,
                            label=f"{task_label} ({ex['data_source']}, {ex['n_frames']}f)",
                            height=160,
                            interactive=False,
                            show_download_button=False,
                            show_fullscreen_button=False,
                            elem_classes="example-card-img",
                        )
                        btn = gr.Button(
                            f"Load {task_label}",
                            size="sm",
                            variant="secondary",
                        )
                    example_btns.append((i, img, btn))

            gr.Markdown("---")

            with gr.Row():
                with gr.Column(scale=2):
                    video_input = gr.Video(label="Upload Video (mp4)")
                    frame_preview = gr.Gallery(
                        label="Loaded Frames",
                        columns=8,
                        rows=2,
                        height=200,
                        interactive=False,
                    )
                with gr.Column(scale=1):
                    infer_question = gr.Textbox(
                        label="Question",
                        lines=5,
                        placeholder="e.g., What surgical actions are being performed?",
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        # Feeds generate(max_new_tokens=...), so the unit is
                        # tokens, not words.
                        max_tokens = gr.Slider(
                            minimum=32,
                            maximum=512,
                            value=256,
                            step=32,
                            label="Max Response Length (tokens)",
                        )
                    infer_btn = gr.Button(
                        "Run Inference", variant="primary", size="lg"
                    )
                    infer_output = gr.Textbox(
                        label="Model Response",
                        lines=8,
                        interactive=False,
                    )

            # Both the card image and its button load the same example; the
            # factory binds each loader to its own index.
            for idx, img, btn in example_btns:
                loader = make_load_live_example(idx)
                btn.click(fn=loader, inputs=[], outputs=[frame_preview, infer_question])
                img.select(fn=loader, inputs=[], outputs=[frame_preview, infer_question])

            infer_btn.click(
                fn=run_inference,
                inputs=[video_input, frame_preview, infer_question, max_tokens],
                outputs=[infer_output],
            )
def _main():
    """Start the Gradio server for the demo app."""
    demo.launch()


if __name__ == "__main__":
    _main()