Spaces:
Configuration error
Configuration error
| """Key frame selection utilities.""" | |
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Iterable, Mapping | |
| import cv2 | |
| import numpy as np | |
| from .config import WorkerSettings | |
| from .pipeline import run_stream3r_inference | |
| from .runtime import WorkerRuntime | |
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class FrameRecord:
    """One frame image on disk plus bookkeeping used during key frame selection.

    Declared as a dataclass so it can be built with keyword arguments
    (``FrameRecord(index=..., frame_id=..., path=...)``) and so that
    ``metadata`` gets a fresh dict per instance via ``default_factory``.
    """

    # Position of the frame within the sequence handed to selection.
    index: int
    # Frame identifier; uploads name the file ``f"{frame_id}.jpg"``.
    frame_id: str
    # Local filesystem path to the frame image.
    path: Path
    # Optional description of where the frame came from.
    source: str | None = None
    # Optional pre-formatted capture time; used verbatim as ``captured_at`` when set.
    timestamp: str | None = None
    # Free-form per-frame data (video extraction stores ``frame_number`` and ``timestamp_s``).
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class KeyframeSelectionResult:
    """Outcome of key frame selection.

    As produced by ``select_keyframes_motion_coverage``, ``indices`` are sorted
    ascending, ``diagnostics`` is aligned one-to-one with ``indices``, and
    ``top_k`` equals ``len(indices)``.
    """

    # Selected positions into the frame list given to the selector.
    indices: list[int]
    # Per-frame diagnostics dicts, one per entry of ``indices``.
    diagnostics: list[dict[str, Any]]
    # Number of frames selected.
    top_k: int
def pose_confidence(predictions: Mapping[str, np.ndarray]) -> np.ndarray | None:
    """Return the model's confidence array as ``float32``, if one is present.

    ``world_points_conf`` takes precedence over ``depth_conf``. Returns ``None``
    when neither key is in ``predictions``.
    """
    for key in ("world_points_conf", "depth_conf"):
        if key in predictions:
            return np.asarray(predictions[key], dtype=np.float32)
    return None
| def _camera_poses(extrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: | |
| matrices = np.asarray(extrinsic, dtype=np.float64) | |
| if matrices.ndim != 3 or matrices.shape[1:] != (3, 4): | |
| raise ValueError("Extrinsic array must have shape (N, 3, 4)") | |
| count = matrices.shape[0] | |
| rotations = np.empty((count, 3, 3), dtype=np.float64) | |
| translations = np.empty((count, 3), dtype=np.float64) | |
| for idx in range(count): | |
| mat = np.eye(4, dtype=np.float64) | |
| mat[:3, :4] = matrices[idx] | |
| cam_to_world = np.linalg.inv(mat) | |
| rotations[idx] = cam_to_world[:3, :3] | |
| translations[idx] = cam_to_world[:3, 3] | |
| return rotations, translations | |
| def _compute_motion_deltas(rotations: np.ndarray, translations: np.ndarray, rot_weight: float) -> np.ndarray: | |
| count = rotations.shape[0] | |
| deltas = np.zeros(count, dtype=np.float64) | |
| if count <= 1: | |
| return deltas | |
| for idx in range(1, count): | |
| delta_t = np.linalg.norm(translations[idx] - translations[idx - 1]) | |
| rel = rotations[idx - 1].T @ rotations[idx] | |
| trace = np.clip((np.trace(rel) - 1.0) / 2.0, -1.0, 1.0) | |
| delta_r = float(np.arccos(trace)) | |
| deltas[idx] = delta_t + rot_weight * delta_r | |
| return deltas | |
| def _hash_quantized_voxels(coords: np.ndarray) -> np.ndarray: | |
| coords = coords.astype(np.int64, copy=False) | |
| primes = np.array([73856093, 19349663, 83492791], dtype=np.int64) | |
| return coords @ primes | |
def _frame_voxel_sets(
    world_points: np.ndarray,
    confidence: np.ndarray,
    *,
    threshold: float,
    voxel_size: float,
    max_points: int,
) -> tuple[list[set[int]], int]:
    """Reduce each frame's confident 3D points to a set of hashed voxel ids.

    For frame ``i``, the points of ``world_points[i]`` whose ``confidence[i]``
    is ``>= threshold`` are kept. If more than ``max_points`` remain, a random
    subset of that size is drawn without replacement from a generator seeded
    with 42, so repeated calls give the same result. The surviving points
    are floored onto a grid with cell edge ``voxel_size`` and hashed with
    ``_hash_quantized_voxels``.

    Returns:
        ``(per_frame_sets, union_size)``: one set per frame (empty when no
        point passes the threshold) and the number of distinct ids across all
        frames. A non-positive ``voxel_size`` yields all-empty sets and 0.
    """
    n_frames = world_points.shape[0]
    if voxel_size <= 0.0:
        return [set() for _ in range(n_frames)], 0

    # Fixed-seed generator created per call keeps subsampling deterministic.
    sampler = np.random.default_rng(42)
    per_frame: list[set[int]] = []
    seen: set[int] = set()
    for i in range(n_frames):
        keep = confidence[i] >= threshold
        if np.any(keep):
            pts = world_points[i][keep]
            n_pts = pts.shape[0]
            if n_pts > max_points:
                pts = pts[sampler.choice(n_pts, max_points, replace=False)]
            cells = np.floor(pts / voxel_size).astype(np.int64, copy=False)
            ids = set(_hash_quantized_voxels(cells).tolist())
        else:
            ids = set()
        per_frame.append(ids)
        seen |= ids
    return per_frame, len(seen)
| def _select_motion_indices( | |
| motion_deltas: np.ndarray, | |
| *, | |
| threshold: float, | |
| min_gap: int, | |
| max_gap: int, | |
| ) -> tuple[list[int], dict[int, dict[str, float]]]: | |
| total_frames = motion_deltas.shape[0] | |
| if total_frames == 0: | |
| return [], {} | |
| selected = [0] | |
| diagnostics: dict[int, dict[str, float]] = {0: {"motion_delta": 0.0, "cum_motion": 0.0}} | |
| cumulative = 0.0 | |
| gap = 0 | |
| for idx in range(1, total_frames): | |
| delta = float(motion_deltas[idx]) | |
| cumulative += delta | |
| gap += 1 | |
| if gap < max(1, min_gap): | |
| continue | |
| should_select = cumulative >= threshold | |
| if max_gap > 0 and gap >= max_gap: | |
| should_select = True | |
| if should_select: | |
| selected.append(idx) | |
| diagnostics[idx] = {"motion_delta": delta, "cum_motion": cumulative} | |
| cumulative = 0.0 | |
| gap = 0 | |
| if selected[-1] != total_frames - 1: | |
| selected.append(total_frames - 1) | |
| diagnostics.setdefault(total_frames - 1, {"motion_delta": float(motion_deltas[-1]), "cum_motion": cumulative}) | |
| return selected, diagnostics | |
def _keyframe_diagnostic(
    record: FrameRecord,
    *,
    reason: str,
    motion_delta: float,
    cum_motion: float,
    gain_count: int,
    gain_ratio: float,
    mean_confidence: float | None,
) -> dict[str, Any]:
    """Build the diagnostics dict recorded for one selected key frame."""
    return {
        "frame_id": record.frame_id,
        "frame_index": record.index,
        "reason": reason,
        "motion_delta": float(motion_delta),
        "cum_motion": float(cum_motion),
        "coverage_gain_ratio": float(gain_ratio),
        "coverage_gain_count": int(gain_count),
        "mean_confidence": mean_confidence,
    }


def select_keyframes_motion_coverage(
    frame_records: list[FrameRecord],
    predictions: Mapping[str, np.ndarray],
    settings: WorkerSettings,
    requested_top_k: int,
) -> KeyframeSelectionResult | None:
    """Choose key frames by camera motion, then add frames greedily by coverage.

    1. Camera poses come from ``predictions["extrinsic"]``. Frames are picked
       wherever accumulated motion crosses ``settings.keyframe_motion_threshold``
       (see ``_select_motion_indices``). These motion frames are always kept.
    2. When a confidence array and world points are available, each frame is
       reduced to a set of hashed voxels. Further frames are then added
       greedily by how many not-yet-covered voxels they contribute. This stops
       when ``top_k`` frames are chosen, no frame adds coverage, or the best
       gain ratio falls below ``settings.keyframe_min_gain_ratio``.

    Args:
        frame_records: Frames in the same order as the prediction arrays.
        predictions: Model outputs by name. ``extrinsic`` is required for a
            result. ``world_points``/``world_points_from_depth`` and
            ``world_points_conf``/``depth_conf`` are optional.
        settings: Supplies the ``keyframe_*`` tuning values.
        requested_top_k: Desired frame count. A value ``<= 0`` means
            ``settings.keyframe_default_top_k``. Motion frames are never dropped,
            so more frames than requested can be returned.

    Returns:
        A ``KeyframeSelectionResult`` with ascending indices and aligned
        diagnostics. Returns ``None`` when ``extrinsic`` is missing or empty,
        or when the pose count differs from ``len(frame_records)``.

    Raises:
        ValueError: if ``extrinsic`` is non-empty but not shaped ``(N, 3, 4)``.
    """
    raw_extrinsic = predictions.get("extrinsic")
    # np.asarray(None) is a 0-d object array of size 1, which would reach
    # _camera_poses and raise; treat a missing key as "no poses" instead.
    if raw_extrinsic is None:
        return None
    extrinsic = np.asarray(raw_extrinsic)
    if extrinsic.size == 0:
        return None

    rotations, translations = _camera_poses(extrinsic)
    total_frames = len(frame_records)
    if rotations.shape[0] != total_frames:
        # Misaligned poses would otherwise index past frame_records / voxel_sets.
        logger.warning(
            "Skipping keyframe selection: %d camera poses for %d frames",
            rotations.shape[0],
            total_frames,
        )
        return None

    motion_deltas = _compute_motion_deltas(rotations, translations, settings.keyframe_rotation_weight)
    motion_indices, motion_diag = _select_motion_indices(
        motion_deltas,
        threshold=settings.keyframe_motion_threshold,
        min_gap=max(1, settings.keyframe_min_gap_frames),
        max_gap=max(0, settings.keyframe_max_gap_frames),
    )

    confidence = pose_confidence(predictions)
    if confidence is not None and (confidence.ndim == 0 or confidence.shape[0] != total_frames):
        logger.warning(
            "Ignoring keyframe confidence with shape %s for %d frames",
            confidence.shape,
            total_frames,
        )
        confidence = None
    world_points = predictions.get("world_points")
    if world_points is None:
        world_points = predictions.get("world_points_from_depth")

    voxel_sets: list[set[int]] = [set() for _ in range(total_frames)]
    total_voxels = 0
    mean_conf: np.ndarray | None = None
    if confidence is not None:
        mean_conf = confidence.reshape(total_frames, -1).mean(axis=1)
        if world_points is not None:
            points = np.asarray(world_points)
            if points.ndim >= 1 and points.shape[0] == total_frames:
                voxel_sets, total_voxels = _frame_voxel_sets(
                    points,
                    confidence,
                    threshold=settings.keyframe_coverage_confidence,
                    voxel_size=settings.keyframe_coverage_voxel_size,
                    max_points=max(1000, settings.keyframe_coverage_max_points),
                )
            else:
                logger.warning(
                    "Ignoring world points with shape %s for %d frames",
                    points.shape,
                    total_frames,
                )
    # Clamp so the gain-ratio divisions below are always safe. Without coverage
    # data every voxel set is empty, so the ratios are simply 0.
    total_voxels = max(total_voxels, 1)

    top_k = requested_top_k if requested_top_k > 0 else settings.keyframe_default_top_k
    # Motion frames are mandatory, so the budget is at least their count.
    top_k = max(min(top_k, total_frames), len(motion_indices))

    selected_set: set[int] = set(motion_indices)
    diagnostics: dict[int, dict[str, Any]] = {}
    covered: set[int] = set()
    for idx in motion_indices:
        gain_count = len(voxel_sets[idx] - covered)
        covered.update(voxel_sets[idx])
        diagnostics[idx] = _keyframe_diagnostic(
            frame_records[idx],
            reason="motion",
            motion_delta=motion_deltas[idx],
            cum_motion=motion_diag.get(idx, {}).get("cum_motion", 0.0),
            gain_count=gain_count,
            gain_ratio=gain_count / total_voxels,
            mean_confidence=None if mean_conf is None else float(mean_conf[idx]),
        )

    if len(selected_set) < top_k:
        min_gain_ratio = settings.keyframe_min_gain_ratio
        remaining = [i for i in range(total_frames) if i not in selected_set and voxel_sets[i]]
        while remaining and len(selected_set) < top_k:
            best_idx = -1
            best_gain = -1
            best_ratio = -1.0
            for idx in remaining:
                gain = len(voxel_sets[idx] - covered)
                if gain <= 0:
                    continue
                ratio = gain / total_voxels
                # Ties on ratio (within float tolerance) go to the larger raw gain.
                if ratio > best_ratio or (np.isclose(ratio, best_ratio) and gain > best_gain):
                    best_idx = idx
                    best_gain = gain
                    best_ratio = ratio
            if best_idx == -1 or best_ratio < min_gain_ratio:
                break
            selected_set.add(best_idx)
            covered.update(voxel_sets[best_idx])
            diagnostics[best_idx] = _keyframe_diagnostic(
                frame_records[best_idx],
                reason="coverage",
                motion_delta=motion_deltas[best_idx],
                cum_motion=motion_diag.get(best_idx, {}).get("cum_motion", 0.0),
                gain_count=best_gain,
                gain_ratio=best_ratio,
                mean_confidence=None if mean_conf is None else float(mean_conf[best_idx]),
            )
            remaining.remove(best_idx)

    # Defensive trim: drop the weakest coverage frames if the request is exceeded.
    if requested_top_k > 0 and len(selected_set) > requested_top_k:
        coverage_candidates = [idx for idx in selected_set if diagnostics[idx]["reason"] == "coverage"]
        coverage_candidates.sort(key=lambda idx: diagnostics[idx].get("coverage_gain_ratio", 0.0))
        while len(selected_set) > requested_top_k and coverage_candidates:
            drop_idx = coverage_candidates.pop(0)
            selected_set.remove(drop_idx)
            diagnostics.pop(drop_idx, None)

    final_indices = sorted(selected_set)
    final_diags = [diagnostics[idx] for idx in final_indices]
    return KeyframeSelectionResult(indices=final_indices, diagnostics=final_diags, top_k=len(final_indices))
| def run_keyframe_prepass( | |
| *, | |
| runtime: WorkerRuntime, | |
| payload: Mapping[str, Any], | |
| frame_records: list[FrameRecord], | |
| mode: str, | |
| streaming: bool, | |
| window_size: int | None, | |
| ) -> KeyframeSelectionResult | None: | |
| if len(frame_records) <= 1: | |
| return None | |
| settings = runtime.settings | |
| top_k_payload = int(payload.get("prepass_top_k") or payload.get("top_k_frames") or payload.get("top_k") or 0) | |
| try: | |
| inference = run_stream3r_inference( | |
| runtime=runtime, | |
| image_paths=[record.path for record in frame_records], | |
| mode=mode, | |
| streaming=streaming, | |
| cache_output_path=None, | |
| progress_cb=None, | |
| window_size=window_size if streaming and mode == "window" else None, | |
| ) | |
| except Exception: | |
| logger.exception("Keyframe pre-pass inference failed") | |
| return None | |
| try: | |
| return select_keyframes_motion_coverage( | |
| frame_records, | |
| inference.predictions, | |
| settings, | |
| requested_top_k=top_k_payload, | |
| ) | |
| finally: | |
| del inference | |
def extract_video_frames(
    video_path: Path,
    output_dir: Path,
    *,
    target_fps: float | None = None,
    max_frames: int | None = None,
) -> tuple[list[FrameRecord], float]:
    """Decode ``video_path`` and save a subsampled set of frames as JPEGs.

    Every ``frame_interval``-th frame is written to
    ``output_dir / "frame_XXXXXX.jpg"``, numbered by extraction order. When
    ``target_fps`` is positive, ``frame_interval = max(1, round(native_fps / target_fps))``;
    otherwise it is 1. Extraction stops at the end of the stream, on a
    decode failure, or after ``max_frames`` frames (a falsy value means no
    limit).

    Returns:
        ``(records, native_fps)``. Each record's ``metadata`` has the source
        ``frame_number`` and ``timestamp_s = frame_number / native_fps``.
        ``native_fps`` falls back to 30.0 when the container reports a
        missing, non-positive or non-finite rate.

    Raises:
        FileNotFoundError: if ``video_path`` does not exist.
        RuntimeError: if the video cannot be opened, a frame cannot be
            written, or no frame was extracted.
    """
    if not video_path.exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")
    output_dir.mkdir(parents=True, exist_ok=True)
    cap = cv2.VideoCapture(str(video_path))
    frame_records: list[FrameRecord] = []
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")
        native_fps = cap.get(cv2.CAP_PROP_FPS)
        # Chained comparisons are False for NaN, so NaN/inf also fall back.
        if not native_fps or not 0.0 < native_fps < float("inf"):
            native_fps = 30.0
        frame_interval = 1
        if target_fps and target_fps > 0:
            frame_interval = max(1, int(round(native_fps / target_fps)))

        frame_idx = 0
        # grab() advances without decoding; retrieve() decodes only the frames we keep.
        while cap.grab():
            if frame_idx % frame_interval == 0:
                ok, frame = cap.retrieve()
                if not ok:
                    break
                extracted = len(frame_records)
                frame_id = f"frame_{extracted:06d}"
                frame_path = output_dir / f"{frame_id}.jpg"
                if not cv2.imwrite(str(frame_path), frame):
                    raise RuntimeError(f"Failed to write frame: {frame_path}")
                frame_records.append(
                    FrameRecord(
                        index=extracted,
                        frame_id=frame_id,
                        path=frame_path,
                        metadata={"frame_number": frame_idx, "timestamp_s": frame_idx / native_fps},
                    )
                )
                if max_frames and len(frame_records) >= max_frames:
                    break
            frame_idx += 1
    finally:
        # Release the capture on every exit path, including the errors above.
        cap.release()

    if not frame_records:
        raise RuntimeError("No frames extracted from video")
    return frame_records, native_fps
def linear_sample_indices(total: int, desired: int) -> list[int]:
    """Pick ``desired`` roughly evenly spaced indices from ``range(total)``.

    Index ``k`` is ``round(k * total / desired)``, using Python's
    round-half-to-even, clamped to ``total - 1``. When ``desired <= 0`` or
    ``total <= desired``, every index in ``range(total)`` is returned.
    """
    if 0 < desired < total:
        stride = total / desired
        return [min(total - 1, int(round(k * stride))) for k in range(desired)]
    return list(range(total))
def build_keyframe_uploads(
    runtime: WorkerRuntime,
    scene_id: str,
    selected_records: Iterable[FrameRecord],
    diagnostics: list[dict[str, Any]],
    *,
    subdir: str,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Upload each selected frame image and describe the uploads.

    Each record is uploaded with ``runtime.storage.upload_file`` under the key
    ``runtime.storage.build_key(scene_id, subdir, f"{frame_id}.jpg")`` with
    content type ``image/jpeg``. Diagnostics are matched to records by their
    ``"frame_index"`` value. A later entry wins on duplicate indices, and a
    record with no match gets ``{}``.

    NOTE(review): the key suffix and content type are always JPEG regardless
    of ``record.path``'s actual format. Confirm that inputs are JPEG files.

    Returns:
        ``(storage_entries, media_entries)``: parallel lists with one item
        per record, in iteration order.
    """
    diag_lookup: dict[Any, dict[str, Any]] = {}
    for entry in diagnostics:
        diag_lookup[entry.get("frame_index")] = entry

    uploaded: list[dict[str, Any]] = []
    media: list[dict[str, Any]] = []
    for rec in selected_records:
        rec_diag = diag_lookup.get(rec.index, {})
        storage_key = runtime.storage.build_key(scene_id, subdir, f"{rec.frame_id}.jpg")
        url = runtime.storage.upload_file(rec.path, storage_key, content_type="image/jpeg")
        uploaded.append(
            {
                "frame_id": rec.frame_id,
                "frame_index": rec.index,
                "url": url,
                "storage_key": storage_key,
                "diagnostics": rec_diag,
            }
        )
        media.append(
            {
                "media_type": "image",
                "file": storage_key,
                "captured_at": _diagnostic_captured_at(rec, rec_diag),
            }
        )
    return uploaded, media
| def _diagnostic_captured_at(record: FrameRecord, diag: Mapping[str, Any]) -> str | None: | |
| if record.timestamp: | |
| return record.timestamp | |
| ts = diag.get("timestamp") or record.metadata.get("timestamp") | |
| if isinstance(ts, str): | |
| return ts | |
| if isinstance(ts, (int, float)): | |
| return datetime.utcfromtimestamp(float(ts)).isoformat() + "Z" | |
| timestamp_s = record.metadata.get("timestamp_s") | |
| if isinstance(timestamp_s, (int, float)): | |
| return datetime.utcfromtimestamp(float(timestamp_s)).isoformat() + "Z" | |
| return None | |