Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """Generate stable shot timelines from kandinsky VideoMAE window predictions. | |
| Pipeline (fixed constants): | |
| 1) Sliding-window inference (16 frames per window). | |
| 2) Uncertainty handling with margin and undefined score. | |
| 3) Window-level stabilization (majority fill + hysteresis). | |
| 4) Segment constraints (minimum length, maximum segment count). | |
| 5) Chinese narrative generation with single/compound/uncertain modes. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| import os | |
| import re | |
| from dataclasses import dataclass | |
| from statistics import mean | |
| from typing import Dict, List, Optional, Sequence, Tuple | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor | |
# ---------------------------
# Fixed constants (locked)
# ---------------------------
MODEL_ID = "ai-forever/kandinsky-videomae-large-camera-motion"
WINDOW_SEC = 1.6  # sliding-window length in seconds
STRIDE_SEC = 0.4  # hop between consecutive window starts
NUM_FRAMES = 16  # frames sampled per window (VideoMAE input length)
UNCERTAIN_P1_MIN = 0.45  # min top-1 primary score; below -> "uncertain"
UNCERTAIN_MARGIN_MIN = 0.12  # min top1-top2 gap; below -> "uncertain"
UNCERTAIN_UNDEFINED_MIN = 0.50  # "undefined" score at/above this -> "uncertain"
SECONDARY_MIN = 0.50  # min score for a compound secondary label
SECONDARY_GAP_MAX = 0.18  # max allowed primary-secondary score gap for a compound
SMOOTH_RADIUS = 2  # neighborhood radius (windows) for majority fill
HYSTERESIS_SUPPORT = 2  # consecutive windows required before switching state
MIN_SEG_SEC = 0.8  # minimum stabilized segment duration (seconds)
MAX_SEGMENTS_PER_SHOT = 3  # hard cap on segments emitted per shot
SHOT_BOUNDARY_GUARD_FRAMES = 1  # frames trimmed from each shot edge when guarded
FRAME_INDEX_EPS = 1e-9  # epsilon for ceil/floor frame-index rounding
# Model label -> Chinese display name used in narratives and outputs.
LABEL_CN_MAP: Dict[str, str] = {
    "arc_left": "左弧绕",
    "arc_right": "右弧绕",
    "dolly_in": "机位前推",
    "dolly_out": "机位后拉",
    "pan_left": "左摇",
    "pan_right": "右摇",
    "pedestal_down": "机位下移",
    "pedestal_up": "机位上移",
    "pov": "主观视角",
    "roll_left": "左滚转",
    "roll_right": "右滚转",
    "shake": "抖动",
    "static": "固定镜头",
    "tilt_down": "下俯",
    "tilt_up": "上仰",
    "track": "跟拍",
    "truck_left": "机位左移",
    "truck_right": "机位右移",
    "undefined": "未定义",
    "zoom_in": "变焦推",
    "zoom_out": "变焦拉",
    "uncertain": "复杂/不确定",
}
# Labels eligible to be a window's primary motion (excludes track/pov/shake).
PRIMARY_LABELS: Tuple[str, ...] = (
    "arc_left",
    "arc_right",
    "pan_left",
    "pan_right",
    "truck_left",
    "truck_right",
    "tilt_up",
    "tilt_down",
    "pedestal_up",
    "pedestal_down",
    "dolly_in",
    "dolly_out",
    "zoom_in",
    "zoom_out",
    "roll_left",
    "roll_right",
    "static",
)
# Secondary can come from any non-undefined label (including track/pov/shake).
SECONDARY_CANDIDATES: Tuple[str, ...] = tuple(
    lb for lb in LABEL_CN_MAP.keys() if lb not in ("undefined", "uncertain")
)
@dataclass
class ShotBoundary:
    """One shot's time span within the source video.

    Fix: the @dataclass decorator was missing although the class is
    constructed with keyword arguments (e.g. in load_shots_jsonl), which
    would raise TypeError without a generated __init__.
    """

    shot_id: int  # 1-based shot index (jsonl `shot_index`, or 1 for the whole video)
    start_sec: float  # shot start, seconds
    end_sec: float  # shot end, seconds (>= start_sec after sanitizing)
@dataclass
class WindowPrediction:
    """Raw model scores plus the stabilized decision for one sliding window.

    Fix: the @dataclass decorator was missing although instances are built
    with keyword arguments (in infer_window), which would raise TypeError
    without a generated __init__.
    """

    start_sec: float  # window start, seconds
    end_sec: float  # window end, seconds
    raw_scores: Dict[str, float]  # label -> score for every model label
    raw_top_labels: List[Dict[str, float]]  # top entries [{"label":..., "score":...}] for debugging
    mode: str  # single | compound | uncertain
    primary: Optional[str]  # primary motion label (None when uncertain)
    secondary: Optional[str]  # secondary label (compound mode only)
    confidence: float  # decision confidence for this window
def parse_timecode(tc: str, fps: float) -> float:
    """Convert a timecode string to seconds.

    Two formats are accepted:
    * ``NmNsNf`` — minutes/seconds/frames; the frame part is divided by
      *fps* (ignored when fps is effectively zero).
    * ``NmNsN`` — minutes/seconds plus a decimal fraction whose scale is
      implied by its digit count (e.g. ``25`` -> 0.25s).

    Raises ValueError for any other shape.
    """
    text = tc.strip()
    frame_form = re.fullmatch(r"(\d+)m(\d+)s(\d+)f", text)
    if frame_form:
        minutes, seconds, frame_count = (int(g) for g in frame_form.groups())
        frame_part = frame_count / fps if fps > 1e-8 else 0.0
        return 60.0 * minutes + seconds + frame_part
    decimal_form = re.fullmatch(r"(\d+)m(\d+)s(\d+)", text)
    if decimal_form:
        minutes = int(decimal_form.group(1))
        seconds = int(decimal_form.group(2))
        digits = decimal_form.group(3)
        return 60.0 * minutes + seconds + int(digits) / (10 ** len(digits))
    raise ValueError(f"Unsupported timecode: {text}")
def should_apply_boundary_guard(start_tc: str, end_tc: str) -> bool:
    """Return True unless both timecodes are frame-based ("...f").

    mmssms-style timecodes are usually converted from frame boundaries and
    can carry +/-1 frame ambiguity after decimal rounding, so they get the
    boundary guard; pure frame timecodes do not need it.
    """
    frame_based = start_tc.strip().endswith("f") and end_tc.strip().endswith("f")
    return not frame_based
def sanitize_shot_bounds(
    start_sec: float,
    end_sec: float,
    fps: float,
    video_duration_sec: float,
    use_boundary_guard: bool,
) -> Tuple[float, float]:
    """Clamp a shot span into [0, duration] and optionally trim its edges.

    The guard shaves SHOT_BOUNDARY_GUARD_FRAMES worth of time off each end
    (only when the span is long enough to survive the trim), compensating
    for +/-1 frame boundary ambiguity. Guarding is skipped when disabled or
    when fps is effectively zero.
    """
    lo = min(max(start_sec, 0.0), video_duration_sec)
    hi = max(min(end_sec, video_duration_sec), lo)
    if use_boundary_guard and fps > 1e-8:
        guard_sec = SHOT_BOUNDARY_GUARD_FRAMES / fps
        if (hi - lo) > 2.0 * guard_sec:
            lo, hi = lo + guard_sec, hi - guard_sec
    return (lo, hi)
def compute_frame_span(start_sec: float, end_sec: float, fps: float, total_frames: int) -> Tuple[int, int]:
    """Map a time span to an inclusive (first_frame, last_frame) index pair.

    Treats [start, end) as half-open so the tail frame does not bleed into
    the next shot. A span that collapses after clamping degenerates to the
    single midpoint frame. Raises ValueError on unusable fps/frame counts.
    """
    if fps <= 1e-8 or total_frames <= 0:
        raise ValueError(f"Invalid fps/frames: fps={fps}, total_frames={total_frames}")
    last_frame = total_frames - 1

    def _clamp(value: int) -> int:
        # Keep indices inside the decodable range [0, total_frames - 1].
        return max(0, min(last_frame, value))

    first = _clamp(int(np.ceil(start_sec * fps - FRAME_INDEX_EPS)))
    last = _clamp(int(np.floor(end_sec * fps - FRAME_INDEX_EPS)))
    if last < first:
        midpoint = _clamp(int(round((start_sec + end_sec) * 0.5 * fps)))
        return (midpoint, midpoint)
    return (first, last)
def load_shots_jsonl(path: str, fps: float, video_duration_sec: float) -> List[ShotBoundary]:
    """Load shot boundaries from a jsonl file into sanitized ShotBoundary rows.

    Each non-empty line is a JSON object with `shot_start`/`shot_end`
    timecodes and an optional `shot_index` (defaulting to the 1-based line
    number). Bounds are clamped/guarded via sanitize_shot_bounds. Raises
    RuntimeError when the file yields no shots.
    """
    shots: List[ShotBoundary] = []
    with open(path, "r", encoding="utf-8") as handle:
        for line_no, raw in enumerate(handle, 1):
            text = raw.strip()
            if not text:
                continue
            record = json.loads(text)
            start_tc = str(record["shot_start"])
            end_tc = str(record["shot_end"])
            start_sec, end_sec = sanitize_shot_bounds(
                start_sec=parse_timecode(start_tc, fps),
                end_sec=parse_timecode(end_tc, fps),
                fps=fps,
                video_duration_sec=video_duration_sec,
                use_boundary_guard=should_apply_boundary_guard(start_tc, end_tc),
            )
            shots.append(
                ShotBoundary(
                    shot_id=int(record.get("shot_index", line_no)),
                    start_sec=start_sec,
                    end_sec=end_sec,
                )
            )
    if not shots:
        raise RuntimeError(f"No valid shots loaded from: {path}")
    return shots
def default_full_shot(video_duration_sec: float) -> List[ShotBoundary]:
    """Fallback when no shots jsonl is given: one shot covering the whole video."""
    whole_video = ShotBoundary(shot_id=1, start_sec=0.0, end_sec=video_duration_sec)
    return [whole_video]
def build_window_ranges(shot_start: float, shot_end: float) -> List[Tuple[float, float]]:
    """Produce sliding-window (start, end) pairs covering the shot.

    Windows are WINDOW_SEC long and advance by STRIDE_SEC. Shots no longer
    than a single window (or degenerate spans) yield exactly one window
    equal to the shot itself. A final window aligned to the shot's tail is
    appended when the stride does not land there exactly.
    """
    duration = max(0.0, shot_end - shot_start)
    if duration <= 1e-6 or duration <= WINDOW_SEC:
        return [(shot_start, shot_end)]
    starts: List[float] = []
    pos = shot_start
    last_valid = shot_end - WINDOW_SEC
    while pos <= last_valid + 1e-9:
        starts.append(pos)
        pos += STRIDE_SEC
    if not starts:
        starts.append(shot_start)
    elif abs(starts[-1] - last_valid) > 1e-6:
        # Ensure tail coverage.
        starts.append(last_valid)
    return [(s, min(shot_end, s + WINDOW_SEC)) for s in starts]
def label_cn(label: Optional[str]) -> str:
    """Chinese display name for *label*.

    Unknown labels pass through unchanged; empty/None maps to 未定义.
    """
    return LABEL_CN_MAP.get(label, label) if label else "未定义"
def axis_of(label: str) -> Optional[Tuple[str, str]]:
    """Map a directional label suffix to an (axis, direction) pair.

    Axes: "h" (left/right), "v" (up/down), "d" (in/out). Returns None for
    non-directional labels such as static/track/pov/shake.
    """
    suffix_table = (
        ("_left", ("h", "left")),
        ("_right", ("h", "right")),
        ("_up", ("v", "up")),
        ("_down", ("v", "down")),
        ("_in", ("d", "in")),
        ("_out", ("d", "out")),
    )
    for suffix, axis_dir in suffix_table:
        if label.endswith(suffix):
            return axis_dir
    return None
def is_axis_conflict(a: str, b: str) -> bool:
    """True when *a* and *b* are opposite directions of the same motion family.

    Conflicts stay within one family (e.g. pan_left vs pan_right);
    cross-family combos such as dolly_in + zoom_out are explicitly allowed.
    Non-directional labels never conflict.
    """
    family_a = a.split("_", 1)[0]
    family_b = b.split("_", 1)[0]
    if family_a != family_b:
        return False
    dir_a = axis_of(a)
    dir_b = axis_of(b)
    if dir_a is None or dir_b is None:
        return False
    same_axis = dir_a[0] == dir_b[0]
    opposite_direction = dir_a[1] != dir_b[1]
    return same_axis and opposite_direction
def classify_raw_scores(raw_scores: Dict[str, float]) -> Tuple[str, Optional[str], Optional[str], float]:
    """Classify one window's raw label scores into a stabilization state.

    Returns (mode, primary, secondary, confidence) where mode is one of
    "uncertain" (weak top-1, narrow margin, or high "undefined" score),
    "single", or "compound" (a strong, non-conflicting secondary exists).
    Compound confidence is the lower of the two label scores.
    """
    ranked_primary = sorted(
        ((lb, raw_scores.get(lb, 0.0)) for lb in PRIMARY_LABELS),
        key=lambda item: (item[1], item[0]),
        reverse=True,
    )
    top_label, top_score = ranked_primary[0]
    runner_up = ranked_primary[1][1] if len(ranked_primary) > 1 else 0.0
    undefined_score = raw_scores.get("undefined", 0.0)
    too_weak = top_score < UNCERTAIN_P1_MIN
    too_close = (top_score - runner_up) < UNCERTAIN_MARGIN_MIN
    too_undefined = undefined_score >= UNCERTAIN_UNDEFINED_MIN
    if too_weak or too_close or too_undefined:
        return ("uncertain", None, None, top_score)
    if top_label == "static":
        # A static camera never forms a compound motion.
        return ("single", top_label, None, top_score)
    ranked_secondary = sorted(
        (
            (lb, raw_scores.get(lb, 0.0))
            for lb in SECONDARY_CANDIDATES
            if lb != "undefined" and lb != top_label
        ),
        key=lambda item: (item[1], item[0]),
        reverse=True,
    )
    for cand_label, cand_score in ranked_secondary:
        if cand_score < SECONDARY_MIN:
            # Sorted descending: nothing further can qualify.
            break
        if top_score - cand_score > SECONDARY_GAP_MAX:
            continue
        if is_axis_conflict(top_label, cand_label):
            continue
        return ("compound", top_label, cand_label, min(top_score, cand_score))
    return ("single", top_label, None, top_score)
def state_key(w: WindowPrediction) -> Tuple[str, Optional[str], Optional[str]]:
    """Collapse a window's stabilized decision into a hashable comparison key."""
    mode, primary, secondary = w.mode, w.primary, w.secondary
    return (mode, primary, secondary)
def set_state(w: WindowPrediction, mode: str, primary: Optional[str], secondary: Optional[str]) -> None:
    """Overwrite a window's stabilized decision fields in place."""
    w.mode, w.primary, w.secondary = mode, primary, secondary
def majority_fill_uncertain(windows: List[WindowPrediction]) -> None:
    """Replace "uncertain" windows with the dominant neighboring state, in place.

    For each uncertain window, tally the non-uncertain states of neighbors
    within SMOOTH_RADIUS (the window itself is skipped). A state is adopted
    only if it has at least 2 votes AND strictly beats the runner-up.
    NOTE: windows are rewritten as the scan progresses, so an earlier fill
    can contribute votes when later windows are evaluated.
    """
    n = len(windows)
    for i, w in enumerate(windows):
        if w.mode != "uncertain":
            continue
        lo = max(0, i - SMOOTH_RADIUS)
        hi = min(n - 1, i + SMOOTH_RADIUS)
        counter: Dict[Tuple[str, Optional[str], Optional[str]], int] = {}
        for j in range(lo, hi + 1):
            if j == i:
                continue
            k = state_key(windows[j])
            if k[0] == "uncertain":
                continue
            counter[k] = counter.get(k, 0) + 1
        if not counter:
            continue
        # Rank by vote count; ties broken by the state tuple for determinism.
        ranked = sorted(counter.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)
        best_k, best_cnt = ranked[0]
        second_cnt = ranked[1][1] if len(ranked) > 1 else -1
        if best_cnt >= 2 and best_cnt > second_cnt:
            set_state(w, best_k[0], best_k[1], best_k[2])
def apply_hysteresis(windows: List[WindowPrediction]) -> None:
    """Debounce state flips: a new state must persist for HYSTERESIS_SUPPORT
    consecutive windows before it replaces the current one.

    The first window's state seeds `current`. For each later window whose
    state differs, a bounded look-ahead counts how many consecutive windows
    share the candidate state; only a long-enough run switches `current`.
    The debounced sequence is written back onto the windows in place.
    """
    if not windows:
        return
    states = [state_key(w) for w in windows]
    out: List[Tuple[str, Optional[str], Optional[str]]] = []
    current = states[0]
    out.append(current)
    for i in range(1, len(states)):
        cand = states[i]
        if cand == current:
            out.append(current)
            continue
        # Look ahead, capping the count at the support threshold.
        support = 1
        j = i + 1
        while j < len(states) and states[j] == cand and support < HYSTERESIS_SUPPORT:
            support += 1
            j += 1
        if support >= HYSTERESIS_SUPPORT:
            current = cand
            out.append(current)
        else:
            # Short blip: keep the current state for this window.
            out.append(current)
    for w, st in zip(windows, out):
        set_state(w, st[0], st[1], st[2])
@dataclass
class SegmentMeta:
    """A maximal run of consecutive windows sharing one stabilized state.

    Fix: the @dataclass decorator was missing although instances are built
    with keyword arguments (in build_segment_meta), which would raise
    TypeError without a generated __init__.
    """

    seg_idx: int  # 0-based segment order within the shot
    start_idx: int  # first window index (inclusive)
    end_idx: int  # last window index (inclusive)
    mode: str  # single | compound | uncertain
    primary: Optional[str]
    secondary: Optional[str]
    duration_sec: float  # span from first window start to last window end
    confidence: float  # mean window confidence over the run
def build_segment_meta(windows: List[WindowPrediction]) -> List[SegmentMeta]:
    """Group consecutive windows with identical stabilized states into segments.

    Segment duration spans from the first window's start to the last
    window's end; confidence is the mean window confidence over the run.
    """
    if not windows:
        return []
    segments: List[SegmentMeta] = []
    run_start = 0
    total = len(windows)
    for pos in range(1, total + 1):
        # Extend the current run while the state is unchanged.
        if pos < total and state_key(windows[pos]) == state_key(windows[run_start]):
            continue
        mode, primary, secondary = state_key(windows[run_start])
        run = windows[run_start:pos]
        span = max(0.0, run[-1].end_sec - run[0].start_sec)
        avg_conf = mean(w.confidence for w in run) if run else 0.0
        segments.append(
            SegmentMeta(
                seg_idx=len(segments),
                start_idx=run_start,
                end_idx=pos - 1,
                mode=mode,
                primary=primary,
                secondary=secondary,
                duration_sec=span,
                confidence=avg_conf,
            )
        )
        run_start = pos
    return segments
def state_similarity(a: SegmentMeta, b: SegmentMeta) -> float:
    """Weighted agreement score between two segment states.

    Matching mode contributes 3, matching (non-empty) primary 2, matching
    (non-empty) secondary 1 — so mode agreement dominates ties.
    """
    total = 3.0 if a.mode == b.mode else 0.0
    if a.primary and a.primary == b.primary:
        total += 2.0
    if a.secondary and a.secondary == b.secondary:
        total += 1.0
    return total
def choose_merge_target(meta: List[SegmentMeta], idx: int) -> int:
    """Pick the neighbor index that segment *idx* should be absorbed into.

    Edge segments take their only neighbor. Interior segments prefer the
    neighbor with the higher (similarity, confidence, duration) key,
    favoring the left side on exact ties.
    """
    last = len(meta) - 1
    if idx <= 0:
        return 1
    if idx >= last:
        return last - 1
    cur = meta[idx]

    def _rank(neighbor: SegmentMeta) -> Tuple[float, float, float]:
        return (state_similarity(cur, neighbor), neighbor.confidence, neighbor.duration_sec)

    if _rank(meta[idx - 1]) >= _rank(meta[idx + 1]):
        return idx - 1
    return idx + 1
def relabel_window_range(
    windows: List[WindowPrediction],
    start_idx: int,
    end_idx: int,
    target_state: Tuple[str, Optional[str], Optional[str]],
) -> None:
    """Overwrite the state of windows[start_idx..end_idx] (inclusive) in place."""
    mode, primary, secondary = target_state
    for window in windows[start_idx : end_idx + 1]:
        set_state(window, mode, primary, secondary)
def enforce_min_duration(windows: List[WindowPrediction], shot_duration_sec: float) -> None:
    """Iteratively merge segments shorter than MIN_SEG_SEC into a neighbor.

    Skipped entirely when the whole shot is shorter than the minimum.
    Each iteration rebuilds the segment list, picks the worst short segment
    (shortest, then lowest-confidence, then earliest), and relabels its
    windows to the chosen neighbor's state. Mutates windows in place; the
    safety counter bounds the loop in case merging stops making progress.
    """
    if shot_duration_sec < MIN_SEG_SEC:
        return
    safety = 0
    while safety < 1000:
        safety += 1
        meta = build_segment_meta(windows)
        if len(meta) <= 1:
            break
        short_idxs = [i for i, seg in enumerate(meta) if seg.duration_sec < MIN_SEG_SEC]
        if not short_idxs:
            break
        # Merge the shortest + lowest-confidence short segment first.
        idx = min(short_idxs, key=lambda i: (meta[i].duration_sec, meta[i].confidence, i))
        target_idx = choose_merge_target(meta, idx)
        target_state = (meta[target_idx].mode, meta[target_idx].primary, meta[target_idx].secondary)
        relabel_window_range(windows, meta[idx].start_idx, meta[idx].end_idx, target_state)
def enforce_max_segments(windows: List[WindowPrediction]) -> None:
    """Merge segments until at most MAX_SEGMENTS_PER_SHOT remain.

    Victim choice each round: shortest duration, then lowest confidence,
    then earliest index; its windows are relabeled to the neighbor chosen
    by choose_merge_target. Mutates windows in place; the safety counter
    bounds the loop.
    """
    safety = 0
    while safety < 1000:
        safety += 1
        meta = build_segment_meta(windows)
        if len(meta) <= MAX_SEGMENTS_PER_SHOT:
            break
        idx = min(range(len(meta)), key=lambda i: (meta[i].duration_sec, meta[i].confidence, i))
        target_idx = choose_merge_target(meta, idx)
        target_state = (meta[target_idx].mode, meta[target_idx].primary, meta[target_idx].secondary)
        relabel_window_range(windows, meta[idx].start_idx, meta[idx].end_idx, target_state)
def segment_piece(mode: str, primary: Optional[str], secondary: Optional[str], raw_scores_mean: Dict[str, float]) -> str:
    """Render one segment as a Chinese narrative fragment.

    Uncertain segments list the top-two primary candidates by mean score;
    compound segments join both labels; everything else shows the primary.
    """
    if mode == "uncertain":
        ranked = sorted(
            ((lb, raw_scores_mean.get(lb, 0.0)) for lb in PRIMARY_LABELS),
            key=lambda item: (item[1], item[0]),
            reverse=True,
        )
        candidates = [label_cn(lb) for lb, _ in ranked[:2]]
        return f"复杂/不确定(候选: {', '.join(candidates)})"
    if mode == "compound" and primary and secondary:
        return f"{label_cn(primary)} + {label_cn(secondary)}"
    return label_cn(primary)
def build_segments_output(windows: List[WindowPrediction]) -> List[dict]:
    """Serialize stabilized windows into per-segment output dictionaries.

    Each segment dict carries 1-based `seg_idx`, rounded time bounds, the
    stabilized state, per-label mean raw scores (excluding the synthetic
    "uncertain" entry), the top-8 of those means, and a narrative piece.
    """
    out: List[dict] = []
    for order, seg in enumerate(build_segment_meta(windows), 1):
        run = windows[seg.start_idx : seg.end_idx + 1]
        mean_scores: Dict[str, float] = {
            lb: mean(w.raw_scores.get(lb, 0.0) for w in run)
            for lb in LABEL_CN_MAP
            if lb != "uncertain"
        }
        top8 = sorted(mean_scores.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)[:8]
        out.append(
            {
                "seg_idx": order,
                "start_sec": round(run[0].start_sec, 3),
                "end_sec": round(run[-1].end_sec, 3),
                "mode": seg.mode,
                "primary": seg.primary,
                "secondary": seg.secondary,
                "confidence": round(seg.confidence, 4),
                "raw_top_labels": [{"label": lb, "score": round(sc, 4)} for lb, sc in top8],
                "raw_scores_mean": {k: round(v, 4) for k, v in mean_scores.items()},
                "narrative_piece": segment_piece(seg.mode, seg.primary, seg.secondary, mean_scores),
            }
        )
    return out
def build_shot_narrative(segments: List[dict]) -> str:
    """Compose the shot-level Chinese narrative from segment pieces.

    1 piece -> "主要为X"; 2 -> "先X,后Y"; 3 -> "先X,然后Y,最后Z";
    more -> comma-joined; none -> a fixed "no segments" message.
    """
    pieces = [seg["narrative_piece"] for seg in segments]
    count = len(pieces)
    if count == 0:
        return "未检测到有效运镜段"
    if count == 1:
        return f"主要为{pieces[0]}"
    first, second, *rest = pieces
    if count == 2:
        return f"先{first},后{second}"
    if count == 3:
        return f"先{first},然后{second},最后{rest[0]}"
    return ",".join(pieces)
class VideoMAEWindowInferencer:
    """Wraps the VideoMAE classifier + image processor for per-window inference."""

    def __init__(self, model_id: str, hf_token: str):
        """Load model/processor from the HF hub and move the model to a device.

        Device selection prefers Apple MPS when available, else CPU
        (no CUDA path in this script).
        """
        self.processor = VideoMAEImageProcessor.from_pretrained(model_id, token=hf_token)
        self.model = VideoMAEForVideoClassification.from_pretrained(model_id, token=hf_token)
        self.model.eval()
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        self.model.to(self.device)
        # Normalize id2label to int keys; HF configs may store keys as strings
        # (or the mapping as a list).
        id2label = self.model.config.id2label
        if isinstance(id2label, dict):
            self.id2label = {int(k): v for k, v in id2label.items()}
        else:
            self.id2label = {i: v for i, v in enumerate(id2label)}

    def _sample_frames(
        self,
        cap: cv2.VideoCapture,
        start_sec: float,
        end_sec: float,
        fps: float,
        total_frames: int,
    ) -> List[np.ndarray]:
        """Sample exactly NUM_FRAMES RGB frames evenly across [start_sec, end_sec].

        Failed reads reuse the previous good frame; if fewer than NUM_FRAMES
        frames are collected, the last one is repeated as padding. Raises
        RuntimeError when no frame at all could be decoded.
        """
        if end_sec <= start_sec:
            # Guarantee a non-empty span (at least ~one frame's duration).
            end_sec = start_sec + max(1.0 / fps, 1e-3)
        f0, f1 = compute_frame_span(start_sec, end_sec, fps, total_frames)
        idxs = np.linspace(f0, f1, NUM_FRAMES).astype(int).tolist()
        frames: List[np.ndarray] = []
        last: Optional[np.ndarray] = None
        for idx in idxs:
            # Random-seek per frame; NOTE(review): slow on long-GOP codecs — confirm acceptable.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if ok:
                # OpenCV decodes BGR; convert to RGB before handing to the processor.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
                last = frame
            elif last is not None:
                frames.append(last)
        if not frames:
            raise RuntimeError("Failed to sample frames for window.")
        while len(frames) < NUM_FRAMES:
            frames.append(frames[-1])
        return frames[:NUM_FRAMES]

    def infer_window(
        self,
        cap: cv2.VideoCapture,
        start_sec: float,
        end_sec: float,
        fps: float,
        total_frames: int,
    ) -> WindowPrediction:
        """Run one window through the model and classify its label scores.

        Scores are per-label sigmoids of the logits, so they need not sum
        to 1. The window's mode/primary/secondary come from
        classify_raw_scores.
        """
        frames = self._sample_frames(cap, start_sec, end_sec, fps, total_frames)
        inputs = self.processor([frames], return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        logits = self.model(**inputs).logits.float().cpu().numpy()[0]
        # Element-wise sigmoid over the logits.
        probs = 1.0 / (1.0 + np.exp(-logits))
        raw_scores = {self.id2label[i]: float(probs[i]) for i in sorted(self.id2label.keys())}
        raw_top = sorted(raw_scores.items(), key=lambda kv: (kv[1], kv[0]), reverse=True)[:8]
        mode, primary, secondary, conf = classify_raw_scores(raw_scores)
        return WindowPrediction(
            start_sec=start_sec,
            end_sec=end_sec,
            raw_scores=raw_scores,
            raw_top_labels=[{"label": lb, "score": round(sc, 4)} for lb, sc in raw_top],
            mode=mode,
            primary=primary,
            secondary=secondary,
            confidence=float(conf),
        )
def process_shot(
    cap: cv2.VideoCapture,
    inferencer: VideoMAEWindowInferencer,
    shot: ShotBoundary,
    fps: float,
    total_frames: int,
) -> dict:
    """Run the full per-shot pipeline and return the shot's output dict.

    Steps: sliding-window inference, then the in-place stabilization passes
    (majority fill, hysteresis, min-duration and max-segment constraints),
    then segment serialization and narrative generation.
    """
    spans = build_window_ranges(shot.start_sec, shot.end_sec)
    windows = [
        inferencer.infer_window(cap, span_start, span_end, fps, total_frames)
        for span_start, span_end in spans
    ]
    # Each pass rewrites window states in place.
    majority_fill_uncertain(windows)
    apply_hysteresis(windows)
    enforce_min_duration(windows, shot.end_sec - shot.start_sec)
    enforce_max_segments(windows)
    segments = build_segments_output(windows)
    return {
        "shot_id": shot.shot_id,
        "start_sec": round(shot.start_sec, 3),
        "end_sec": round(shot.end_sec, 3),
        "segments": segments,
        "narrative_cn": build_shot_narrative(segments),
    }
def write_csv(path: str, shots: List[dict]) -> None:
    """Flatten per-shot segment dicts into a CSV with one row per segment."""
    fields = [
        "shot_id",
        "seg_idx",
        "start_sec",
        "end_sec",
        "mode",
        "primary",
        "secondary",
        "confidence",
        "narrative_piece",
    ]
    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fields)
        writer.writeheader()
        for shot in shots:
            shot_id = shot["shot_id"]
            for seg in shot["segments"]:
                row = {field: seg[field] for field in fields if field != "shot_id"}
                row["shot_id"] = shot_id
                writer.writerow(row)
def main() -> int:
    """CLI entry point: open the video, infer per-shot timelines, write JSON + CSV.

    Requires an HF token (flag or HF_TOKEN env var) and a readable video.
    Returns 0 on success; raises RuntimeError/FileNotFoundError on bad
    inputs or unusable video metadata.
    """
    parser = argparse.ArgumentParser(description="Stable timeline generation from kandinsky VideoMAE")
    parser.add_argument("--video", required=True, help="Input video path")
    parser.add_argument("--hf-token", default=os.environ.get("HF_TOKEN", ""), help="HF token (or set HF_TOKEN)")
    parser.add_argument("--shots-jsonl", default="", help="Optional shot boundaries jsonl")
    parser.add_argument("--output-json", default="", help="Output timeline JSON path")
    parser.add_argument("--output-csv", default="", help="Output segment CSV path")
    args = parser.parse_args()
    if not args.hf_token:
        raise RuntimeError("HF token required: use --hf-token or set HF_TOKEN")
    if not os.path.exists(args.video):
        raise FileNotFoundError(args.video)
    # Default output paths sit next to the input video.
    output_json = args.output_json or f"{args.video}.timeline.json"
    output_csv = args.output_csv or f"{args.video}.timeline.csv"
    cap = cv2.VideoCapture(args.video)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video: {args.video}")
    fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    if fps <= 1e-8 or total_frames <= 0:
        cap.release()
        raise RuntimeError(f"Invalid video metadata (fps={fps}, frames={total_frames})")
    duration_sec = total_frames / fps
    if args.shots_jsonl:
        shots = load_shots_jsonl(args.shots_jsonl, fps=fps, video_duration_sec=duration_sec)
    else:
        # No boundaries supplied: treat the entire video as one shot.
        shots = default_full_shot(video_duration_sec=duration_sec)
    inferencer = VideoMAEWindowInferencer(MODEL_ID, args.hf_token)
    shot_outputs = []
    for shot in shots:
        print(f"[RUN] shot={shot.shot_id} {shot.start_sec:.3f}s~{shot.end_sec:.3f}s")
        shot_result = process_shot(cap, inferencer, shot, fps=fps, total_frames=total_frames)
        print(f" segments={len(shot_result['segments'])} narrative={shot_result['narrative_cn']}")
        shot_outputs.append(shot_result)
    cap.release()
    # The JSON payload records the locked constants for reproducibility.
    payload = {
        "video_path": os.path.abspath(args.video),
        "model_id": MODEL_ID,
        "fps": fps,
        "total_frames": total_frames,
        "duration_sec": round(duration_sec, 3),
        "fixed_constants": {
            "window_sec": WINDOW_SEC,
            "stride_sec": STRIDE_SEC,
            "num_frames": NUM_FRAMES,
            "uncertain_p1_min": UNCERTAIN_P1_MIN,
            "uncertain_margin_min": UNCERTAIN_MARGIN_MIN,
            "uncertain_undefined_min": UNCERTAIN_UNDEFINED_MIN,
            "secondary_min": SECONDARY_MIN,
            "secondary_gap_max": SECONDARY_GAP_MAX,
            "smooth_radius": SMOOTH_RADIUS,
            "hysteresis_support": HYSTERESIS_SUPPORT,
            "min_seg_sec": MIN_SEG_SEC,
            "max_segments_per_shot": MAX_SEGMENTS_PER_SHOT,
            "shot_boundary_guard_frames": SHOT_BOUNDARY_GUARD_FRAMES,
        },
        "shots": shot_outputs,
    }
    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    write_csv(output_csv, shot_outputs)
    print(f"[OK] wrote json: {output_json}")
    print(f"[OK] wrote csv: {output_csv}")
    return 0
if __name__ == "__main__":
    # SystemExit with main()'s return value yields a proper process exit code.
    raise SystemExit(main())