# MedGRPO-Demo / app.py
# gaozhongpai's picture
# Pass VideoMetadata so timestamps match real time
# 8b894ad
"""
MedGRPO Demo — Medical Video Understanding with uAI-NEXUS-MedVLM-1.0b-4B-RL
A Gradio demo showcasing the 4B Qwen3-VL based MedGRPO RL model on MedVidBench
across 8 medical video understanding tasks. Includes pre-computed examples
and live inference.
"""
import json
import os
from pathlib import Path
import gradio as gr
import spaces
import torch
from PIL import Image
# Directory containing this file; example assets are resolved relative to it.
ROOT = Path(__file__).parent
# Model identifier passed to ``from_pretrained`` for the live-inference tab.
MODEL_ID = "UII-AI/uAI-NEXUS-MedVLM-1.0b-4B-RL"

# ── Load examples ─────────────────────────────────────────────────────────────
# The JSON files are decoded as UTF-8 explicitly so that loading does not
# depend on the platform's locale default encoding.
with open(ROOT / "examples.json", encoding="utf-8") as f:
    EXAMPLES = json.load(f)  # pre-computed examples for the "Examples" tab
with open(ROOT / "live_examples.json", encoding="utf-8") as f:
    LIVE_EXAMPLES = json.load(f)  # example cards for the "Live Inference" tab
# Per-task display metadata used by the Examples tab: an emoji icon, a short
# acronym and a one-sentence description. Insertion order defines the order
# of TASKS below.
TASK_INFO = {
    "Temporal Action Localization": dict(
        icon="\u23f1\ufe0f",
        short="TAL",
        desc="Identify when specific surgical actions occur in the video (start\u2013end times).",
    ),
    "Spatiotemporal Grounding": dict(
        icon="\U0001f4cd",
        short="STG",
        desc="Locate instruments or anatomy in both space (bounding boxes) and time.",
    ),
    "Dense Captioning": dict(
        icon="\U0001f4dd",
        short="DC",
        desc="Generate detailed, time-stamped descriptions of each action segment.",
    ),
    "Next Action Prediction": dict(
        icon="\U0001f52e",
        short="NAP",
        desc="Predict the next procedural step given the current video context.",
    ),
    "Video Summary": dict(
        icon="\U0001f4cb",
        short="VS",
        desc="Produce a concise summary of the entire surgical procedure shown.",
    ),
    "Region Caption": dict(
        icon="\U0001f50d",
        short="RC",
        desc="Describe the activity of a specific instrument or region across the clip.",
    ),
    "CVS Assessment": dict(
        icon="\u2705",
        short="CVS",
        desc="Score the three Critical View of Safety criteria for cholecystectomy.",
    ),
    "Skill Assessment": dict(
        icon="\U0001f3af",
        short="SA",
        desc="Rate surgical skill on multiple dimensions (1\u20135 scale).",
    ),
}
# Task names in declaration order.
TASKS = list(TASK_INFO)
# ── Model loading (lazy, cached) ──────────────────────────────────────────────
# Process-wide cache filled on the first call to get_model_and_processor().
_model = None
_processor = None


def get_model_and_processor():
    """Return the ``(model, processor)`` pair, loading it on first use.

    The first call imports ``transformers``, loads the processor and the
    bfloat16 model for ``MODEL_ID`` (authenticating with the ``HF_TOKEN``
    environment variable when it is set), switches the model to eval mode
    and stores both in module globals. Later calls return the cached pair.
    """
    global _model, _processor

    if _model is not None:
        return _model, _processor

    from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

    hf_token = os.environ.get("HF_TOKEN")
    print(f"[MedGRPO] Loading model from {MODEL_ID}...")
    print(f"[MedGRPO] HF_TOKEN present: {bool(hf_token)}")

    _processor = AutoProcessor.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        token=hf_token,
    )
    _model = Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        trust_remote_code=True,
        token=hf_token,
    )
    _model.eval()
    print("[MedGRPO] Model loaded.")
    return _model, _processor
# ── Examples tab helpers ──────────────────────────────────────────────────────
def get_examples_for_task(task_name: str) -> list[dict]:
    """Return the entries of ``EXAMPLES`` whose ``"task"`` equals *task_name*.

    Matches keep their original order; the list is empty when nothing matches.
    """
    return [
        example
        for example in EXAMPLES
        if example["task"] == task_name
    ]
def load_example(task_name: str, example_idx: int):
    """Collect everything the Examples tab shows for one pre-computed example.

    Args:
        task_name: Task whose examples are looked up with
            ``get_examples_for_task``.
        example_idx: Zero-based position of the example within that task.

    Returns:
        A 5-tuple ``(images, question, ground_truth, prediction, info)``.
        ``images`` lists, as strings, the example's frame paths that exist
        under ``ROOT``; ``info`` is a Markdown summary. When the task has no
        examples or ``example_idx`` is out of range (including negative),
        the placeholder ``([], "", "", "", "")`` is returned.
    """
    task_examples = get_examples_for_task(task_name)
    # Reject negative indices explicitly: list indexing would otherwise wrap
    # around and silently return an example counted from the end.
    if not 0 <= example_idx < len(task_examples):
        return [], "", "", "", ""
    ex = task_examples[example_idx]
    # Frames missing on disk are skipped rather than treated as an error.
    images = [str(ROOT / fp) for fp in ex["frames"] if (ROOT / fp).exists()]
    # Two trailing spaces before "\n" form a Markdown hard line break, so each
    # field is rendered on its own line.
    info = (
        f"**Task:** {ex['task']}  \n"
        f"**Data Source:** {ex['data_source']}  \n"
        f"**Original Frames:** {ex['n_original_frames']} "
        f"(showing {len(images)} sampled)"
    )
    return images, ex["question"], ex["ground_truth"], ex["prediction"], info
def on_task_change(task_name):
    """Refresh the Examples tab after a task has been selected.

    Returns a 7-tuple: a ``gr.update`` that resets the example choices (with
    the first one selected, or ``None`` when the task has no examples), the
    five values from ``load_example(task_name, 0)``, and a Markdown header
    describing the task.
    """
    n_examples = len(get_examples_for_task(task_name))
    labels = [f"Example {number}" for number in range(1, n_examples + 1)]
    first_label = labels[0] if labels else None

    images, question, gt, pred, info = load_example(task_name, 0)

    meta = TASK_INFO[task_name]
    header = f"### {meta['icon']} {task_name} ({meta['short']})\n{meta['desc']}"

    dropdown_update = gr.update(choices=labels, value=first_label)
    return dropdown_update, images, question, gt, pred, info, header
def on_example_change(task_name, example_choice):
    """Load the example named by a choice label such as ``"Example 3"``.

    An empty or missing label yields the placeholder ``([], "", "", "", "")``;
    otherwise the trailing 1-based number is converted to a 0-based index and
    passed to ``load_example``.
    """
    if example_choice:
        number = example_choice.split()[-1]
        return load_example(task_name, int(number) - 1)
    return [], "", "", "", ""
# ── Live inference ────────────────────────────────────────────────────────────
# Upper bound on frames extracted from an uploaded video
# (1fps × 60s = up to 60 frames).
MAX_FRAMES = 60
# Real wall-clock fps the extracted frames represent. Passed to the processor
# via VideoMetadata so the model's emitted timestamps match real time. Without
# it, Qwen3-VL defaults to fps=24 and compresses the timeline
# (e.g. 0.0–1.1s for a 27s clip).
SAMPLE_FPS = 1.0
def make_load_live_example(example_idx):
    """Build a zero-argument callback that loads one live example.

    The returned function reads ``LIVE_EXAMPLES[example_idx]`` when it is
    called and returns ``(images, question)``, where ``images`` holds the
    example's frames that exist under ``ROOT``, opened as RGB PIL images.
    """
    def _load():
        example = LIVE_EXAMPLES[example_idx]
        candidates = (ROOT / rel_path for rel_path in example["frames"])
        loaded = [
            Image.open(str(path)).convert("RGB")
            for path in candidates
            if path.exists()
        ]
        return loaded, example["question"]

    return _load
def extract_frames_1fps(video_path: str, max_frames: int = MAX_FRAMES) -> list:
    """Extract frames at roughly 1 fps from a video, up to ``max_frames``.

    Args:
        video_path: Path of the video file to read with OpenCV.
        max_frames: Maximum number of frames to return.

    Returns:
        A list of RGB ``PIL.Image`` frames: the first frame, then one frame
        every ``round(fps)`` source frames. The list is empty when no frame
        can be read.
    """
    import math

    import cv2

    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Fall back to 30 fps when the reported rate is unusable. NaN and
        # infinity must be caught here too, because round() raises on them.
        if not math.isfinite(fps) or fps <= 0:
            fps = 30.0
        frame_interval = max(1, round(fps))  # skip frames to get ~1fps
        frames = []
        frame_idx = 0
        while cap.isOpened() and len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % frame_interval == 0:
                # OpenCV decodes to BGR; PIL expects RGB channel order.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
            frame_idx += 1
        return frames
    finally:
        # Release the capture even if reading or conversion raises.
        cap.release()
def _gallery_item_to_image(item):
    """Convert one Gallery value into a PIL image, or ``None`` if unusable.

    Gallery returns different formats depending on Gradio version: a file
    path, a tuple whose first element is the image, a dict carrying the path
    under ``"name"``, ``"path"`` or ``"url"``, or a PIL image. Tuples and
    lists are unwrapped recursively so that their first element may be any of
    these formats.
    """
    if isinstance(item, Image.Image):
        return item
    if isinstance(item, (tuple, list)):
        return _gallery_item_to_image(item[0]) if item else None
    if isinstance(item, dict):
        item = item.get("name") or item.get("path") or item.get("url")
    if isinstance(item, (str, os.PathLike)) and str(item):
        return Image.open(item).convert("RGB")
    return None


@spaces.GPU(duration=120)
def run_inference(video_file, uploaded_images, question, max_tokens):
    """Answer *question* about an uploaded video or a set of frames.

    Args:
        video_file: Path of the uploaded video, or ``None``. When given it
            takes precedence over ``uploaded_images`` and is sampled with
            ``extract_frames_1fps``.
        uploaded_images: Gallery value holding pre-extracted frames; used
            only when no video is supplied.
        question: Free-text question; must contain non-whitespace text.
        max_tokens: Generation budget, forwarded as ``max_new_tokens``.

    Returns:
        The decoded model response, or a human-readable message when the
        question or the visual input is missing. Any exception is printed
        with its traceback and reported as ``"Error: <message>"`` instead of
        being raised, so the UI always receives a string.
    """
    import traceback

    if not question or not question.strip():
        return "Please enter a question."
    try:
        # Collect frames from video or uploaded images
        frames = []
        if video_file is not None:
            frames = extract_frames_1fps(video_file, MAX_FRAMES)
        elif uploaded_images:
            for item in uploaded_images:
                image = _gallery_item_to_image(item)
                if image is not None:
                    frames.append(image)
        if not frames:
            return "Please upload a video or images."
        print(f"[MedGRPO] Collected {len(frames)} frames")

        model, processor = get_model_and_processor()
        # ZeroGPU provides GPU only inside @spaces.GPU — move model here
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        print(f"[MedGRPO] Model on {device}")

        # Build chat prompt
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": frames},
                    {"type": "text", "text": question.strip()},
                ],
            }
        ]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(f"[MedGRPO] Chat template applied, prompt length: {len(text)} chars")

        # When you pass pre-sampled PIL frames to Qwen3-VL, the processor cannot
        # infer the source fps and silently defaults to fps=24 — which compresses
        # the model's perceived timeline (e.g. 0.0–1.1s for a 27s clip). Pass
        # explicit VideoMetadata + do_sample_frames=False so the time grid matches
        # the real sampling rate.
        from transformers.video_utils import VideoMetadata

        n_frames = len(frames)
        first_w, first_h = frames[0].size
        video_meta = VideoMetadata(
            total_num_frames=n_frames,
            fps=SAMPLE_FPS,
            width=first_w,
            height=first_h,
            duration=float(n_frames) / SAMPLE_FPS,
            video_backend="opencv",
            frames_indices=list(range(n_frames)),
        )

        # Use processor() to handle text tokenization + video processing together.
        # This correctly expands <|video_pad|> placeholders in input_ids to match
        # the number of visual patches — separate tokenizer + video_processor calls
        # would produce mismatched input_ids (no placeholder expansion).
        inputs = processor(
            text=[text],
            videos=[frames],
            video_metadata=[video_meta],
            do_sample_frames=False,
            padding=True,
            return_tensors="pt",
        )
        print(f"[MedGRPO] Processed inputs keys: {list(inputs.keys())}")
        print(f"[MedGRPO] input_ids shape: {inputs['input_ids'].shape}")

        # Build generate() kwargs
        # Strip video_metadata — generate() doesn't accept it as an input
        gen_kwargs = {}
        for key, value in inputs.items():
            if key == "video_metadata":
                continue
            if key == "second_per_grid_ts":
                gen_kwargs[key] = value if isinstance(value, list) else value.tolist()
            elif isinstance(value, torch.Tensor):
                # Floating-point tensors are cast to the model dtype; integer
                # tensors keep their dtype. Both are moved to the device.
                if torch.is_floating_point(value):
                    value = value.to(model.dtype)
                gen_kwargs[key] = value.to(device)
            else:
                gen_kwargs[key] = value
        gen_kwargs["max_new_tokens"] = int(max_tokens)
        gen_kwargs["do_sample"] = False  # greedy decoding

        print("[MedGRPO] Starting generation...")
        with torch.inference_mode():
            generated_ids = model.generate(**gen_kwargs)

        # generate() returns prompt + continuation; keep only the new tokens so
        # that both the log line and the decoded text exclude the prompt.
        prompt_len = inputs["input_ids"].shape[1]
        output_ids = generated_ids[:, prompt_len:]
        print(f"[MedGRPO] Generated {output_ids.shape[1]} tokens")

        response = processor.batch_decode(
            output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        print(f"[MedGRPO] Response: {response[:200]}...")
        return response
    except Exception as e:
        # UI boundary: log the full traceback, show a short message to the user.
        traceback.print_exc()
        return f"Error: {e}"
# ── Build UI ──────────────────────────────────────────────────────────────────
# Page heading rendered as a top-level Markdown title.
TITLE = "MedGRPO Demo — Medical Video Understanding"
# Markdown introduction shown under the title. A trailing backslash joins the
# next source line, so the first paragraph is a single line in the string.
DESCRIPTION = """\
This demo runs **[uAI-NEXUS-MedVLM-1.0b-4B-RL](https://huggingface.co/UII-AI/uAI-NEXUS-MedVLM-1.0b-4B-RL)** \
(base: Qwen3-VL-4B), part of the uAI-NEXUS-MedVLM 1.0 family trained with SFT + MedGRPO on \
[MedVidBench](https://huggingface.co/datasets/UII-AI/MedVidBench), \
for medical video question answering across **8 tasks**: temporal reasoning, \
spatial grounding, captioning, and clinical assessment. Sibling release: \
[uAI-NEXUS-MedVLM-1.0a-7B-RL](https://huggingface.co/UII-AI/uAI-NEXUS-MedVLM-1.0a-7B-RL) (Qwen2.5-VL-7B base).
📄 [Paper](https://arxiv.org/abs/2512.06581) &nbsp; 🌐 [Project Page](https://uii-ai.github.io/MedGRPO/) &nbsp; 💾 [Dataset](https://huggingface.co/datasets/UII-AI/MedVidBench) &nbsp; 🤖 [Model](https://huggingface.co/UII-AI/uAI-NEXUS-MedVLM-1.0b-4B-RL) &nbsp; 💻 [GitHub](https://github.com/UII-AI/MedGRPO-Code) &nbsp; 📊 [Leaderboard](https://huggingface.co/spaces/UII-AI/MedVidBench-Leaderboard)
"""
# Custom CSS passed to gr.Blocks. The selectors match the elem_classes /
# elem_id values assigned to the text boxes, the gallery and the example cards.
CSS = """
.output-box { min-height: 120px; }
#gallery { height: 380px !important; }
.example-card-img { cursor: pointer !important; }
.example-card-img:hover { opacity: 0.8; }
"""
def _with_video_reset(loader):
    """Wrap a live-example loader so that it also clears the video input.

    ``run_inference`` prefers an uploaded video over gallery frames, so a
    video left over from an earlier run would otherwise be used instead of
    the example the user has just loaded. The wrapped loader returns
    ``(frames, question, None)``; the trailing ``None`` resets the video.
    """
    def _load_and_reset():
        frames, question = loader()
        return frames, question, None

    return _load_and_reset


# Two-tab Gradio app: a browser for pre-computed examples, and live inference.
with gr.Blocks(
    title="MedGRPO Demo",
    css=CSS,
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        # ── Tab 1: Pre-computed Examples ──
        with gr.TabItem("Examples"):
            gr.Markdown(
                "> Browse pre-computed predictions from the test set "
                "(no GPU needed)."
            )
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Select Task")
                    task_radio = gr.Radio(
                        choices=TASKS, value=TASKS[0], label="Task", interactive=True
                    )
                    task_desc = gr.Markdown(
                        f"### {TASK_INFO[TASKS[0]]['icon']} {TASKS[0]} "
                        f"({TASK_INFO[TASKS[0]]['short']})\n"
                        f"{TASK_INFO[TASKS[0]]['desc']}"
                    )
                    # Placeholder choices; replaced by on_task_change when the
                    # page loads and whenever the task changes.
                    example_dropdown = gr.Dropdown(
                        choices=["Example 1", "Example 2"],
                        value="Example 1",
                        label="Choose Example",
                        interactive=True,
                    )
                    info_md = gr.Markdown("")
                with gr.Column(scale=3):
                    gallery = gr.Gallery(
                        label="Video Frames",
                        columns=4,
                        rows=2,
                        height=380,
                        object_fit="contain",
                        elem_id="gallery",
                    )
            with gr.Row():
                with gr.Column():
                    question_box = gr.Textbox(
                        label="Question",
                        lines=4,
                        interactive=False,
                        elem_classes="output-box",
                    )
                with gr.Column():
                    gt_box = gr.Textbox(
                        label="Ground Truth",
                        lines=4,
                        interactive=False,
                        elem_classes="output-box",
                    )
                with gr.Column():
                    pred_box = gr.Textbox(
                        label="Model Prediction",
                        lines=4,
                        interactive=False,
                        elem_classes="output-box",
                    )
            task_radio.change(
                fn=on_task_change,
                inputs=[task_radio],
                outputs=[
                    example_dropdown,
                    gallery,
                    question_box,
                    gt_box,
                    pred_box,
                    info_md,
                    task_desc,
                ],
            )
            example_dropdown.change(
                fn=on_example_change,
                inputs=[task_radio, example_dropdown],
                outputs=[gallery, question_box, gt_box, pred_box, info_md],
            )
            # Populate the tab for the default task as soon as the page loads.
            demo.load(
                fn=on_task_change,
                inputs=[task_radio],
                outputs=[
                    example_dropdown,
                    gallery,
                    question_box,
                    gt_box,
                    pred_box,
                    info_md,
                    task_desc,
                ],
            )
        # ── Tab 2: Live Inference ──
        with gr.TabItem("Live Inference"):
            gr.Markdown(
                "> Upload a medical video or frames and ask a question, "
                "or try a pre-loaded example. "
                "The model runs on ZeroGPU (may take 30\u201360s on first load)."
            )
            # Example cards - clickable thumbnails
            gr.Markdown("**Try a Pre-loaded Example** (click a card below):")
            with gr.Row(equal_height=True):
                example_btns = []
                for i, ex in enumerate(LIVE_EXAMPLES):
                    thumb = ROOT / ex["frames"][0]
                    task_label = ex["task"].replace("_", " ").title()
                    with gr.Column(min_width=180):
                        img = gr.Image(
                            value=str(thumb) if thumb.exists() else None,
                            label=f"{task_label} ({ex['data_source']}, {ex['n_frames']}f)",
                            height=160,
                            interactive=False,
                            show_download_button=False,
                            show_fullscreen_button=False,
                            elem_classes="example-card-img",
                        )
                        btn = gr.Button(
                            f"Load {task_label}",
                            size="sm",
                            variant="secondary",
                        )
                    example_btns.append((i, img, btn))
            gr.Markdown("---")
            with gr.Row():
                with gr.Column(scale=2):
                    video_input = gr.Video(label="Upload Video (mp4)")
                    frame_preview = gr.Gallery(
                        label="Loaded Frames",
                        columns=8,
                        rows=2,
                        height=200,
                        interactive=False,
                    )
                with gr.Column(scale=1):
                    infer_question = gr.Textbox(
                        label="Question",
                        lines=5,
                        placeholder="e.g., What surgical actions are being performed?",
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        # The value is forwarded as max_new_tokens, so the
                        # unit shown to the user is tokens, not words.
                        max_tokens = gr.Slider(
                            minimum=32,
                            maximum=512,
                            value=256,
                            step=32,
                            label="Max Response Length (tokens)",
                        )
                    infer_btn = gr.Button(
                        "Run Inference", variant="primary", size="lg"
                    )
                    infer_output = gr.Textbox(
                        label="Model Response",
                        lines=8,
                        interactive=False,
                    )
            # Both the button and the thumbnail of a card load its example.
            # The video input is cleared as well, so the example frames, not a
            # previously uploaded video, are what the next inference run sees.
            for idx, img, btn in example_btns:
                loader = _with_video_reset(make_load_live_example(idx))
                btn.click(
                    fn=loader,
                    inputs=[],
                    outputs=[frame_preview, infer_question, video_input],
                )
                img.select(
                    fn=loader,
                    inputs=[],
                    outputs=[frame_preview, infer_question, video_input],
                )
            infer_btn.click(
                fn=run_inference,
                inputs=[video_input, frame_preview, infer_question, max_tokens],
                outputs=[infer_output],
            )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()