Spaces:
Running on Zero
Running on Zero
File size: 4,559 Bytes
2e5766d 43c5604 2e5766d 43c5604 2e5766d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | #!/usr/bin/env python3
"""
Select 2 good examples per task (16 total) from the test set,
copy frames, and build examples.json for the Gradio demo.
Run once locally before uploading to HF Space.
"""
import json
import os
import random
import shutil
from pathlib import Path
# Fixed seed so the shuffle that picks examples in main() is reproducible.
random.seed(42)
# Test-set JSON. main() enumerates it as a sequence of samples and reads each
# sample's "qa_type", "conversations", "video" and "data_source" entries.
TEST_DATA = "/root/code/qa_instances/qwen3vl_data_test_03_17.json"
# Model inference results. main() looks a prediction up by the stringified
# test index and reads its "answer" entry.
INFER_DATA = "/root/code/LlamaFactory/medical_finetune/eval/results/qwen3_5vl_sft_inference_final_0321_0911.json"
# Destination for the copied frames; lives next to this script.
OUTPUT_DIR = Path(__file__).parent / "examples"
MAX_FRAMES = 16 # subsample frames for demo (keep it lightweight)
# Maps the task name stored in the manifest to the "qa_type" values that
# count as that task when gathering candidates.
TASK_GROUPS = {
    "Temporal Action Localization": ["tal"],
    "Spatiotemporal Grounding": ["stg"],
    "Dense Captioning": ["dense_captioning_gpt", "dense_captioning_gemini"],
    "Next Action Prediction": ["next_action"],
    "Video Summary": ["video_summary_gpt", "video_summary_gemini"],
    "Region Caption": ["region_caption_gpt", "region_caption_gemini"],
    "CVS Assessment": ["cvs_assessment"],
    "Skill Assessment": ["skill_assessment"],
}
# Upper bound on how many examples main() keeps per task group.
EXAMPLES_PER_TASK = 2
def subsample_frames(frame_paths: list[str], max_frames: int) -> list[str]:
    """Uniformly subsample at most ``max_frames`` entries from ``frame_paths``.

    Args:
        frame_paths: Ordered frame paths.
        max_frames: Upper bound on the number of frames to return.

    Returns:
        ``frame_paths`` itself (not a copy) when it already holds no more than
        ``max_frames`` entries. Otherwise a new list of exactly ``max_frames``
        entries taken at evenly spaced indices, starting with the first frame.
        An empty list when ``max_frames`` is zero or negative.
    """
    if max_frames <= 0:
        # Guard the division below: a budget of 0 used to raise
        # ZeroDivisionError, while negative budgets already produced [].
        return []
    if len(frame_paths) <= max_frames:
        return frame_paths
    step = len(frame_paths) / max_frames
    # int() truncates, so index i * step always stays below len(frame_paths).
    return [frame_paths[int(i * step)] for i in range(max_frames)]
def extract_qa(conversations):
    """Pull the question and the answer out of a conversation.

    Turns whose "from" is "human" or "user" supply the question (with every
    "<video>\\n" occurrence removed); turns whose "from" is "gpt" or
    "assistant" supply the answer. When several turns match the same side,
    the last one wins. A matching turn without a "value" yields "".

    Returns:
        A ``(question, answer)`` tuple; each part is "" if no turn matched.
    """
    asker_roles = ("human", "user")
    responder_roles = ("gpt", "assistant")
    question = answer = ""
    for turn in conversations:
        role = turn.get("from")
        if role in asker_roles:
            raw_text = turn.get("value", "")
            question = raw_text.replace("<video>\n", "")
            continue
        if role in responder_roles:
            answer = turn.get("value", "")
    return question, answer
def _copy_frames(frame_paths, example_dir):
    """Copy frames into ``example_dir`` and return their manifest paths.

    Each frame is stored as ``frame_NNN<ext>``, where NNN is its position in
    ``frame_paths``. Sources that are not regular files are reported and
    skipped, so the numbering may have gaps.

    Returns:
        Paths of the copied files, relative to the parent of ``OUTPUT_DIR``.
    """
    copied = []
    for frame_idx, src in enumerate(frame_paths):
        dst = example_dir / f"frame_{frame_idx:03d}{Path(src).suffix}"
        # isfile() rather than exists(): a directory path would otherwise
        # reach copy2 and abort the whole run.
        if not os.path.isfile(src):
            print(f" WARNING: Frame not found: {src}")
            continue
        shutil.copy2(src, dst)
        copied.append(str(dst.relative_to(OUTPUT_DIR.parent)))
    return copied


def main():
    """Build the demo example set and its ``examples.json`` manifest.

    Loads ``TEST_DATA`` and ``INFER_DATA``, deletes and recreates
    ``OUTPUT_DIR``, then for every entry of ``TASK_GROUPS`` picks up to
    ``EXAMPLES_PER_TASK`` samples, copies a subsample of their frames into a
    per-example directory and records question, ground truth and prediction.
    The manifest is written next to ``OUTPUT_DIR``.
    """
    print("Loading test data...")
    # Explicit UTF-8: JSON is UTF-8 by spec, and the platform default
    # encoding is not guaranteed to be.
    with open(TEST_DATA, encoding="utf-8") as f:
        test_data = json.load(f)
    print("Loading inference results...")
    with open(INFER_DATA, encoding="utf-8") as f:
        infer_data = json.load(f)

    # Clean output
    if OUTPUT_DIR.exists():
        shutil.rmtree(OUTPUT_DIR)
    OUTPUT_DIR.mkdir(parents=True)

    examples = []
    example_id = 0
    for task_name, qa_types in TASK_GROUPS.items():
        print(f"\n--- {task_name} ---")
        # Gather candidates, remembering each sample's index in the test set
        candidates = [
            (i, d) for i, d in enumerate(test_data) if d["qa_type"] in qa_types
        ]
        # Rank by how close the frame count is to 60, then keep the closest
        # 30 (or the closest quarter, if that is larger) as the random pool.
        candidates.sort(key=lambda x: abs(len(x[1].get("video", [])) - 60))
        pool = candidates[: max(30, len(candidates) // 4)]
        random.shuffle(pool)
        picks = pool[:EXAMPLES_PER_TASK]
        if len(picks) < EXAMPLES_PER_TASK:
            print(
                f" WARNING: Only {len(picks)} of {EXAMPLES_PER_TASK} "
                f"examples available for {task_name}"
            )

        for idx, sample in picks:
            # NOTE(review): assumes infer_data is a dict keyed by the
            # stringified test index - confirm against the inference script.
            pred = infer_data.get(str(idx))
            if pred is None:
                print(f" WARNING: No prediction for test_idx={idx}")
                pred = {}
            question, gt_answer = extract_qa(sample["conversations"])
            pred_answer = pred.get("answer", "N/A")

            # Subsample and copy frames
            frames = sample.get("video", [])
            sub_frames = subsample_frames(frames, MAX_FRAMES)
            example_dir = OUTPUT_DIR / f"example_{example_id:03d}"
            example_dir.mkdir()
            copied_frames = _copy_frames(sub_frames, example_dir)

            examples.append(
                {
                    "id": example_id,
                    "test_idx": idx,
                    "task": task_name,
                    "qa_type": sample["qa_type"],
                    "data_source": sample.get("data_source", "unknown"),
                    "question": question,
                    "ground_truth": gt_answer,
                    "prediction": pred_answer,
                    "frames": copied_frames,
                    "n_original_frames": len(frames),
                }
            )
            print(
                f" Example {example_id}: test_idx={idx}, {sample['qa_type']} "
                f"({len(copied_frames)} frames from {len(frames)}, "
                f"{sample.get('data_source', 'unknown')})"
            )
            example_id += 1

    # Save manifest
    manifest_path = OUTPUT_DIR.parent / "examples.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(examples, f, indent=2)
    print(f"\nSaved {len(examples)} examples to {manifest_path}")
# Run the selection only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|