"""Key frame selection utilities."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, Mapping
import cv2
import numpy as np
from .config import WorkerSettings
from .pipeline import run_stream3r_inference
from .runtime import WorkerRuntime
logger = logging.getLogger(__name__)


@dataclass(slots=True)
class FrameRecord:
    """A single extracted frame plus its provenance metadata."""

    index: int
    frame_id: str
    path: Path
    source: str | None = None
    timestamp: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass(slots=True)
class KeyframeSelectionResult:
    """Selected frame indices with per-frame diagnostics."""

    indices: list[int]
    diagnostics: list[dict[str, Any]]
    top_k: int


def pose_confidence(predictions: Mapping[str, np.ndarray]) -> np.ndarray | None:
    """Return the best available per-pixel confidence map, if any."""
    if "world_points_conf" in predictions:
        return np.asarray(predictions["world_points_conf"], dtype=np.float32)
    if "depth_conf" in predictions:
        return np.asarray(predictions["depth_conf"], dtype=np.float32)
    return None


def _camera_poses(extrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Convert (N, 3, 4) world-to-camera extrinsics into camera-to-world poses."""
    matrices = np.asarray(extrinsic, dtype=np.float64)
    if matrices.ndim != 3 or matrices.shape[1:] != (3, 4):
        raise ValueError("Extrinsic array must have shape (N, 3, 4)")
    count = matrices.shape[0]
    rotations = np.empty((count, 3, 3), dtype=np.float64)
    translations = np.empty((count, 3), dtype=np.float64)
    for idx in range(count):
        # Promote the 3x4 extrinsic to a homogeneous 4x4 and invert it to get
        # the camera-to-world rotation and camera center.
        mat = np.eye(4, dtype=np.float64)
        mat[:3, :4] = matrices[idx]
        cam_to_world = np.linalg.inv(mat)
        rotations[idx] = cam_to_world[:3, :3]
        translations[idx] = cam_to_world[:3, 3]
    return rotations, translations


def _compute_motion_deltas(rotations: np.ndarray, translations: np.ndarray, rot_weight: float) -> np.ndarray:
    """Per-frame motion magnitude: translation distance plus weighted rotation angle."""
    count = rotations.shape[0]
    deltas = np.zeros(count, dtype=np.float64)
    if count <= 1:
        return deltas
    for idx in range(1, count):
        delta_t = np.linalg.norm(translations[idx] - translations[idx - 1])
        # Geodesic rotation angle between consecutive frames, recovered from the
        # trace of the relative rotation; the clip guards against numerical drift.
        rel = rotations[idx - 1].T @ rotations[idx]
        trace = np.clip((np.trace(rel) - 1.0) / 2.0, -1.0, 1.0)
        delta_r = float(np.arccos(trace))
        deltas[idx] = delta_t + rot_weight * delta_r
    return deltas


def _hash_quantized_voxels(coords: np.ndarray) -> np.ndarray:
    """Hash integer voxel coordinates to scalars via a prime-weighted dot product."""
    coords = coords.astype(np.int64, copy=False)
    primes = np.array([73856093, 19349663, 83492791], dtype=np.int64)
    return coords @ primes
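

# The primes above are the ones commonly used for spatial hashing (e.g. in
# Teschner et al.'s scheme, which XORs the per-axis products rather than
# summing them). A quick sanity check of this variant, derivable by hand
# from the primes:
#
#   >>> int(_hash_quantized_voxels(np.array([[1, 2, 3]]))[0])
#   363033792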


def _frame_voxel_sets(
    world_points: np.ndarray,
    confidence: np.ndarray,
    *,
    threshold: float,
    voxel_size: float,
    max_points: int,
) -> tuple[list[set[int]], int]:
    """Quantize each frame's confident points into voxel hash sets.

    Returns one hash set per frame plus the size of their union, which later
    serves as the denominator for coverage-gain ratios.
    """
    rng = np.random.default_rng(42)
    frames = world_points.shape[0]
    voxel_sets: list[set[int]] = []
    global_union: set[int] = set()
    if voxel_size <= 0.0:
        return [set() for _ in range(frames)], 0
    for idx in range(frames):
        conf_frame = confidence[idx]
        mask = conf_frame >= threshold
        if not np.any(mask):
            voxel_sets.append(set())
            continue
        points = world_points[idx][mask]
        # Subsample very dense frames so hashing stays bounded.
        if points.shape[0] > max_points:
            sample_idx = rng.choice(points.shape[0], max_points, replace=False)
            points = points[sample_idx]
        quantized = np.floor(points / voxel_size).astype(np.int64, copy=False)
        hashes = np.unique(_hash_quantized_voxels(quantized))
        voxel_set = set(int(v) for v in hashes.tolist())
        voxel_sets.append(voxel_set)
        global_union.update(voxel_set)
    return voxel_sets, len(global_union)


def _select_motion_indices(
    motion_deltas: np.ndarray,
    *,
    threshold: float,
    min_gap: int,
    max_gap: int,
) -> tuple[list[int], dict[int, dict[str, float]]]:
    """Pick frames whenever accumulated motion crosses ``threshold``.

    ``min_gap`` suppresses selections that land too close together; ``max_gap``
    (if positive) forces a selection after that many frames regardless of
    motion. The first and last frames are always included.
    """
    total_frames = motion_deltas.shape[0]
    if total_frames == 0:
        return [], {}
    selected = [0]
    diagnostics: dict[int, dict[str, float]] = {0: {"motion_delta": 0.0, "cum_motion": 0.0}}
    cumulative = 0.0
    gap = 0
    for idx in range(1, total_frames):
        delta = float(motion_deltas[idx])
        cumulative += delta
        gap += 1
        if gap < max(1, min_gap):
            continue
        should_select = cumulative >= threshold
        if max_gap > 0 and gap >= max_gap:
            should_select = True
        if should_select:
            selected.append(idx)
            diagnostics[idx] = {"motion_delta": delta, "cum_motion": cumulative}
            cumulative = 0.0
            gap = 0
    if selected[-1] != total_frames - 1:
        selected.append(total_frames - 1)
        diagnostics.setdefault(total_frames - 1, {"motion_delta": float(motion_deltas[-1]), "cum_motion": cumulative})
    return selected, diagnostics
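

# Behaviour sketch on synthetic deltas (values chosen purely for illustration):
# a frame is emitted once the running sum crosses the threshold, and the final
# frame is appended even though it contributes no motion.
#
#   >>> idx, _ = _select_motion_indices(
#   ...     np.array([0.0, 0.05, 0.05, 0.3, 0.0]),
#   ...     threshold=0.25, min_gap=1, max_gap=0,
#   ... )
#   >>> idx
#   [0, 3, 4]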


def select_keyframes_motion_coverage(
    frame_records: list[FrameRecord],
    predictions: Mapping[str, np.ndarray],
    settings: WorkerSettings,
    requested_top_k: int,
) -> KeyframeSelectionResult | None:
    """Select keyframes by motion gating, then greedily add coverage frames."""
    extrinsic = predictions.get("extrinsic")
    if extrinsic is None:
        # np.asarray(None) yields a size-1 object array, so check first.
        return None
    extrinsic = np.asarray(extrinsic)
    if extrinsic.size == 0:
        return None
    rotations, translations = _camera_poses(extrinsic)
    motion_deltas = _compute_motion_deltas(rotations, translations, settings.keyframe_rotation_weight)
    motion_indices, motion_diag = _select_motion_indices(
        motion_deltas,
        threshold=settings.keyframe_motion_threshold,
        min_gap=max(1, settings.keyframe_min_gap_frames),
        max_gap=max(0, settings.keyframe_max_gap_frames),
    )
    total_frames = len(frame_records)
    confidence = pose_confidence(predictions)
    world_points = predictions.get("world_points")
    if world_points is None:
        world_points = predictions.get("world_points_from_depth")
    voxel_sets: list[set[int]] = [set() for _ in range(total_frames)]
    total_voxels = 0
    mean_conf = np.zeros(total_frames, dtype=np.float32)
    if confidence is not None:
        mean_conf = confidence.reshape(confidence.shape[0], -1).mean(axis=1)
    if confidence is not None and world_points is not None:
        voxel_sets, total_voxels = _frame_voxel_sets(
            np.asarray(world_points),
            np.asarray(confidence),
            threshold=settings.keyframe_coverage_confidence,
            voxel_size=settings.keyframe_coverage_voxel_size,
            max_points=max(1000, settings.keyframe_coverage_max_points),
        )
    # Clamp to 1 so the gain-ratio divisions below are always well defined.
    total_voxels = max(total_voxels, 1)
    top_k = requested_top_k if requested_top_k > 0 else settings.keyframe_default_top_k
    top_k = max(min(top_k, total_frames), len(motion_indices))
    selected_set: set[int] = set(motion_indices)
    diagnostics: dict[int, dict[str, Any]] = {}
    covered: set[int] = set()
    for idx in motion_indices:
        gain_count = len(voxel_sets[idx] - covered) if voxel_sets[idx] else 0
        gain_ratio = gain_count / total_voxels
        covered.update(voxel_sets[idx])
        diagnostics[idx] = {
            "frame_id": frame_records[idx].frame_id,
            "frame_index": frame_records[idx].index,
            "reason": "motion",
            "motion_delta": float(motion_deltas[idx]),
            "cum_motion": float(motion_diag.get(idx, {}).get("cum_motion", 0.0)),
            "coverage_gain_ratio": float(gain_ratio),
            "coverage_gain_count": int(gain_count),
            "mean_confidence": float(mean_conf[idx]) if confidence is not None else None,
        }
    if len(selected_set) < top_k and total_voxels > 0:
        # Greedy max-coverage pass: repeatedly add the frame whose voxels cover
        # the most previously unseen volume, until top_k is reached or gains
        # fall below the configured minimum.
        min_gain_ratio = settings.keyframe_min_gain_ratio
        remaining = [i for i in range(total_frames) if i not in selected_set and voxel_sets[i]]
        while remaining and len(selected_set) < top_k:
            best_idx = -1
            best_gain = -1
            best_ratio = -1.0
            for idx in remaining:
                gain = len(voxel_sets[idx] - covered)
                if gain <= 0:
                    continue
                ratio = gain / total_voxels
                if ratio > best_ratio or (np.isclose(ratio, best_ratio) and gain > best_gain):
                    best_idx = idx
                    best_gain = gain
                    best_ratio = ratio
            if best_idx == -1 or best_ratio < min_gain_ratio:
                break
            selected_set.add(best_idx)
            covered.update(voxel_sets[best_idx])
            diagnostics[best_idx] = {
                "frame_id": frame_records[best_idx].frame_id,
                "frame_index": frame_records[best_idx].index,
                "reason": "coverage",
                "motion_delta": float(motion_deltas[best_idx]),
                "cum_motion": float(motion_diag.get(best_idx, {}).get("cum_motion", 0.0)),
                "coverage_gain_ratio": float(best_ratio),
                "coverage_gain_count": int(best_gain),
                "mean_confidence": float(mean_conf[best_idx]) if confidence is not None else None,
            }
            remaining.remove(best_idx)
    if requested_top_k > 0 and len(selected_set) > requested_top_k:
        # Over budget: drop the weakest coverage picks first; motion picks stay.
        coverage_candidates = [idx for idx in selected_set if diagnostics[idx]["reason"] == "coverage"]
        coverage_candidates.sort(key=lambda idx: diagnostics[idx].get("coverage_gain_ratio", 0.0))
        while len(selected_set) > requested_top_k and coverage_candidates:
            drop_idx = coverage_candidates.pop(0)
            selected_set.remove(drop_idx)
            diagnostics.pop(drop_idx, None)
    final_indices = sorted(selected_set)
    final_diags = [diagnostics[idx] for idx in final_indices]
    return KeyframeSelectionResult(indices=final_indices, diagnostics=final_diags, top_k=len(final_indices))
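

# Example wiring (hypothetical; ``records``, ``preds``, and ``settings`` are
# supplied by the calling worker, not defined in this module):
#
#   result = select_keyframes_motion_coverage(records, preds, settings, requested_top_k=8)
#   if result is not None:
#       keyframes = [records[i] for i in result.indices]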


def run_keyframe_prepass(
    *,
    runtime: WorkerRuntime,
    payload: Mapping[str, Any],
    frame_records: list[FrameRecord],
    mode: str,
    streaming: bool,
    window_size: int | None,
) -> KeyframeSelectionResult | None:
    """Run a lightweight inference pass and select keyframes from its poses."""
    if len(frame_records) <= 1:
        return None
    settings = runtime.settings
    top_k_payload = int(payload.get("prepass_top_k") or payload.get("top_k_frames") or payload.get("top_k") or 0)
    try:
        inference = run_stream3r_inference(
            runtime=runtime,
            image_paths=[record.path for record in frame_records],
            mode=mode,
            streaming=streaming,
            cache_output_path=None,
            progress_cb=None,
            window_size=window_size if streaming and mode == "window" else None,
        )
    except Exception:
        logger.exception("Keyframe pre-pass inference failed")
        return None
    try:
        return select_keyframes_motion_coverage(
            frame_records,
            inference.predictions,
            settings,
            requested_top_k=top_k_payload,
        )
    finally:
        # Release prediction tensors promptly; the pre-pass output can be large.
        del inference


def extract_video_frames(
    video_path: Path,
    output_dir: Path,
    *,
    target_fps: float | None = None,
    max_frames: int | None = None,
) -> tuple[list[FrameRecord], float]:
    """Extract frames from ``video_path`` as JPEGs, optionally downsampled to ``target_fps``.

    Returns the extracted records and the video's native FPS.
    """
    if not video_path.exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")
    output_dir.mkdir(parents=True, exist_ok=True)
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open video: {video_path}")
    native_fps = cap.get(cv2.CAP_PROP_FPS)
    if not native_fps or native_fps <= 0:
        # Some containers report no FPS; assume a common default.
        native_fps = 30.0
    frame_interval = 1
    if target_fps and target_fps > 0:
        frame_interval = max(1, int(round(native_fps / target_fps)))
    frame_records: list[FrameRecord] = []
    frame_idx = 0
    extracted = 0
    success, frame = cap.read()
    while success:
        if frame_idx % frame_interval == 0:
            frame_id = f"frame_{extracted:06d}"
            frame_path = output_dir / f"{frame_id}.jpg"
            if not cv2.imwrite(str(frame_path), frame):
                cap.release()
                raise RuntimeError(f"Failed to write frame: {frame_path}")
            timestamp_s = frame_idx / native_fps
            frame_records.append(
                FrameRecord(
                    index=extracted,
                    frame_id=frame_id,
                    path=frame_path,
                    metadata={"frame_number": frame_idx, "timestamp_s": timestamp_s},
                )
            )
            extracted += 1
            if max_frames and extracted >= max_frames:
                break
        frame_idx += 1
        success, frame = cap.read()
    cap.release()
    if not frame_records:
        raise RuntimeError("No frames extracted from video")
    return frame_records, native_fps


def linear_sample_indices(total: int, desired: int) -> list[int]:
    """Pick ``desired`` roughly evenly spaced indices from ``range(total)``.

    >>> linear_sample_indices(10, 4)
    [0, 2, 5, 8]
    """
    if desired <= 0 or total <= desired:
        return list(range(total))
    step = total / desired
    return [min(total - 1, int(round(i * step))) for i in range(desired)]


def build_keyframe_uploads(
    runtime: WorkerRuntime,
    scene_id: str,
    selected_records: Iterable[FrameRecord],
    diagnostics: list[dict[str, Any]],
    *,
    subdir: str,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Upload selected frames and build storage plus media manifest entries."""
    diag_by_index = {entry.get("frame_index"): entry for entry in diagnostics}
    storage_entries: list[dict[str, Any]] = []
    media_entries: list[dict[str, Any]] = []
    for record in selected_records:
        diag = diag_by_index.get(record.index, {})
        filename = f"{record.frame_id}.jpg"
        key = runtime.storage.build_key(scene_id, subdir, filename)
        uri = runtime.storage.upload_file(record.path, key, content_type="image/jpeg")
        storage_entries.append(
            {
                "frame_id": record.frame_id,
                "frame_index": record.index,
                "url": uri,
                "storage_key": key,
                "diagnostics": diag,
            }
        )
        media_entries.append(
            {
                "media_type": "image",
                "file": key,
                "captured_at": _diagnostic_captured_at(record, diag),
            }
        )
    return storage_entries, media_entries


def _diagnostic_captured_at(record: FrameRecord, diag: Mapping[str, Any]) -> str | None:
    """Best-effort capture timestamp for a frame, as an ISO-8601 UTC string."""
    if record.timestamp:
        return record.timestamp
    ts = diag.get("timestamp") or record.metadata.get("timestamp")
    if isinstance(ts, str):
        return ts
    if isinstance(ts, (int, float)):
        # datetime.utcfromtimestamp is deprecated; use an aware UTC datetime.
        return datetime.fromtimestamp(float(ts), tz=timezone.utc).isoformat().replace("+00:00", "Z")
    timestamp_s = record.metadata.get("timestamp_s")
    if isinstance(timestamp_s, (int, float)):
        return datetime.fromtimestamp(float(timestamp_s), tz=timezone.utc).isoformat().replace("+00:00", "Z")
    return None
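

if __name__ == "__main__":  # pragma: no cover
    # Minimal smoke-test sketch, not part of the worker flow: extract frames
    # from a clip given on the command line, then fall back to plain linear
    # sampling. ``./_frames`` is a placeholder output directory.
    import sys

    records, fps = extract_video_frames(
        Path(sys.argv[1]), Path("./_frames"), target_fps=2.0, max_frames=64
    )
    sample = linear_sample_indices(len(records), desired=16)
    print(f"extracted {len(records)} frames at {fps:.2f} fps; linear sample: {sample}")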