# camera-motion-ab-eval / run_videomae_ab_test.py
# kaier111's picture
# Deploy latest timeline + speed optimizations
# 3b6218c verified
#!/usr/bin/env python3
"""A/B test for timeline postprocess under strict replayability metric.
Run one-pass window inference, then evaluate multiple fixed postprocess variants.
Strict pass criteria (per shot):
1) no uncertain segment
2) predicted stage count == GT stage count
3) each stage motion-set exactly matches GT stage motion-set (order-sensitive across stages)
"""
from __future__ import annotations
import argparse
import concurrent.futures
import csv
import json
import os
import re
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import asdict, dataclass
from statistics import mean
from typing import Any, Dict, List, Optional, Sequence, Tuple
import cv2
import numpy as np
import torch
try:
from huggingface_hub import InferenceClient
except Exception: # pragma: no cover - optional at runtime for non-LLM mode
InferenceClient = None # type: ignore[assignment]
from run_videomae_timeline import (
LABEL_CN_MAP,
PRIMARY_LABELS,
SECONDARY_CANDIDATES,
MODEL_ID,
NUM_FRAMES,
WindowPrediction,
ShotBoundary,
VideoMAEWindowInferencer,
build_window_ranges,
compute_frame_span,
build_segments_output,
build_shot_narrative,
build_segment_meta,
enforce_min_duration,
enforce_max_segments,
load_shots_jsonl,
)
@dataclass(frozen=True)
class SampleCase:
    """One evaluation case: a source video plus its shot list and ground truth."""

    name: str  # short case identifier, used for reporting and --cases filtering
    video: str  # path to the source video (relative paths resolved via resolve_case_path)
    shots_jsonl: Optional[str]  # path to the shot-boundary info file, if available
    gt_json: Optional[str]  # path to the ground-truth description JSON, if available
    sample_ids: Tuple[int, ...]  # shot ids selected for evaluation in this case
# Built-in evaluation cases (video + shot info + ground truth). Paths are
# relative to the material root passed on the command line unless absolute.
SAMPLE_CASES: Tuple[SampleCase, ...] = (
    SampleCase(
        name="baseus",
        video="测试物料/01_测试样本/把自然放进耳机 倍思 WX5 蓝_精选全球优秀创意视频案例编号122691_光厂案例VJshi案例.mp4",
        shots_jsonl="测试物料/01_测试样本/shots_info.json",
        gt_json="测试物料/02_真值/ground_truth_baseus.json",
        sample_ids=(1, 2, 8, 10),
    ),
    SampleCase(
        name="runner",
        video="测试物料/01_测试样本/跑者.mp4",
        shots_jsonl="测试物料/01_测试样本/跑者镜头信息.json",
        gt_json="测试物料/02_真值/ground_truth_runner.json",
        sample_ids=(6, 9, 29),
    ),
    SampleCase(
        name="vertical",
        video="测试物料/01_测试样本/竖屏.mp4",
        shots_jsonl="测试物料/01_测试样本/竖屏.json",
        gt_json="测试物料/02_真值/ground_truth_vertical.json",
        sample_ids=(5, 8),
    ),
)
@dataclass(frozen=True)
class VariantConfig:
    """Hyperparameters for one fixed postprocess variant applied to raw window scores."""

    name: str  # variant identifier used in reports
    # Uncertain gating: a window becomes "uncertain" when the top-1 primary
    # score, the top-1/top-2 margin, or the "undefined" score violate these.
    uncertain_p1_min: float
    uncertain_margin_min: float
    uncertain_undefined_min: float
    # Compound (primary+secondary) acceptance thresholds.
    secondary_min: float
    secondary_gap_max: float
    # Temporal smoothing / stability controls.
    smooth_radius: int
    hysteresis_support: int
    compound_support_windows: int
    # Optional optical-flow gates (see apply_cv_gates).
    use_shake_static_gate: bool = False
    use_dolly_zoom_gate: bool = False
    allow_shake_primary: bool = False
    # Rescue of would-be-uncertain windows into compound/single predictions.
    allow_uncertain_compound_rescue: bool = False
    rescue_min_primary: float = 0.0
    rescue_min_secondary: float = 0.0
    rescue_undefined_max: float = 1.0
    undefined_override_primary_min: float = 1.1  # default > 1.0 disables the override
    # Fallback when "undefined" dominates but a usable primary label exists.
    allow_undefined_single_fallback: bool = False
    undefined_fallback_min: float = 0.0
    undefined_fallback_margin_min: float = 0.0
    undefined_fallback_compound_min: float = 1.1
    undefined_fallback_compound_gap_max: float = 0.0
    # Hitchcock-style (zoom with counter dolly) heuristic thresholds.
    use_undefined_hitchcock_gate: bool = False
    hitch_undefined_min: float = 1.1
    hitch_zoom_min: float = 0.0
    hitch_motion_min: float = 1.1
    hitch_translation_min: float = 1.1
    # Additional flow-driven relabeling gates.
    use_arc_orbit_gate: bool = False
    use_truck_low_translation_gate: bool = False
    use_roll_to_pan_gate: bool = False
    use_shake_roll_pan_gate: bool = False
    # Terminal-segment cleanup passes.
    use_terminal_uncertain_absorb: bool = False
    terminal_uncertain_max_windows: int = 1
    terminal_uncertain_neighbor_conf_min: float = 0.7
    use_tail_shake_rebound_collapse: bool = False
    tail_shake_max_windows: int = 1
# Variant A: the locked baseline thresholds (no CV gates, no rescues).
VARIANT_A = VariantConfig(
    name="A_baseline_locked",
    uncertain_p1_min=0.45,
    uncertain_margin_min=0.12,
    uncertain_undefined_min=0.50,
    secondary_min=0.50,
    secondary_gap_max=0.18,
    smooth_radius=2,
    hysteresis_support=2,
    compound_support_windows=1,
)
# Variant B: stricter stability (more hysteresis / compound support), looser
# uncertain gating, tuned toward the replayability metric.
VARIANT_B = VariantConfig(
    name="B_replay_priority",
    uncertain_p1_min=0.40,
    uncertain_margin_min=0.08,
    uncertain_undefined_min=0.60,
    secondary_min=0.55,
    secondary_gap_max=0.15,
    smooth_radius=2,
    hysteresis_support=3,
    compound_support_windows=2,
)
# Variant C: B plus the optical-flow shake/static gate and shake as a primary.
VARIANT_C = VariantConfig(
    name="C_replay_plus_cv_shake_static",
    uncertain_p1_min=0.40,
    uncertain_margin_min=0.08,
    uncertain_undefined_min=0.60,
    secondary_min=0.55,
    secondary_gap_max=0.15,
    smooth_radius=2,
    hysteresis_support=3,
    compound_support_windows=2,
    use_shake_static_gate=True,
    allow_shake_primary=True,
)
# Variant D: most aggressive — compound rescue, undefined fallbacks, the
# Hitchcock gate, and every flow-driven relabeling/cleanup pass enabled.
VARIANT_D = VariantConfig(
    name="D_replay_plus_compound_rescue",
    uncertain_p1_min=0.30,
    uncertain_margin_min=0.05,
    uncertain_undefined_min=0.60,
    secondary_min=0.35,
    secondary_gap_max=0.30,
    smooth_radius=2,
    hysteresis_support=3,
    compound_support_windows=1,
    use_shake_static_gate=True,
    use_dolly_zoom_gate=False,
    allow_shake_primary=True,
    allow_uncertain_compound_rescue=True,
    rescue_min_primary=0.82,
    rescue_min_secondary=0.70,
    rescue_undefined_max=0.40,
    undefined_override_primary_min=0.90,
    allow_undefined_single_fallback=True,
    undefined_fallback_min=0.02,
    undefined_fallback_margin_min=0.0,
    undefined_fallback_compound_min=0.02,
    undefined_fallback_compound_gap_max=0.08,
    use_undefined_hitchcock_gate=True,
    hitch_undefined_min=0.85,
    hitch_zoom_min=0.02,
    hitch_motion_min=0.55,
    hitch_translation_min=0.05,
    use_arc_orbit_gate=True,
    use_truck_low_translation_gate=True,
    use_roll_to_pan_gate=True,
    use_shake_roll_pan_gate=True,
    use_terminal_uncertain_absorb=True,
    terminal_uncertain_max_windows=1,
    terminal_uncertain_neighbor_conf_min=0.65,
    use_tail_shake_rebound_collapse=True,
    tail_shake_max_windows=2,
)
# All variants evaluated in one A/B run.
VARIANTS: Tuple[VariantConfig, ...] = (VARIANT_A, VARIANT_B, VARIANT_C, VARIANT_D)
# (label, Chinese keyword spellings) pairs used by parse_gt_stages to turn
# ground-truth descriptions into motion tokens. Keyword specificity matters:
# non-zoom phrases are filtered against generic zoom keywords at parse time.
TOKEN_PATTERNS: Sequence[Tuple[str, Sequence[str]]] = (
    ("arc_left", ("左弧绕", "逆时针环绕", "左环绕")),
    ("arc_right", ("右弧绕", "顺时针环绕", "右环绕")),
    ("pan_left", ("左摇", "向左摇", "左摇镜")),
    ("pan_right", ("右摇", "向右摇", "右摇镜")),
    ("truck_left", ("左移", "机位左移", "向左移", "左跟拍")),
    ("truck_right", ("右移", "机位右移", "向右移", "右跟拍")),
    ("tilt_up", ("上摇", "上仰")),
    ("tilt_down", ("下摇", "下俯")),
    ("pedestal_up", ("上移", "上升")),
    ("pedestal_down", ("下移", "下降")),
    ("dolly_in", ("非变焦推", "非变焦推进", "前推(非变焦)", "推进(非变焦)", "推(非变焦)", "推(非变焦)")),
    ("dolly_out", ("非变焦拉", "轻微拉(非变焦)", "拉(非变焦)", "拉(非变焦)")),
    ("zoom_in", ("变焦推", "变焦推进", "飞变焦", "推镜头", "推进", "前推", "推")),
    ("zoom_out", ("变焦拉", "焦距拉远", "拉镜头", "后拉", "拉远", "拉")),
    ("track", ("跟拍",)),
    ("shake", ("晃动", "手持", "抖动")),
    ("roll_left", ("左滚转",)),
    ("roll_right", ("右滚转", "旋转")),
    ("static", ("固定", "基本固定", "静止", "没变动")),
)
# Connector words meaning two motions happen at the same time (same stage).
SIMULTANEOUS_MARKERS = ("同时", "伴", "并", "并且", "+", "和", "的同时")
# Connector words meaning the next motion starts a new sequential stage.
SEQUENTIAL_MARKERS = ("然后", "接", "再", "随后", "之后", "最后", "快切", "切至", "切为")
EPS = 1e-6  # numerical guard against division by zero
# LLM-judge defaults; used only when the optional judge mode is enabled.
DEFAULT_JUDGE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
DEFAULT_JUDGE_PROVIDER = "hf"
DEFAULT_GEMINI_MODEL = "gemini-3-flash"
LLM_JUDGE_PASS_THRESHOLD = 85  # judge score threshold — usage defined later in the file
JUDGE_PROMPT_VERSION = "v1_cn_replay"
# Strict structured-output schema (OpenAI-style response_format) for the
# HF-router judge: four 0-25 sub-scores, a 0-100 overall score, and a verdict.
JUDGE_JSON_SCHEMA = {
    "type": "json_schema",
    "json_schema": {
        "name": "camera_motion_judge",
        "schema": {
            "type": "object",
            "properties": {
                "action_type_score": {"type": "integer", "minimum": 0, "maximum": 25},
                "direction_score": {"type": "integer", "minimum": 0, "maximum": 25},
                "order_score": {"type": "integer", "minimum": 0, "maximum": 25},
                "replayability_score": {"type": "integer", "minimum": 0, "maximum": 25},
                "overall_score": {"type": "integer", "minimum": 0, "maximum": 100},
                "verdict": {"type": "string", "enum": ["pass", "partial", "fail"]},
                "reason": {"type": "string"},
            },
            "required": [
                "action_type_score",
                "direction_score",
                "order_score",
                "replayability_score",
                "overall_score",
                "verdict",
                "reason",
            ],
            "additionalProperties": False,
        },
        "strict": True,
    },
}
# Same response shape for the Gemini judge API, but without the range/enum
# constraints (Gemini response schemas are loosely typed here).
GEMINI_JUDGE_RESPONSE_SCHEMA = {
    "type": "object",
    "properties": {
        "action_type_score": {"type": "integer"},
        "direction_score": {"type": "integer"},
        "order_score": {"type": "integer"},
        "replayability_score": {"type": "integer"},
        "overall_score": {"type": "integer"},
        "verdict": {"type": "string"},
        "reason": {"type": "string"},
    },
    "required": [
        "action_type_score",
        "direction_score",
        "order_score",
        "replayability_score",
        "overall_score",
        "verdict",
        "reason",
    ],
}
@dataclass(frozen=True)
class WindowCVPrior:
    """Optical-flow statistics for one window, used as priors by apply_cv_gates."""

    mean_flow_mag: float  # mean Farneback flow magnitude, normalized by frame diagonal and clamped to [0, 1]
    translation_ratio: float  # |median flow| / mean |flow|: high => coherent global translation
    radial_alignment: float  # weighted flow/radial-field cosine in [-1, 1]; positive when flow points away from center
    jitter_ratio: float  # clamped frame-to-frame instability of the translation vector
    static_conf: float  # heuristic confidence that the window is static
    shake_conf: float  # heuristic confidence that the window is handheld/shaky
    zoom_bias: float  # flow evidence favoring focal zoom over dolly
    dolly_bias: float  # flow evidence favoring dolly over focal zoom
def clamp01(v: float) -> float:
    """Clamp *v* to the closed unit interval [0.0, 1.0]."""
    if v > 1.0:
        return 1.0
    if v < 0.0:
        return 0.0
    return v
def direction_sign_from_scores(scores: Dict[str, float], radial_alignment: float) -> int:
    """Return +1 for inward (push-in) motion, -1 for outward (pull-out).

    Strong radial-flow evidence decides outright; otherwise the model's
    in-vs-out label confidences are compared (ties resolve inward).
    """
    if abs(radial_alignment) >= 0.12:
        return 1 if radial_alignment > 0.0 else -1
    inward = max(scores.get("zoom_in", 0.0), scores.get("dolly_in", 0.0))
    outward = max(scores.get("zoom_out", 0.0), scores.get("dolly_out", 0.0))
    return 1 if inward >= outward else -1
def _radial_grid(h: int, w: int) -> Tuple[np.ndarray, np.ndarray]:
yy, xx = np.mgrid[0:h, 0:w]
cx = (w - 1) * 0.5
cy = (h - 1) * 0.5
rx = xx.astype(np.float32) - cx
ry = yy.astype(np.float32) - cy
rr = np.sqrt(rx * rx + ry * ry)
rr = np.maximum(rr, EPS)
return rx / rr, ry / rr
def compute_window_cv_prior(frames: List[np.ndarray]) -> WindowCVPrior:
    """Compute dense-optical-flow statistics for one window of RGB frames.

    Runs Farneback flow between consecutive grayscale frames and aggregates
    magnitude, translation coherence, radial alignment, and jitter into a
    WindowCVPrior. Fewer than two frames yields a fully static prior.
    """
    if len(frames) < 2:
        return WindowCVPrior(0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0)
    gframes = [cv2.cvtColor(x, cv2.COLOR_RGB2GRAY) for x in frames]
    h, w = gframes[0].shape[:2]
    diag = max(float(np.sqrt(h * h + w * w)), 1.0)
    radial_x, radial_y = _radial_grid(h, w)
    mean_mags: List[float] = []
    trans_ratios: List[float] = []
    radial_aligns: List[float] = []
    trans_vecs: List[Tuple[float, float]] = []
    for g0, g1 in zip(gframes[:-1], gframes[1:]):
        flow = cv2.calcOpticalFlowFarneback(
            g0, g1, None, pyr_scale=0.5, levels=3, winsize=15, iterations=3, poly_n=5, poly_sigma=1.2, flags=0
        )
        fx = flow[..., 0]
        fy = flow[..., 1]
        mag = np.sqrt(fx * fx + fy * fy)
        mean_mag = float(np.mean(mag))
        mean_mags.append(mean_mag)
        # Median flow approximates the global translation component.
        med_dx = float(np.median(fx))
        med_dy = float(np.median(fy))
        trans_mag = float(np.sqrt(med_dx * med_dx + med_dy * med_dy))
        trans_vecs.append((med_dx, med_dy))
        trans_ratios.append(clamp01(trans_mag / (mean_mag + EPS)))
        # Magnitude-weighted cosine between flow direction and the radial field.
        flow_norm_x = fx / (mag + EPS)
        flow_norm_y = fy / (mag + EPS)
        radial_dot = (flow_norm_x * radial_x + flow_norm_y * radial_y).astype(np.float32)
        weight = mag.astype(np.float32)
        align = float(np.sum(radial_dot * weight) / (np.sum(weight) + EPS))
        radial_aligns.append(float(np.clip(align, -1.0, 1.0)))
    mean_mag = float(mean(mean_mags)) if mean_mags else 0.0
    # Normalize magnitude by a resolution-dependent scale (0.5% of the diagonal).
    mean_mag_norm = clamp01(mean_mag / max(1.25, 0.005 * diag))
    translation_ratio = float(mean(trans_ratios)) if trans_ratios else 0.0
    radial_alignment = float(mean(radial_aligns)) if radial_aligns else 0.0
    jitter_ratio = 0.0
    if len(trans_vecs) >= 2:
        # Frame-to-frame change of the translation vector, relative to flow magnitude.
        deltas = [
            float(np.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2))
            for (x0, y0), (x1, y1) in zip(trans_vecs[:-1], trans_vecs[1:])
        ]
        jitter_ratio = float(mean(deltas)) / (mean_mag + EPS)
    jitter_ratio_clamped = clamp01(jitter_ratio / 1.5)
    # Heuristic confidences derived from the aggregated statistics.
    static_conf = clamp01(((0.55 - mean_mag_norm) / 0.55) * (1.0 - 0.6 * jitter_ratio_clamped))
    shake_conf_base = mean_mag_norm * (1.0 - translation_ratio) * 1.5
    shake_conf_jitter = 0.75 * jitter_ratio_clamped * min(1.0, mean_mag_norm * 2.0 + 0.15)
    shake_conf = clamp01(max(shake_conf_base, shake_conf_jitter))
    zoom_bias = clamp01(abs(radial_alignment) * mean_mag_norm - 0.25 * translation_ratio)
    dolly_bias = clamp01(translation_ratio * mean_mag_norm - 0.2 * abs(radial_alignment))
    return WindowCVPrior(
        mean_flow_mag=mean_mag_norm,
        translation_ratio=translation_ratio,
        radial_alignment=radial_alignment,
        jitter_ratio=jitter_ratio_clamped,
        static_conf=static_conf,
        shake_conf=shake_conf,
        zoom_bias=zoom_bias,
        dolly_bias=dolly_bias,
    )
@torch.inference_mode()
def infer_raw_window_from_frames(
    inferencer: VideoMAEWindowInferencer, frames: List[np.ndarray], start_sec: float, end_sec: float
) -> WindowPrediction:
    """Run one raw model pass over pre-sampled frames; classification is deferred.

    Returns a WindowPrediction carrying per-label sigmoid scores and the top-8
    labels, with mode/primary/secondary left at placeholders so each variant
    can classify the same raw scores later.
    """
    inputs = inferencer.processor([frames], return_tensors="pt")
    inputs = {k: v.to(inferencer.device) for k, v in inputs.items()}
    logits = inferencer.model(**inputs).logits.float().cpu().numpy()[0]
    # Multi-label head: independent sigmoid per label (not softmax).
    probs = 1.0 / (1.0 + np.exp(-logits))
    raw_scores = {inferencer.id2label[i]: float(probs[i]) for i in sorted(inferencer.id2label.keys())}
    raw_top = sorted(raw_scores.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)[:8]
    return WindowPrediction(
        start_sec=start_sec,
        end_sec=end_sec,
        raw_scores=raw_scores,
        raw_top_labels=[{"label": lb, "score": round(sc, 4)} for lb, sc in raw_top],
        mode="single",
        primary=None,
        secondary=None,
        confidence=0.0,
    )
def build_window_frame_index_lists(
    ranges: Sequence[Tuple[float, float]], fps: float, total_frames: int
) -> List[List[int]]:
    """Map each (start_sec, end_sec) window to NUM_FRAMES evenly spaced frame indices."""
    index_lists: List[List[int]] = []
    for s, e in ranges:
        f0, f1 = compute_frame_span(s, e, fps, total_frames)
        idxs = np.linspace(f0, f1, NUM_FRAMES).astype(int).tolist()
        index_lists.append(idxs)
    return index_lists
def sample_frames_for_window_indices(
    cap: cv2.VideoCapture,
    inferencer: VideoMAEWindowInferencer,
    index_lists: Sequence[List[int]],
    fps: float,
    total_frames: int,
) -> List[List[np.ndarray]]:
    """Decode RGB frames for every window's index list with one forward sweep.

    Fast path decodes in temporal order (grab/read) and caches frames shared by
    overlapping windows; on any decode failure it falls back to the inferencer's
    original random-seek sampler for all windows.
    """
    if not index_lists:
        return []
    # Fast path: decode in temporal order once and reuse overlap via a tiny cache.
    first_idx = min(index_lists[0])
    cap.set(cv2.CAP_PROP_POS_FRAMES, int(first_idx))
    cur = first_idx
    frame_cache: Dict[int, np.ndarray] = {}
    out: List[List[np.ndarray]] = []

    def _read_at(target_idx: int) -> np.ndarray:
        # Return the RGB frame at target_idx, seeking back or grabbing forward as needed.
        nonlocal cur
        if target_idx in frame_cache:
            return frame_cache[target_idx]
        if target_idx < cur:
            # Backwards request: explicit seek (rare; most requests are monotonic).
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(target_idx))
            cur = target_idx
        while cur < target_idx:
            ok = cap.grab()
            if not ok:
                raise RuntimeError(f"grab failed before frame={target_idx}")
            cur += 1
        ok, frame = cap.read()
        if not ok:
            raise RuntimeError(f"read failed at frame={target_idx}")
        cur += 1
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_cache[target_idx] = rgb
        return rgb

    try:
        for wi, idxs in enumerate(index_lists):
            if not idxs:
                raise RuntimeError("empty index list")
            frames: List[np.ndarray] = []
            for idx in idxs:
                frames.append(_read_at(int(idx)))
            out.append(frames)
            # Keep only frames that can still be reused by upcoming windows.
            if wi + 1 < len(index_lists):
                next_min = min(index_lists[wi + 1])
                stale = [k for k in frame_cache.keys() if k < next_min]
                for k in stale:
                    del frame_cache[k]
    except Exception:
        # Safe fallback to the original random-seek sampler.
        out = []
        for idxs in index_lists:
            if not idxs:
                out.append([])
                continue
            s_idx = int(min(idxs))
            e_idx = int(max(idxs))
            start_sec = s_idx / fps
            end_sec = min(total_frames - 1, e_idx + 1) / fps
            out.append(inferencer._sample_frames(cap, start_sec, end_sec, fps, total_frames))
    return out
def axis_of(label: str) -> Optional[Tuple[str, str]]:
    """Map a motion label's direction suffix to its (axis, direction) pair.

    Axes: "h" horizontal, "v" vertical, "d" depth. Labels without a known
    direction suffix (e.g. "static", "shake", "track") return None.
    """
    suffix_map = (
        ("_left", ("h", "left")),
        ("_right", ("h", "right")),
        ("_up", ("v", "up")),
        ("_down", ("v", "down")),
        ("_in", ("d", "in")),
        ("_out", ("d", "out")),
    )
    for suffix, axis_dir in suffix_map:
        if label.endswith(suffix):
            return axis_dir
    return None
def is_axis_conflict(a: str, b: str) -> bool:
    """True when a and b are opposite directions within the same motion family.

    Only same-family opposites (e.g. pan_left vs pan_right) conflict;
    cross-family combinations are always allowed.
    """
    family_a = a.partition("_")[0]
    family_b = b.partition("_")[0]
    if family_a != family_b:
        return False
    axis_a = axis_of(a)
    axis_b = axis_of(b)
    if axis_a is None or axis_b is None:
        return False
    return axis_a[0] == axis_b[0] and axis_a[1] != axis_b[1]
def is_redundant_zoom_dolly_pair(a: str, b: str) -> bool:
    """True when the pair is dolly+zoom in the same direction (redundant labels)."""
    combo = frozenset((a, b))
    return combo == {"dolly_in", "zoom_in"} or combo == {"dolly_out", "zoom_out"}
def classify_scores(raw_scores: Dict[str, float], cfg: VariantConfig) -> Tuple[str, Optional[str], Optional[str], float]:
    """Classify one window's label scores into (mode, primary, secondary, confidence).

    Mode is "single", "compound", or "uncertain". The cascade is:
    1) rank primary candidates (optionally including "shake");
    2) if top-1 score / margin / "undefined" trigger the uncertain gate, try the
       variant's rescue paths (compound rescue, single rescue, undefined
       fallback) before conceding "uncertain";
    3) otherwise attach a secondary label when it clears the compound gates.
    """
    primary_labels = list(PRIMARY_LABELS)
    if cfg.allow_shake_primary and "shake" not in primary_labels:
        primary_labels.append("shake")
    # Rank by score, breaking ties by label name (both descending).
    scored_primary = sorted(
        ((lb, raw_scores.get(lb, 0.0)) for lb in primary_labels),
        key=lambda x: (x[1], x[0]),
        reverse=True,
    )
    p1_label, p1 = scored_primary[0]
    p2 = scored_primary[1][1] if len(scored_primary) > 1 else 0.0
    p2_label = scored_primary[1][0] if len(scored_primary) > 1 else None
    margin = p1 - p2
    undef = raw_scores.get("undefined", 0.0)
    # A very confident primary can override a high "undefined" score.
    uncertain_by_undef = (undef >= cfg.uncertain_undefined_min) and (p1 < cfg.undefined_override_primary_min)
    if p1 < cfg.uncertain_p1_min or margin < cfg.uncertain_margin_min or uncertain_by_undef:
        if cfg.allow_uncertain_compound_rescue:
            # Rescue path 1: both top labels confident -> emit a compound.
            can_rescue = (
                undef <= cfg.rescue_undefined_max
                and p1 >= cfg.rescue_min_primary
                and (p2_label is not None)
                and p2 >= cfg.rescue_min_secondary
            )
            if can_rescue and p1_label != "static" and p2_label not in ("static", p1_label):
                if not is_axis_conflict(p1_label, p2_label) and not is_redundant_zoom_dolly_pair(p1_label, p2_label):
                    return ("compound", p1_label, p2_label, min(p1, p2))
            # Rescue path 2: confident primary alone -> emit a single.
            if undef <= cfg.rescue_undefined_max and p1 >= cfg.rescue_min_primary:
                return ("single", p1_label, None, p1)
        # Rescue path 3: "undefined" dominates but a usable primary exists.
        if (
            cfg.allow_undefined_single_fallback
            and undef >= cfg.uncertain_undefined_min
            and p1 >= cfg.undefined_fallback_min
            and margin >= cfg.undefined_fallback_margin_min
            and p1_label != "static"
        ):
            scored_secondary_fallback = sorted(
                (
                    (lb, raw_scores.get(lb, 0.0))
                    for lb in SECONDARY_CANDIDATES
                    if lb not in ("undefined", p1_label, "static")
                ),
                key=lambda x: (x[1], x[0]),
                reverse=True,
            )
            for lb, sc in scored_secondary_fallback:
                if sc < cfg.undefined_fallback_compound_min:
                    break
                if p1 - sc > cfg.undefined_fallback_compound_gap_max:
                    continue
                if is_axis_conflict(p1_label, lb):
                    continue
                if is_redundant_zoom_dolly_pair(p1_label, lb):
                    continue
                return ("compound", p1_label, lb, min(p1, sc))
            return ("single", p1_label, None, p1)
        return ("uncertain", None, None, p1)
    # Static never takes a secondary label.
    if p1_label == "static":
        return ("single", p1_label, None, p1)
    scored_secondary = sorted(
        (
            (lb, raw_scores.get(lb, 0.0))
            for lb in SECONDARY_CANDIDATES
            if lb not in ("undefined", p1_label)
        ),
        key=lambda x: (x[1], x[0]),
        reverse=True,
    )
    secondary = None
    secondary_score = 0.0
    for lb, sc in scored_secondary:
        if sc < cfg.secondary_min:
            break
        if p1 - sc > cfg.secondary_gap_max:
            continue
        if is_axis_conflict(p1_label, lb):
            continue
        if is_redundant_zoom_dolly_pair(p1_label, lb):
            continue
        secondary = lb
        secondary_score = sc
        break
    if secondary:
        return ("compound", p1_label, secondary, min(p1, secondary_score))
    return ("single", p1_label, None, p1)
def apply_cv_gates(raw_scores: Dict[str, float], prior: WindowCVPrior, cfg: VariantConfig) -> Dict[str, float]:
    """Return a copy of raw_scores adjusted by the variant's enabled flow gates.

    Each gate nudges label scores using the window's optical-flow prior:
    static/shake reinforcement, roll->pan remapping, arc/truck demotion under
    low translation, dolly<->zoom disambiguation, and the Hitchcock heuristic.
    The input dict is not mutated.
    """
    scores = dict(raw_scores)
    if cfg.use_shake_static_gate:
        motion_peak = max(
            scores.get(lb, 0.0)
            for lb in (
                "arc_left",
                "arc_right",
                "pan_left",
                "pan_right",
                "truck_left",
                "truck_right",
                "tilt_up",
                "tilt_down",
                "pedestal_up",
                "pedestal_down",
                "dolly_in",
                "dolly_out",
                "zoom_in",
                "zoom_out",
                "roll_left",
                "roll_right",
            )
        )
        # Very low flow + no strong motion label: reinforce "static".
        if (
            prior.static_conf >= 0.78
            and prior.shake_conf < 0.45
            and scores.get("static", 0.0) >= 0.45
            and motion_peak < 0.65
        ):
            target = clamp01(0.68 + 0.18 * prior.static_conf)
            scores["static"] = max(scores.get("static", 0.0), target)
            for lb in ("pan_left", "pan_right", "truck_left", "truck_right", "arc_left", "arc_right"):
                scores[lb] = clamp01(scores.get(lb, 0.0) * 0.92)
        # High jitter/incoherent flow: reinforce "shake", suppress "static".
        if prior.shake_conf >= 0.55 and (scores.get("shake", 0.0) >= 0.35 or prior.jitter_ratio >= 0.50):
            target = clamp01(0.68 + 0.20 * prior.shake_conf)
            scores["shake"] = max(scores.get("shake", 0.0), target)
            scores["static"] = clamp01(scores.get("static", 0.0) * 0.78)
    # Very low flow but a roll label: remap toward the matching pan.
    if cfg.use_roll_to_pan_gate and prior.mean_flow_mag <= 0.08:
        roll_r = scores.get("roll_right", 0.0)
        roll_l = scores.get("roll_left", 0.0)
        if roll_r >= 0.20:
            scores["pan_right"] = max(scores.get("pan_right", 0.0), clamp01(roll_r * 0.82))
        if roll_l >= 0.20:
            scores["pan_left"] = max(scores.get("pan_left", 0.0), clamp01(roll_l * 0.82))
    if cfg.use_shake_roll_pan_gate and prior.static_conf >= 0.97 and scores.get("shake", 0.0) >= 0.90:
        roll_r = scores.get("roll_right", 0.0)
        roll_l = scores.get("roll_left", 0.0)
        if max(roll_r, roll_l) >= 0.20:
            # Handheld jitter with dominant roll often corresponds to slight pan in GT wording.
            scores["shake"] = clamp01(scores.get("shake", 0.0) * 0.45)
            if roll_r >= 0.20:
                scores["pan_right"] = max(scores.get("pan_right", 0.0), clamp01(roll_r * 1.85))
            if roll_l >= 0.20:
                scores["pan_left"] = max(scores.get("pan_left", 0.0), clamp01(roll_l * 1.85))
    if cfg.use_arc_orbit_gate:
        tr = prior.translation_ratio
        # Near-zero translation with strong radial flow contradicts an arc/orbit:
        # demote the arc and promote dolly_in plus the counter-direction pan.
        if tr <= 0.03 and abs(prior.radial_alignment) >= 0.40:
            arc_r = scores.get("arc_right", 0.0)
            arc_l = scores.get("arc_left", 0.0)
            if arc_r >= 0.95:
                scores["arc_right"] = clamp01(arc_r * 0.75)
                scores["dolly_in"] = max(scores.get("dolly_in", 0.0), clamp01(arc_r * 0.88))
                scores["pan_left"] = max(scores.get("pan_left", 0.0), clamp01(arc_r * 0.86))
            if arc_l >= 0.95:
                scores["arc_left"] = clamp01(arc_l * 0.75)
                scores["dolly_in"] = max(scores.get("dolly_in", 0.0), clamp01(arc_l * 0.88))
                scores["pan_right"] = max(scores.get("pan_right", 0.0), clamp01(arc_l * 0.86))
    if cfg.use_truck_low_translation_gate:
        tr = prior.translation_ratio
        # Motion present but no coherent translation contradicts a truck: demote
        # the truck label and promote dolly_in.
        if tr <= 0.03 and prior.mean_flow_mag >= 0.10:
            tr_r = scores.get("truck_right", 0.0)
            tr_l = scores.get("truck_left", 0.0)
            if tr_r >= 0.80:
                scores["truck_right"] = clamp01(tr_r * 0.55)
                scores["dolly_in"] = max(scores.get("dolly_in", 0.0), clamp01(tr_r * 0.86))
            if tr_l >= 0.80:
                scores["truck_left"] = clamp01(tr_l * 0.55)
                scores["dolly_in"] = max(scores.get("dolly_in", 0.0), clamp01(tr_l * 0.86))
    if cfg.use_dolly_zoom_gate:
        # Disambiguate zoom vs dolly using the flow priors, in the in/out
        # direction the scores + radial alignment indicate.
        sign = direction_sign_from_scores(scores, prior.radial_alignment)
        if prior.zoom_bias >= 0.22 and prior.zoom_bias >= prior.dolly_bias + 0.08:
            boost = 0.08 + 0.20 * prior.zoom_bias
            suppress = 0.10 * prior.zoom_bias
            if sign > 0:
                scores["zoom_in"] = clamp01(max(scores.get("zoom_in", 0.0), scores.get("dolly_in", 0.0) + boost))
                scores["dolly_in"] = clamp01(scores.get("dolly_in", 0.0) - suppress)
            else:
                scores["zoom_out"] = clamp01(max(scores.get("zoom_out", 0.0), scores.get("dolly_out", 0.0) + boost))
                scores["dolly_out"] = clamp01(scores.get("dolly_out", 0.0) - suppress)
        if prior.dolly_bias >= 0.22 and prior.dolly_bias >= prior.zoom_bias + 0.08:
            boost = 0.08 + 0.18 * prior.dolly_bias
            suppress = 0.08 * prior.dolly_bias
            if sign > 0:
                scores["dolly_in"] = clamp01(max(scores.get("dolly_in", 0.0), scores.get("zoom_in", 0.0) + boost))
                scores["zoom_in"] = clamp01(scores.get("zoom_in", 0.0) - suppress)
            else:
                scores["dolly_out"] = clamp01(max(scores.get("dolly_out", 0.0), scores.get("zoom_out", 0.0) + boost))
                scores["zoom_out"] = clamp01(scores.get("zoom_out", 0.0) - suppress)
    if cfg.use_undefined_hitchcock_gate:
        undef = scores.get("undefined", 0.0)
        # High "undefined" with real translation + motion may be a Hitchcock
        # (dolly with counter-zoom): pair each zoom with a small opposite dolly.
        if (
            undef >= cfg.hitch_undefined_min
            and prior.mean_flow_mag >= cfg.hitch_motion_min
            and prior.translation_ratio >= cfg.hitch_translation_min
        ):
            zoom_out = scores.get("zoom_out", 0.0)
            zoom_in = scores.get("zoom_in", 0.0)
            if zoom_out >= cfg.hitch_zoom_min:
                zoom_out = max(zoom_out, cfg.hitch_zoom_min)
                scores["zoom_out"] = clamp01(zoom_out)
                scores["dolly_in"] = clamp01(max(scores.get("dolly_in", 0.0), min(0.25, zoom_out + 0.03)))
            if zoom_in >= cfg.hitch_zoom_min:
                zoom_in = max(zoom_in, cfg.hitch_zoom_min)
                scores["zoom_in"] = clamp01(zoom_in)
                scores["dolly_out"] = clamp01(max(scores.get("dolly_out", 0.0), min(0.25, zoom_in + 0.03)))
    return scores
def state_key(w: WindowPrediction) -> Tuple[str, Optional[str], Optional[str]]:
    """Comparable (mode, primary, secondary) tuple for a window prediction."""
    mode, primary, secondary = w.mode, w.primary, w.secondary
    return (mode, primary, secondary)
def set_state(w: WindowPrediction, mode: str, primary: Optional[str], secondary: Optional[str]) -> None:
    """Overwrite a window prediction's decoded state in place."""
    w.mode, w.primary, w.secondary = mode, primary, secondary
def majority_fill_uncertain(windows: List[WindowPrediction], radius: int) -> None:
    """Fill uncertain windows in place with the dominant nearby state.

    Looks at up to `radius` windows on each side; the winning neighbor state
    needs at least two supporting windows and a strict plurality over the
    runner-up to be adopted.
    """
    n = len(windows)
    for i, w in enumerate(windows):
        if w.mode != "uncertain":
            continue
        lo = max(0, i - radius)
        hi = min(n - 1, i + radius)
        counter: Dict[Tuple[str, Optional[str], Optional[str]], int] = {}
        for j in range(lo, hi + 1):
            if j == i:
                continue
            k = state_key(windows[j])
            if k[0] == "uncertain":
                continue
            counter[k] = counter.get(k, 0) + 1
        if not counter:
            continue
        ranked = sorted(counter.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)
        best_k, best_cnt = ranked[0]
        second_cnt = ranked[1][1] if len(ranked) > 1 else -1
        # Require >= 2 votes and an outright win over the runner-up.
        if best_cnt >= 2 and best_cnt > second_cnt:
            set_state(w, best_k[0], best_k[1], best_k[2])
def apply_hysteresis(windows: List[WindowPrediction], support_need: int) -> None:
    """Suppress short state flips in place.

    Walking left to right, a new state replaces the current one only when it
    persists for `support_need` consecutive windows (looking ahead); otherwise
    the previous state is carried forward over the blip.
    """
    if not windows:
        return
    states = [state_key(w) for w in windows]
    out: List[Tuple[str, Optional[str], Optional[str]]] = [states[0]]
    current = states[0]
    for i in range(1, len(states)):
        cand = states[i]
        if cand == current:
            out.append(current)
            continue
        # Count how long the candidate state persists starting at i.
        support = 1
        j = i + 1
        while j < len(states) and states[j] == cand and support < support_need:
            support += 1
            j += 1
        if support >= support_need:
            current = cand
        out.append(current)
    for w, st in zip(windows, out):
        set_state(w, st[0], st[1], st[2])
def enforce_compound_support(windows: List[WindowPrediction], support_windows: int) -> None:
    """Demote compound segments shorter than `support_windows` to primary-only singles (in place)."""
    if support_windows <= 1:
        return
    meta = build_segment_meta(windows)
    for seg in meta:
        if seg.mode != "compound" or not seg.primary:
            continue
        num_windows = seg.end_idx - seg.start_idx + 1
        if num_windows >= support_windows:
            continue
        for i in range(seg.start_idx, seg.end_idx + 1):
            set_state(windows[i], "single", seg.primary, None)
def _seg_windows(seg: Any) -> int:
return int(seg.end_idx - seg.start_idx + 1)
def _seg_labels(seg: Any) -> set:
labels = set()
if seg.primary:
labels.add(seg.primary)
if seg.secondary:
labels.add(seg.secondary)
return labels
def _seg_is_shake_like(seg: Any) -> bool:
    """True when the segment carries a shake label (as primary or secondary)."""
    return "shake" in _seg_labels(seg)
def _seg_is_motion_with_shake(seg: Any) -> bool:
    """True when shake co-occurs with at least one real (non-shake, non-static) motion label."""
    labels = _seg_labels(seg)
    if "shake" not in labels:
        return False
    real_motion = [lb for lb in labels if lb not in ("shake", "static")]
    return bool(real_motion)
def absorb_terminal_uncertain(
    windows: List[WindowPrediction], max_windows: int, neighbor_conf_min: float
) -> None:
    """Merge short leading/trailing uncertain segments into a confident neighbor.

    Repeats (bounded at 8 passes) while a terminal uncertain segment of at most
    `max_windows` windows borders a non-uncertain segment whose confidence is
    at least `neighbor_conf_min`. The head is tried first each pass; the tail
    only when the head did not change. Mutates windows in place.
    """
    if max_windows <= 0:
        return
    safety = 0
    while safety < 8:
        safety += 1
        meta = build_segment_meta(windows)
        if len(meta) < 2:
            return
        changed = False
        first = meta[0]
        right = meta[1]
        if (
            first.mode == "uncertain"
            and _seg_windows(first) <= max_windows
            and right.mode != "uncertain"
            and right.confidence >= neighbor_conf_min
        ):
            target = (right.mode, right.primary, right.secondary)
            for i in range(first.start_idx, first.end_idx + 1):
                set_state(windows[i], target[0], target[1], target[2])
            changed = True
        if not changed:
            last = meta[-1]
            left = meta[-2]
            if (
                last.mode == "uncertain"
                and _seg_windows(last) <= max_windows
                and left.mode != "uncertain"
                and left.confidence >= neighbor_conf_min
            ):
                target = (left.mode, left.primary, left.secondary)
                for i in range(last.start_idx, last.end_idx + 1):
                    set_state(windows[i], target[0], target[1], target[2])
                changed = True
        if not changed:
            return
def collapse_tail_shake_rebound(windows: List[WindowPrediction], max_windows: int) -> None:
    """Collapse a short trailing shake "rebound" into the preceding segment.

    Applies only to shots of 3-4 segments that start shake-like and already
    contain an earlier shake: a final shake segment of at most `max_windows`
    windows is rewritten to the previous segment's state, unless that previous
    segment is pure shake without real motion. Bounded at 8 passes; mutates
    windows in place.
    """
    if max_windows <= 0:
        return
    safety = 0
    while safety < 8:
        safety += 1
        meta = build_segment_meta(windows)
        if len(meta) < 3 or len(meta) > 4:
            return
        if not _seg_is_shake_like(meta[0]):
            return
        last = meta[-1]
        prev = meta[-2]
        if not (_seg_is_shake_like(last) and _seg_windows(last) <= max_windows):
            return
        has_early_shake = any(_seg_is_shake_like(seg) for seg in meta[:-1])
        if not has_early_shake:
            return
        # Merge when the previous segment has motion+shake, or no shake at all.
        merge_tail = False
        if _seg_is_motion_with_shake(prev):
            merge_tail = True
        elif not _seg_is_shake_like(prev):
            merge_tail = True
        if not merge_tail:
            return
        target = (prev.mode, prev.primary, prev.secondary)
        for i in range(last.start_idx, last.end_idx + 1):
            set_state(windows[i], target[0], target[1], target[2])
def resolve_hysteresis_support(window_count: int, cfg_support: int) -> int:
    """Scale the configured hysteresis support down for short shots.

    Shots of <= 3 windows get support 1, <= 8 windows are capped at 2,
    and longer shots use the configured value (floored at 1).
    """
    if window_count <= 3:
        return 1
    if window_count <= 8:
        return 2 if cfg_support > 2 else cfg_support
    return cfg_support if cfg_support > 1 else 1
def windows_for_variant(
    raw_windows: List[WindowPrediction], cv_priors: List[WindowCVPrior], cfg: VariantConfig
) -> List[WindowPrediction]:
    """Re-classify shared raw window scores under one postprocess variant.

    For each raw window: apply the variant's CV gates, classify the adjusted
    scores, then run the smoothing pipeline (majority fill, hysteresis,
    compound support, optional terminal-uncertain absorb and tail-shake
    collapse). Raw windows are not mutated; new WindowPrediction objects are
    returned.
    """
    windows = []
    for idx, w in enumerate(raw_windows):
        prior = cv_priors[idx]
        adjusted_scores = apply_cv_gates(w.raw_scores, prior, cfg)
        mode, p, s, conf = classify_scores(adjusted_scores, cfg)
        windows.append(
            WindowPrediction(
                start_sec=w.start_sec,
                end_sec=w.end_sec,
                raw_scores=adjusted_scores,
                raw_top_labels=w.raw_top_labels,
                mode=mode,
                primary=p,
                secondary=s,
                confidence=conf,
            )
        )
    majority_fill_uncertain(windows, radius=cfg.smooth_radius)
    support_need = resolve_hysteresis_support(len(windows), cfg.hysteresis_support)
    apply_hysteresis(windows, support_need=support_need)
    enforce_compound_support(windows, support_windows=cfg.compound_support_windows)
    if cfg.use_terminal_uncertain_absorb:
        absorb_terminal_uncertain(
            windows,
            max_windows=max(1, int(cfg.terminal_uncertain_max_windows)),
            neighbor_conf_min=float(cfg.terminal_uncertain_neighbor_conf_min),
        )
    if cfg.use_tail_shake_rebound_collapse:
        collapse_tail_shake_rebound(windows, max_windows=max(1, int(cfg.tail_shake_max_windows)))
    return windows
def _all_keyword_hits(text: str, keyword: str) -> List[int]:
hits: List[int] = []
start = 0
while True:
idx = text.find(keyword, start)
if idx < 0:
break
hits.append(idx)
start = idx + len(keyword)
return hits
def parse_gt_stages(desc: str) -> List[List[str]]:
    """Parse a Chinese ground-truth description into ordered motion-label stages.

    Each stage is a list of co-occurring labels; sequential connector words
    start a new stage, simultaneous connectors (or very short bridges) extend
    the current one. Returns [] for empty text, text marked "不验证"
    (do-not-verify), or text with no recognized keywords.
    """
    text = (desc or "").strip()
    if not text:
        return []
    if "不验证" in text:
        return []
    mentions: List[Tuple[int, int, str, str]] = []
    has_nonzoom = "非变焦" in text
    for label, keywords in TOKEN_PATTERNS:
        for kw in keywords:
            for pos in _all_keyword_hits(text, kw):
                end = pos + len(kw)
                local = text[max(0, pos - 12) : min(len(text), end + 12)]
                local_nonzoom = "非变焦" in local
                # Prefer dolly_* on explicit non-zoom phrases.
                if label == "zoom_in" and (local_nonzoom or has_nonzoom) and kw in (
                    "变焦推",
                    "变焦推进",
                    "推镜头",
                    "推进",
                    "前推",
                    "推",
                ):
                    continue
                if label == "zoom_out" and (local_nonzoom or has_nonzoom) and kw in (
                    "变焦拉",
                    "拉镜头",
                    "后拉",
                    "拉远",
                    "拉",
                ):
                    continue
                # Prefer zoom_* on explicit focal-length wording.
                if label == "dolly_in" and ((("变焦" in local) and (not local_nonzoom)) or ("焦距" in local)):
                    continue
                if label == "dolly_out" and ((("变焦" in local) and (not local_nonzoom)) or ("焦距" in local)):
                    continue
                mentions.append((pos, end, label, kw))
    # "Hitchcock" shorthand implies dolly-in with a simultaneous counter zoom-out.
    if "希区柯克" in text:
        hitch_pos = text.find("希区柯克")
        mentions.append((hitch_pos, hitch_pos + 4, "dolly_in", "希区柯克"))
        mentions.append((hitch_pos, hitch_pos + 4, "zoom_out", "希区柯克"))
    if not mentions:
        return []
    mentions.sort(key=lambda x: (x[0], x[1], x[2], x[3]))
    # Group mentions into stages by the connector text between them.
    stages: List[List[str]] = [[mentions[0][2]]]
    last_end = mentions[0][1]
    for pos, end, label, _kw in mentions[1:]:
        bridge = text[last_end:pos]
        if any(m in bridge for m in SIMULTANEOUS_MARKERS):
            if label not in stages[-1]:
                stages[-1].append(label)
        else:
            is_seq = any(m in bridge for m in SEQUENTIAL_MARKERS)
            # A very short, unpunctuated bridge is treated as the same stage.
            is_short_bridge = len(bridge.strip()) <= 6 and ("," not in bridge) and ("。" not in bridge)
            if not is_seq and is_short_bridge:
                if label not in stages[-1]:
                    stages[-1].append(label)
            elif label not in stages[-1]:
                stages.append([label])
        last_end = max(last_end, end)
    # Deduplicate labels within a stage and collapse adjacent identical stages.
    deduped: List[List[str]] = []
    for st in stages:
        uniq = []
        for lb in st:
            if lb not in uniq:
                uniq.append(lb)
        if "希区柯克" in text:
            # "推/拉 + 希区柯克变焦" often means dolly + opposite zoom, not an extra generic zoom token.
            if "dolly_in" in uniq and "zoom_out" in uniq and "zoom_in" in uniq:
                uniq = [x for x in uniq if x != "zoom_in"]
            if "dolly_out" in uniq and "zoom_in" in uniq and "zoom_out" in uniq:
                uniq = [x for x in uniq if x != "zoom_out"]
        if not deduped or deduped[-1] != uniq:
            deduped.append(uniq)
    return deduped
def predicted_stages_from_segments(segments: List[dict]) -> List[List[str]]:
    """Convert decoded segment dicts into ordered stage label lists.

    Uncertain segments map to ["uncertain"]; other segments contribute their
    truthy primary/secondary labels. Segments with no labels are dropped.
    """
    stages: List[List[str]] = []
    for seg in segments:
        if seg["mode"] == "uncertain":
            stages.append(["uncertain"])
            continue
        labels = [seg[key] for key in ("primary", "secondary") if seg.get(key)]
        if labels:
            stages.append(labels)
    return stages
def stage_to_cn(stage: List[str]) -> str:
    """Render a stage's labels as a Chinese display string joined with ' + '."""
    parts = [LABEL_CN_MAP.get(label, label) for label in stage]
    return " + ".join(parts)
def strict_compare(expected: List[List[str]], predicted: List[List[str]]) -> Tuple[bool, List[str]]:
    """Strict replayability check: (passed, reasons).

    Fails when any predicted stage is uncertain, ground truth is unparsed,
    stage counts differ, or any aligned stage's label set differs (stage
    order matters, label order within a stage does not).
    """
    reasons: List[str] = []
    if any("uncertain" in stage for stage in predicted):
        reasons.append("contains_uncertain")
    if not expected:
        reasons.append("gt_unparsed_or_skipped")
        return (False, reasons)
    if len(expected) != len(predicted):
        reasons.append(f"stage_count_mismatch exp={len(expected)} pred={len(predicted)}")
    else:
        for idx, (exp_stage, pred_stage) in enumerate(zip(expected, predicted), 1):
            if set(exp_stage) != set(pred_stage):
                reasons.append(
                    f"stage_{idx}_mismatch exp=({'+'.join(exp_stage)}) pred=({'+'.join(pred_stage)})"
                )
    return (not reasons, reasons)
def parse_sample_ids_csv(text: str) -> Tuple[int, ...]:
    """Parse a comma- and/or whitespace-separated id list (e.g. "1, 2 8") into ints.

    Empty or None input yields an empty tuple; non-numeric tokens raise ValueError.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return ()
    return tuple(int(token) for token in re.split(r"[,\s]+", cleaned) if token)
def resolve_case_path(root: str, path: Optional[str]) -> Optional[str]:
    """Resolve a case-relative path against *root*; absolute paths pass through.

    Returns None for falsy input (a missing optional asset).
    """
    if not path:
        return None
    return path if os.path.isabs(path) else os.path.join(root, path)
def load_gt_map(path: Optional[str]) -> Dict[int, str]:
    """Load a ground-truth JSON file mapping shot id -> description text.

    Accepts either a list of ``{"id": ..., "desc": ...}`` objects or a dict
    keyed by shot id (values may be strings or ``{"desc": ...}`` dicts).
    Returns {} when *path* is falsy; raises RuntimeError for other shapes.
    """
    if not path:
        return {}
    with open(path, "r", encoding="utf-8") as fh:
        data = json.load(fh)
    if isinstance(data, list):
        return {
            int(item["id"]): str(item.get("desc", ""))
            for item in data
            if isinstance(item, dict) and "id" in item
        }
    if isinstance(data, dict):
        result: Dict[int, str] = {}
        for key, value in data.items():
            try:
                shot_id = int(key)
            except Exception:
                # Non-numeric keys are silently skipped (e.g. metadata entries).
                continue
            if isinstance(value, str):
                result[shot_id] = value
            elif isinstance(value, dict):
                result[shot_id] = str(value.get("desc", ""))
            else:
                result[shot_id] = ""
        return result
    raise RuntimeError(f"Unsupported GT json format: {path}")
def build_cases(args: argparse.Namespace) -> Tuple[SampleCase, ...]:
    """Build the tuple of sample cases to run.

    Custom mode (--video set) produces a single case and requires --gt-json.
    Otherwise an optional --cases CSV filters the built-in SAMPLE_CASES;
    an empty filter selects them all, an unmatched one raises.
    """
    if args.video:
        if not args.gt_json:
            raise RuntimeError("Custom mode requires --gt-json")
        case_name = args.case_name.strip() if args.case_name else "custom"
        custom = SampleCase(
            name=case_name,
            video=args.video,
            shots_jsonl=args.shots_jsonl,
            gt_json=args.gt_json,
            sample_ids=parse_sample_ids_csv(args.sample_ids),
        )
        return (custom,)
    wanted = {name.strip() for name in (args.cases or "").split(",") if name.strip()}
    if not wanted:
        return SAMPLE_CASES
    filtered = tuple(case for case in SAMPLE_CASES if case.name in wanted)
    if not filtered:
        raise RuntimeError(f"--cases did not match built-in cases: {args.cases}")
    return filtered
def _segments_to_judge_text(segments: List[dict]) -> str:
lines = []
for s in segments:
piece = s.get("narrative_piece", s.get("piece", ""))
lines.append(
(
f"{s['seg_idx']}. {s['start_sec']:.2f}-{s['end_sec']:.2f}s"
f" | mode={s['mode']}"
f" | primary={s['primary'] or '-'}"
f" | secondary={s['secondary'] or '-'}"
f" | piece={piece}"
)
)
return "\n".join(lines) if lines else "(no segments)"
def _extract_json_object(text: str) -> Dict[str, Any]:
stripped = (text or "").strip()
if stripped.startswith("```"):
stripped = re.sub(r"^```(?:json)?\s*", "", stripped)
stripped = re.sub(r"\s*```$", "", stripped)
try:
obj = json.loads(stripped)
if isinstance(obj, dict):
return obj
except Exception:
pass
m = re.search(r"\{.*\}", stripped, re.S)
if not m:
raise RuntimeError("No JSON object found in judge output")
obj = json.loads(m.group(0))
if not isinstance(obj, dict):
raise RuntimeError("Judge output JSON is not an object")
return obj
def _as_int(v: Any, lo: int, hi: int) -> int:
try:
x = int(round(float(v)))
except Exception:
x = lo
return min(hi, max(lo, x))
def _verdict_from_overall(overall_score: int) -> str:
    """Map an overall 0-100 judge score to a verdict via the fixed thresholds."""
    if overall_score >= LLM_JUDGE_PASS_THRESHOLD:
        return "pass"
    return "partial" if overall_score >= 60 else "fail"
def _judge_result_template(status: str, model: str, reason: str) -> dict:
    """Build the zero-score judge payload used for skipped/error/base cases.

    For status "ok" the verdict starts as "fail" (the caller overwrites it
    with real scores); otherwise the verdict mirrors the status string.
    """
    verdict = "fail" if status == "ok" else status
    return {
        "status": status,
        "model": model,
        "prompt_version": JUDGE_PROMPT_VERSION,
        "action_type_score": 0,
        "direction_score": 0,
        "order_score": 0,
        "replayability_score": 0,
        "overall_score": 0,
        "verdict": verdict,
        "reason": reason[:240],
    }
def _build_judge_prompts(
    gt_desc: str, gt_stages_cn: List[str], pred_narrative_cn: str, segments: List[dict]
) -> Tuple[str, str]:
    """Assemble the (system, user) prompt pair for the replayability judge.

    All prompt text is kept byte-identical to the pinned prompt version;
    only the per-shot GT/prediction payloads vary.
    """
    system_msg = (
        "你是运镜复拍评委。只看动作类型、方向、先后顺序、可复拍性。"
        "不要考虑文采。必须输出JSON。"
    )
    gt_stage_text = "; ".join(gt_stages_cn) if gt_stages_cn else "(empty)"
    rubric = (
        "请按固定规则评分:\n"
        "1) action_type_score: 0-25,动作类型是否匹配。\n"
        "2) direction_score: 0-25,方向/正反是否匹配。\n"
        "3) order_score: 0-25,先后顺序是否匹配;并行复合不算顺序错。\n"
        "4) replayability_score: 0-25,仅凭预测描述,摄影师能否复拍。\n"
        "5) overall_score 必须是上述四项之和(0-100)。\n"
        "6) verdict: pass/partial/fail。阈值:pass>=85, partial>=60。\n"
        "7) reason: 用中文给一句短理由(<=60字)。\n\n"
    )
    user_msg = (
        rubric
        + f"GT描述:\n{gt_desc}\n\n"
        + f"GT解析阶段(供参考):\n{gt_stage_text}\n\n"
        + f"预测叙述:\n{pred_narrative_cn}\n\n"
        + f"预测分段:\n{_segments_to_judge_text(segments)}"
    )
    return (system_msg, user_msg)
def _llm_judge_one_hf(
    client: Any,
    model: str,
    sys_prompt: str,
    user_prompt: str,
) -> Dict[str, Any]:
    """Call the HF chat-completion judge and return its parsed JSON verdict.

    First attempts a structured response (JSON schema); on any failure it
    retries once without the schema. Raises RuntimeError when both fail.
    """
    last_err: Optional[Exception] = None
    for attempt_schema in (True, False):
        request: Dict[str, Any] = {
            "model": model,
            "messages": [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt},
            ],
            # Deterministic settings so repeat runs are comparable.
            "temperature": 0.0,
            "max_tokens": 320,
            "seed": 42,
        }
        if attempt_schema:
            request["response_format"] = JUDGE_JSON_SCHEMA
        try:
            resp = client.chat_completion(**request)
            choices = getattr(resp, "choices", None)
            if not choices:
                raise RuntimeError("judge response has no choices")
            content = getattr(choices[0].message, "content", "")
            if isinstance(content, list):
                # Some providers return a list of content parts.
                text = "".join(
                    str(part.get("text", "")) if isinstance(part, dict) else str(part)
                    for part in content
                )
            else:
                text = str(content)
            return _extract_json_object(text)
        except Exception as exc:  # pragma: no cover - network/runtime path
            last_err = exc
    raise RuntimeError(f"judge_failed: {last_err}" if last_err else "judge_failed: unknown")
def _llm_judge_one_gemini(
    api_key: str,
    model: str,
    sys_prompt: str,
    user_prompt: str,
) -> Dict[str, Any]:
    """Call the Gemini generateContent endpoint and parse its JSON verdict.

    Uses the REST API directly (no SDK). Raises RuntimeError with a short,
    prefixed reason on HTTP errors, transport errors, or malformed responses.
    """
    url = (
        f"https://generativelanguage.googleapis.com/v1beta/models/"
        f"{urllib.parse.quote(model)}:generateContent?key={urllib.parse.quote(api_key)}"
    )
    request_body = json.dumps(
        {
            "systemInstruction": {"parts": [{"text": sys_prompt}]},
            "contents": [{"role": "user", "parts": [{"text": user_prompt}]}],
            "generationConfig": {
                "temperature": 0,
                "responseMimeType": "application/json",
                "responseSchema": GEMINI_JUDGE_RESPONSE_SCHEMA,
            },
        }
    ).encode("utf-8")
    req = urllib.request.Request(
        url,
        data=request_body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=90) as resp:
            raw = resp.read().decode("utf-8")
    except urllib.error.HTTPError as exc:
        # HTTPError bodies usually carry the API's error message; keep a snippet.
        detail = exc.read().decode("utf-8", errors="ignore") if hasattr(exc, "read") else str(exc)
        raise RuntimeError(f"gemini_http_error: {exc.code} {detail[:300]}") from exc
    except Exception as exc:
        raise RuntimeError(f"gemini_request_error: {exc}") from exc
    reply = json.loads(raw)
    candidates = reply.get("candidates", [])
    if not candidates:
        raise RuntimeError(f"gemini_no_candidates: {raw[:300]}")
    parts = candidates[0].get("content", {}).get("parts", [])
    text = "".join(str(p.get("text", "")) for p in parts if isinstance(p, dict))
    if not text.strip():
        raise RuntimeError(f"gemini_empty_text: {raw[:300]}")
    return _extract_json_object(text)
def llm_judge_one(
    *,
    provider: str,
    client: Any,
    model: str,
    gt_desc: str,
    gt_stages_cn: List[str],
    pred_narrative_cn: str,
    segments: List[dict],
) -> dict:
    """Run the LLM judge for one shot/variant and normalize its scores.

    Returns a judge payload dict (shape of _judge_result_template). Missing
    GT descriptions are reported as "skipped"; provider failures as "error".
    Sub-scores are clamped to [0, 25]; overall is forced back to their sum
    when the model's overall deviates by more than 5; the verdict is always
    derived from the final overall score via the fixed thresholds.
    """
    if not gt_desc.strip():
        return _judge_result_template("skipped", model, "missing_gt_desc")
    sys_prompt, user_prompt = _build_judge_prompts(gt_desc, gt_stages_cn, pred_narrative_cn, segments)
    try:
        if provider == "gemini":
            parsed = _llm_judge_one_gemini(
                api_key=str(client["api_key"]),
                model=model,
                sys_prompt=sys_prompt,
                user_prompt=user_prompt,
            )
        else:
            parsed = _llm_judge_one_hf(
                client=client,
                model=model,
                sys_prompt=sys_prompt,
                user_prompt=user_prompt,
            )
    except Exception as exc:
        return _judge_result_template("error", model, f"{provider}_judge_failed: {exc}")
    action = _as_int(parsed.get("action_type_score"), 0, 25)
    direction = _as_int(parsed.get("direction_score"), 0, 25)
    order = _as_int(parsed.get("order_score"), 0, 25)
    replayability = _as_int(parsed.get("replayability_score"), 0, 25)
    summed = action + direction + order + replayability
    overall = _as_int(parsed.get("overall_score"), 0, 100)
    if abs(overall - summed) > 5:
        # Distrust an overall that disagrees with its own components.
        overall = summed
    # FIX (dead code removed): the original parsed and validated the model's
    # verdict, then unconditionally overwrote it with the threshold-derived
    # one. The net behavior was "verdict is always derived"; say so directly.
    verdict = _verdict_from_overall(overall)
    reason = str(parsed.get("reason", "")).strip()
    if not reason:
        reason = "评分由四项匹配程度综合给出。"
    out = _judge_result_template("ok", model, reason)
    out.update(
        {
            "action_type_score": action,
            "direction_score": direction,
            "order_score": order,
            "replayability_score": replayability,
            "overall_score": overall,
            "verdict": verdict,
            "reason": reason[:240],
        }
    )
    return out
def summarize_variant(rows: List[dict], variant_name: str, llm_enabled: bool) -> dict:
    """Aggregate strict-pass (and optional LLM-judge) metrics for one variant.

    Only rows marked scorable (parseable GT) are counted. When llm_enabled,
    adds judged count, mean overall, pass rate, and error count.
    """
    scored = [row for row in rows if row[variant_name]["scorable"]]
    hit_count = sum(1 for row in scored if row[variant_name]["strict_pass"])
    out = {
        "scored_shots": len(scored),
        "strict_hits": hit_count,
        "strict_hit_rate": round(hit_count / len(scored) if scored else 0.0, 4),
        "contains_uncertain": sum(1 for row in scored if row[variant_name]["has_uncertain"]),
    }
    if llm_enabled:
        judged = [row[variant_name].get("llm_judge", {}) for row in scored]
        ok_results = [j for j in judged if isinstance(j, dict) and j.get("status") == "ok"]
        out["llm_judged"] = len(ok_results)
        out["llm_overall_mean"] = round(
            float(mean(float(j.get("overall_score", 0.0)) for j in ok_results)) if ok_results else 0.0,
            2,
        )
        out["llm_pass_rate"] = round(
            float(sum(1 for j in ok_results if j.get("verdict") == "pass") / len(ok_results))
            if ok_results
            else 0.0,
            4,
        )
        out["llm_error_count"] = sum(1 for j in judged if isinstance(j, dict) and j.get("status") == "error")
    return out
def run_benchmark(
    *,
    hf_token: str,
    output_json: str,
    output_csv: str,
    cases: Sequence[SampleCase],
    llm_judge: bool,
    judge_provider: str,
    judge_model: str,
    judge_token: str,
    gemini_api_key: str,
    judge_workers: int,
    max_shots: int,
) -> dict:
    """Run one-pass window inference per shot, score every postprocess
    variant, and write JSON + CSV reports.

    For each selected shot: sample frames once, run VideoMAE window
    inference and CV priors once, then apply each variant in VARIANTS to
    the shared raw windows so all variants see identical model outputs.
    Optionally scores each variant with an LLM judge (HF or Gemini).
    Returns the full report payload also written to *output_json*.
    Raises RuntimeError for missing credentials, unreadable videos, or
    invalid video metadata.
    """
    root = os.path.abspath(os.path.dirname(__file__))
    judge_client = None
    if llm_judge:
        if judge_provider == "gemini":
            # Gemini path carries a bare API key dict (REST calls, no SDK client).
            key = gemini_api_key or os.environ.get("GEMINI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "")
            if not key:
                raise RuntimeError("Gemini judge requires --gemini-api-key or GEMINI_API_KEY/GOOGLE_API_KEY")
            judge_client = {"api_key": key}
        else:
            if InferenceClient is None:
                raise RuntimeError("huggingface_hub is required for --llm-judge with provider=hf")
            # Judge can reuse the inference token when no dedicated one is given.
            token = judge_token or hf_token
            if not token:
                raise RuntimeError("Judge token required: --judge-token or JUDGE_TOKEN/HF_TOKEN")
            judge_client = InferenceClient(token=token)
    inferencer = VideoMAEWindowInferencer(MODEL_ID, hf_token)
    rows: List[dict] = []
    for case in cases:
        video_path = resolve_case_path(root, case.video)
        shots_path = resolve_case_path(root, case.shots_jsonl)
        gt_path = resolve_case_path(root, case.gt_json)
        if not video_path:
            raise RuntimeError(f"Missing video path for case={case.name}")
        print(f"[CASE] {case.name}", flush=True)
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        if fps <= 1e-8 or total_frames <= 0:
            # Release before raising so the capture handle is not leaked.
            cap.release()
            raise RuntimeError(f"Invalid metadata for {video_path}")
        duration_sec = total_frames / fps
        if shots_path:
            shots = load_shots_jsonl(shots_path, fps=fps, video_duration_sec=duration_sec)
        else:
            # No shot metadata: treat the whole video as a single shot.
            shots = [ShotBoundary(shot_id=1, start_sec=0.0, end_sec=duration_sec)]
        shot_map: Dict[int, ShotBoundary] = {s.shot_id: s for s in shots}
        gt_map = load_gt_map(gt_path)
        selected_ids: Tuple[int, ...]
        if case.sample_ids:
            # Keep only requested ids that actually exist in the shot metadata.
            selected_ids = tuple(x for x in case.sample_ids if x in shot_map)
        else:
            selected_ids = tuple(sorted(shot_map.keys()))
        if max_shots > 0:
            selected_ids = selected_ids[:max_shots]
        for sid in selected_ids:
            shot = shot_map[sid]
            desc = gt_map.get(sid, "")
            expected_stages = parse_gt_stages(desc)
            print(f" [SHOT] {case.name}#{sid}", flush=True)
            # One-pass sampling + inference; every variant reuses these results.
            ranges = build_window_ranges(shot.start_sec, shot.end_sec)
            index_lists = build_window_frame_index_lists(ranges, fps=fps, total_frames=total_frames)
            frames_by_window = sample_frames_for_window_indices(
                cap,
                inferencer,
                index_lists=index_lists,
                fps=fps,
                total_frames=total_frames,
            )
            raw_windows: List[WindowPrediction] = []
            cv_priors: List[WindowCVPrior] = []
            for (s, e), frames in zip(ranges, frames_by_window):
                raw_windows.append(infer_raw_window_from_frames(inferencer, frames, s, e))
                cv_priors.append(compute_window_cv_prior(frames))
            entry = {
                "dataset": case.name,
                "shot_id": sid,
                "shot_start_sec": round(shot.start_sec, 3),
                "shot_end_sec": round(shot.end_sec, 3),
                "gt_desc": desc,
                "gt_stages": expected_stages,
                "gt_stages_cn": [stage_to_cn(st) for st in expected_stages],
            }
            variant_payload_map: Dict[str, dict] = {}
            for cfg in VARIANTS:
                # Each variant postprocesses the shared raw windows independently.
                windows = windows_for_variant(raw_windows, cv_priors, cfg)
                enforce_min_duration(windows, shot.end_sec - shot.start_sec)
                enforce_max_segments(windows)
                segments = build_segments_output(windows)
                pred_stages = predicted_stages_from_segments(segments)
                ok, reasons = strict_compare(expected_stages, pred_stages)
                has_uncertain = any(seg["mode"] == "uncertain" for seg in segments)
                narrative_cn = build_shot_narrative(segments)
                variant_payload = {
                    "narrative_cn": narrative_cn,
                    "segments": [
                        {
                            "seg_idx": s["seg_idx"],
                            "start_sec": s["start_sec"],
                            "end_sec": s["end_sec"],
                            "mode": s["mode"],
                            "primary": s["primary"],
                            "secondary": s["secondary"],
                            "confidence": s["confidence"],
                            "piece": s["narrative_piece"],
                        }
                        for s in segments
                    ],
                    "pred_stages": pred_stages,
                    "pred_stages_cn": [stage_to_cn(st) for st in pred_stages],
                    "strict_pass": ok,
                    "strict_reasons": reasons,
                    "has_uncertain": has_uncertain,
                    # A shot without parseable GT cannot be strictly scored.
                    "scorable": len(expected_stages) > 0,
                }
                variant_payload_map[cfg.name] = variant_payload
            if llm_judge:
                if judge_client is None:
                    raise RuntimeError("judge client init failed")
                # NOTE(review): closure captures the current loop's desc/entry;
                # safe because it is only submitted and resolved within this
                # iteration of the shot loop.
                def _judge_variant(name: str, payload: dict) -> Tuple[str, dict]:
                    judge = llm_judge_one(
                        provider=judge_provider,
                        client=judge_client,
                        model=judge_model,
                        gt_desc=desc,
                        gt_stages_cn=entry["gt_stages_cn"],
                        pred_narrative_cn=payload["narrative_cn"],
                        segments=payload["segments"],
                    )
                    return (name, judge)
                jobs = list(variant_payload_map.items())
                workers = max(1, min(int(judge_workers), len(jobs)))
                if workers == 1:
                    for name, payload in jobs:
                        payload["llm_judge"] = _judge_variant(name, payload)[1]
                else:
                    # Judge the variants of this shot in parallel (network-bound).
                    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as ex:
                        futs = [ex.submit(_judge_variant, name, payload) for name, payload in jobs]
                        for fut in concurrent.futures.as_completed(futs):
                            name, judge = fut.result()
                            variant_payload_map[name]["llm_judge"] = judge
            for cfg in VARIANTS:
                entry[cfg.name] = variant_payload_map[cfg.name]
            # Mean CV priors across windows, kept for post-hoc analysis.
            entry["window_cv_mean"] = {
                "motion": round(float(mean(x.mean_flow_mag for x in cv_priors)) if cv_priors else 0.0, 4),
                "jitter": round(float(mean(x.jitter_ratio for x in cv_priors)) if cv_priors else 0.0, 4),
                "shake": round(float(mean(x.shake_conf for x in cv_priors)) if cv_priors else 0.0, 4),
                "static": round(float(mean(x.static_conf for x in cv_priors)) if cv_priors else 0.0, 4),
                "zoom_bias": round(float(mean(x.zoom_bias for x in cv_priors)) if cv_priors else 0.0, 4),
                "dolly_bias": round(float(mean(x.dolly_bias for x in cv_priors)) if cv_priors else 0.0, 4),
            }
            rows.append(entry)
        cap.release()
    summary = {cfg.name: summarize_variant(rows, cfg.name, llm_enabled=llm_judge) for cfg in VARIANTS}
    # Pairwise deltas along the fixed variant ladder A -> B -> C -> D.
    summary["delta_B_minus_A"] = round(
        summary[VARIANT_B.name]["strict_hit_rate"] - summary[VARIANT_A.name]["strict_hit_rate"], 4
    )
    summary["delta_C_minus_B"] = round(
        summary[VARIANT_C.name]["strict_hit_rate"] - summary[VARIANT_B.name]["strict_hit_rate"], 4
    )
    summary["delta_D_minus_C"] = round(
        summary[VARIANT_D.name]["strict_hit_rate"] - summary[VARIANT_C.name]["strict_hit_rate"], 4
    )
    summary["delta_D_minus_B"] = round(
        summary[VARIANT_D.name]["strict_hit_rate"] - summary[VARIANT_B.name]["strict_hit_rate"], 4
    )
    if llm_judge:
        summary["llm_delta_B_minus_A"] = round(
            summary[VARIANT_B.name]["llm_overall_mean"] - summary[VARIANT_A.name]["llm_overall_mean"], 2
        )
        summary["llm_delta_C_minus_B"] = round(
            summary[VARIANT_C.name]["llm_overall_mean"] - summary[VARIANT_B.name]["llm_overall_mean"], 2
        )
        summary["llm_delta_D_minus_C"] = round(
            summary[VARIANT_D.name]["llm_overall_mean"] - summary[VARIANT_C.name]["llm_overall_mean"], 2
        )
    payload = {
        "model_id": MODEL_ID,
        "variants": [asdict(v) for v in VARIANTS],
        "llm_judge": {
            "enabled": llm_judge,
            "provider": judge_provider if llm_judge else "",
            "model": judge_model if llm_judge else "",
            "prompt_version": JUDGE_PROMPT_VERSION,
            "pass_threshold": LLM_JUDGE_PASS_THRESHOLD,
        },
        "cases": [asdict(c) for c in cases],
        "summary": summary,
        "rows": rows,
    }
    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    # (score, verdict, reason) triple for the CSV; non-ok statuses surface
    # in the verdict column with an empty score.
    def llm_fields(v: dict) -> Tuple[Any, Any, Any]:
        j = v.get("llm_judge", {})
        if j.get("status") == "ok":
            return (j.get("overall_score", ""), j.get("verdict", ""), j.get("reason", ""))
        return ("", j.get("status", ""), j.get("reason", ""))
    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(
            [
                "dataset",
                "shot_id",
                "gt_desc",
                "gt_stages_cn",
                "A_pass",
                "A_pred_stages_cn",
                "A_reasons",
                "A_llm_score",
                "A_llm_verdict",
                "A_llm_reason",
                "B_pass",
                "B_pred_stages_cn",
                "B_reasons",
                "B_llm_score",
                "B_llm_verdict",
                "B_llm_reason",
                "C_pass",
                "C_pred_stages_cn",
                "C_reasons",
                "C_llm_score",
                "C_llm_verdict",
                "C_llm_reason",
                "D_pass",
                "D_pred_stages_cn",
                "D_reasons",
                "D_llm_score",
                "D_llm_verdict",
                "D_llm_reason",
            ]
        )
        for r in rows:
            a = r[VARIANT_A.name]
            b = r[VARIANT_B.name]
            c = r[VARIANT_C.name]
            d = r[VARIANT_D.name]
            a_llm = llm_fields(a)
            b_llm = llm_fields(b)
            c_llm = llm_fields(c)
            d_llm = llm_fields(d)
            w.writerow(
                [
                    r["dataset"],
                    r["shot_id"],
                    r["gt_desc"],
                    " | ".join(r["gt_stages_cn"]),
                    int(a["strict_pass"]),
                    " | ".join(a["pred_stages_cn"]),
                    " | ".join(a["strict_reasons"]),
                    a_llm[0],
                    a_llm[1],
                    a_llm[2],
                    int(b["strict_pass"]),
                    " | ".join(b["pred_stages_cn"]),
                    " | ".join(b["strict_reasons"]),
                    b_llm[0],
                    b_llm[1],
                    b_llm[2],
                    int(c["strict_pass"]),
                    " | ".join(c["pred_stages_cn"]),
                    " | ".join(c["strict_reasons"]),
                    c_llm[0],
                    c_llm[1],
                    c_llm[2],
                    int(d["strict_pass"]),
                    " | ".join(d["pred_stages_cn"]),
                    " | ".join(d["strict_reasons"]),
                    d_llm[0],
                    d_llm[1],
                    d_llm[2],
                ]
            )
    print(f"[OK] wrote {output_json}")
    print(f"[OK] wrote {output_csv}")
    print("[SUMMARY]")
    print(json.dumps(summary, ensure_ascii=False, indent=2))
    return payload
def main() -> int:
    """CLI entry point: parse arguments, validate them, run the benchmark."""
    arg_parser = argparse.ArgumentParser(description="A/B strict replayability benchmark for VideoMAE timeline")
    add = arg_parser.add_argument
    add("--hf-token", default=os.environ.get("HF_TOKEN", ""), help="HF token or set HF_TOKEN")
    add("--output-json", default="ab_strict_report.json")
    add("--output-csv", default="ab_strict_report.csv")
    add("--cases", default="", help="Built-in cases filter, comma separated (baseus,runner,vertical)")
    add("--max-shots", type=int, default=0, help="Cap processed shots per case; 0 means all selected")
    add("--video", default="", help="Custom video path; when set, run custom mode")
    add("--shots-jsonl", default="", help="Custom shot boundary JSONL path (optional)")
    add("--gt-json", default="", help="Custom GT json path (required in custom mode)")
    add("--sample-ids", default="", help="Custom sample shot ids, comma separated; default all")
    add("--case-name", default="custom", help="Custom case display name")
    add("--llm-judge", action="store_true", help="Enable LLM-as-judge scoring")
    add(
        "--judge-provider",
        choices=("hf", "gemini"),
        default=os.environ.get("JUDGE_PROVIDER", DEFAULT_JUDGE_PROVIDER),
        help="Judge backend provider",
    )
    add("--judge-model", default=os.environ.get("JUDGE_MODEL", DEFAULT_JUDGE_MODEL))
    add("--judge-token", default=os.environ.get("JUDGE_TOKEN", ""))
    add("--gemini-api-key", default=os.environ.get("GEMINI_API_KEY", os.environ.get("GOOGLE_API_KEY", "")))
    add("--judge-workers", type=int, default=4, help="Parallel workers for judge calls per shot")
    opts = arg_parser.parse_args()
    if not opts.hf_token:
        raise RuntimeError("HF token required: --hf-token or HF_TOKEN")
    # Gemini provider still carrying the (HF-oriented) default model: swap in the Gemini default.
    if opts.llm_judge and opts.judge_provider == "gemini" and opts.judge_model == DEFAULT_JUDGE_MODEL:
        opts.judge_model = os.environ.get("GEMINI_MODEL", DEFAULT_GEMINI_MODEL)
    selected_cases = build_cases(opts)
    run_benchmark(
        hf_token=opts.hf_token,
        output_json=opts.output_json,
        output_csv=opts.output_csv,
        cases=selected_cases,
        llm_judge=bool(opts.llm_judge),
        judge_provider=opts.judge_provider,
        judge_model=opts.judge_model,
        judge_token=opts.judge_token,
        gemini_api_key=opts.gemini_api_key,
        judge_workers=max(1, int(opts.judge_workers)),
        max_shots=max(0, int(opts.max_shots)),
    )
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit code.
    raise SystemExit(main())