Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """A/B test for timeline postprocess under strict replayability metric. | |
| Run one-pass window inference, then evaluate multiple fixed postprocess variants. | |
| Strict pass criteria (per shot): | |
| 1) no uncertain segment | |
| 2) predicted stage count == GT stage count | |
| 3) each stage motion-set exactly matches GT stage motion-set (order-sensitive across stages) | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import concurrent.futures | |
| import csv | |
| import json | |
| import os | |
| import re | |
| import urllib.error | |
| import urllib.parse | |
| import urllib.request | |
| from dataclasses import asdict, dataclass | |
| from statistics import mean | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| try: | |
| from huggingface_hub import InferenceClient | |
| except Exception: # pragma: no cover - optional at runtime for non-LLM mode | |
| InferenceClient = None # type: ignore[assignment] | |
| from run_videomae_timeline import ( | |
| LABEL_CN_MAP, | |
| PRIMARY_LABELS, | |
| SECONDARY_CANDIDATES, | |
| MODEL_ID, | |
| NUM_FRAMES, | |
| WindowPrediction, | |
| ShotBoundary, | |
| VideoMAEWindowInferencer, | |
| build_window_ranges, | |
| compute_frame_span, | |
| build_segments_output, | |
| build_shot_narrative, | |
| build_segment_meta, | |
| enforce_min_duration, | |
| enforce_max_segments, | |
| load_shots_jsonl, | |
| ) | |
@dataclass
class SampleCase:
    """One evaluation video with its shot metadata, ground truth, and shot ids to score.

    Note: the ``@dataclass`` decorator is required — instances are built with
    keyword arguments in SAMPLE_CASES below.
    """

    name: str  # short case identifier used in reports
    video: str  # path to the source video file
    shots_jsonl: Optional[str]  # path to shot-boundary info file, or None
    gt_json: Optional[str]  # path to ground-truth annotation JSON, or None
    sample_ids: Tuple[int, ...]  # shot ids within the video that are evaluated
# Evaluation set: each case pairs one video with its shot metadata, the
# ground-truth file, and the shot ids that are scored in the A/B test.
SAMPLE_CASES: Tuple[SampleCase, ...] = (
    SampleCase(
        name="baseus",
        video="测试物料/01_测试样本/把自然放进耳机 倍思 WX5 蓝_精选全球优秀创意视频案例编号122691_光厂案例VJshi案例.mp4",
        shots_jsonl="测试物料/01_测试样本/shots_info.json",
        gt_json="测试物料/02_真值/ground_truth_baseus.json",
        sample_ids=(1, 2, 8, 10),
    ),
    SampleCase(
        name="runner",
        video="测试物料/01_测试样本/跑者.mp4",
        shots_jsonl="测试物料/01_测试样本/跑者镜头信息.json",
        gt_json="测试物料/02_真值/ground_truth_runner.json",
        sample_ids=(6, 9, 29),
    ),
    SampleCase(
        name="vertical",
        video="测试物料/01_测试样本/竖屏.mp4",
        shots_jsonl="测试物料/01_测试样本/竖屏.json",
        gt_json="测试物料/02_真值/ground_truth_vertical.json",
        sample_ids=(5, 8),
    ),
)
@dataclass
class VariantConfig:
    """All postprocess knobs for one A/B variant.

    The first nine fields have no defaults and define the core uncertain /
    secondary-label thresholds plus the smoothing passes; everything after
    them is an optional gate or rescue that defaults to "off" (boolean False
    or a threshold that can never trigger, e.g. 1.1 on a [0, 1] score).
    The ``@dataclass`` decorator is required — VARIANT_A..D below construct
    instances with keyword arguments.
    """

    name: str  # variant identifier used in reports
    # Core uncertainty thresholds on the top-1 primary score.
    uncertain_p1_min: float
    uncertain_margin_min: float
    uncertain_undefined_min: float
    # Secondary-label admission thresholds for compound motions.
    secondary_min: float
    secondary_gap_max: float
    # Smoothing / stabilization passes.
    smooth_radius: int
    hysteresis_support: int
    compound_support_windows: int
    # Optional CV gates (all off by default).
    use_shake_static_gate: bool = False
    use_dolly_zoom_gate: bool = False
    allow_shake_primary: bool = False
    # Rescue of windows that would otherwise be marked uncertain.
    allow_uncertain_compound_rescue: bool = False
    rescue_min_primary: float = 0.0
    rescue_min_secondary: float = 0.0
    rescue_undefined_max: float = 1.0
    # Primary score high enough to override a high "undefined" score (1.1 = never).
    undefined_override_primary_min: float = 1.1
    # Fallback decode when "undefined" dominates but a motion label is plausible.
    allow_undefined_single_fallback: bool = False
    undefined_fallback_min: float = 0.0
    undefined_fallback_margin_min: float = 0.0
    undefined_fallback_compound_min: float = 1.1
    undefined_fallback_compound_gap_max: float = 0.0
    # "Hitchcock" (dolly-zoom) gate on high-undefined translating windows.
    use_undefined_hitchcock_gate: bool = False
    hitch_undefined_min: float = 1.1
    hitch_zoom_min: float = 0.0
    hitch_motion_min: float = 1.1
    hitch_translation_min: float = 1.1
    # Additional reinterpretation gates.
    use_arc_orbit_gate: bool = False
    use_truck_low_translation_gate: bool = False
    use_roll_to_pan_gate: bool = False
    use_shake_roll_pan_gate: bool = False
    # Absorb short uncertain segments at either end of a shot.
    use_terminal_uncertain_absorb: bool = False
    terminal_uncertain_max_windows: int = 1
    terminal_uncertain_neighbor_conf_min: float = 0.7
    # Collapse a short shake rebound at the shot tail.
    use_tail_shake_rebound_collapse: bool = False
    tail_shake_max_windows: int = 1
# Variant A: the locked baseline thresholds (no CV gates, no rescues).
VARIANT_A = VariantConfig(
    name="A_baseline_locked",
    uncertain_p1_min=0.45,
    uncertain_margin_min=0.12,
    uncertain_undefined_min=0.50,
    secondary_min=0.50,
    secondary_gap_max=0.18,
    smooth_radius=2,
    hysteresis_support=2,
    compound_support_windows=1,
)
# Variant B: replayability-priority thresholds — looser uncertain criteria,
# stricter secondary admission and longer support requirements.
VARIANT_B = VariantConfig(
    name="B_replay_priority",
    uncertain_p1_min=0.40,
    uncertain_margin_min=0.08,
    uncertain_undefined_min=0.60,
    secondary_min=0.55,
    secondary_gap_max=0.15,
    smooth_radius=2,
    hysteresis_support=3,
    compound_support_windows=2,
)
# Variant C: B plus the CV shake/static gate and shake allowed as a primary.
VARIANT_C = VariantConfig(
    name="C_replay_plus_cv_shake_static",
    uncertain_p1_min=0.40,
    uncertain_margin_min=0.08,
    uncertain_undefined_min=0.60,
    secondary_min=0.55,
    secondary_gap_max=0.15,
    smooth_radius=2,
    hysteresis_support=3,
    compound_support_windows=2,
    use_shake_static_gate=True,
    allow_shake_primary=True,
)
# Variant D: most aggressive — compound rescue, undefined fallbacks, and all
# reinterpretation/absorption gates enabled.
VARIANT_D = VariantConfig(
    name="D_replay_plus_compound_rescue",
    uncertain_p1_min=0.30,
    uncertain_margin_min=0.05,
    uncertain_undefined_min=0.60,
    secondary_min=0.35,
    secondary_gap_max=0.30,
    smooth_radius=2,
    hysteresis_support=3,
    compound_support_windows=1,
    use_shake_static_gate=True,
    use_dolly_zoom_gate=False,
    allow_shake_primary=True,
    allow_uncertain_compound_rescue=True,
    rescue_min_primary=0.82,
    rescue_min_secondary=0.70,
    rescue_undefined_max=0.40,
    undefined_override_primary_min=0.90,
    allow_undefined_single_fallback=True,
    undefined_fallback_min=0.02,
    undefined_fallback_margin_min=0.0,
    undefined_fallback_compound_min=0.02,
    undefined_fallback_compound_gap_max=0.08,
    use_undefined_hitchcock_gate=True,
    hitch_undefined_min=0.85,
    hitch_zoom_min=0.02,
    hitch_motion_min=0.55,
    hitch_translation_min=0.05,
    use_arc_orbit_gate=True,
    use_truck_low_translation_gate=True,
    use_roll_to_pan_gate=True,
    use_shake_roll_pan_gate=True,
    use_terminal_uncertain_absorb=True,
    terminal_uncertain_max_windows=1,
    terminal_uncertain_neighbor_conf_min=0.65,
    use_tail_shake_rebound_collapse=True,
    tail_shake_max_windows=2,
)
# All variants evaluated by the A/B run, in report order.
VARIANTS: Tuple[VariantConfig, ...] = (VARIANT_A, VARIANT_B, VARIANT_C, VARIANT_D)
# Maps each canonical motion label to the Chinese keywords that may express it
# in ground-truth descriptions. All patterns are matched; zoom vs. dolly
# ambiguity is resolved downstream via the "非变焦" (non-zoom) marker.
TOKEN_PATTERNS: Sequence[Tuple[str, Sequence[str]]] = (
    ("arc_left", ("左弧绕", "逆时针环绕", "左环绕")),
    ("arc_right", ("右弧绕", "顺时针环绕", "右环绕")),
    ("pan_left", ("左摇", "向左摇", "左摇镜")),
    ("pan_right", ("右摇", "向右摇", "右摇镜")),
    ("truck_left", ("左移", "机位左移", "向左移", "左跟拍")),
    ("truck_right", ("右移", "机位右移", "向右移", "右跟拍")),
    ("tilt_up", ("上摇", "上仰")),
    ("tilt_down", ("下摇", "下俯")),
    ("pedestal_up", ("上移", "上升")),
    ("pedestal_down", ("下移", "下降")),
    ("dolly_in", ("非变焦推", "非变焦推进", "前推(非变焦)", "推进(非变焦)", "推(非变焦)", "推(非变焦)")),
    ("dolly_out", ("非变焦拉", "轻微拉(非变焦)", "拉(非变焦)", "拉(非变焦)")),
    ("zoom_in", ("变焦推", "变焦推进", "飞变焦", "推镜头", "推进", "前推", "推")),
    ("zoom_out", ("变焦拉", "焦距拉远", "拉镜头", "后拉", "拉远", "拉")),
    ("track", ("跟拍",)),
    ("shake", ("晃动", "手持", "抖动")),
    ("roll_left", ("左滚转",)),
    ("roll_right", ("右滚转", "旋转")),
    ("static", ("固定", "基本固定", "静止", "没变动")),
)
# Connectives in GT text that indicate motions happening at the same time
# versus one after another (used to split stages).
SIMULTANEOUS_MARKERS = ("同时", "伴", "并", "并且", "+", "和", "的同时")
SEQUENTIAL_MARKERS = ("然后", "接", "再", "随后", "之后", "最后", "快切", "切至", "切为")
# Small epsilon guarding divisions by near-zero magnitudes.
EPS = 1e-6
# LLM judge defaults (HF Inference or Gemini backends).
DEFAULT_JUDGE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
DEFAULT_JUDGE_PROVIDER = "hf"
DEFAULT_GEMINI_MODEL = "gemini-3-flash"
# Minimum overall_score (out of 100) for the LLM judge to count a shot as passed.
LLM_JUDGE_PASS_THRESHOLD = 85
JUDGE_PROMPT_VERSION = "v1_cn_replay"
# Strict structured-output schema (OpenAI-style response_format) for the judge:
# four 0-25 sub-scores, a 0-100 overall score, a verdict, and a free-text reason.
JUDGE_JSON_SCHEMA = {
    "type": "json_schema",
    "json_schema": {
        "name": "camera_motion_judge",
        "schema": {
            "type": "object",
            "properties": {
                "action_type_score": {"type": "integer", "minimum": 0, "maximum": 25},
                "direction_score": {"type": "integer", "minimum": 0, "maximum": 25},
                "order_score": {"type": "integer", "minimum": 0, "maximum": 25},
                "replayability_score": {"type": "integer", "minimum": 0, "maximum": 25},
                "overall_score": {"type": "integer", "minimum": 0, "maximum": 100},
                "verdict": {"type": "string", "enum": ["pass", "partial", "fail"]},
                "reason": {"type": "string"},
            },
            "required": [
                "action_type_score",
                "direction_score",
                "order_score",
                "replayability_score",
                "overall_score",
                "verdict",
                "reason",
            ],
            "additionalProperties": False,
        },
        "strict": True,
    },
}
# Same response shape for the Gemini backend, without min/max or enum
# constraints (Gemini response schemas are looser).
GEMINI_JUDGE_RESPONSE_SCHEMA = {
    "type": "object",
    "properties": {
        "action_type_score": {"type": "integer"},
        "direction_score": {"type": "integer"},
        "order_score": {"type": "integer"},
        "replayability_score": {"type": "integer"},
        "overall_score": {"type": "integer"},
        "verdict": {"type": "string"},
        "reason": {"type": "string"},
    },
    "required": [
        "action_type_score",
        "direction_score",
        "order_score",
        "replayability_score",
        "overall_score",
        "verdict",
        "reason",
    ],
}
@dataclass
class WindowCVPrior:
    """Optical-flow derived motion priors for one inference window.

    All fields are floats; magnitude/ratio/confidence fields are normalized to
    [0, 1], radial_alignment is in [-1, 1] (positive = expanding flow).
    The ``@dataclass`` decorator is required — compute_window_cv_prior builds
    instances both positionally and with keyword arguments.
    """

    mean_flow_mag: float  # normalized mean optical-flow magnitude
    translation_ratio: float  # how much of the flow is a global translation
    radial_alignment: float  # weighted cosine with the outward radial field
    jitter_ratio: float  # frame-to-frame instability of the translation vector
    static_conf: float  # confidence the window is a static shot
    shake_conf: float  # confidence the window is handheld/shaky
    zoom_bias: float  # evidence favoring zoom over dolly
    dolly_bias: float  # evidence favoring dolly over zoom
def clamp01(v: float) -> float:
    """Clamp *v* into the closed interval [0.0, 1.0]."""
    return 0.0 if v < 0.0 else (1.0 if v > 1.0 else v)
def direction_sign_from_scores(scores: Dict[str, float], radial_alignment: float) -> int:
    """Return +1 for inward (push/zoom-in) motion and -1 for outward.

    A sufficiently strong radial flow alignment decides on its own; otherwise
    the model's in/out score peaks break the tie (in wins on equality).
    """
    if radial_alignment >= 0.12:
        return 1
    if radial_alignment <= -0.12:
        return -1
    inward = max(scores.get("zoom_in", 0.0), scores.get("dolly_in", 0.0))
    outward = max(scores.get("zoom_out", 0.0), scores.get("dolly_out", 0.0))
    return 1 if inward >= outward else -1
def _radial_grid(h: int, w: int) -> Tuple[np.ndarray, np.ndarray]:
    """Return per-pixel unit vectors pointing away from the image center."""
    ys, xs = np.mgrid[0:h, 0:w]
    center_x = (w - 1) * 0.5
    center_y = (h - 1) * 0.5
    dx = xs.astype(np.float32) - center_x
    dy = ys.astype(np.float32) - center_y
    # Guard the center pixel (zero radius) against division by zero.
    norm = np.maximum(np.sqrt(dx * dx + dy * dy), EPS)
    return dx / norm, dy / norm
def compute_window_cv_prior(frames: List[np.ndarray]) -> WindowCVPrior:
    """Derive optical-flow motion priors for one window of RGB frames.

    Runs Farneback dense flow between consecutive frames and aggregates:
    overall motion magnitude, how translational the flow is, how radial
    (zoom/dolly-like) it is, and frame-to-frame jitter. These feed the
    static/shake confidences and zoom/dolly biases used by the CV gates.
    """
    # Fewer than two frames: no flow can be computed; report "fully static".
    if len(frames) < 2:
        return WindowCVPrior(0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0)
    gframes = [cv2.cvtColor(x, cv2.COLOR_RGB2GRAY) for x in frames]
    h, w = gframes[0].shape[:2]
    diag = max(float(np.sqrt(h * h + w * w)), 1.0)
    radial_x, radial_y = _radial_grid(h, w)
    mean_mags: List[float] = []
    trans_ratios: List[float] = []
    radial_aligns: List[float] = []
    trans_vecs: List[Tuple[float, float]] = []
    for g0, g1 in zip(gframes[:-1], gframes[1:]):
        flow = cv2.calcOpticalFlowFarneback(
            g0, g1, None, pyr_scale=0.5, levels=3, winsize=15, iterations=3, poly_n=5, poly_sigma=1.2, flags=0
        )
        fx = flow[..., 0]
        fy = flow[..., 1]
        mag = np.sqrt(fx * fx + fy * fy)
        mean_mag = float(np.mean(mag))
        mean_mags.append(mean_mag)
        # Median flow approximates the camera's global translation component.
        med_dx = float(np.median(fx))
        med_dy = float(np.median(fy))
        trans_mag = float(np.sqrt(med_dx * med_dx + med_dy * med_dy))
        trans_vecs.append((med_dx, med_dy))
        trans_ratios.append(clamp01(trans_mag / (mean_mag + EPS)))
        # Magnitude-weighted cosine between the flow and the outward radial
        # field: positive = expanding (push-in), negative = contracting.
        flow_norm_x = fx / (mag + EPS)
        flow_norm_y = fy / (mag + EPS)
        radial_dot = (flow_norm_x * radial_x + flow_norm_y * radial_y).astype(np.float32)
        weight = mag.astype(np.float32)
        align = float(np.sum(radial_dot * weight) / (np.sum(weight) + EPS))
        radial_aligns.append(float(np.clip(align, -1.0, 1.0)))
    mean_mag = float(mean(mean_mags)) if mean_mags else 0.0
    # Normalize magnitude by a resolution-dependent scale (0.5% of diagonal,
    # floored at 1.25 px) so thresholds work across video sizes.
    mean_mag_norm = clamp01(mean_mag / max(1.25, 0.005 * diag))
    translation_ratio = float(mean(trans_ratios)) if trans_ratios else 0.0
    radial_alignment = float(mean(radial_aligns)) if radial_aligns else 0.0
    jitter_ratio = 0.0
    if len(trans_vecs) >= 2:
        # Jitter: how much the global translation vector changes frame to frame.
        deltas = [
            float(np.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2))
            for (x0, y0), (x1, y1) in zip(trans_vecs[:-1], trans_vecs[1:])
        ]
        jitter_ratio = float(mean(deltas)) / (mean_mag + EPS)
    jitter_ratio_clamped = clamp01(jitter_ratio / 1.5)
    # Low, steady motion -> static; strong incoherent motion or jitter -> shake.
    static_conf = clamp01(((0.55 - mean_mag_norm) / 0.55) * (1.0 - 0.6 * jitter_ratio_clamped))
    shake_conf_base = mean_mag_norm * (1.0 - translation_ratio) * 1.5
    shake_conf_jitter = 0.75 * jitter_ratio_clamped * min(1.0, mean_mag_norm * 2.0 + 0.15)
    shake_conf = clamp01(max(shake_conf_base, shake_conf_jitter))
    # Radial flow favors zoom; translational flow favors dolly/truck.
    zoom_bias = clamp01(abs(radial_alignment) * mean_mag_norm - 0.25 * translation_ratio)
    dolly_bias = clamp01(translation_ratio * mean_mag_norm - 0.2 * abs(radial_alignment))
    return WindowCVPrior(
        mean_flow_mag=mean_mag_norm,
        translation_ratio=translation_ratio,
        radial_alignment=radial_alignment,
        jitter_ratio=jitter_ratio_clamped,
        static_conf=static_conf,
        shake_conf=shake_conf,
        zoom_bias=zoom_bias,
        dolly_bias=dolly_bias,
    )
def infer_raw_window_from_frames(
    inferencer: VideoMAEWindowInferencer, frames: List[np.ndarray], start_sec: float, end_sec: float
) -> WindowPrediction:
    """Run the VideoMAE model on one window of frames and return raw scores.

    The returned WindowPrediction is undecoded: per-label sigmoid scores plus
    the top-8 labels, with mode/primary/secondary left at placeholder values.
    """
    batch = inferencer.processor([frames], return_tensors="pt")
    batch = {name: tensor.to(inferencer.device) for name, tensor in batch.items()}
    logits = inferencer.model(**batch).logits.float().cpu().numpy()[0]
    # Multi-label head: independent sigmoid per class.
    sigmoid = 1.0 / (1.0 + np.exp(-logits))
    scores = {inferencer.id2label[i]: float(sigmoid[i]) for i in sorted(inferencer.id2label.keys())}
    # Deterministic ranking: score first, then label name for ties.
    top_eight = sorted(scores.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)[:8]
    return WindowPrediction(
        start_sec=start_sec,
        end_sec=end_sec,
        raw_scores=scores,
        raw_top_labels=[{"label": lb, "score": round(sc, 4)} for lb, sc in top_eight],
        mode="single",
        primary=None,
        secondary=None,
        confidence=0.0,
    )
def build_window_frame_index_lists(
    ranges: Sequence[Tuple[float, float]], fps: float, total_frames: int
) -> List[List[int]]:
    """Map each (start_sec, end_sec) window to NUM_FRAMES evenly spaced frame indices."""
    result: List[List[int]] = []
    for start_sec, end_sec in ranges:
        first, last = compute_frame_span(start_sec, end_sec, fps, total_frames)
        result.append(np.linspace(first, last, NUM_FRAMES).astype(int).tolist())
    return result
def sample_frames_for_window_indices(
    cap: cv2.VideoCapture,
    inferencer: VideoMAEWindowInferencer,
    index_lists: Sequence[List[int]],
    fps: float,
    total_frames: int,
) -> List[List[np.ndarray]]:
    """Decode the RGB frames for each window's frame-index list.

    Fast path decodes the video once in temporal order, caching frames that
    overlapping windows will reuse; on any decode error it falls back to the
    inferencer's original random-seek sampler for every window.
    """
    if not index_lists:
        return []
    # Fast path: decode in temporal order once and reuse overlap via a tiny cache.
    first_idx = min(index_lists[0])
    cap.set(cv2.CAP_PROP_POS_FRAMES, int(first_idx))
    cur = first_idx
    frame_cache: Dict[int, np.ndarray] = {}
    out: List[List[np.ndarray]] = []
    def _read_at(target_idx: int) -> np.ndarray:
        # Return the RGB frame at target_idx, advancing (or seeking) the decoder.
        nonlocal cur
        if target_idx in frame_cache:
            return frame_cache[target_idx]
        if target_idx < cur:
            # Backward request that fell out of the cache: reseek the decoder.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(target_idx))
            cur = target_idx
        # Skip forward cheaply with grab() (no decode of the skipped frames).
        while cur < target_idx:
            ok = cap.grab()
            if not ok:
                raise RuntimeError(f"grab failed before frame={target_idx}")
            cur += 1
        ok, frame = cap.read()
        if not ok:
            raise RuntimeError(f"read failed at frame={target_idx}")
        cur += 1
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_cache[target_idx] = rgb
        return rgb
    try:
        for wi, idxs in enumerate(index_lists):
            if not idxs:
                raise RuntimeError("empty index list")
            frames: List[np.ndarray] = []
            for idx in idxs:
                frames.append(_read_at(int(idx)))
            out.append(frames)
            # Keep only frames that can still be reused by upcoming windows.
            if wi + 1 < len(index_lists):
                next_min = min(index_lists[wi + 1])
                stale = [k for k in frame_cache.keys() if k < next_min]
                for k in stale:
                    del frame_cache[k]
    except Exception:
        # Safe fallback to the original random-seek sampler.
        out = []
        for idxs in index_lists:
            if not idxs:
                out.append([])
                continue
            s_idx = int(min(idxs))
            e_idx = int(max(idxs))
            start_sec = s_idx / fps
            end_sec = min(total_frames - 1, e_idx + 1) / fps
            out.append(inferencer._sample_frames(cap, start_sec, end_sec, fps, total_frames))
    return out
def axis_of(label: str) -> Optional[Tuple[str, str]]:
    """Map a motion label's suffix to its (axis, direction) pair.

    Axes: "h" horizontal, "v" vertical, "d" depth. Labels without a
    directional suffix (e.g. "static", "shake") return None.
    """
    suffix_table = (
        ("_left", ("h", "left")),
        ("_right", ("h", "right")),
        ("_up", ("v", "up")),
        ("_down", ("v", "down")),
        ("_in", ("d", "in")),
        ("_out", ("d", "out")),
    )
    for suffix, axis_dir in suffix_table:
        if label.endswith(suffix):
            return axis_dir
    return None
def is_axis_conflict(a: str, b: str) -> bool:
    """True when *a* and *b* are opposite directions within the same motion family.

    Cross-family combos (e.g. pan_left + tilt_up) never conflict; only
    same-family, same-axis, opposite-direction pairs are blocked.
    """
    family_a = a.split("_", 1)[0]
    family_b = b.split("_", 1)[0]
    if family_a != family_b:
        return False
    axis_a = axis_of(a)
    axis_b = axis_of(b)
    if axis_a is None or axis_b is None:
        return False
    same_axis = axis_a[0] == axis_b[0]
    opposite_direction = axis_a[1] != axis_b[1]
    return same_axis and opposite_direction
def is_redundant_zoom_dolly_pair(a: str, b: str) -> bool:
    """True when the pair mixes zoom and dolly of the same direction.

    Such pairs describe the same apparent motion twice, so they are rejected
    as compound secondary candidates.
    """
    combo = {a, b}
    return combo == {"dolly_in", "zoom_in"} or combo == {"dolly_out", "zoom_out"}
def classify_scores(raw_scores: Dict[str, float], cfg: VariantConfig) -> Tuple[str, Optional[str], Optional[str], float]:
    """Decode one window's per-label scores into (mode, primary, secondary, confidence).

    mode is "single", "compound", or "uncertain". A window is uncertain when
    the top primary score is too low, its margin over the runner-up is too
    small, or "undefined" dominates — unless one of the variant's rescue or
    fallback paths recovers a usable label.
    """
    primary_labels = list(PRIMARY_LABELS)
    if cfg.allow_shake_primary and "shake" not in primary_labels:
        primary_labels.append("shake")
    # Rank primary candidates by (score, label) so ties break deterministically.
    scored_primary = sorted(
        ((lb, raw_scores.get(lb, 0.0)) for lb in primary_labels),
        key=lambda x: (x[1], x[0]),
        reverse=True,
    )
    p1_label, p1 = scored_primary[0]
    p2 = scored_primary[1][1] if len(scored_primary) > 1 else 0.0
    p2_label = scored_primary[1][0] if len(scored_primary) > 1 else None
    margin = p1 - p2
    undef = raw_scores.get("undefined", 0.0)
    # "undefined" only forces uncertainty when the primary is not overwhelming.
    uncertain_by_undef = (undef >= cfg.uncertain_undefined_min) and (p1 < cfg.undefined_override_primary_min)
    if p1 < cfg.uncertain_p1_min or margin < cfg.uncertain_margin_min or uncertain_by_undef:
        if cfg.allow_uncertain_compound_rescue:
            # Rescue path 1: two strong, compatible motions -> compound.
            can_rescue = (
                undef <= cfg.rescue_undefined_max
                and p1 >= cfg.rescue_min_primary
                and (p2_label is not None)
                and p2 >= cfg.rescue_min_secondary
            )
            if can_rescue and p1_label != "static" and p2_label not in ("static", p1_label):
                if not is_axis_conflict(p1_label, p2_label) and not is_redundant_zoom_dolly_pair(p1_label, p2_label):
                    return ("compound", p1_label, p2_label, min(p1, p2))
            # Rescue path 2: one strong motion with low "undefined" -> single.
            if undef <= cfg.rescue_undefined_max and p1 >= cfg.rescue_min_primary:
                return ("single", p1_label, None, p1)
        # Fallback: "undefined" dominates but a plausible motion label exists.
        if (
            cfg.allow_undefined_single_fallback
            and undef >= cfg.uncertain_undefined_min
            and p1 >= cfg.undefined_fallback_min
            and margin >= cfg.undefined_fallback_margin_min
            and p1_label != "static"
        ):
            scored_secondary_fallback = sorted(
                (
                    (lb, raw_scores.get(lb, 0.0))
                    for lb in SECONDARY_CANDIDATES
                    if lb not in ("undefined", p1_label, "static")
                ),
                key=lambda x: (x[1], x[0]),
                reverse=True,
            )
            # Try to attach a compatible secondary; otherwise fall back to single.
            for lb, sc in scored_secondary_fallback:
                if sc < cfg.undefined_fallback_compound_min:
                    break
                if p1 - sc > cfg.undefined_fallback_compound_gap_max:
                    continue
                if is_axis_conflict(p1_label, lb):
                    continue
                if is_redundant_zoom_dolly_pair(p1_label, lb):
                    continue
                return ("compound", p1_label, lb, min(p1, sc))
            return ("single", p1_label, None, p1)
        return ("uncertain", None, None, p1)
    # Static never takes a secondary label.
    if p1_label == "static":
        return ("single", p1_label, None, p1)
    scored_secondary = sorted(
        (
            (lb, raw_scores.get(lb, 0.0))
            for lb in SECONDARY_CANDIDATES
            if lb not in ("undefined", p1_label)
        ),
        key=lambda x: (x[1], x[0]),
        reverse=True,
    )
    # Accept the first secondary that is strong enough, close enough to the
    # primary, and neither axis-conflicting nor a redundant zoom/dolly pair.
    secondary = None
    secondary_score = 0.0
    for lb, sc in scored_secondary:
        if sc < cfg.secondary_min:
            break
        if p1 - sc > cfg.secondary_gap_max:
            continue
        if is_axis_conflict(p1_label, lb):
            continue
        if is_redundant_zoom_dolly_pair(p1_label, lb):
            continue
        secondary = lb
        secondary_score = sc
        break
    if secondary:
        return ("compound", p1_label, secondary, min(p1, secondary_score))
    return ("single", p1_label, None, p1)
def apply_cv_gates(raw_scores: Dict[str, float], prior: WindowCVPrior, cfg: VariantConfig) -> Dict[str, float]:
    """Return a copy of *raw_scores* adjusted by the variant's optical-flow gates.

    Each gate is toggled by a *cfg* flag and nudges model scores toward labels
    the CV prior supports (or away from ones it contradicts). The input dict
    is never mutated; gates apply in a fixed order and may compound.
    """
    scores = dict(raw_scores)
    if cfg.use_shake_static_gate:
        # Peak over all directional motion labels (excludes static/shake/undefined).
        motion_peak = max(
            scores.get(lb, 0.0)
            for lb in (
                "arc_left",
                "arc_right",
                "pan_left",
                "pan_right",
                "truck_left",
                "truck_right",
                "tilt_up",
                "tilt_down",
                "pedestal_up",
                "pedestal_down",
                "dolly_in",
                "dolly_out",
                "zoom_in",
                "zoom_out",
                "roll_left",
                "roll_right",
            )
        )
        # Strong static prior + weak motion evidence: boost static, damp lateral motions.
        if (
            prior.static_conf >= 0.78
            and prior.shake_conf < 0.45
            and scores.get("static", 0.0) >= 0.45
            and motion_peak < 0.65
        ):
            target = clamp01(0.68 + 0.18 * prior.static_conf)
            scores["static"] = max(scores.get("static", 0.0), target)
            for lb in ("pan_left", "pan_right", "truck_left", "truck_right", "arc_left", "arc_right"):
                scores[lb] = clamp01(scores.get(lb, 0.0) * 0.92)
        # Strong jitter evidence: boost shake and damp static.
        if prior.shake_conf >= 0.55 and (scores.get("shake", 0.0) >= 0.35 or prior.jitter_ratio >= 0.50):
            target = clamp01(0.68 + 0.20 * prior.shake_conf)
            scores["shake"] = max(scores.get("shake", 0.0), target)
            scores["static"] = clamp01(scores.get("static", 0.0) * 0.78)
    # Near-zero flow with a roll score: treat roll as a slight pan.
    if cfg.use_roll_to_pan_gate and prior.mean_flow_mag <= 0.08:
        roll_r = scores.get("roll_right", 0.0)
        roll_l = scores.get("roll_left", 0.0)
        if roll_r >= 0.20:
            scores["pan_right"] = max(scores.get("pan_right", 0.0), clamp01(roll_r * 0.82))
        if roll_l >= 0.20:
            scores["pan_left"] = max(scores.get("pan_left", 0.0), clamp01(roll_l * 0.82))
    if cfg.use_shake_roll_pan_gate and prior.static_conf >= 0.97 and scores.get("shake", 0.0) >= 0.90:
        roll_r = scores.get("roll_right", 0.0)
        roll_l = scores.get("roll_left", 0.0)
        if max(roll_r, roll_l) >= 0.20:
            # Handheld jitter with dominant roll often corresponds to slight pan in GT wording.
            scores["shake"] = clamp01(scores.get("shake", 0.0) * 0.45)
            if roll_r >= 0.20:
                scores["pan_right"] = max(scores.get("pan_right", 0.0), clamp01(roll_r * 1.85))
            if roll_l >= 0.20:
                scores["pan_left"] = max(scores.get("pan_left", 0.0), clamp01(roll_l * 1.85))
    if cfg.use_arc_orbit_gate:
        tr = prior.translation_ratio
        # Confident arc score but no global translation and strongly radial
        # flow: reinterpret as dolly-in plus a counter-direction pan.
        if tr <= 0.03 and abs(prior.radial_alignment) >= 0.40:
            arc_r = scores.get("arc_right", 0.0)
            arc_l = scores.get("arc_left", 0.0)
            if arc_r >= 0.95:
                scores["arc_right"] = clamp01(arc_r * 0.75)
                scores["dolly_in"] = max(scores.get("dolly_in", 0.0), clamp01(arc_r * 0.88))
                scores["pan_left"] = max(scores.get("pan_left", 0.0), clamp01(arc_r * 0.86))
            if arc_l >= 0.95:
                scores["arc_left"] = clamp01(arc_l * 0.75)
                scores["dolly_in"] = max(scores.get("dolly_in", 0.0), clamp01(arc_l * 0.88))
                scores["pan_right"] = max(scores.get("pan_right", 0.0), clamp01(arc_l * 0.86))
    if cfg.use_truck_low_translation_gate:
        tr = prior.translation_ratio
        # A truck needs lateral translation; real motion without it favors dolly_in.
        if tr <= 0.03 and prior.mean_flow_mag >= 0.10:
            tr_r = scores.get("truck_right", 0.0)
            tr_l = scores.get("truck_left", 0.0)
            if tr_r >= 0.80:
                scores["truck_right"] = clamp01(tr_r * 0.55)
                scores["dolly_in"] = max(scores.get("dolly_in", 0.0), clamp01(tr_r * 0.86))
            if tr_l >= 0.80:
                scores["truck_left"] = clamp01(tr_l * 0.55)
                scores["dolly_in"] = max(scores.get("dolly_in", 0.0), clamp01(tr_l * 0.86))
    if cfg.use_dolly_zoom_gate:
        sign = direction_sign_from_scores(scores, prior.radial_alignment)
        # Radial-dominant flow favors zoom over dolly (mirror case below).
        if prior.zoom_bias >= 0.22 and prior.zoom_bias >= prior.dolly_bias + 0.08:
            boost = 0.08 + 0.20 * prior.zoom_bias
            suppress = 0.10 * prior.zoom_bias
            if sign > 0:
                scores["zoom_in"] = clamp01(max(scores.get("zoom_in", 0.0), scores.get("dolly_in", 0.0) + boost))
                scores["dolly_in"] = clamp01(scores.get("dolly_in", 0.0) - suppress)
            else:
                scores["zoom_out"] = clamp01(max(scores.get("zoom_out", 0.0), scores.get("dolly_out", 0.0) + boost))
                scores["dolly_out"] = clamp01(scores.get("dolly_out", 0.0) - suppress)
        if prior.dolly_bias >= 0.22 and prior.dolly_bias >= prior.zoom_bias + 0.08:
            boost = 0.08 + 0.18 * prior.dolly_bias
            suppress = 0.08 * prior.dolly_bias
            if sign > 0:
                scores["dolly_in"] = clamp01(max(scores.get("dolly_in", 0.0), scores.get("zoom_in", 0.0) + boost))
                scores["zoom_in"] = clamp01(scores.get("zoom_in", 0.0) - suppress)
            else:
                scores["dolly_out"] = clamp01(max(scores.get("dolly_out", 0.0), scores.get("zoom_out", 0.0) + boost))
                scores["zoom_out"] = clamp01(scores.get("zoom_out", 0.0) - suppress)
    if cfg.use_undefined_hitchcock_gate:
        undef = scores.get("undefined", 0.0)
        # High "undefined" with real translating motion: treat as a possible
        # dolly-zoom ("Hitchcock") shot and seed the opposing zoom/dolly pair.
        if (
            undef >= cfg.hitch_undefined_min
            and prior.mean_flow_mag >= cfg.hitch_motion_min
            and prior.translation_ratio >= cfg.hitch_translation_min
        ):
            zoom_out = scores.get("zoom_out", 0.0)
            zoom_in = scores.get("zoom_in", 0.0)
            if zoom_out >= cfg.hitch_zoom_min:
                zoom_out = max(zoom_out, cfg.hitch_zoom_min)
                scores["zoom_out"] = clamp01(zoom_out)
                scores["dolly_in"] = clamp01(max(scores.get("dolly_in", 0.0), min(0.25, zoom_out + 0.03)))
            if zoom_in >= cfg.hitch_zoom_min:
                zoom_in = max(zoom_in, cfg.hitch_zoom_min)
                scores["zoom_in"] = clamp01(zoom_in)
                scores["dolly_out"] = clamp01(max(scores.get("dolly_out", 0.0), min(0.25, zoom_in + 0.03)))
    return scores
| def state_key(w: WindowPrediction) -> Tuple[str, Optional[str], Optional[str]]: | |
| return (w.mode, w.primary, w.secondary) | |
| def set_state(w: WindowPrediction, mode: str, primary: Optional[str], secondary: Optional[str]) -> None: | |
| w.mode = mode | |
| w.primary = primary | |
| w.secondary = secondary | |
def majority_fill_uncertain(windows: List[WindowPrediction], radius: int) -> None:
    """Replace each uncertain window with the dominant neighbor state, in place.

    A neighborhood of +/- *radius* windows votes (uncertain neighbors do not
    get a vote); the window is overwritten only when the winning state has at
    least 2 votes and strictly beats the runner-up.
    """
    n = len(windows)
    for i, w in enumerate(windows):
        if w.mode != "uncertain":
            continue
        lo = max(0, i - radius)
        hi = min(n - 1, i + radius)
        counter: Dict[Tuple[str, Optional[str], Optional[str]], int] = {}
        for j in range(lo, hi + 1):
            if j == i:
                continue
            k = state_key(windows[j])
            if k[0] == "uncertain":
                continue
            counter[k] = counter.get(k, 0) + 1
        if not counter:
            continue
        # Rank by (count, state-key) so count ties break deterministically.
        ranked = sorted(counter.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)
        best_k, best_cnt = ranked[0]
        second_cnt = ranked[1][1] if len(ranked) > 1 else -1
        # Require a clear majority before overwriting the uncertain window.
        if best_cnt >= 2 and best_cnt > second_cnt:
            set_state(w, best_k[0], best_k[1], best_k[2])
def apply_hysteresis(windows: List[WindowPrediction], support_need: int) -> None:
    """Suppress short state flips, in place.

    Walking left to right, a new state is only adopted once it persists for at
    least *support_need* consecutive windows; windows belonging to a run that
    fails the test are rewritten with the previously accepted state.
    """
    if not windows:
        return
    states = [state_key(w) for w in windows]
    out: List[Tuple[str, Optional[str], Optional[str]]] = [states[0]]
    current = states[0]
    for i in range(1, len(states)):
        cand = states[i]
        if cand == current:
            out.append(current)
            continue
        # Count how long the candidate state persists starting at position i.
        support = 1
        j = i + 1
        while j < len(states) and states[j] == cand and support < support_need:
            support += 1
            j += 1
        if support >= support_need:
            current = cand
        out.append(current)
    for w, st in zip(windows, out):
        set_state(w, st[0], st[1], st[2])
def enforce_compound_support(windows: List[WindowPrediction], support_windows: int) -> None:
    """Demote compound segments shorter than *support_windows* to single-motion, in place."""
    if support_windows <= 1:
        return
    for seg in build_segment_meta(windows):
        if seg.mode != "compound" or not seg.primary:
            continue
        segment_length = seg.end_idx - seg.start_idx + 1
        if segment_length < support_windows:
            # Too short to trust the secondary label: keep only the primary.
            for idx in range(seg.start_idx, seg.end_idx + 1):
                set_state(windows[idx], "single", seg.primary, None)
| def _seg_windows(seg: Any) -> int: | |
| return int(seg.end_idx - seg.start_idx + 1) | |
| def _seg_labels(seg: Any) -> set: | |
| labels = set() | |
| if seg.primary: | |
| labels.add(seg.primary) | |
| if seg.secondary: | |
| labels.add(seg.secondary) | |
| return labels | |
def _seg_is_shake_like(seg: Any) -> bool:
    """True when the segment carries a shake label (primary or secondary)."""
    return "shake" in _seg_labels(seg)
def _seg_is_motion_with_shake(seg: Any) -> bool:
    """True when shake co-occurs with a real motion label (not just static/shake)."""
    labels = _seg_labels(seg)
    if "shake" not in labels:
        return False
    other_motion = [lb for lb in labels if lb not in ("shake", "static")]
    return bool(other_motion)
def absorb_terminal_uncertain(
    windows: List[WindowPrediction], max_windows: int, neighbor_conf_min: float
) -> None:
    """Merge short uncertain segments at either end of the shot into their
    confident neighbor, in place.

    Only segments spanning at most *max_windows* windows are absorbed, and
    only when the adjacent segment is not uncertain and has confidence of at
    least *neighbor_conf_min*. Repeats until stable, bounded by a safety cap.
    """
    if max_windows <= 0:
        return
    safety = 0
    # Hard cap on passes in case segment metadata never stabilizes.
    while safety < 8:
        safety += 1
        meta = build_segment_meta(windows)
        if len(meta) < 2:
            return
        changed = False
        # Head of the shot: absorb into the segment to its right.
        first = meta[0]
        right = meta[1]
        if (
            first.mode == "uncertain"
            and _seg_windows(first) <= max_windows
            and right.mode != "uncertain"
            and right.confidence >= neighbor_conf_min
        ):
            target = (right.mode, right.primary, right.secondary)
            for i in range(first.start_idx, first.end_idx + 1):
                set_state(windows[i], target[0], target[1], target[2])
            changed = True
        if not changed:
            # Tail of the shot: absorb into the segment to its left.
            last = meta[-1]
            left = meta[-2]
            if (
                last.mode == "uncertain"
                and _seg_windows(last) <= max_windows
                and left.mode != "uncertain"
                and left.confidence >= neighbor_conf_min
            ):
                target = (left.mode, left.primary, left.secondary)
                for i in range(last.start_idx, last.end_idx + 1):
                    set_state(windows[i], target[0], target[1], target[2])
                changed = True
        if not changed:
            return
def collapse_tail_shake_rebound(windows: List[WindowPrediction], max_windows: int) -> None:
    """Collapse a short shake "rebound" at the end of a shake-opening shot, in place.

    Applies only to 3-4 segment timelines that open with a shake-like segment
    and end with a shake-like segment of at most *max_windows* windows; the
    tail is merged into the preceding segment unless that segment is
    shake-only. Repeats until the pattern no longer matches (safety-capped).
    """
    if max_windows <= 0:
        return
    safety = 0
    while safety < 8:
        safety += 1
        meta = build_segment_meta(windows)
        # Pattern only defined for 3-4 segment timelines.
        if len(meta) < 3 or len(meta) > 4:
            return
        if not _seg_is_shake_like(meta[0]):
            return
        last = meta[-1]
        prev = meta[-2]
        if not (_seg_is_shake_like(last) and _seg_windows(last) <= max_windows):
            return
        # Require shake somewhere earlier so the tail reads as a rebound.
        has_early_shake = any(_seg_is_shake_like(seg) for seg in meta[:-1])
        if not has_early_shake:
            return
        # Merge unless the preceding segment is shake-only.
        merge_tail = False
        if _seg_is_motion_with_shake(prev):
            merge_tail = True
        elif not _seg_is_shake_like(prev):
            merge_tail = True
        if not merge_tail:
            return
        target = (prev.mode, prev.primary, prev.secondary)
        for i in range(last.start_idx, last.end_idx + 1):
            set_state(windows[i], target[0], target[1], target[2])
def resolve_hysteresis_support(window_count: int, cfg_support: int) -> int:
    """Scale the hysteresis support requirement down for short shots.

    Very short shots (<=3 windows) get no support requirement, medium shots
    (<=8) cap it at 2, and longer shots use the configured value (minimum 1).
    """
    if window_count <= 3:
        return 1
    is_medium = window_count <= 8
    return min(cfg_support, 2) if is_medium else max(1, cfg_support)
def windows_for_variant(
    raw_windows: List[WindowPrediction], cv_priors: List[WindowCVPrior], cfg: VariantConfig
) -> List[WindowPrediction]:
    """Decode raw window scores under one variant config and run its postprocess chain."""
    decoded: List[WindowPrediction] = []
    for idx, raw in enumerate(raw_windows):
        gated = apply_cv_gates(raw.raw_scores, cv_priors[idx], cfg)
        mode, primary, secondary, confidence = classify_scores(gated, cfg)
        decoded.append(
            WindowPrediction(
                start_sec=raw.start_sec,
                end_sec=raw.end_sec,
                raw_scores=gated,
                raw_top_labels=raw.raw_top_labels,
                mode=mode,
                primary=primary,
                secondary=secondary,
                confidence=confidence,
            )
        )
    # Stabilization passes, in the same order as the locked pipeline.
    majority_fill_uncertain(decoded, radius=cfg.smooth_radius)
    needed_support = resolve_hysteresis_support(len(decoded), cfg.hysteresis_support)
    apply_hysteresis(decoded, support_need=needed_support)
    enforce_compound_support(decoded, support_windows=cfg.compound_support_windows)
    if cfg.use_terminal_uncertain_absorb:
        absorb_terminal_uncertain(
            decoded,
            max_windows=max(1, int(cfg.terminal_uncertain_max_windows)),
            neighbor_conf_min=float(cfg.terminal_uncertain_neighbor_conf_min),
        )
    if cfg.use_tail_shake_rebound_collapse:
        collapse_tail_shake_rebound(decoded, max_windows=max(1, int(cfg.tail_shake_max_windows)))
    return decoded
| def _all_keyword_hits(text: str, keyword: str) -> List[int]: | |
| hits: List[int] = [] | |
| start = 0 | |
| while True: | |
| idx = text.find(keyword, start) | |
| if idx < 0: | |
| break | |
| hits.append(idx) | |
| start = idx + len(keyword) | |
| return hits | |
def parse_gt_stages(desc: str) -> List[List[str]]:
    """Parse a Chinese GT description into ordered stages of motion labels.

    Each stage is a list of labels considered simultaneous; consecutive
    stages are sequential. Returns [] for empty/unverifiable descriptions,
    which marks the shot as not scorable under the strict metric.

    Relies on module-level tables defined elsewhere in this file:
    TOKEN_PATTERNS (label -> keywords), SIMULTANEOUS_MARKERS, and
    SEQUENTIAL_MARKERS.
    """
    text = (desc or "").strip()
    if not text:
        return []
    # "不验证" (= "do not verify") explicitly opts the shot out of scoring.
    if "不验证" in text:
        return []
    # Each mention: (start, end, label, matched keyword).
    mentions: List[Tuple[int, int, str, str]] = []
    has_nonzoom = "非变焦" in text
    for label, keywords in TOKEN_PATTERNS:
        for kw in keywords:
            for pos in _all_keyword_hits(text, kw):
                end = pos + len(kw)
                # 12-char context window around the hit for local disambiguation.
                local = text[max(0, pos - 12) : min(len(text), end + 12)]
                local_nonzoom = "非变焦" in local
                # Prefer dolly_* on explicit non-zoom phrases.
                if label == "zoom_in" and (local_nonzoom or has_nonzoom) and kw in (
                    "变焦推",
                    "变焦推进",
                    "推镜头",
                    "推进",
                    "前推",
                    "推",
                ):
                    continue
                if label == "zoom_out" and (local_nonzoom or has_nonzoom) and kw in (
                    "变焦拉",
                    "拉镜头",
                    "后拉",
                    "拉远",
                    "拉",
                ):
                    continue
                # Prefer zoom_* on explicit focal-length wording.
                if label == "dolly_in" and ((("变焦" in local) and (not local_nonzoom)) or ("焦距" in local)):
                    continue
                if label == "dolly_out" and ((("变焦" in local) and (not local_nonzoom)) or ("焦距" in local)):
                    continue
                mentions.append((pos, end, label, kw))
    # "希区柯克" (Hitchcock/dolly-zoom) implies simultaneous dolly_in + zoom_out.
    if "希区柯克" in text:
        hitch_pos = text.find("希区柯克")
        mentions.append((hitch_pos, hitch_pos + 4, "dolly_in", "希区柯克"))
        mentions.append((hitch_pos, hitch_pos + 4, "zoom_out", "希区柯克"))
    if not mentions:
        return []
    # Deterministic ordering: by position, then span end, label, keyword.
    mentions.sort(key=lambda x: (x[0], x[1], x[2], x[3]))
    stages: List[List[str]] = [[mentions[0][2]]]
    last_end = mentions[0][1]
    for pos, end, label, _kw in mentions[1:]:
        # Text between the previous mention and this one decides the grouping.
        bridge = text[last_end:pos]
        if any(m in bridge for m in SIMULTANEOUS_MARKERS):
            if label not in stages[-1]:
                stages[-1].append(label)
        else:
            is_seq = any(m in bridge for m in SEQUENTIAL_MARKERS)
            # A short, punctuation-free bridge also counts as "same stage".
            is_short_bridge = len(bridge.strip()) <= 6 and ("," not in bridge) and ("。" not in bridge)
            if not is_seq and is_short_bridge:
                if label not in stages[-1]:
                    stages[-1].append(label)
            elif label not in stages[-1]:
                stages.append([label])
        last_end = max(last_end, end)
    # Dedupe labels within a stage and collapse identical adjacent stages.
    deduped: List[List[str]] = []
    for st in stages:
        uniq = []
        for lb in st:
            if lb not in uniq:
                uniq.append(lb)
        if "希区柯克" in text:
            # "推/拉 + 希区柯克变焦" often means dolly + opposite zoom, not an extra generic zoom token.
            if "dolly_in" in uniq and "zoom_out" in uniq and "zoom_in" in uniq:
                uniq = [x for x in uniq if x != "zoom_in"]
            if "dolly_out" in uniq and "zoom_in" in uniq and "zoom_out" in uniq:
                uniq = [x for x in uniq if x != "zoom_out"]
        if not deduped or deduped[-1] != uniq:
            deduped.append(uniq)
    return deduped
def predicted_stages_from_segments(segments: List[dict]) -> List[List[str]]:
    """Convert output segments into ordered stage label lists.

    Uncertain segments become the marker stage ["uncertain"]; other segments
    contribute their primary/secondary labels when present. Segments with no
    labels at all are dropped.
    """
    out: List[List[str]] = []
    for seg in segments:
        if seg["mode"] == "uncertain":
            out.append(["uncertain"])
            continue
        labels = [seg[key] for key in ("primary", "secondary") if seg.get(key)]
        if labels:
            out.append(labels)
    return out
def stage_to_cn(stage: List[str]) -> str:
    """Render a stage's labels in Chinese (falling back to the raw label), joined by ' + '."""
    rendered = [LABEL_CN_MAP.get(label, label) for label in stage]
    return " + ".join(rendered)
def strict_compare(expected: List[List[str]], predicted: List[List[str]]) -> Tuple[bool, List[str]]:
    """Strict replayability comparison between GT stages and predicted stages.

    Pass requires: no uncertain stage in the prediction, matching stage count,
    and each stage's motion set matching exactly (order-sensitive across
    stages, order-free within a stage).

    Returns (passed, reasons); ``reasons`` is empty iff ``passed`` is True.
    An empty ``expected`` (unparseable or skipped GT) always fails.
    """
    reasons: List[str] = []
    if any("uncertain" in st for st in predicted):
        reasons.append("contains_uncertain")
    if not expected:
        reasons.append("gt_unparsed_or_skipped")
        return (False, reasons)
    # Original code compared the lengths twice (!= then ==); fold into if/else.
    if len(expected) != len(predicted):
        reasons.append(f"stage_count_mismatch exp={len(expected)} pred={len(predicted)}")
    else:
        for i, (e, p) in enumerate(zip(expected, predicted), 1):
            if set(e) != set(p):
                reasons.append(
                    f"stage_{i}_mismatch exp=({'+'.join(e)}) pred=({'+'.join(p)})"
                )
    return (len(reasons) == 0, reasons)
def parse_sample_ids_csv(text: str) -> Tuple[int, ...]:
    """Parse a comma/whitespace separated shot-id list into a tuple of ints."""
    cleaned = (text or "").strip()
    if not cleaned:
        return ()
    return tuple(int(token) for token in re.split(r"[,\s]+", cleaned) if token)
def resolve_case_path(root: str, path: Optional[str]) -> Optional[str]:
    """Resolve a case-relative path against ``root``; absolute paths pass through unchanged."""
    if not path:
        return None
    return path if os.path.isabs(path) else os.path.join(root, path)
def load_gt_map(path: Optional[str]) -> Dict[int, str]:
    """Load shot-id -> GT description from a JSON file.

    Accepts either a list of {"id": ..., "desc": ...} objects or a mapping of
    id -> str/{"desc": ...}. Missing path yields {}; non-numeric mapping keys
    are skipped; unexpected value types map to "".
    """
    if not path:
        return {}
    with open(path, "r", encoding="utf-8") as fh:
        data = json.load(fh)
    if isinstance(data, list):
        # List form: items without an "id" key are silently ignored.
        return {
            int(item["id"]): str(item.get("desc", ""))
            for item in data
            if isinstance(item, dict) and "id" in item
        }
    if isinstance(data, dict):
        result: Dict[int, str] = {}
        for key, value in data.items():
            try:
                shot_id = int(key)
            except Exception:
                continue
            if isinstance(value, str):
                result[shot_id] = value
            elif isinstance(value, dict):
                result[shot_id] = str(value.get("desc", ""))
            else:
                result[shot_id] = ""
        return result
    raise RuntimeError(f"Unsupported GT json format: {path}")
def build_cases(args: argparse.Namespace) -> Tuple[SampleCase, ...]:
    """Resolve CLI args into the tuple of benchmark cases (custom or built-in)."""
    if args.video:
        # Custom single-video mode: GT json is mandatory, shots file optional.
        if not args.gt_json:
            raise RuntimeError("Custom mode requires --gt-json")
        case = SampleCase(
            name=args.case_name.strip() if args.case_name else "custom",
            video=args.video,
            shots_jsonl=args.shots_jsonl,
            gt_json=args.gt_json,
            sample_ids=parse_sample_ids_csv(args.sample_ids),
        )
        return (case,)
    wanted = {token.strip() for token in (args.cases or "").split(",") if token.strip()}
    if not wanted:
        return SAMPLE_CASES
    matched = tuple(c for c in SAMPLE_CASES if c.name in wanted)
    if not matched:
        raise RuntimeError(f"--cases did not match built-in cases: {args.cases}")
    return matched
| def _segments_to_judge_text(segments: List[dict]) -> str: | |
| lines = [] | |
| for s in segments: | |
| piece = s.get("narrative_piece", s.get("piece", "")) | |
| lines.append( | |
| ( | |
| f"{s['seg_idx']}. {s['start_sec']:.2f}-{s['end_sec']:.2f}s" | |
| f" | mode={s['mode']}" | |
| f" | primary={s['primary'] or '-'}" | |
| f" | secondary={s['secondary'] or '-'}" | |
| f" | piece={piece}" | |
| ) | |
| ) | |
| return "\n".join(lines) if lines else "(no segments)" | |
| def _extract_json_object(text: str) -> Dict[str, Any]: | |
| stripped = (text or "").strip() | |
| if stripped.startswith("```"): | |
| stripped = re.sub(r"^```(?:json)?\s*", "", stripped) | |
| stripped = re.sub(r"\s*```$", "", stripped) | |
| try: | |
| obj = json.loads(stripped) | |
| if isinstance(obj, dict): | |
| return obj | |
| except Exception: | |
| pass | |
| m = re.search(r"\{.*\}", stripped, re.S) | |
| if not m: | |
| raise RuntimeError("No JSON object found in judge output") | |
| obj = json.loads(m.group(0)) | |
| if not isinstance(obj, dict): | |
| raise RuntimeError("Judge output JSON is not an object") | |
| return obj | |
| def _as_int(v: Any, lo: int, hi: int) -> int: | |
| try: | |
| x = int(round(float(v))) | |
| except Exception: | |
| x = lo | |
| return min(hi, max(lo, x)) | |
def _verdict_from_overall(overall_score: int) -> str:
    """Map an overall 0-100 score onto pass/partial/fail buckets."""
    if overall_score >= LLM_JUDGE_PASS_THRESHOLD:
        return "pass"
    return "partial" if overall_score >= 60 else "fail"
def _judge_result_template(status: str, model: str, reason: str) -> dict:
    """Build a zeroed judge-result dict for the given status.

    Key insertion order matches the report layout (scores before verdict),
    and the reason is truncated to 240 chars.
    """
    return dict(
        status=status,
        model=model,
        prompt_version=JUDGE_PROMPT_VERSION,
        action_type_score=0,
        direction_score=0,
        order_score=0,
        replayability_score=0,
        overall_score=0,
        # Non-"ok" statuses (error/skipped) double as the verdict marker.
        verdict="fail" if status == "ok" else status,
        reason=reason[:240],
    )
def _build_judge_prompts(
    gt_desc: str, gt_stages_cn: List[str], pred_narrative_cn: str, segments: List[dict]
) -> Tuple[str, str]:
    """Compose the (system, user) prompt pair for the replayability judge."""
    stages_text = "; ".join(gt_stages_cn) if gt_stages_cn else "(empty)"
    system_prompt = (
        "你是运镜复拍评委。只看动作类型、方向、先后顺序、可复拍性。"
        "不要考虑文采。必须输出JSON。"
    )
    user_prompt = (
        "请按固定规则评分:\n"
        "1) action_type_score: 0-25,动作类型是否匹配。\n"
        "2) direction_score: 0-25,方向/正反是否匹配。\n"
        "3) order_score: 0-25,先后顺序是否匹配;并行复合不算顺序错。\n"
        "4) replayability_score: 0-25,仅凭预测描述,摄影师能否复拍。\n"
        "5) overall_score 必须是上述四项之和(0-100)。\n"
        "6) verdict: pass/partial/fail。阈值:pass>=85, partial>=60。\n"
        "7) reason: 用中文给一句短理由(<=60字)。\n\n"
        f"GT描述:\n{gt_desc}\n\n"
        f"GT解析阶段(供参考):\n{stages_text}\n\n"
        f"预测叙述:\n{pred_narrative_cn}\n\n"
        f"预测分段:\n{_segments_to_judge_text(segments)}"
    )
    return (system_prompt, user_prompt)
def _llm_judge_one_hf(
    client: Any,
    model: str,
    sys_prompt: str,
    user_prompt: str,
) -> Dict[str, Any]:
    """Call the HF chat-completion judge and parse its JSON reply.

    First attempt requests structured output via JUDGE_JSON_SCHEMA; on any
    failure a second attempt runs without the schema. Raises RuntimeError
    when both attempts fail.
    """
    failure: Optional[Exception] = None
    for with_schema in (True, False):
        request: Dict[str, Any] = {
            "model": model,
            "messages": [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": 0.0,
            "max_tokens": 320,
            "seed": 42,
        }
        if with_schema:
            request["response_format"] = JUDGE_JSON_SCHEMA
        try:
            resp = client.chat_completion(**request)
            choices = getattr(resp, "choices", None)
            if not choices:
                raise RuntimeError("judge response has no choices")
            content = getattr(choices[0].message, "content", "")
            if isinstance(content, list):
                # Some providers return the message content as a list of parts.
                text = "".join(str(p.get("text", "")) if isinstance(p, dict) else str(p) for p in content)
            else:
                text = str(content)
            return _extract_json_object(text)
        except Exception as exc:  # pragma: no cover - network/runtime path
            failure = exc
    raise RuntimeError(f"judge_failed: {failure}" if failure is not None else "judge_failed: unknown")
def _llm_judge_one_gemini(
    api_key: str,
    model: str,
    sys_prompt: str,
    user_prompt: str,
) -> Dict[str, Any]:
    """Call the Gemini ``generateContent`` REST endpoint and parse the judge's JSON reply.

    Raises RuntimeError with a short tag prefix (gemini_http_error /
    gemini_request_error / gemini_no_candidates / gemini_empty_text) on any
    failure, so callers can surface a compact reason string.
    """
    # v1beta REST convention: API key travels as a query parameter.
    endpoint = (
        f"https://generativelanguage.googleapis.com/v1beta/models/"
        f"{urllib.parse.quote(model)}:generateContent?key={urllib.parse.quote(api_key)}"
    )
    payload = {
        "systemInstruction": {"parts": [{"text": sys_prompt}]},
        "contents": [{"role": "user", "parts": [{"text": user_prompt}]}],
        "generationConfig": {
            # Deterministic decoding; force structured JSON via the response schema.
            "temperature": 0,
            "responseMimeType": "application/json",
            "responseSchema": GEMINI_JUDGE_RESPONSE_SCHEMA,
        },
    }
    body = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(
        endpoint,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=90) as resp:
            raw = resp.read().decode("utf-8")
    except urllib.error.HTTPError as exc:
        # Include truncated server detail so quota/schema errors stay diagnosable.
        detail = exc.read().decode("utf-8", errors="ignore") if hasattr(exc, "read") else str(exc)
        raise RuntimeError(f"gemini_http_error: {exc.code} {detail[:300]}") from exc
    except Exception as exc:
        raise RuntimeError(f"gemini_request_error: {exc}") from exc
    obj = json.loads(raw)
    cands = obj.get("candidates", [])
    if not cands:
        raise RuntimeError(f"gemini_no_candidates: {raw[:300]}")
    # Only the first candidate is consumed; its text parts are concatenated.
    parts = cands[0].get("content", {}).get("parts", [])
    text = "".join(str(x.get("text", "")) for x in parts if isinstance(x, dict))
    if not text.strip():
        raise RuntimeError(f"gemini_empty_text: {raw[:300]}")
    return _extract_json_object(text)
def llm_judge_one(
    *,
    provider: str,
    client: Any,
    model: str,
    gt_desc: str,
    gt_stages_cn: List[str],
    pred_narrative_cn: str,
    segments: List[dict],
) -> dict:
    """Score one shot's prediction with the LLM judge and normalize the result.

    Returns a judge-result dict (shape of ``_judge_result_template``). The four
    sub-scores are clamped to 0-25; the overall score is replaced by the
    sub-score sum when they disagree by more than 5; the verdict is always
    derived from the normalized overall score. Missing GT yields a "skipped"
    result; any judge failure yields an "error" result (never raises).
    """
    if not gt_desc.strip():
        return _judge_result_template("skipped", model, "missing_gt_desc")
    sys_prompt, user_prompt = _build_judge_prompts(gt_desc, gt_stages_cn, pred_narrative_cn, segments)
    try:
        if provider == "gemini":
            parsed = _llm_judge_one_gemini(
                api_key=str(client["api_key"]),
                model=model,
                sys_prompt=sys_prompt,
                user_prompt=user_prompt,
            )
        else:
            parsed = _llm_judge_one_hf(
                client=client,
                model=model,
                sys_prompt=sys_prompt,
                user_prompt=user_prompt,
            )
    except Exception as exc:
        return _judge_result_template("error", model, f"{provider}_judge_failed: {exc}")
    action = _as_int(parsed.get("action_type_score"), 0, 25)
    direction = _as_int(parsed.get("direction_score"), 0, 25)
    order = _as_int(parsed.get("order_score"), 0, 25)
    replayability = _as_int(parsed.get("replayability_score"), 0, 25)
    summed = action + direction + order + replayability
    overall = _as_int(parsed.get("overall_score"), 0, 100)
    if abs(overall - summed) > 5:
        # Distrust a reported overall that strays far from the sub-score sum.
        overall = summed
    # The model-reported verdict was previously read, validated, and then
    # unconditionally overwritten with the derived one (two dead branches);
    # derive it directly from the normalized overall score instead.
    verdict = _verdict_from_overall(overall)
    reason = str(parsed.get("reason", "")).strip()
    if not reason:
        reason = "评分由四项匹配程度综合给出。"
    out = _judge_result_template("ok", model, reason)
    out.update(
        {
            "action_type_score": action,
            "direction_score": direction,
            "order_score": order,
            "replayability_score": replayability,
            "overall_score": overall,
            "verdict": verdict,
            "reason": reason[:240],
        }
    )
    return out
def summarize_variant(rows: List[dict], variant_name: str, llm_enabled: bool) -> dict:
    """Aggregate per-shot results for one variant into a summary dict.

    Only scorable shots contribute. When ``llm_enabled`` is set, judge stats
    (count, mean overall, pass rate, error count) are added as well.
    """
    scorable = [row[variant_name] for row in rows if row[variant_name]["scorable"]]
    total = len(scorable)
    strict_hits = sum(1 for v in scorable if v["strict_pass"])
    summary = {
        "scored_shots": total,
        "strict_hits": strict_hits,
        "strict_hit_rate": round((strict_hits / total) if total else 0.0, 4),
        "contains_uncertain": sum(1 for v in scorable if v["has_uncertain"]),
    }
    if llm_enabled:
        judge_results = [v.get("llm_judge", {}) for v in scorable]
        ok_results = [j for j in judge_results if isinstance(j, dict) and j.get("status") == "ok"]
        summary["llm_judged"] = len(ok_results)
        summary["llm_overall_mean"] = round(
            float(mean(float(j.get("overall_score", 0.0)) for j in ok_results)) if ok_results else 0.0, 2
        )
        summary["llm_pass_rate"] = round(
            float(sum(1 for j in ok_results if j.get("verdict") == "pass") / len(ok_results)) if ok_results else 0.0, 4
        )
        summary["llm_error_count"] = sum(
            1 for j in judge_results if isinstance(j, dict) and j.get("status") == "error"
        )
    return summary
def run_benchmark(
    *,
    hf_token: str,
    output_json: str,
    output_csv: str,
    cases: Sequence[SampleCase],
    llm_judge: bool,
    judge_provider: str,
    judge_model: str,
    judge_token: str,
    gemini_api_key: str,
    judge_workers: int,
    max_shots: int,
) -> dict:
    """Run one-pass window inference per shot, evaluate every postprocess
    variant against strict GT, optionally add LLM-judge scores, and write
    the JSON and CSV reports.

    Returns the full report payload (the same dict written to ``output_json``).
    Raises RuntimeError on missing credentials, unreadable video, or invalid
    video metadata.
    """
    root = os.path.abspath(os.path.dirname(__file__))
    # --- Judge client setup (only when LLM judging is requested) ---
    judge_client = None
    if llm_judge:
        if judge_provider == "gemini":
            key = gemini_api_key or os.environ.get("GEMINI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "")
            if not key:
                raise RuntimeError("Gemini judge requires --gemini-api-key or GEMINI_API_KEY/GOOGLE_API_KEY")
            # Gemini path keeps only the key; requests go through urllib.
            judge_client = {"api_key": key}
        else:
            if InferenceClient is None:
                raise RuntimeError("huggingface_hub is required for --llm-judge with provider=hf")
            token = judge_token or hf_token
            if not token:
                raise RuntimeError("Judge token required: --judge-token or JUDGE_TOKEN/HF_TOKEN")
            judge_client = InferenceClient(token=token)
    inferencer = VideoMAEWindowInferencer(MODEL_ID, hf_token)
    rows: List[dict] = []
    for case in cases:
        video_path = resolve_case_path(root, case.video)
        shots_path = resolve_case_path(root, case.shots_jsonl)
        gt_path = resolve_case_path(root, case.gt_json)
        if not video_path:
            raise RuntimeError(f"Missing video path for case={case.name}")
        print(f"[CASE] {case.name}", flush=True)
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        if fps <= 1e-8 or total_frames <= 0:
            # Release before raising so the capture handle is not leaked.
            cap.release()
            raise RuntimeError(f"Invalid metadata for {video_path}")
        duration_sec = total_frames / fps
        # Without a shots file, treat the whole video as a single shot.
        if shots_path:
            shots = load_shots_jsonl(shots_path, fps=fps, video_duration_sec=duration_sec)
        else:
            shots = [ShotBoundary(shot_id=1, start_sec=0.0, end_sec=duration_sec)]
        shot_map: Dict[int, ShotBoundary] = {s.shot_id: s for s in shots}
        gt_map = load_gt_map(gt_path)
        selected_ids: Tuple[int, ...]
        if case.sample_ids:
            # Keep only the configured ids that exist in the shots file.
            selected_ids = tuple(x for x in case.sample_ids if x in shot_map)
        else:
            selected_ids = tuple(sorted(shot_map.keys()))
        if max_shots > 0:
            selected_ids = selected_ids[:max_shots]
        for sid in selected_ids:
            shot = shot_map[sid]
            desc = gt_map.get(sid, "")
            expected_stages = parse_gt_stages(desc)
            print(f"  [SHOT] {case.name}#{sid}", flush=True)
            # One-pass inference: raw windows and CV priors are computed once
            # per shot and shared by every postprocess variant below.
            ranges = build_window_ranges(shot.start_sec, shot.end_sec)
            index_lists = build_window_frame_index_lists(ranges, fps=fps, total_frames=total_frames)
            frames_by_window = sample_frames_for_window_indices(
                cap,
                inferencer,
                index_lists=index_lists,
                fps=fps,
                total_frames=total_frames,
            )
            raw_windows: List[WindowPrediction] = []
            cv_priors: List[WindowCVPrior] = []
            for (s, e), frames in zip(ranges, frames_by_window):
                raw_windows.append(infer_raw_window_from_frames(inferencer, frames, s, e))
                cv_priors.append(compute_window_cv_prior(frames))
            entry = {
                "dataset": case.name,
                "shot_id": sid,
                "shot_start_sec": round(shot.start_sec, 3),
                "shot_end_sec": round(shot.end_sec, 3),
                "gt_desc": desc,
                "gt_stages": expected_stages,
                "gt_stages_cn": [stage_to_cn(st) for st in expected_stages],
            }
            # --- Evaluate each fixed postprocess variant on the shared windows ---
            variant_payload_map: Dict[str, dict] = {}
            for cfg in VARIANTS:
                windows = windows_for_variant(raw_windows, cv_priors, cfg)
                enforce_min_duration(windows, shot.end_sec - shot.start_sec)
                enforce_max_segments(windows)
                segments = build_segments_output(windows)
                pred_stages = predicted_stages_from_segments(segments)
                ok, reasons = strict_compare(expected_stages, pred_stages)
                has_uncertain = any(seg["mode"] == "uncertain" for seg in segments)
                narrative_cn = build_shot_narrative(segments)
                variant_payload = {
                    "narrative_cn": narrative_cn,
                    "segments": [
                        {
                            "seg_idx": s["seg_idx"],
                            "start_sec": s["start_sec"],
                            "end_sec": s["end_sec"],
                            "mode": s["mode"],
                            "primary": s["primary"],
                            "secondary": s["secondary"],
                            "confidence": s["confidence"],
                            "piece": s["narrative_piece"],
                        }
                        for s in segments
                    ],
                    "pred_stages": pred_stages,
                    "pred_stages_cn": [stage_to_cn(st) for st in pred_stages],
                    "strict_pass": ok,
                    "strict_reasons": reasons,
                    "has_uncertain": has_uncertain,
                    # Shots with unparseable/skipped GT are excluded from rates.
                    "scorable": len(expected_stages) > 0,
                }
                variant_payload_map[cfg.name] = variant_payload
            if llm_judge:
                if judge_client is None:
                    raise RuntimeError("judge client init failed")
                # Closure over this iteration's desc/entry; it is consumed
                # synchronously below, before the loop variables change.
                def _judge_variant(name: str, payload: dict) -> Tuple[str, dict]:
                    judge = llm_judge_one(
                        provider=judge_provider,
                        client=judge_client,
                        model=judge_model,
                        gt_desc=desc,
                        gt_stages_cn=entry["gt_stages_cn"],
                        pred_narrative_cn=payload["narrative_cn"],
                        segments=payload["segments"],
                    )
                    return (name, judge)
                jobs = list(variant_payload_map.items())
                workers = max(1, min(int(judge_workers), len(jobs)))
                if workers == 1:
                    for name, payload in jobs:
                        payload["llm_judge"] = _judge_variant(name, payload)[1]
                else:
                    # Fan the per-variant judge calls out over a thread pool.
                    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as ex:
                        futs = [ex.submit(_judge_variant, name, payload) for name, payload in jobs]
                        for fut in concurrent.futures.as_completed(futs):
                            name, judge = fut.result()
                            variant_payload_map[name]["llm_judge"] = judge
            for cfg in VARIANTS:
                entry[cfg.name] = variant_payload_map[cfg.name]
            # Shot-level CV prior means, for post-hoc analysis of the report.
            entry["window_cv_mean"] = {
                "motion": round(float(mean(x.mean_flow_mag for x in cv_priors)) if cv_priors else 0.0, 4),
                "jitter": round(float(mean(x.jitter_ratio for x in cv_priors)) if cv_priors else 0.0, 4),
                "shake": round(float(mean(x.shake_conf for x in cv_priors)) if cv_priors else 0.0, 4),
                "static": round(float(mean(x.static_conf for x in cv_priors)) if cv_priors else 0.0, 4),
                "zoom_bias": round(float(mean(x.zoom_bias for x in cv_priors)) if cv_priors else 0.0, 4),
                "dolly_bias": round(float(mean(x.dolly_bias for x in cv_priors)) if cv_priors else 0.0, 4),
            }
            rows.append(entry)
        cap.release()
    # --- Summaries and pairwise variant deltas ---
    summary = {cfg.name: summarize_variant(rows, cfg.name, llm_enabled=llm_judge) for cfg in VARIANTS}
    summary["delta_B_minus_A"] = round(
        summary[VARIANT_B.name]["strict_hit_rate"] - summary[VARIANT_A.name]["strict_hit_rate"], 4
    )
    summary["delta_C_minus_B"] = round(
        summary[VARIANT_C.name]["strict_hit_rate"] - summary[VARIANT_B.name]["strict_hit_rate"], 4
    )
    summary["delta_D_minus_C"] = round(
        summary[VARIANT_D.name]["strict_hit_rate"] - summary[VARIANT_C.name]["strict_hit_rate"], 4
    )
    summary["delta_D_minus_B"] = round(
        summary[VARIANT_D.name]["strict_hit_rate"] - summary[VARIANT_B.name]["strict_hit_rate"], 4
    )
    if llm_judge:
        summary["llm_delta_B_minus_A"] = round(
            summary[VARIANT_B.name]["llm_overall_mean"] - summary[VARIANT_A.name]["llm_overall_mean"], 2
        )
        summary["llm_delta_C_minus_B"] = round(
            summary[VARIANT_C.name]["llm_overall_mean"] - summary[VARIANT_B.name]["llm_overall_mean"], 2
        )
        summary["llm_delta_D_minus_C"] = round(
            summary[VARIANT_D.name]["llm_overall_mean"] - summary[VARIANT_C.name]["llm_overall_mean"], 2
        )
    payload = {
        "model_id": MODEL_ID,
        "variants": [asdict(v) for v in VARIANTS],
        "llm_judge": {
            "enabled": llm_judge,
            "provider": judge_provider if llm_judge else "",
            "model": judge_model if llm_judge else "",
            "prompt_version": JUDGE_PROMPT_VERSION,
            "pass_threshold": LLM_JUDGE_PASS_THRESHOLD,
        },
        "cases": [asdict(c) for c in cases],
        "summary": summary,
        "rows": rows,
    }
    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    # CSV helper: flatten one variant's judge result into (score, verdict, reason).
    def llm_fields(v: dict) -> Tuple[Any, Any, Any]:
        j = v.get("llm_judge", {})
        if j.get("status") == "ok":
            return (j.get("overall_score", ""), j.get("verdict", ""), j.get("reason", ""))
        return ("", j.get("status", ""), j.get("reason", ""))
    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(
            [
                "dataset",
                "shot_id",
                "gt_desc",
                "gt_stages_cn",
                "A_pass",
                "A_pred_stages_cn",
                "A_reasons",
                "A_llm_score",
                "A_llm_verdict",
                "A_llm_reason",
                "B_pass",
                "B_pred_stages_cn",
                "B_reasons",
                "B_llm_score",
                "B_llm_verdict",
                "B_llm_reason",
                "C_pass",
                "C_pred_stages_cn",
                "C_reasons",
                "C_llm_score",
                "C_llm_verdict",
                "C_llm_reason",
                "D_pass",
                "D_pred_stages_cn",
                "D_reasons",
                "D_llm_score",
                "D_llm_verdict",
                "D_llm_reason",
            ]
        )
        for r in rows:
            a = r[VARIANT_A.name]
            b = r[VARIANT_B.name]
            c = r[VARIANT_C.name]
            d = r[VARIANT_D.name]
            a_llm = llm_fields(a)
            b_llm = llm_fields(b)
            c_llm = llm_fields(c)
            d_llm = llm_fields(d)
            w.writerow(
                [
                    r["dataset"],
                    r["shot_id"],
                    r["gt_desc"],
                    " | ".join(r["gt_stages_cn"]),
                    int(a["strict_pass"]),
                    " | ".join(a["pred_stages_cn"]),
                    " | ".join(a["strict_reasons"]),
                    a_llm[0],
                    a_llm[1],
                    a_llm[2],
                    int(b["strict_pass"]),
                    " | ".join(b["pred_stages_cn"]),
                    " | ".join(b["strict_reasons"]),
                    b_llm[0],
                    b_llm[1],
                    b_llm[2],
                    int(c["strict_pass"]),
                    " | ".join(c["pred_stages_cn"]),
                    " | ".join(c["strict_reasons"]),
                    c_llm[0],
                    c_llm[1],
                    c_llm[2],
                    int(d["strict_pass"]),
                    " | ".join(d["pred_stages_cn"]),
                    " | ".join(d["strict_reasons"]),
                    d_llm[0],
                    d_llm[1],
                    d_llm[2],
                ]
            )
    print(f"[OK] wrote {output_json}")
    print(f"[OK] wrote {output_csv}")
    print("[SUMMARY]")
    print(json.dumps(summary, ensure_ascii=False, indent=2))
    return payload
def main() -> int:
    """CLI entry point: parse arguments, resolve cases, and run the benchmark.

    Returns 0 on success; raises RuntimeError on missing tokens or bad args.
    """
    parser = argparse.ArgumentParser(description="A/B strict replayability benchmark for VideoMAE timeline")
    parser.add_argument("--hf-token", default=os.environ.get("HF_TOKEN", ""), help="HF token or set HF_TOKEN")
    parser.add_argument("--output-json", default="ab_strict_report.json")
    parser.add_argument("--output-csv", default="ab_strict_report.csv")
    parser.add_argument("--cases", default="", help="Built-in cases filter, comma separated (baseus,runner,vertical)")
    parser.add_argument("--max-shots", type=int, default=0, help="Cap processed shots per case; 0 means all selected")
    parser.add_argument("--video", default="", help="Custom video path; when set, run custom mode")
    parser.add_argument("--shots-jsonl", default="", help="Custom shot boundary JSONL path (optional)")
    parser.add_argument("--gt-json", default="", help="Custom GT json path (required in custom mode)")
    parser.add_argument("--sample-ids", default="", help="Custom sample shot ids, comma separated; default all")
    parser.add_argument("--case-name", default="custom", help="Custom case display name")
    parser.add_argument("--llm-judge", action="store_true", help="Enable LLM-as-judge scoring")
    parser.add_argument(
        "--judge-provider",
        choices=("hf", "gemini"),
        default=os.environ.get("JUDGE_PROVIDER", DEFAULT_JUDGE_PROVIDER),
        help="Judge backend provider",
    )
    parser.add_argument("--judge-model", default=os.environ.get("JUDGE_MODEL", DEFAULT_JUDGE_MODEL))
    parser.add_argument("--judge-token", default=os.environ.get("JUDGE_TOKEN", ""))
    parser.add_argument("--gemini-api-key", default=os.environ.get("GEMINI_API_KEY", os.environ.get("GOOGLE_API_KEY", "")))
    parser.add_argument("--judge-workers", type=int, default=4, help="Parallel workers for judge calls per shot")
    args = parser.parse_args()
    # The HF token is always required: inference runs on the VideoMAE model
    # even when the judge uses Gemini.
    if not args.hf_token:
        raise RuntimeError("HF token required: --hf-token or HF_TOKEN")
    # When judging with Gemini and the model was left at the HF default,
    # swap in the Gemini default model instead.
    if args.llm_judge and args.judge_provider == "gemini" and args.judge_model == DEFAULT_JUDGE_MODEL:
        args.judge_model = os.environ.get("GEMINI_MODEL", DEFAULT_GEMINI_MODEL)
    cases = build_cases(args)
    run_benchmark(
        hf_token=args.hf_token,
        output_json=args.output_json,
        output_csv=args.output_csv,
        cases=cases,
        llm_judge=bool(args.llm_judge),
        judge_provider=args.judge_provider,
        judge_model=args.judge_model,
        judge_token=args.judge_token,
        gemini_api_key=args.gemini_api_key,
        judge_workers=max(1, int(args.judge_workers)),
        max_shots=max(0, int(args.max_shots)),
    )
    return 0
# Script entry point: exit with main()'s return code.
if __name__ == "__main__":
    raise SystemExit(main())