# DW-KhotTaeVL-2B-QueryFrames / build_hybrid.py
# commandeaw's picture
# Initial release: DW-KhotTaeVL-2B-QueryFrames v1.0
# 84c8a9d verified
"""Standalone benchmark-mode hybrid policy builder.

Combines two eval JSONs (wild-mode, i.e. query-aware frame selection, and
stock-uniform-8f) by selecting, per question, whichever prediction the
policy says to use:

- If task_type ∈ {Object Reasoning, Temporal Reasoning} → take the
  stock-uniform pred (these are tasks where Video-MME 64f stock does NOT
  outperform 8f stock, so query-aware frame selection cannot help).
- Else → take the wild-mode (query-aware) pred.

This is a pure post-hoc combination of two prediction sets — it runs no
inference, takes no GPU. The output JSON has the same shape as the
eval JSONs, with an added ``policy_source`` field per result row.

Usage::

    python eval_videomme.py --mode wild --n-questions 300 \\
        --out-json wild_300q.json
    python eval_videomme.py --mode stock-uniform --n-questions 300 \\
        --out-json stock_uniform_300q.json
    python build_hybrid.py \\
        --wild-json wild_300q.json \\
        --stock-uniform-json stock_uniform_300q.json \\
        --out-json hybrid_300q.json
"""
from __future__ import annotations
import argparse
import json
from collections import defaultdict
from pathlib import Path
# Task types for which the hybrid policy falls back to uniform sampling.
# On the 300Q mini split, Video-MME stock-64f did not beat stock-8f for
# these tasks (measured: Object Reasoning Δ -0.083, Temporal Reasoning
# Δ +0.000), i.e. frame coverage is not the bottleneck there.
NO_FRAME_GAIN_TASKS = frozenset(("Object Reasoning", "Temporal Reasoning"))
def load_eval(path: str | Path) -> tuple[dict, list[dict]]:
    """Read a Video-MME eval JSON and return ``(summary, results)``.

    Args:
        path: Location of the eval JSON file.

    Returns:
        The values of the top-level ``"summary"`` and ``"results"`` keys.
        A key that is missing or explicitly ``null`` yields ``{}`` / ``[]``
        respectively, so callers can always iterate the results.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode as UTF-8 explicitly: the default would be the platform's locale
    # encoding, which need not be UTF-8 and breaks on non-ASCII content.
    payload = json.loads(Path(path).read_text(encoding="utf-8"))
    # `or` (rather than a .get default) also covers an explicit JSON null.
    return payload.get("summary") or {}, payload.get("results") or []
def _is_correct(row: dict) -> bool:
    """Return True when *row* has a truthy ``correct`` value (missing = wrong)."""
    return bool(row.get("correct"))


def _accuracy(rows: list[dict]) -> float:
    """Return the fraction of *rows* marked correct, or 0.0 for no rows."""
    return sum(map(_is_correct, rows)) / len(rows) if rows else 0.0


def _print_per_task_table(hybrid_rows: list[dict], wild_rows: list[dict],
                          stock_rows: list[dict]) -> None:
    """Print per-task accuracy for three index-aligned prediction lists."""
    # task -> [n, stock hits, wild hits, hybrid hits]. All three columns are
    # tallied over the same questions, grouped by the hybrid row's task, so
    # the per-task deltas compare like with like.
    tallies = defaultdict(lambda: [0, 0, 0, 0])
    for h_row, w_row, s_row in zip(hybrid_rows, wild_rows, stock_rows):
        tally = tallies[h_row["task_type"]]
        tally[0] += 1
        tally[1] += _is_correct(s_row)
        tally[2] += _is_correct(w_row)
        tally[3] += _is_correct(h_row)

    print("\n=== per-task (n / stock-uniform / wild / hybrid / Δ_hyb_vs_stock) ===")
    for t in sorted(tallies):
        n_t, s_hits, w_hits, h_hits = tallies[t]
        s_acc = s_hits / n_t
        w_acc = w_hits / n_t
        h_acc = h_hits / n_t
        d = h_acc - s_acc
        # Flag tasks where the hybrid moves accuracy by at least 5 points.
        flag = " ⭐" if d >= 0.05 else (" ⚠️" if d <= -0.05 else "")
        print(f" {t:<25s} n={n_t:>3d} s={s_acc:.3f} w={w_acc:.3f} "
              f"h={h_acc:.3f} Δ_hyb_vs_s={d:+.3f}{flag}")


def main(argv: list[str] | None = None) -> int:
    """Build the hybrid prediction set from two eval JSONs and report on it.

    For every question index present in both inputs, the stock-uniform
    prediction is taken when the question's task type is in
    ``NO_FRAME_GAIN_TASKS``, and the wild-mode prediction otherwise. The
    combined rows and a summary are written to ``--out-json``; overall and
    per-task accuracies are printed. All accuracies, including the per-task
    table, are computed on the overlapping questions only.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If the two inputs share no question index (or on
            argparse usage errors).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--wild-json", required=True,
                    help="path to wild-mode eval JSON (QA frames). "
                         "Produced by `eval_videomme.py --mode wild`.")
    ap.add_argument("--stock-uniform-json", required=True,
                    help="path to stock-uniform-8f eval JSON. "
                         "Produced by `eval_videomme.py --mode stock-uniform`.")
    ap.add_argument("--out-json", required=True,
                    help="output hybrid JSON path")
    args = ap.parse_args(argv)

    # Only the per-question rows are needed; the input summaries are unused.
    _, wild_results = load_eval(args.wild_json)
    _, stk_results = load_eval(args.stock_uniform_json)
    wild_by = {r["index"]: r for r in wild_results}
    stk_by = {r["index"]: r for r in stk_results}

    common = sorted(wild_by.keys() & stk_by.keys())
    if not common:
        raise SystemExit(
            "[hybrid] no overlapping question indices between the two "
            "eval JSONs — make sure both runs used the same n_questions "
            "and chunks.")
    if len(common) != len(wild_by) or len(common) != len(stk_by):
        print(f"[hybrid] WARN: wild={len(wild_by)} stock-uniform={len(stk_by)} "
              f"overlap={len(common)}; computing on overlap only.")

    # Index-aligned views of both inputs, restricted to the overlap.
    wild_rows = [wild_by[i] for i in common]
    stk_rows = [stk_by[i] for i in common]

    hybrid_results = []
    src_count = {"query_aware": 0, "uniform_fallback": 0}
    for i, w, s in zip(common, wild_rows, stk_rows):
        task = w.get("task_type", "")
        use_uniform = task in NO_FRAME_GAIN_TASKS
        source = "uniform_fallback" if use_uniform else "query_aware"
        chosen = s if use_uniform else w
        src_count[source] += 1
        hybrid_results.append({
            "index": i,
            "videoID": w.get("videoID"),
            "task_type": task,
            "gold": w.get("gold"),
            "pred": chosen.get("pred"),
            "correct": chosen.get("correct"),
            "policy_source": source,
        })

    n = len(hybrid_results)
    acc = _accuracy(hybrid_results)
    qa_acc = _accuracy(wild_rows)
    sk_acc = _accuracy(stk_rows)
    summary = {
        "tag": "benchmark_mode_hybrid",
        "policy": ("uniform-fallback for tasks where stock-64f does not "
                   "exceed stock-8f (Object Reasoning, Temporal Reasoning); "
                   "query-aware otherwise"),
        "no_frame_gain_tasks": sorted(NO_FRAME_GAIN_TASKS),
        "n_questions": n,
        "accuracy": round(acc, 4),
        "wild_accuracy": round(qa_acc, 4),
        "stock_uniform_accuracy": round(sk_acc, 4),
        "delta_hybrid_vs_stock_uniform": round(acc - sk_acc, 4),
        "delta_hybrid_vs_wild": round(acc - qa_acc, 4),
        "policy_source_counts": src_count,
    }

    out_path = Path(args.out_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # written as UTF-8 rather than in the platform's locale encoding.
    out_path.write_text(
        json.dumps({"summary": summary, "results": hybrid_results},
                   indent=2, ensure_ascii=False),
        encoding="utf-8")

    print(f"[hybrid] wrote {out_path}")
    print(f"[hybrid] hybrid acc = {acc:.4f} "
          f"(wild {qa_acc:.4f}, stock-uniform {sk_acc:.4f})")
    print(f"[hybrid] Δ vs stock = {acc-sk_acc:+.4f} "
          f"Δ vs wild = {acc-qa_acc:+.4f}")
    print(f"[hybrid] policy: query_aware={src_count['query_aware']} "
          f"uniform_fallback={src_count['uniform_fallback']}")

    # Per-task breakdown for transparency.
    _print_per_task_table(hybrid_results, wild_rows, stk_rows)
    return 0
# Script entry point: run the builder and use its return value as the
# process exit status.
if __name__ == "__main__":
    import sys

    exit_status = main()
    sys.exit(exit_status)