MedGRPO-Demo / prepare_examples.py
gaozhongpai's picture
Add test_idx to all examples for traceability
43c5604
#!/usr/bin/env python3
"""
Select 2 good examples per task (16 total) from the test set,
copy frames, and build examples.json for the Gradio demo.
Run once locally before uploading to HF Space.
"""
import json
import os
import random
import shutil
from pathlib import Path
# Fixed seed so the random shuffle used when picking examples is repeatable.
random.seed(42)

# Input JSON files: the test set and the model's inference results.
TEST_DATA = "/root/code/qa_instances/qwen3vl_data_test_03_17.json"
INFER_DATA = (
    "/root/code/LlamaFactory/medical_finetune/eval/results/"
    "qwen3_5vl_sft_inference_final_0321_0911.json"
)

# Example frames are written to an "examples" directory next to this script.
OUTPUT_DIR = Path(__file__).parent.joinpath("examples")

# Upper bound on frames kept per example (keeps the demo lightweight).
MAX_FRAMES = 16

# Display name of each demo task -> the "qa_type" values that belong to it.
# Insertion order matters: examples are numbered in this order.
TASK_GROUPS = {
    "Temporal Action Localization": ["tal"],
    "Spatiotemporal Grounding": ["stg"],
    "Dense Captioning": [
        "dense_captioning_gpt",
        "dense_captioning_gemini",
    ],
    "Next Action Prediction": ["next_action"],
    "Video Summary": [
        "video_summary_gpt",
        "video_summary_gemini",
    ],
    "Region Caption": [
        "region_caption_gpt",
        "region_caption_gemini",
    ],
    "CVS Assessment": ["cvs_assessment"],
    "Skill Assessment": ["skill_assessment"],
}

# Number of examples selected for every entry of TASK_GROUPS.
EXAMPLES_PER_TASK = 2
def subsample_frames(frame_paths: list[str], max_frames: int) -> list[str]:
    """Uniformly subsample frames.

    Args:
        frame_paths: Ordered frame paths.
        max_frames: Maximum number of frames to keep.

    Returns:
        A new list with at most ``max_frames`` entries. When the input
        already fits, it is a copy of the input; otherwise it holds
        ``max_frames`` evenly spaced entries starting with the first frame.
        A non-positive ``max_frames`` yields an empty list.
    """
    # Guard against max_frames == 0, which would otherwise divide by zero.
    if max_frames <= 0:
        return []
    total = len(frame_paths)
    if total <= max_frames:
        return list(frame_paths)
    # Integer arithmetic gives floor(i * total / max_frames) exactly; the
    # float form int(i * (total / max_frames)) can land one index too low
    # when the product rounds just below a whole number.
    return [frame_paths[i * total // max_frames] for i in range(max_frames)]
def extract_qa(conversations):
    """Pull the question and answer texts out of a conversation.

    Each message is a mapping read through its "from" and "value" keys.
    Messages from "human"/"user" supply the question (with every
    "<video>\\n" occurrence removed); messages from "gpt"/"assistant"
    supply the answer. If a role appears more than once, the last
    message wins. Missing roles leave the corresponding text empty.

    Returns:
        A ``(question, answer)`` tuple of strings.
    """
    asker_roles = ("human", "user")
    responder_roles = ("gpt", "assistant")

    question = answer = ""
    for turn in conversations:
        speaker = turn.get("from")
        if speaker in asker_roles:
            text = turn.get("value", "")
            question = text.replace("<video>\n", "")
        elif speaker in responder_roles:
            answer = turn.get("value", "")
    return question, answer
def _copy_frames(frame_paths, example_dir, base_dir):
    """Copy the frames that exist on disk into ``example_dir``.

    Each frame is stored as ``frame_NNN<ext>``, where NNN is its position in
    ``frame_paths``. Missing source files are reported and skipped, so the
    numbering of the copied files may have gaps.

    Returns:
        The copied files' paths relative to ``base_dir``, as strings with
        forward slashes so the manifest does not depend on the OS that
        generated it.
    """
    copied = []
    for position, src in enumerate(frame_paths):
        dst = example_dir / f"frame_{position:03d}{Path(src).suffix}"
        if not os.path.exists(src):
            print(f" WARNING: Frame not found: {src}")
            continue
        shutil.copy2(src, dst)
        copied.append(dst.relative_to(base_dir).as_posix())
    return copied


def main():
    """Build the demo example set.

    Loads the test set and the inference results, recreates OUTPUT_DIR from
    scratch, selects EXAMPLES_PER_TASK samples for each entry of TASK_GROUPS,
    copies up to MAX_FRAMES frames per sample, and writes ``examples.json``
    next to OUTPUT_DIR.
    """
    print("Loading test data...")
    # Explicit UTF-8: JSON text should not depend on the platform's default
    # encoding.
    with open(TEST_DATA, encoding="utf-8") as f:
        test_data = json.load(f)
    print("Loading inference results...")
    with open(INFER_DATA, encoding="utf-8") as f:
        infer_data = json.load(f)

    # Clean output: any previous OUTPUT_DIR is deleted before regenerating.
    if OUTPUT_DIR.exists():
        shutil.rmtree(OUTPUT_DIR)
    OUTPUT_DIR.mkdir(parents=True)

    examples = []
    example_id = 0
    for task_name, qa_types in TASK_GROUPS.items():
        print(f"\n--- {task_name} ---")

        # Gather candidates, keeping each sample's index in the test set.
        candidates = [
            (i, d) for i, d in enumerate(test_data) if d["qa_type"] in qa_types
        ]

        # Prefer videos whose frame count is closest to 60, then pick at
        # random from the best-ranked slice (at least 30 candidates, or a
        # quarter of them if that is larger).
        candidates.sort(key=lambda x: abs(len(x[1].get("video", [])) - 60))
        pool = candidates[: max(30, len(candidates) // 4)]
        random.shuffle(pool)
        picks = pool[:EXAMPLES_PER_TASK]
        if len(picks) < EXAMPLES_PER_TASK:
            print(
                f" WARNING: only {len(picks)} of {EXAMPLES_PER_TASK} "
                f"examples available for {task_name}"
            )

        for idx, sample in picks:
            # Predictions are looked up by the test index as a string key.
            pred_key = str(idx)
            if pred_key not in infer_data:
                print(f" WARNING: No prediction found for test_idx={idx}")
            pred = infer_data.get(pred_key, {})
            question, gt_answer = extract_qa(sample["conversations"])
            pred_answer = pred.get("answer", "N/A")

            # Subsample and copy frames
            frames = sample.get("video", [])
            sub_frames = subsample_frames(frames, MAX_FRAMES)
            example_dir = OUTPUT_DIR / f"example_{example_id:03d}"
            example_dir.mkdir()
            copied_frames = _copy_frames(
                sub_frames, example_dir, OUTPUT_DIR.parent
            )

            data_source = sample.get("data_source", "unknown")
            examples.append(
                {
                    "id": example_id,
                    "test_idx": idx,
                    "task": task_name,
                    "qa_type": sample["qa_type"],
                    "data_source": data_source,
                    "question": question,
                    "ground_truth": gt_answer,
                    "prediction": pred_answer,
                    "frames": copied_frames,
                    "n_original_frames": len(frames),
                }
            )
            print(
                f" Example {example_id}: test_idx={idx}, {sample['qa_type']} "
                f"({len(copied_frames)} frames from {len(frames)}, "
                f"{data_source})"
            )
            example_id += 1

    # Save manifest
    manifest_path = OUTPUT_DIR.parent / "examples.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(examples, f, indent=2)
    print(f"\nSaved {len(examples)} examples to {manifest_path}")


if __name__ == "__main__":
    main()