"""Standalone benchmark-mode hybrid policy builder.

Combines two eval JSONs (wild-mode QA and stock-uniform-8f) by selecting,
per question, whichever prediction the policy says to use:

  - If task_type ∈ {Object Reasoning, Temporal Reasoning} → take the
    stock-uniform pred (these are tasks where Video-MME 64f stock does NOT
    outperform 8f stock, so query-aware frame selection cannot help).
  - Else → take the wild-mode (query-aware) pred.

This is a pure post-hoc combination of two prediction sets — it runs no
inference, takes no GPU. The output JSON has the same shape as the eval
JSONs, with an added ``policy_source`` field per result row.

Only questions present in BOTH inputs (matched on ``index``) are combined;
every reported number, including the per-task breakdown, is computed on
that overlap. A row whose ``correct`` field is missing or null counts as
incorrect.

Usage::

    python eval_videomme.py --mode wild --n-questions 300 \\
        --out-json wild_300q.json
    python eval_videomme.py --mode stock-uniform --n-questions 300 \\
        --out-json stock_uniform_300q.json
    python build_hybrid.py \\
        --wild-json wild_300q.json \\
        --stock-uniform-json stock_uniform_300q.json \\
        --out-json hybrid_300q.json
"""
from __future__ import annotations

import argparse
import json
from collections import defaultdict
from pathlib import Path

# Tasks where Video-MME stock-64f does NOT outperform stock-8f on the
# 300Q mini split (measured: Object Reasoning Δ -0.083, Temporal
# Reasoning Δ +0.000). For these tasks frame coverage is not the
# bottleneck, so the hybrid policy reverts to uniform sampling.
NO_FRAME_GAIN_TASKS = frozenset({"Object Reasoning", "Temporal Reasoning"})


def load_eval(path: str | Path) -> tuple[dict, list[dict]]:
    """Read a Video-MME eval JSON. Returns (summary, results).

    The file is decoded as UTF-8 regardless of the platform locale. A
    missing ``summary`` key yields ``{}`` and a missing ``results`` key
    yields ``[]``.
    """
    d = json.loads(Path(path).read_text(encoding="utf-8"))
    return d.get("summary", {}), d.get("results", [])


def _is_correct(row: dict) -> bool:
    """Return True iff *row* has a truthy ``correct`` (missing/None -> False)."""
    return bool(row.get("correct"))


def _accuracy(hits: int, total: int) -> float:
    """Return hits / total, or 0.0 when *total* is zero."""
    return hits / total if total else 0.0


def main(argv: list[str] | None = None) -> int:
    """Build the hybrid eval JSON from a wild-mode and a stock-uniform run.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        0 on success (process exit code).

    Raises:
        SystemExit: if the two inputs share no question ``index``, or on
            invalid command-line arguments (raised by argparse).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--wild-json", required=True,
                    help="path to wild-mode eval JSON (QA frames). "
                         "Produced by `eval_videomme.py --mode wild`.")
    ap.add_argument("--stock-uniform-json", required=True,
                    help="path to stock-uniform-8f eval JSON. "
                         "Produced by `eval_videomme.py --mode stock-uniform`.")
    ap.add_argument("--out-json", required=True,
                    help="output hybrid JSON path")
    args = ap.parse_args(argv)

    # The input summaries are not used; only the per-question rows matter.
    _, wild_results = load_eval(args.wild_json)
    _, stk_results = load_eval(args.stock_uniform_json)
    wild_by = {r["index"]: r for r in wild_results}
    stk_by = {r["index"]: r for r in stk_results}
    common = sorted(set(wild_by) & set(stk_by))
    if not common:
        raise SystemExit(
            "[hybrid] no overlapping question indices between the two "
            "eval JSONs — make sure both runs used the same n_questions "
            "and chunks.")
    if len(common) != len(wild_by) or len(common) != len(stk_by):
        print(f"[hybrid] WARN: wild={len(wild_by)} stock-uniform={len(stk_by)} "
              f"overlap={len(common)}; computing on overlap only.")

    hybrid_results = []
    src_count = {"query_aware": 0, "uniform_fallback": 0}
    for i in common:
        w, s = wild_by[i], stk_by[i]
        # Question metadata (task, video, gold) is taken from the wild row.
        task = w.get("task_type", "")
        use_uniform = task in NO_FRAME_GAIN_TASKS
        source = "uniform_fallback" if use_uniform else "query_aware"
        chosen = s if use_uniform else w
        src_count[source] += 1
        hybrid_results.append({
            "index": i,
            "videoID": w.get("videoID"),
            "task_type": task,
            "gold": w.get("gold"),
            "pred": chosen.get("pred"),
            "correct": chosen.get("correct"),
            "policy_source": source,
        })

    n = len(hybrid_results)
    acc = _accuracy(sum(_is_correct(r) for r in hybrid_results), n)
    qa_acc = _accuracy(sum(_is_correct(wild_by[i]) for i in common), n)
    sk_acc = _accuracy(sum(_is_correct(stk_by[i]) for i in common), n)

    summary = {
        "tag": "benchmark_mode_hybrid",
        "policy": ("uniform-fallback for tasks where stock-64f does not "
                   "exceed stock-8f (Object Reasoning, Temporal Reasoning); "
                   "query-aware otherwise"),
        "no_frame_gain_tasks": sorted(NO_FRAME_GAIN_TASKS),
        "n_questions": n,
        "accuracy": round(acc, 4),
        "wild_accuracy": round(qa_acc, 4),
        "stock_uniform_accuracy": round(sk_acc, 4),
        "delta_hybrid_vs_stock_uniform": round(acc - sk_acc, 4),
        "delta_hybrid_vs_wild": round(acc - qa_acc, 4),
        "policy_source_counts": src_count,
    }

    out_path = Path(args.out_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding
    # must be pinned to UTF-8 instead of following the platform locale.
    out_path.write_text(
        json.dumps({"summary": summary, "results": hybrid_results},
                   indent=2, ensure_ascii=False),
        encoding="utf-8")
    print(f"[hybrid] wrote {out_path}")
    print(f"[hybrid] hybrid acc = {acc:.4f} "
          f"(wild {qa_acc:.4f}, stock-uniform {sk_acc:.4f})")
    print(f"[hybrid] Δ vs stock = {acc-sk_acc:+.4f} "
          f"Δ vs wild = {acc-qa_acc:+.4f}")
    print(f"[hybrid] policy: query_aware={src_count['query_aware']} "
          f"uniform_fallback={src_count['uniform_fallback']}")

    # Per-task breakdown for transparency. All three systems are tallied
    # over the same overlapping questions and grouped by the same task
    # label, so the columns of a row are directly comparable.
    # Each entry is [n, stock hits, wild hits, hybrid hits].
    by_task = defaultdict(lambda: [0, 0, 0, 0])
    for r in hybrid_results:
        i = r["index"]
        tally = by_task[r["task_type"]]
        tally[0] += 1
        tally[1] += _is_correct(stk_by[i])
        tally[2] += _is_correct(wild_by[i])
        tally[3] += _is_correct(r)

    print("\n=== per-task (n / stock-uniform / wild / hybrid / Δ_hyb_vs_stock) ===")
    for t in sorted(by_task):
        n_t, s_hits, w_hits, h_hits = by_task[t]
        s_acc = _accuracy(s_hits, n_t)
        w_acc = _accuracy(w_hits, n_t)
        h_acc = _accuracy(h_hits, n_t)
        d = h_acc - s_acc
        # Flag per-task swings of at least five accuracy points either way.
        flag = " ⭐" if d >= 0.05 else (" ⚠️" if d <= -0.05 else "")
        print(f" {t:<25s} n={n_t:>3d} s={s_acc:.3f} w={w_acc:.3f} "
              f"h={h_acc:.3f} Δ_hyb_vs_s={d:+.3f}{flag}")
    return 0


if __name__ == "__main__":
    import sys
    sys.exit(main())