"""Key frame selection utilities.""" from __future__ import annotations import logging from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from typing import Any, Iterable, Mapping import cv2 import numpy as np from .config import WorkerSettings from .pipeline import run_stream3r_inference from .runtime import WorkerRuntime logger = logging.getLogger(__name__) @dataclass(slots=True) class FrameRecord: index: int frame_id: str path: Path source: str | None = None timestamp: str | None = None metadata: dict[str, Any] = field(default_factory=dict) @dataclass(slots=True) class KeyframeSelectionResult: indices: list[int] diagnostics: list[dict[str, Any]] top_k: int def pose_confidence(predictions: Mapping[str, np.ndarray]) -> np.ndarray | None: if "world_points_conf" in predictions: return np.asarray(predictions["world_points_conf"], dtype=np.float32) if "depth_conf" in predictions: return np.asarray(predictions["depth_conf"], dtype=np.float32) return None def _camera_poses(extrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: matrices = np.asarray(extrinsic, dtype=np.float64) if matrices.ndim != 3 or matrices.shape[1:] != (3, 4): raise ValueError("Extrinsic array must have shape (N, 3, 4)") count = matrices.shape[0] rotations = np.empty((count, 3, 3), dtype=np.float64) translations = np.empty((count, 3), dtype=np.float64) for idx in range(count): mat = np.eye(4, dtype=np.float64) mat[:3, :4] = matrices[idx] cam_to_world = np.linalg.inv(mat) rotations[idx] = cam_to_world[:3, :3] translations[idx] = cam_to_world[:3, 3] return rotations, translations def _compute_motion_deltas(rotations: np.ndarray, translations: np.ndarray, rot_weight: float) -> np.ndarray: count = rotations.shape[0] deltas = np.zeros(count, dtype=np.float64) if count <= 1: return deltas for idx in range(1, count): delta_t = np.linalg.norm(translations[idx] - translations[idx - 1]) rel = rotations[idx - 1].T @ rotations[idx] trace = np.clip((np.trace(rel) - 1.0) / 2.0, -1.0, 1.0) delta_r = float(np.arccos(trace)) deltas[idx] = delta_t + rot_weight * delta_r return deltas def _hash_quantized_voxels(coords: np.ndarray) -> np.ndarray: coords = coords.astype(np.int64, copy=False) primes = np.array([73856093, 19349663, 83492791], dtype=np.int64) return coords @ primes def _frame_voxel_sets( world_points: np.ndarray, confidence: np.ndarray, *, threshold: float, voxel_size: float, max_points: int, ) -> tuple[list[set[int]], int]: rng = np.random.default_rng(42) frames = world_points.shape[0] voxel_sets: list[set[int]] = [] global_union: set[int] = set() if voxel_size <= 0.0: return [set() for _ in range(frames)], 0 for idx in range(frames): conf_frame = confidence[idx] mask = conf_frame >= threshold if not np.any(mask): voxel_sets.append(set()) continue points = world_points[idx][mask] if points.shape[0] > max_points: sample_idx = rng.choice(points.shape[0], max_points, replace=False) points = points[sample_idx] quantized = np.floor(points / voxel_size).astype(np.int64, copy=False) hashes = np.unique(_hash_quantized_voxels(quantized)) voxel_set = set(int(v) for v in hashes.tolist()) voxel_sets.append(voxel_set) global_union.update(voxel_set) return voxel_sets, len(global_union) def _select_motion_indices( motion_deltas: np.ndarray, *, threshold: float, min_gap: int, max_gap: int, ) -> tuple[list[int], dict[int, dict[str, float]]]: total_frames = motion_deltas.shape[0] if total_frames == 0: return [], {} selected = [0] diagnostics: dict[int, dict[str, float]] = {0: {"motion_delta": 0.0, 
"cum_motion": 0.0}} cumulative = 0.0 gap = 0 for idx in range(1, total_frames): delta = float(motion_deltas[idx]) cumulative += delta gap += 1 if gap < max(1, min_gap): continue should_select = cumulative >= threshold if max_gap > 0 and gap >= max_gap: should_select = True if should_select: selected.append(idx) diagnostics[idx] = {"motion_delta": delta, "cum_motion": cumulative} cumulative = 0.0 gap = 0 if selected[-1] != total_frames - 1: selected.append(total_frames - 1) diagnostics.setdefault(total_frames - 1, {"motion_delta": float(motion_deltas[-1]), "cum_motion": cumulative}) return selected, diagnostics def select_keyframes_motion_coverage( frame_records: list[FrameRecord], predictions: Mapping[str, np.ndarray], settings: WorkerSettings, requested_top_k: int, ) -> KeyframeSelectionResult | None: extrinsic = np.asarray(predictions.get("extrinsic")) if extrinsic.size == 0: return None rotations, translations = _camera_poses(extrinsic) motion_deltas = _compute_motion_deltas(rotations, translations, settings.keyframe_rotation_weight) motion_indices, motion_diag = _select_motion_indices( motion_deltas, threshold=settings.keyframe_motion_threshold, min_gap=max(1, settings.keyframe_min_gap_frames), max_gap=max(0, settings.keyframe_max_gap_frames), ) total_frames = len(frame_records) confidence = pose_confidence(predictions) world_points = predictions.get("world_points") if world_points is None: world_points = predictions.get("world_points_from_depth") voxel_sets: list[set[int]] = [set() for _ in range(total_frames)] total_voxels = 0 mean_conf = np.zeros(total_frames, dtype=np.float32) if confidence is not None: mean_conf = confidence.reshape(confidence.shape[0], -1).mean(axis=1) if confidence is not None and world_points is not None: voxel_sets, total_voxels = _frame_voxel_sets( np.asarray(world_points), np.asarray(confidence), threshold=settings.keyframe_coverage_confidence, voxel_size=settings.keyframe_coverage_voxel_size, max_points=max(1000, settings.keyframe_coverage_max_points), ) total_voxels = max(total_voxels, 1) top_k = requested_top_k if requested_top_k > 0 else settings.keyframe_default_top_k top_k = max(min(top_k, total_frames), len(motion_indices)) selected_set: set[int] = set(motion_indices) diagnostics: dict[int, dict[str, Any]] = {} covered: set[int] = set() for idx in motion_indices: gain_count = len(voxel_sets[idx] - covered) if voxel_sets[idx] else 0 gain_ratio = gain_count / total_voxels covered.update(voxel_sets[idx]) diagnostics[idx] = { "frame_id": frame_records[idx].frame_id, "frame_index": frame_records[idx].index, "reason": "motion", "motion_delta": float(motion_deltas[idx]), "cum_motion": float(motion_diag.get(idx, {}).get("cum_motion", 0.0)), "coverage_gain_ratio": float(gain_ratio), "coverage_gain_count": int(gain_count), "mean_confidence": float(mean_conf[idx]) if confidence is not None else None, } if len(selected_set) < top_k and total_voxels > 0: min_gain_ratio = settings.keyframe_min_gain_ratio remaining = [i for i in range(total_frames) if i not in selected_set and voxel_sets[i]] while remaining and len(selected_set) < top_k: best_idx = -1 best_gain = -1 best_ratio = -1.0 for idx in remaining: gain = len(voxel_sets[idx] - covered) if gain <= 0: continue ratio = gain / total_voxels if ratio > best_ratio or (np.isclose(ratio, best_ratio) and gain > best_gain): best_idx = idx best_gain = gain best_ratio = ratio if best_idx == -1 or best_ratio < min_gain_ratio: break selected_set.add(best_idx) covered.update(voxel_sets[best_idx]) diagnostics[best_idx] = { 
"frame_id": frame_records[best_idx].frame_id, "frame_index": frame_records[best_idx].index, "reason": "coverage", "motion_delta": float(motion_deltas[best_idx]), "cum_motion": float(motion_diag.get(best_idx, {}).get("cum_motion", 0.0)), "coverage_gain_ratio": float(best_ratio), "coverage_gain_count": int(best_gain), "mean_confidence": float(mean_conf[best_idx]) if confidence is not None else None, } remaining.remove(best_idx) if requested_top_k > 0 and len(selected_set) > requested_top_k: coverage_candidates = [idx for idx in selected_set if diagnostics[idx]["reason"] == "coverage"] coverage_candidates.sort(key=lambda idx: diagnostics[idx].get("coverage_gain_ratio", 0.0)) while len(selected_set) > requested_top_k and coverage_candidates: drop_idx = coverage_candidates.pop(0) selected_set.remove(drop_idx) diagnostics.pop(drop_idx, None) final_indices = sorted(selected_set) final_diags = [diagnostics[idx] for idx in final_indices] return KeyframeSelectionResult(indices=final_indices, diagnostics=final_diags, top_k=len(final_indices)) def run_keyframe_prepass( *, runtime: WorkerRuntime, payload: Mapping[str, Any], frame_records: list[FrameRecord], mode: str, streaming: bool, window_size: int | None, ) -> KeyframeSelectionResult | None: if len(frame_records) <= 1: return None settings = runtime.settings top_k_payload = int(payload.get("prepass_top_k") or payload.get("top_k_frames") or payload.get("top_k") or 0) try: inference = run_stream3r_inference( runtime=runtime, image_paths=[record.path for record in frame_records], mode=mode, streaming=streaming, cache_output_path=None, progress_cb=None, window_size=window_size if streaming and mode == "window" else None, ) except Exception: logger.exception("Keyframe pre-pass inference failed") return None try: return select_keyframes_motion_coverage( frame_records, inference.predictions, settings, requested_top_k=top_k_payload, ) finally: del inference def extract_video_frames( video_path: Path, output_dir: Path, *, target_fps: float | None = None, max_frames: int | None = None, ) -> tuple[list[FrameRecord], float]: if not video_path.exists(): raise FileNotFoundError(f"Video file not found: {video_path}") output_dir.mkdir(parents=True, exist_ok=True) cap = cv2.VideoCapture(str(video_path)) if not cap.isOpened(): raise RuntimeError(f"Failed to open video: {video_path}") native_fps = cap.get(cv2.CAP_PROP_FPS) if not native_fps or native_fps <= 0: native_fps = 30.0 frame_interval = 1 if target_fps and target_fps > 0: frame_interval = max(1, int(round(native_fps / target_fps))) frame_records: list[FrameRecord] = [] total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0) frame_idx = 0 extracted = 0 success, frame = cap.read() while success: if frame_idx % frame_interval == 0: frame_id = f"frame_{extracted:06d}" frame_path = output_dir / f"{frame_id}.jpg" if not cv2.imwrite(str(frame_path), frame): cap.release() raise RuntimeError(f"Failed to write frame: {frame_path}") timestamp_s = frame_idx / native_fps frame_records.append( FrameRecord( index=extracted, frame_id=frame_id, path=frame_path, metadata={"frame_number": frame_idx, "timestamp_s": timestamp_s}, ) ) extracted += 1 if max_frames and extracted >= max_frames: break frame_idx += 1 success, frame = cap.read() cap.release() if not frame_records: raise RuntimeError("No frames extracted from video") return frame_records, native_fps def linear_sample_indices(total: int, desired: int) -> list[int]: if desired <= 0 or total <= desired: return list(range(total)) step = total / desired return [min(total - 
def linear_sample_indices(total: int, desired: int) -> list[int]:
    """Evenly sample ``desired`` indices from ``range(total)``."""
    if desired <= 0 or total <= desired:
        return list(range(total))
    step = total / desired
    return [min(total - 1, int(round(i * step))) for i in range(desired)]


def build_keyframe_uploads(
    runtime: WorkerRuntime,
    scene_id: str,
    selected_records: Iterable[FrameRecord],
    diagnostics: list[dict[str, Any]],
    *,
    subdir: str,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Upload selected frames to storage, returning storage and media manifests."""
    diag_by_index = {entry.get("frame_index"): entry for entry in diagnostics}
    storage_entries: list[dict[str, Any]] = []
    media_entries: list[dict[str, Any]] = []
    for record in selected_records:
        diag = diag_by_index.get(record.index, {})
        filename = f"{record.frame_id}.jpg"
        key = runtime.storage.build_key(scene_id, subdir, filename)
        uri = runtime.storage.upload_file(record.path, key, content_type="image/jpeg")
        storage_entries.append(
            {
                "frame_id": record.frame_id,
                "frame_index": record.index,
                "url": uri,
                "storage_key": key,
                "diagnostics": diag,
            }
        )
        media_entries.append(
            {
                "media_type": "image",
                "file": key,
                "captured_at": _diagnostic_captured_at(record, diag),
            }
        )
    return storage_entries, media_entries


def _diagnostic_captured_at(record: FrameRecord, diag: Mapping[str, Any]) -> str | None:
    """Best-effort ISO-8601 capture time from record or diagnostic metadata."""
    if record.timestamp:
        return record.timestamp
    ts = diag.get("timestamp") or record.metadata.get("timestamp")
    if isinstance(ts, str):
        return ts
    if isinstance(ts, (int, float)):
        return _utc_isoformat(float(ts))
    timestamp_s = record.metadata.get("timestamp_s")
    if isinstance(timestamp_s, (int, float)):
        return _utc_isoformat(float(timestamp_s))
    return None


def _utc_isoformat(ts: float) -> str:
    """Render an epoch timestamp as a Z-suffixed ISO-8601 string."""
    return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat().replace("+00:00", "Z")
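

# Quick illustration of linear_sample_indices: asking for 4 of 10 indices
# strides evenly through the range (Python's round() ties to even at 2.5),
# while asking for more indices than exist returns every index.
def _linear_sampling_sketch() -> None:
    assert linear_sample_indices(10, 4) == [0, 2, 5, 8]
    assert linear_sample_indices(3, 10) == [0, 1, 2]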