| """RQ job entrypoints for STream3R worker.""" | |
| from __future__ import annotations | |
| import base64 | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import shutil | |
| import tempfile | |
| import traceback | |
| import uuid | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from contextlib import nullcontext | |
| from time import perf_counter | |
| from typing import Any, Callable, Mapping | |
| import numpy as np | |
| import requests | |
| from rq import get_current_job | |
| from stream3r.utils.visual_utils import predictions_to_glb | |
| from .keyframes import ( | |
| FrameRecord, | |
| KeyframeSelectionResult, | |
| build_keyframe_uploads, | |
| extract_video_frames, | |
| linear_sample_indices, | |
| pose_confidence, | |
| run_keyframe_prepass, | |
| ) | |
| from .pipeline import InferenceResult, run_stream3r_inference | |
| from .runtime import WorkerRuntime, get_runtime | |
| logger = logging.getLogger(__name__) | |
| IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".webp"} | |
| _SAFE_CHARS = re.compile(r"[^0-9A-Za-z_-]") | |

def _as_bool(value: Any, default: bool) -> bool:
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in {"1", "true", "yes", "y", "on"}:
            return True
        if lowered in {"0", "false", "no", "n", "off"}:
            return False
    return default


def _as_int(value: Any, default: int) -> int:
    try:
        return int(value)
    except (TypeError, ValueError):
        return default

class ProgressTracker:
    """Aggregates frame progress into deduplicated percentage updates."""

    def __init__(self, runtime: WorkerRuntime, job_meta: Mapping[str, str | None]):
        self.runtime = runtime
        self.job_meta = job_meta
        self.last_value = -1

    def __call__(self, processed: int, total: int) -> None:
        if total <= 0:
            return
        percent = int(round((processed / total) * 100))
        percent = max(0, min(100, percent))
        if percent == self.last_value:
            return
        self.last_value = percent
        payload = {
            **self.job_meta,
            "status": "progress",
            "progress": percent,
            "ts": datetime.now(timezone.utc).timestamp(),
        }
        runtime_emit(self.runtime, payload)


def runtime_emit(runtime: WorkerRuntime, payload: Mapping[str, Any]) -> None:
    runtime.emit_event(payload)


def _slugify(value: str, fallback: str) -> str:
    candidate = _SAFE_CHARS.sub("_", value).strip("_")
    if not candidate:
        candidate = fallback
    return candidate[:128]
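
# Example behavior (illustrative inputs): _slugify("Living Room #1",
# "frame_000000") returns "Living_Room__1"; an all-symbol input collapses to
# the supplied fallback name, and results are capped at 128 characters.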

def _is_url(value: str) -> bool:
    return value.startswith(("http://", "https://"))


def _download_to_path(url: str, destination: Path) -> None:
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    with destination.open("wb") as handle:
        for chunk in response.iter_content(chunk_size=1 << 16):
            if chunk:
                handle.write(chunk)


def _write_base64(content: str, destination: Path) -> None:
    destination.write_bytes(base64.b64decode(content))

def _register_scene_media_entries(runtime: WorkerRuntime, scene_id: str, entries: list[dict[str, Any]]) -> None:
    if not entries:
        return
    base_url = runtime.settings.scene_media_api_base_url
    if not base_url:
        logger.info("Scene media API base URL not configured; skipping registration for %s", scene_id)
        return
    url = f"{base_url.rstrip('/')}/scenes/{scene_id}/media"
    headers: dict[str, str] = {"Content-Type": "application/json"}
    token = runtime.settings.scene_media_api_token
    if token:
        headers["Authorization"] = f"Bearer {token}"
    secret = runtime.settings.scene_media_api_secret
    if secret:
        headers["x-internal-secret"] = secret
    try:
        response = requests.post(url, json={"entries": entries}, headers=headers, timeout=30)
        if response.status_code == 405:
            logger.info(
                "Scene media API does not accept POST at %s (status %s); skipping registration",
                url,
                response.status_code,
            )
            return
        response.raise_for_status()
    except requests.HTTPError as exc:
        status = exc.response.status_code if exc.response is not None else None
        if status == 422:
            logger.warning(
                "Scene media API rejected payload for scene %s (422): %s",
                scene_id,
                exc.response.text if exc.response is not None else "",
            )
            return
        if status == 500:
            logger.warning(
                "Scene media API encountered a server error (500) for scene %s; skipping registration",
                scene_id,
            )
            return
        logger.exception("Failed to register scene media entries for scene %s", scene_id)
    except requests.RequestException:
        logger.exception("Failed to register scene media entries for scene %s", scene_id)

def _resolve_frame_entry(
    entry: Any,
    *,
    index: int,
    dest_dir: Path,
    runtime: WorkerRuntime | None = None,
) -> FrameRecord:
    metadata: dict[str, Any] = {}
    timestamp = None
    source = None
    dest_dir.mkdir(parents=True, exist_ok=True)
    if isinstance(entry, str):
        if _is_url(entry):
            source = entry
            frame_id = _slugify(Path(entry).stem or f"frame_{index:06d}", f"frame_{index:06d}")
            destination = dest_dir / f"{frame_id}.jpg"
            _download_to_path(entry, destination)
        else:
            path = Path(entry)
            if not path.exists():
                raise FileNotFoundError(f"Frame path does not exist: {entry}")
            frame_id = _slugify(path.stem, f"frame_{index:06d}")
            destination = dest_dir / path.name
            shutil.copy2(path, destination)
    elif isinstance(entry, Mapping):
        frame_id = _slugify(str(entry.get("frame_id") or entry.get("id") or f"frame_{index:06d}"), f"frame_{index:06d}")
        timestamp = entry.get("timestamp")
        metadata = {k: v for k, v in entry.items() if k not in {"path", "url", "content", "frame_id", "id", "timestamp"}}
        if path := entry.get("path") or entry.get("local_path"):
            path = Path(path)
            if not path.exists():
                raise FileNotFoundError(f"Frame path does not exist: {path}")
            destination = dest_dir / (path.name if path.suffix else f"{frame_id}.jpg")
            shutil.copy2(path, destination)
        elif storage_key := entry.get("storage_key") or entry.get("file") or entry.get("key"):
            if runtime is None:
                raise ValueError("Frame entry provided a storage key but runtime is not available")
            filename = Path(str(storage_key)).name or f"{frame_id}.jpg"
            destination = dest_dir / filename
            runtime.storage.download_to_path(str(storage_key), destination)
            source = str(storage_key)
        elif url := entry.get("url"):
            source = url
            suffix = Path(url).suffix or ".jpg"
            destination = dest_dir / f"{frame_id}{suffix}"
            _download_to_path(url, destination)
        elif content := entry.get("content"):
            destination = dest_dir / f"{frame_id}.jpg"
            _write_base64(content, destination)
        else:
            raise ValueError("Frame entry must include 'path', 'storage_key', 'url', or 'content'")
    else:
        raise ValueError(f"Unsupported frame entry type: {type(entry)!r}")
    if destination.suffix.lower() not in IMAGE_EXTENSIONS:
        # Rename on disk so the recorded path actually exists; this only
        # normalizes the extension, it does not re-encode the image bytes.
        destination = destination.rename(destination.with_suffix(".png"))
    return FrameRecord(
        index=index,
        frame_id=_slugify(destination.stem, f"frame_{index:06d}"),
        path=destination,
        source=source,
        timestamp=timestamp,
        metadata=metadata,
    )
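
# Accepted frame entry shapes (summarizing the branches above; any extra
# mapping keys are preserved as FrameRecord metadata):
#
#   "relative/or/absolute/path.jpg"             # local file path
#   "https://example.com/frame.jpg"             # direct URL (illustrative)
#   {"frame_id": "f1", "path": "..."}           # local path with explicit id
#   {"frame_id": "f2", "storage_key": "..."}    # object-storage key (requires runtime)
#   {"frame_id": "f3", "url": "..."}            # remote URL
#   {"frame_id": "f4", "content": "<base64>"}   # inline base64-encoded image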

def _collect_frames(
    runtime: WorkerRuntime,
    scene_id: str,
    payload: Mapping[str, Any],
    tmp_dir: Path,
) -> list[FrameRecord]:
    frames_dir = tmp_dir / "frames"
    frames_payload = payload.get("frames") or []
    frame_limit = runtime.settings.max_frames_per_job
    records: list[FrameRecord] = []
    if frames_payload:
        for entry in frames_payload:
            if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
                break
            records.append(
                _resolve_frame_entry(entry, index=len(records), dest_dir=frames_dir, runtime=runtime)
            )
    else:
        directory = payload.get("frames_dir") or payload.get("images_dir")
        if directory:
            directory_path = Path(directory)
            if not directory_path.is_dir():
                raise ValueError(f"frames_dir does not exist: {directory}")
            # Ensure the destination exists before copying (the other intake
            # paths create it inside their helpers).
            frames_dir.mkdir(parents=True, exist_ok=True)
            for idx, file in enumerate(sorted(directory_path.iterdir())):
                if file.suffix.lower() not in IMAGE_EXTENSIONS:
                    continue
                if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
                    break
                destination = frames_dir / file.name
                shutil.copy2(file, destination)
                records.append(
                    FrameRecord(
                        index=len(records),
                        frame_id=_slugify(file.stem, f"frame_{idx:06d}"),
                        path=destination,
                    )
                )
    if not records:
        records = _collect_frames_from_scene_media(runtime, scene_id, frames_dir)
    if not records:
        raise ValueError(f"No valid frames found for scene '{scene_id}'")
    limit = runtime.settings.max_frames_per_job
    if limit and limit > 0 and len(records) > limit:
        records = records[:limit]
    for new_idx, record in enumerate(records):
        if record.index != new_idx:
            record.index = new_idx
    return records

def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]:
    result = dict(payload)
    frames = result.pop("frames", None)
    if frames is not None:
        result["frame_count"] = len(frames)
    if "frames_dir" in result:
        result["frames_dir"] = str(result["frames_dir"])
    return result


def _prepare_session_settings(
    payload: Mapping[str, Any],
    *,
    mode: str,
    streaming: bool,
    frame_records: list[FrameRecord],
    window_size: int | None = None,
) -> dict[str, Any]:
    base_settings = payload.get("session_settings") or {}
    session_settings = dict(base_settings)
    session_settings.update(
        {
            "mode": mode,
            "streaming": streaming,
            "frame_count": len(frame_records),
        }
    )
    window_setting = window_size if window_size is not None else payload.get("window_size")
    if window_setting:
        try:
            session_settings["window_size"] = int(window_setting)
        except (TypeError, ValueError):
            pass
    return session_settings
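
# Example (illustrative values): with payload {"window_size": "8"} and three
# frame records, _prepare_session_settings(payload, mode="window",
# streaming=True, frame_records=records) yields
# {"mode": "window", "streaming": True, "frame_count": 3, "window_size": 8}.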

def _collect_frames_from_scene_media(
    runtime: WorkerRuntime,
    scene_id: str,
    dest_dir: Path,
) -> list[FrameRecord]:
    base_url = runtime.settings.scene_media_api_base_url
    if not base_url:
        raise ValueError(
            "Scene media API base URL is not configured. Set API_BASE_URL"
        )
    base_url = base_url.rstrip("/")
    dest_dir.mkdir(parents=True, exist_ok=True)
    per_page = runtime.settings.scene_media_page_size
    if per_page <= 0:
        per_page = 100
    per_page = max(1, min(per_page, 1000))
    frame_limit = runtime.settings.max_frames_per_job
    headers = {}
    token = runtime.settings.scene_media_api_token
    if token:
        headers["Authorization"] = f"Bearer {token}"
    url = f"{base_url}/scenes/{scene_id}/media"
    session = requests.Session()
    records: list[FrameRecord] = []
    offset = 0
    while True:
        if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
            break
        request_limit = per_page
        if frame_limit and frame_limit > 0:
            remaining = frame_limit - len(records)
            if remaining <= 0:
                break
            request_limit = min(request_limit, remaining)
        params = {
            "limit": request_limit,
            "offset": offset,
            "media_type": "image",
        }
        try:
            response = session.get(url, params=params, headers=headers, timeout=30)
            response.raise_for_status()
        except requests.RequestException as exc:
            raise RuntimeError(f"Failed to fetch media for scene '{scene_id}': {exc}") from exc
        data = response.json()
        items = data.get("items") or []
        if not items:
            break
        for item in items:
            if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
                break
            file_key = item.get("file")
            if not file_key:
                continue
            idx = len(records)
            source_path = Path(str(file_key))
            suffix = source_path.suffix if source_path.suffix else ".png"
            frame_id = _slugify(source_path.stem or f"frame_{idx:06d}", f"frame_{idx:06d}")
            destination = dest_dir / f"{frame_id}{suffix}"
            try:
                runtime.storage.download_to_path(str(file_key), destination)
            except Exception as exc:  # pragma: no cover - download depends on external storage
                raise RuntimeError(f"Failed to download media '{file_key}' for scene '{scene_id}': {exc}") from exc
            records.append(
                FrameRecord(
                    index=idx,
                    frame_id=frame_id,
                    path=destination,
                    source=str(file_key),
                    timestamp=item.get("captured_at"),
                    metadata={
                        "media_id": item.get("id"),
                        "media_type": item.get("media_type"),
                    },
                )
            )
        if len(items) < request_limit:
            break
        offset += request_limit
    return records
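
# Assumed response shape for GET {base_url}/scenes/{scene_id}/media, inferred
# from the fields read above (the real API may return more):
#
#   {"items": [{"id": "...", "file": "scenes/<id>/img_0001.jpg",
#               "media_type": "image", "captured_at": "2024-01-01T00:00:00Z"}]}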

def _save_pointmaps(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    predictions: Mapping[str, np.ndarray],
    frame_records: list[FrameRecord],
    temp_dir: Path,
) -> dict[str, Any]:
    world_points = predictions.get("world_points")
    if world_points is None:
        world_points = predictions.get("world_points_from_depth")
    if world_points is None:
        raise RuntimeError("Predictions missing world points")
    world_points = np.asarray(world_points)
    confidence = pose_confidence(predictions)
    if confidence is None:
        confidence = np.ones(world_points.shape[:-1], dtype=np.float32)
    local_dir = temp_dir / "pointmaps"
    local_dir.mkdir(parents=True, exist_ok=True)
    entries: list[dict[str, Any]] = []
    for record in frame_records:
        idx = record.index
        filename = f"{record.frame_id}.npz"
        local_file = local_dir / filename
        np.savez(
            local_file,
            xyz=np.asarray(world_points[idx], dtype=np.float32),
            confidence=np.asarray(confidence[idx], dtype=np.float32),
        )
        key = runtime.storage.build_key(scene_id, runtime.settings.pointmap_dir, filename)
        uri = runtime.storage.upload_file(local_file, key, content_type="application/octet-stream")
        entries.append(
            {
                "frame_id": record.frame_id,
                "frame_index": record.index,
                "url": uri,
                "timestamp": record.timestamp,
            }
        )
    directory_uri = runtime.storage.build_uri(
        runtime.storage.build_key(scene_id, runtime.settings.pointmap_dir)
    )
    return {
        "pointmaps": entries,
        "pointmap_dir": directory_uri,
    }
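
# Each uploaded pointmap is a plain NumPy archive; a consumer can read it back
# with (a sketch, assuming the per-frame arrays are H x W maps):
#
#   data = np.load("frame_000000.npz")
#   xyz = data["xyz"]            # (H, W, 3) world-space coordinates
#   conf = data["confidence"]    # (H, W) per-pixel confidence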

def _write_poses_jsonl(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    job_id: str,
    predictions: Mapping[str, np.ndarray],
    frame_records: list[FrameRecord],
    temp_dir: Path,
) -> str:
    extrinsic = np.asarray(predictions.get("extrinsic"))
    intrinsic = predictions.get("intrinsic")
    if intrinsic is not None:
        intrinsic = np.asarray(intrinsic)
    local_file = temp_dir / "poses.jsonl"
    with local_file.open("w", encoding="utf-8") as handle:
        for record in frame_records:
            idx = record.index
            payload = {
                "job_id": job_id,
                "scene_id": scene_id,
                "frame_id": record.frame_id,
                "frame_index": record.index,
                "extrinsic": extrinsic[idx].tolist(),
            }
            if intrinsic is not None:
                payload["intrinsic"] = intrinsic[idx].tolist()
            if record.timestamp is not None:
                payload["timestamp"] = record.timestamp
            if record.source is not None:
                payload["source"] = record.source
            if record.metadata:
                payload["metadata"] = record.metadata
            handle.write(json.dumps(payload))
            handle.write("\n")
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.poses_filename,
    )
    return runtime.storage.upload_file(local_file, key, content_type="application/json")

def _upload_cache(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    cache_path: Path | None,
) -> str | None:
    if cache_path is None or not cache_path.exists():
        return None
    if not runtime.settings.upload_session_cache:
        logger.debug(
            "Skipping session cache upload for scene %s (disabled via settings)",
            scene_id,
        )
        return None
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.session_cache_filename,
    )
    return runtime.storage.upload_file(cache_path, key, content_type="application/octet-stream")

def _write_predictions_npz(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    predictions: Mapping[str, np.ndarray],
    temp_dir: Path,
) -> str:
    payload = {k: v for k, v in predictions.items() if isinstance(v, np.ndarray)}
    local_file = temp_dir / runtime.settings.predictions_filename
    np.savez(local_file, **payload)
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.predictions_filename,
    )
    return runtime.storage.upload_file(local_file, key, content_type="application/octet-stream")


def _write_session_settings(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    session_settings: Mapping[str, Any],
    temp_dir: Path,
) -> str:
    local_file = temp_dir / runtime.settings.session_settings_filename
    local_file.write_text(json.dumps(session_settings, indent=2), encoding="utf-8")
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.session_settings_filename,
    )
    return runtime.storage.upload_file(local_file, key, content_type="application/json")


def _write_selected_frames(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    selected_frames: list[dict[str, Any]],
    top_k: int,
    temp_dir: Path,
) -> str | None:
    if not selected_frames:
        return None
    local_file = temp_dir / runtime.settings.selected_frames_filename
    payload = {"top_k": top_k, "frames": selected_frames}
    local_file.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.selected_frames_filename,
    )
    return runtime.storage.upload_file(local_file, key, content_type="application/json")

def _camera_poses(extrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Convert (N, 3, 4) world-to-camera extrinsics into camera-to-world rotations and translations."""
    matrices = np.asarray(extrinsic, dtype=np.float64)
    if matrices.ndim != 3 or matrices.shape[1:] != (3, 4):
        raise ValueError("Extrinsic array must have shape (N, 3, 4)")
    count = matrices.shape[0]
    rotations = np.empty((count, 3, 3), dtype=np.float64)
    translations = np.empty((count, 3), dtype=np.float64)
    for idx in range(count):
        # Promote the 3x4 [R | t] extrinsic to a 4x4 homogeneous matrix and
        # invert it to obtain the camera-to-world transform.
        mat = np.eye(4, dtype=np.float64)
        mat[:3, :4] = matrices[idx]
        cam_to_world = np.linalg.inv(mat)
        rotations[idx] = cam_to_world[:3, :3]
        translations[idx] = cam_to_world[:3, 3]
    return rotations, translations
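
# For a rigid [R | t] the inverse has the closed form [R^T | -R^T t], so the
# loop above could equivalently be vectorized (a sketch, assuming the
# extrinsics are exact rigid transforms):
#
#   R = matrices[:, :3, :3]
#   t = matrices[:, :3, 3]
#   rotations = np.transpose(R, (0, 2, 1))
#   translations = -np.einsum("nij,nj->ni", rotations, t)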

def _compute_motion_deltas(rotations: np.ndarray, translations: np.ndarray, rot_weight: float) -> np.ndarray:
    count = rotations.shape[0]
    deltas = np.zeros(count, dtype=np.float64)
    if count <= 1:
        return deltas
    for idx in range(1, count):
        delta_t = np.linalg.norm(translations[idx] - translations[idx - 1])
        # Geodesic angle between consecutive rotations: for a rotation matrix
        # R, trace(R) = 1 + 2*cos(theta), so theta = arccos((trace(R) - 1) / 2).
        rel = rotations[idx - 1].T @ rotations[idx]
        trace = np.clip((np.trace(rel) - 1.0) / 2.0, -1.0, 1.0)
        delta_r = float(np.arccos(trace))
        deltas[idx] = delta_t + rot_weight * delta_r
    return deltas
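
# Quick sanity check on the angle formula: a 90-degree rotation about z has
# trace(R) = 0 + 0 + 1 = 1, giving arccos((1 - 1) / 2) = arccos(0) = pi / 2.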

def _hash_quantized_voxels(coords: np.ndarray) -> np.ndarray:
    # Dot-product variant of the classic spatial-hashing primes
    # (Teschner et al., "Optimized Spatial Hashing for Collision Detection
    # of Deformable Objects").
    coords = coords.astype(np.int64, copy=False)
    primes = np.array([73856093, 19349663, 83492791], dtype=np.int64)
    return coords @ primes

def _frame_voxel_sets(
    world_points: np.ndarray,
    confidence: np.ndarray,
    *,
    threshold: float,
    voxel_size: float,
    max_points: int,
) -> tuple[list[set[int]], int]:
    rng = np.random.default_rng(42)
    frames = world_points.shape[0]
    voxel_sets: list[set[int]] = []
    global_union: set[int] = set()
    if voxel_size <= 0.0:
        return [set() for _ in range(frames)], 0
    for idx in range(frames):
        conf_frame = confidence[idx]
        mask = conf_frame >= threshold
        if not np.any(mask):
            voxel_sets.append(set())
            continue
        points = world_points[idx][mask]
        if points.shape[0] > max_points:
            sample_idx = rng.choice(points.shape[0], max_points, replace=False)
            points = points[sample_idx]
        quantized = np.floor(points / voxel_size).astype(np.int64, copy=False)
        hashes = np.unique(_hash_quantized_voxels(quantized))
        voxel_set = set(int(v) for v in hashes.tolist())
        voxel_sets.append(voxel_set)
        global_union.update(voxel_set)
    return voxel_sets, len(global_union)

def _select_motion_indices(
    motion_deltas: np.ndarray,
    *,
    threshold: float,
    min_gap: int,
    max_gap: int,
) -> tuple[list[int], dict[int, dict[str, float]]]:
    total_frames = motion_deltas.shape[0]
    if total_frames == 0:
        return [], {}
    selected = [0]
    diagnostics: dict[int, dict[str, float]] = {0: {"motion_delta": 0.0, "cum_motion": 0.0}}
    cumulative = 0.0
    gap = 0
    for idx in range(1, total_frames):
        delta = float(motion_deltas[idx])
        cumulative += delta
        gap += 1
        if gap < max(1, min_gap):
            continue
        should_select = cumulative >= threshold
        if max_gap > 0 and gap >= max_gap:
            should_select = True
        if should_select:
            selected.append(idx)
            diagnostics[idx] = {"motion_delta": delta, "cum_motion": cumulative}
            cumulative = 0.0
            gap = 0
    if selected[-1] != total_frames - 1:
        selected.append(total_frames - 1)
        diagnostics.setdefault(total_frames - 1, {"motion_delta": float(motion_deltas[-1]), "cum_motion": cumulative})
    return selected, diagnostics
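
# Worked example (illustrative numbers): with motion_deltas
# [0.0, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], threshold=0.5, min_gap=1, max_gap=0,
# the accumulated motion reaches 0.5 at indices 3 and 6, so the selection is
# [0, 3, 6]; the final frame is always appended if not already selected.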

def _select_keyframes_motion_coverage(
    frame_records: list[FrameRecord],
    predictions: Mapping[str, np.ndarray],
    settings: WorkerSettings,
    requested_top_k: int,
) -> KeyframeSelectionResult | None:
    extrinsic = predictions.get("extrinsic")
    if extrinsic is None:
        return None
    extrinsic = np.asarray(extrinsic)
    if extrinsic.size == 0:
        return None
    rotations, translations = _camera_poses(extrinsic)
    motion_deltas = _compute_motion_deltas(rotations, translations, settings.keyframe_rotation_weight)
    motion_indices, motion_diag = _select_motion_indices(
        motion_deltas,
        threshold=settings.keyframe_motion_threshold,
        min_gap=max(1, settings.keyframe_min_gap_frames),
        max_gap=max(0, settings.keyframe_max_gap_frames),
    )
    total_frames = len(frame_records)
    confidence = pose_confidence(predictions)
    world_points = predictions.get("world_points")
    if world_points is None:
        world_points = predictions.get("world_points_from_depth")
    voxel_sets: list[set[int]] = [set() for _ in range(total_frames)]
    total_voxels = 0
    mean_conf = np.zeros(total_frames, dtype=np.float32)
    if confidence is not None:
        mean_conf = confidence.reshape(confidence.shape[0], -1).mean(axis=1)
    if confidence is not None and world_points is not None:
        voxel_sets, total_voxels = _frame_voxel_sets(
            np.asarray(world_points),
            np.asarray(confidence),
            threshold=settings.keyframe_coverage_confidence,
            voxel_size=settings.keyframe_coverage_voxel_size,
            max_points=max(1000, settings.keyframe_coverage_max_points),
        )
    # Clamp unconditionally so the gain-ratio divisions below never divide by zero.
    total_voxels = max(total_voxels, 1)
    top_k = requested_top_k if requested_top_k > 0 else settings.keyframe_default_top_k
    top_k = max(min(top_k, total_frames), len(motion_indices))
    selected_set: set[int] = set(motion_indices)
    diagnostics: dict[int, dict[str, Any]] = {}
    covered: set[int] = set()
    for idx in motion_indices:
        gain_count = len(voxel_sets[idx] - covered) if voxel_sets[idx] else 0
        gain_ratio = gain_count / total_voxels
        covered.update(voxel_sets[idx])
        diagnostics[idx] = {
            "frame_id": frame_records[idx].frame_id,
            "frame_index": frame_records[idx].index,
            "reason": "motion",
            "motion_delta": float(motion_deltas[idx]),
            "cum_motion": float(motion_diag.get(idx, {}).get("cum_motion", 0.0)),
            "coverage_gain_ratio": float(gain_ratio),
            "coverage_gain_count": int(gain_count),
            "mean_confidence": float(mean_conf[idx]) if confidence is not None else None,
        }
    if len(selected_set) < top_k and total_voxels > 0:
        # Greedy max-coverage pass: repeatedly add the frame contributing the
        # most not-yet-covered voxels until top_k is reached or gains dry up.
        min_gain_ratio = settings.keyframe_min_gain_ratio
        remaining = [i for i in range(total_frames) if i not in selected_set and voxel_sets[i]]
        while remaining and len(selected_set) < top_k:
            best_idx = -1
            best_gain = -1
            best_ratio = -1.0
            for idx in remaining:
                gain = len(voxel_sets[idx] - covered)
                if gain <= 0:
                    continue
                ratio = gain / total_voxels
                if ratio > best_ratio or (np.isclose(ratio, best_ratio) and gain > best_gain):
                    best_idx = idx
                    best_gain = gain
                    best_ratio = ratio
            if best_idx == -1 or best_ratio < min_gain_ratio:
                break
            selected_set.add(best_idx)
            covered.update(voxel_sets[best_idx])
            diagnostics[best_idx] = {
                "frame_id": frame_records[best_idx].frame_id,
                "frame_index": frame_records[best_idx].index,
                "reason": "coverage",
                "motion_delta": float(motion_deltas[best_idx]),
                "cum_motion": float(motion_diag.get(best_idx, {}).get("cum_motion", 0.0)),
                "coverage_gain_ratio": float(best_ratio),
                "coverage_gain_count": int(best_gain),
                "mean_confidence": float(mean_conf[best_idx]) if confidence is not None else None,
            }
            remaining.remove(best_idx)
    if requested_top_k > 0 and len(selected_set) > requested_top_k:
        # Trim back toward the requested budget by dropping the lowest-gain
        # coverage picks first; motion keyframes are never dropped.
        coverage_candidates = [idx for idx in selected_set if diagnostics[idx]["reason"] == "coverage"]
        coverage_candidates.sort(key=lambda idx: diagnostics[idx].get("coverage_gain_ratio", 0.0))
        while len(selected_set) > requested_top_k and coverage_candidates:
            drop_idx = coverage_candidates.pop(0)
            selected_set.remove(drop_idx)
            diagnostics.pop(drop_idx, None)
    final_indices = sorted(selected_set)
    final_diags = [diagnostics[idx] for idx in final_indices]
    return KeyframeSelectionResult(indices=final_indices, diagnostics=final_diags, top_k=len(final_indices))
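
# Sample diagnostics entry emitted above (field names from the code, values
# purely illustrative):
#
#   {"frame_id": "frame_000012", "frame_index": 12, "reason": "coverage",
#    "motion_delta": 0.031, "cum_motion": 0.274,
#    "coverage_gain_ratio": 0.042, "coverage_gain_count": 1210,
#    "mean_confidence": 7.9}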

def _compute_selected_frames(
    predictions: Mapping[str, np.ndarray],
    frame_records: list[FrameRecord],
    top_k: int,
) -> list[dict[str, Any]]:
    if top_k <= 0:
        return []
    confidence = pose_confidence(predictions)
    if confidence is None:
        return []
    scores = confidence.reshape(confidence.shape[0], -1).mean(axis=1)
    indices = np.argsort(scores)[::-1][:top_k]
    result = []
    for idx in indices:
        record = frame_records[int(idx)]
        result.append(
            {
                "frame_id": record.frame_id,
                "frame_index": record.index,
                "score": float(scores[idx]),
            }
        )
    return result

def _run_keyframe_prepass(
    *,
    runtime: WorkerRuntime,
    payload: Mapping[str, Any],
    frame_records: list[FrameRecord],
    mode: str,
    streaming: bool,
    window_size: int | None,
) -> KeyframeSelectionResult | None:
    if len(frame_records) <= 1:
        return None
    settings = runtime.settings
    top_k_payload = _as_int(payload.get("prepass_top_k") or payload.get("top_k_frames") or payload.get("top_k"), 0)
    try:
        inference = run_stream3r_inference(
            runtime=runtime,
            image_paths=[record.path for record in frame_records],
            mode=mode,
            streaming=streaming,
            cache_output_path=None,
            progress_cb=None,
            window_size=window_size if streaming and mode == "window" else None,
        )
    except Exception:
        logger.exception("Keyframe pre-pass inference failed")
        return None
    try:
        selection = _select_keyframes_motion_coverage(
            frame_records,
            inference.predictions,
            settings,
            requested_top_k=top_k_payload,
        )
    finally:
        # Drop the prediction arrays promptly; they can be large and the main
        # inference pass still has to run.
        del inference
    return selection

def _save_scene_glb(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    predictions: Mapping[str, np.ndarray],
    temp_dir: Path,
    payload: Mapping[str, Any],
) -> str:
    local_file = temp_dir / runtime.settings.scene_glb_filename
    ceiling_percentile = payload.get("ceiling_percentile")
    try:
        ceiling_percentile_value = float(ceiling_percentile) if ceiling_percentile is not None else None
    except (TypeError, ValueError):
        ceiling_percentile_value = None
    ceiling_margin_value = payload.get("ceiling_margin")
    try:
        ceiling_margin_value = float(ceiling_margin_value) if ceiling_margin_value is not None else 0.05
    except (TypeError, ValueError):
        ceiling_margin_value = 0.05
    ceiling_z_max = payload.get("ceiling_z_max")
    try:
        ceiling_z_max_value = float(ceiling_z_max) if ceiling_z_max is not None else None
    except (TypeError, ValueError):
        ceiling_z_max_value = None
    scene = predictions_to_glb(
        dict(predictions),
        conf_thres=float(payload.get("conf_thres", 3.0)),
        filter_by_frames=payload.get("frame_filter", "All"),
        mask_black_bg=_as_bool(payload.get("mask_black_bg"), False),
        mask_white_bg=_as_bool(payload.get("mask_white_bg"), False),
        show_cam=_as_bool(payload.get("show_cam"), False),
        mask_sky=_as_bool(payload.get("mask_sky"), False),
        target_dir=str(temp_dir),
        prediction_mode=payload.get("prediction_mode", "Predicted Pointmap"),
        ceiling_percentile=ceiling_percentile_value,
        ceiling_margin=ceiling_margin_value,
        ceiling_z_max=ceiling_z_max_value,
    )
    scene.export(file_obj=str(local_file))
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.scene_glb_filename,
    )
    return runtime.storage.upload_file(local_file, key, content_type="model/gltf-binary")

def _write_summary_json(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    summary: Mapping[str, Any],
    temp_dir: Path,
) -> str:
    filename = runtime.settings.result_filename
    local_file = temp_dir / filename
    local_file.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        filename,
    )
    return runtime.storage.upload_file(local_file, key, content_type="application/json")


def _upload_result_record(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    job_id: str,
    payload: Mapping[str, Any],
) -> str:
    local = json.dumps(payload, indent=2).encode("utf-8")
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.results_dir,
        f"{job_id}.json",
    )
    return runtime.storage.upload_bytes(local, key, content_type="application/json")


def _model_dir_uri(runtime: WorkerRuntime, scene_id: str) -> str:
    return runtime.storage.build_uri(
        runtime.storage.build_key(scene_id, runtime.settings.models_dir)
    )

def _generate_core_outputs(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    job_id: str,
    predictions: Mapping[str, np.ndarray],
    frame_records: list[FrameRecord],
    inference: InferenceResult,
    session_settings: Mapping[str, Any],
    temp_dir: Path,
) -> dict[str, Any]:
    pointmap_info = _save_pointmaps(
        runtime=runtime,
        scene_id=scene_id,
        predictions=predictions,
        frame_records=frame_records,
        temp_dir=temp_dir,
    )
    poses_url = _write_poses_jsonl(
        runtime=runtime,
        scene_id=scene_id,
        job_id=job_id,
        predictions=predictions,
        frame_records=frame_records,
        temp_dir=temp_dir,
    )
    cache_url = _upload_cache(
        runtime=runtime,
        scene_id=scene_id,
        cache_path=inference.cache_path,
    )
    predictions_url = _write_predictions_npz(
        runtime=runtime,
        scene_id=scene_id,
        predictions=predictions,
        temp_dir=temp_dir,
    )
    session_settings_url = _write_session_settings(
        runtime=runtime,
        scene_id=scene_id,
        session_settings=session_settings,
        temp_dir=temp_dir,
    )
    extrinsic = np.asarray(predictions.get("extrinsic"))
    intrinsic = predictions.get("intrinsic")
    if intrinsic is not None:
        intrinsic = np.asarray(intrinsic)
    frames_payload: list[dict[str, Any]] = []
    for entry in pointmap_info["pointmaps"]:
        idx = entry["frame_index"]
        frame = frame_records[idx]
        frame_payload = {
            "frame_id": frame.frame_id,
            "frame_index": frame.index,
            "pointmap_url": entry["url"],
            "extrinsic": extrinsic[idx].tolist(),
        }
        if intrinsic is not None:
            frame_payload["intrinsic"] = intrinsic[idx].tolist()
        if frame.timestamp is not None:
            frame_payload["timestamp"] = frame.timestamp
        if frame.source is not None:
            frame_payload["source"] = frame.source
        frames_payload.append(frame_payload)
    artifacts = {
        "poses_url": poses_url,
        "pointmap_dir": pointmap_info["pointmap_dir"],
        "pointmaps": pointmap_info["pointmaps"],
        "predictions_url": predictions_url,
        "session_settings_url": session_settings_url,
    }
    if cache_url:
        artifacts["kv_cache_url"] = cache_url
    return {
        "artifacts": artifacts,
        "frames": frames_payload,
    }

def _handle_pose_pointmap(
    *,
    runtime: WorkerRuntime,
    payload: Mapping[str, Any],
    mode: str,
    streaming: bool,
    job_id: str,
    scene_id: str,
    frame_records: list[FrameRecord],
    inference: InferenceResult,
    session_settings: Mapping[str, Any],
    temp_dir: Path,
) -> dict[str, Any]:
    predictions = inference.predictions
    core = _generate_core_outputs(
        runtime=runtime,
        scene_id=scene_id,
        job_id=job_id,
        predictions=predictions,
        frame_records=frame_records,
        inference=inference,
        session_settings=session_settings,
        temp_dir=temp_dir,
    )
    result_payload = {
        "job_id": job_id,
        "job_type": "pose_pointmap",
        "scene_id": scene_id,
        "mode": mode,
        "streaming": streaming,
        "frame_count": inference.total_frames,
        "created_at": datetime.now(timezone.utc).isoformat(),
        "artifacts": core["artifacts"],
        "frames": core["frames"],
    }
    selected_frames_payload = payload.get("_selected_frames_info")
    if selected_frames_payload:
        result_payload["selected_frames"] = list(selected_frames_payload)
        try:
            selected_frames_url = _write_selected_frames(
                runtime=runtime,
                scene_id=scene_id,
                selected_frames=list(selected_frames_payload),
                top_k=_as_int(payload.get("_selected_top_k"), len(selected_frames_payload)),
                temp_dir=temp_dir,
            )
            if selected_frames_url:
                result_payload["artifacts"]["selected_frames_url"] = selected_frames_url
        except Exception:
            logger.exception("Failed to persist selected frames artifact for pose_pointmap job")
    result_url = _upload_result_record(
        runtime=runtime,
        scene_id=scene_id,
        job_id=job_id,
        payload=result_payload,
    )
    result_payload["result_url"] = result_url
    result_payload["model_dir"] = _model_dir_uri(runtime, scene_id)
    return result_payload

JobHandler = Callable[..., dict[str, Any]]


def _execute_job(job_type: str, payload: Mapping[str, Any], handler: JobHandler) -> dict[str, Any]:
    runtime = get_runtime()
    job = get_current_job()
    payload = dict(payload)
    job_id = str(payload.get("job_id") or (job.id if job else uuid.uuid4()))
    scene_id = payload.get("scene_id")
    if not scene_id:
        raise ValueError("Job payload is missing 'scene_id'")
    payload.setdefault("job_type", job_type)
    payload.setdefault("scene_id", scene_id)
    mode = payload.get("mode") or runtime.settings.default_mode
    streaming = _as_bool(payload.get("streaming"), runtime.settings.default_streaming)
    window_size: int | None = None
    if mode == "window":
        streaming = True
        payload["streaming"] = True
        window_candidate = payload.get("window_size") or runtime.settings.stream_window_size
        try:
            window_size = int(window_candidate) if window_candidate else None
        except (TypeError, ValueError):
            window_size = runtime.settings.stream_window_size or None
        if window_size and window_size > 0:
            payload["window_size"] = window_size
        else:
            window_size = None
    payload["mode"] = mode
    desired_timeout = runtime.settings.default_job_timeout
    timeout_override = payload.get("timeout")
    applied_timeout: int | None = None
    if timeout_override is not None:
        try:
            applied_timeout = int(timeout_override)
            if job is not None:
                job.timeout = applied_timeout
        except (TypeError, ValueError):
            applied_timeout = None
    if applied_timeout is None and desired_timeout and desired_timeout > 0:
        if job is not None:
            current_timeout = getattr(job, "timeout", None)
            try:
                current_timeout_value = int(current_timeout) if current_timeout is not None else None
            except (TypeError, ValueError):
                current_timeout_value = None
            if current_timeout_value is None or current_timeout_value < desired_timeout:
                job.timeout = desired_timeout
                applied_timeout = desired_timeout
            else:
                applied_timeout = current_timeout_value
        else:
            applied_timeout = desired_timeout
    if applied_timeout is not None:
        payload["timeout"] = applied_timeout
    sanitized_payload = _sanitize_payload(payload)
    job_meta = {
        "job_id": job_id,
        "job_type": job_type,
        "scene_id": scene_id,
    }
    logger.info(
        "Job %s (%s) started for scene %s (timeout=%s)",
        job_id,
        job_type,
        scene_id,
        applied_timeout or desired_timeout or "default",
    )
    start_time = perf_counter()
    last_time = start_time

    def log_progress(stage: str) -> None:
        nonlocal last_time
        now = perf_counter()
        logger.info(
            "Job %s (%s): %s [delta=%.2fs total=%.2fs]",
            job_id,
            job_type,
            stage,
            now - last_time,
            now - start_time,
        )
        last_time = now

    runtime.db.upsert_job(
        job_id=job_id,
        job_type=job_type,
        scene_id=scene_id,
        status="started",
        payload=sanitized_payload,
    )
    runtime_emit(
        runtime,
        {
            **job_meta,
            "status": "started",
            "progress": 0,
            "ts": datetime.now(timezone.utc).timestamp(),
        },
    )
    # Honor a GPU lock already held by the surrounding process (signaled via
    # environment variable); otherwise acquire the runtime's own lock.
    lock_ctx = nullcontext() if os.getenv("STREAM3R_GPU_LOCK_HELD") == "1" else runtime.gpu_lock()
    try:
        with lock_ctx:
            with tempfile.TemporaryDirectory(prefix=f"stream3r_{job_id}_") as tmp_dir:
                temp_path = Path(tmp_dir)
                frame_records = _collect_frames(runtime, scene_id, payload, temp_path)
                log_progress(f"collected frames ({len(frame_records)} items)")
                selection_result: KeyframeSelectionResult | None = None
                if runtime.settings.keyframe_prepass_enabled and len(frame_records) > 1:
                    log_progress("starting keyframe pre-pass")
                    try:
                        selection_result = _run_keyframe_prepass(
                            runtime=runtime,
                            payload=payload,
                            frame_records=frame_records,
                            mode=mode,
                            streaming=streaming,
                            window_size=window_size,
                        )
                    except Exception:
                        selection_result = None
                        logger.exception("Keyframe pre-pass failed; falling back to full frame set")
                if selection_result and selection_result.indices:
                    log_progress(
                        f"pre-pass selected {len(selection_result.indices)} frames from {len(frame_records)}"
                    )
                    frame_records = [frame_records[i] for i in selection_result.indices]
                    for new_idx, record in enumerate(frame_records):
                        record.index = new_idx
                    payload["_selected_frames_info"] = selection_result.diagnostics
                    payload["_selected_top_k"] = selection_result.top_k
                    payload["_selected_frame_indices"] = selection_result.indices
                    if len(frame_records) <= runtime.settings.keyframe_full_mode_max_frames:
                        # Few enough frames after selection: switch to the
                        # non-streaming "full" mode.
                        mode = "full"
                        streaming = False
                        window_size = None
                        payload["mode"] = mode
                        payload["streaming"] = streaming
                else:
                    selection_result = None
                cache_path = temp_path / runtime.settings.session_cache_filename if streaming else None
                tracker = ProgressTracker(runtime, job_meta)
                inference = run_stream3r_inference(
                    runtime=runtime,
                    image_paths=[record.path for record in frame_records],
                    mode=mode,
                    streaming=streaming,
                    cache_output_path=cache_path,
                    progress_cb=tracker,
                    window_size=window_size if streaming and mode == "window" else None,
                )
                log_progress(f"inference completed ({inference.total_frames} frames)")
                session_settings = _prepare_session_settings(
                    payload,
                    mode=mode,
                    streaming=streaming,
                    frame_records=frame_records,
                    window_size=window_size,
                )
                result_payload = handler(
                    runtime=runtime,
                    payload=payload,
                    mode=mode,
                    streaming=streaming,
                    job_id=job_id,
                    scene_id=scene_id,
                    frame_records=frame_records,
                    inference=inference,
                    session_settings=session_settings,
                    temp_dir=temp_path,
                )
                log_progress("artifact generation completed")
    except Exception as exc:
        error_text = traceback.format_exc()
        runtime.db.upsert_job(
            job_id=job_id,
            job_type=job_type,
            scene_id=scene_id,
            status="failed",
            error=error_text,
        )
        runtime_emit(
            runtime,
            {
                **job_meta,
                "status": "failed",
                "ts": datetime.now(timezone.utc).timestamp(),
                "error": str(exc),
            },
        )
        logger.exception(
            "Job %s (%s) failed after %.2fs: %s",
            job_id,
            job_type,
            perf_counter() - start_time,
            exc,
        )
        raise
    log_progress("job finished")
    runtime.db.upsert_job(
        job_id=job_id,
        job_type=job_type,
        scene_id=scene_id,
        status="finished",
        result=result_payload,
    )
    runtime_emit(
        runtime,
        {
            **job_meta,
            "status": "finished",
            "progress": 100,
            "result_url": result_payload.get("result_url"),
            "model_dir": result_payload.get("model_dir"),
            "ts": datetime.now(timezone.utc).timestamp(),
        },
    )
    return result_payload

def pose_pointmap_job(payload: Mapping[str, Any]) -> dict[str, Any]:
    """Process a pose + pointmap job."""
    return _execute_job("pose_pointmap", payload, _handle_pose_pointmap)


def model_build_job(payload: Mapping[str, Any]) -> dict[str, Any]:
    """Process a full model build job."""
    return _execute_job("model_build", payload, _handle_model_build)

def _fallback_selection(frame_records: list[FrameRecord], top_k: int) -> KeyframeSelectionResult:
    indices = linear_sample_indices(len(frame_records), top_k)
    diagnostics = [
        {
            "frame_id": frame_records[idx].frame_id,
            "frame_index": frame_records[idx].index,
            "reason": "linear",
        }
        for idx in indices
    ]
    return KeyframeSelectionResult(indices=indices, diagnostics=diagnostics, top_k=len(indices))

def keyframe_selection_job(payload: Mapping[str, Any]) -> dict[str, Any]:
    runtime = get_runtime()
    job = get_current_job()
    payload = dict(payload)
    job_id = str(payload.get("job_id") or (job.id if job else uuid.uuid4()))
    scene_id = payload.get("scene_id")
    if not scene_id:
        raise ValueError("Keyframe job payload is missing 'scene_id'")
    video_key = payload.get("video_key")
    if not video_key:
        raise ValueError("Keyframe job payload is missing 'video_key'")
    job_type = "keyframe_selection"
    job_meta = {
        "job_id": job_id,
        "job_type": job_type,
        "scene_id": scene_id,
    }
    sanitized_payload = {
        "scene_id": scene_id,
        "video_key": video_key,
        "top_k": payload.get("top_k"),
        "extract_fps": payload.get("extract_fps"),
        "extract_max_frames": payload.get("extract_max_frames"),
    }
    runtime.db.upsert_job(
        job_id=job_id,
        job_type=job_type,
        scene_id=scene_id,
        status="started",
        payload=sanitized_payload,
    )
    runtime_emit(
        runtime,
        {
            **job_meta,
            "status": "started",
            "progress": 0,
            "ts": datetime.now(timezone.utc).timestamp(),
        },
    )
    start_time = perf_counter()
    try:
        with tempfile.TemporaryDirectory(prefix=f"keyframe_{job_id}_") as tmp_dir:
            temp_path = Path(tmp_dir)
            video_path = temp_path / "input_video"
            runtime.storage.download_to_path(video_key, video_path)
            extract_fps = payload.get("extract_fps")
            try:
                extract_fps_value = float(extract_fps) if extract_fps is not None else runtime.settings.keyframe_extract_fps
            except (TypeError, ValueError):
                extract_fps_value = runtime.settings.keyframe_extract_fps
            max_frames = _as_int(
                payload.get("extract_max_frames"),
                runtime.settings.keyframe_extract_max_frames,
            )
            frame_records, native_fps = extract_video_frames(
                video_path,
                temp_path / "frames",
                target_fps=extract_fps_value,
                max_frames=max_frames,
            )
            selection = run_keyframe_prepass(
                runtime=runtime,
                payload=payload,
                frame_records=frame_records,
                mode="window",
                streaming=True,
                window_size=runtime.settings.stream_window_size,
            )
            if selection is None or not selection.indices:
                requested_top_k = _as_int(payload.get("top_k"), runtime.settings.keyframe_default_top_k)
                selection = _fallback_selection(frame_records, requested_top_k)
            selected_records = [frame_records[i] for i in selection.indices]
            storage_entries, media_entries = build_keyframe_uploads(
                runtime,
                scene_id,
                selected_records,
                selection.diagnostics,
                subdir=runtime.settings.keyframe_upload_dir,
            )
            _register_scene_media_entries(runtime, scene_id, media_entries)
            result_payload = {
                "job_id": job_id,
                "job_type": job_type,
                "scene_id": scene_id,
                "video_key": video_key,
                "native_fps": native_fps,
                "total_frames": len(frame_records),
                "selected_frames": storage_entries,
                "selection": selection.diagnostics,
            }
    except Exception as exc:
        error_text = traceback.format_exc()
        runtime.db.upsert_job(
            job_id=job_id,
            job_type=job_type,
            scene_id=scene_id,
            status="failed",
            error=error_text,
        )
        runtime_emit(
            runtime,
            {
                **job_meta,
                "status": "failed",
                "ts": datetime.now(timezone.utc).timestamp(),
                "error": str(exc),
            },
        )
        logger.exception("Keyframe selection job %s failed", job_id)
        raise
    runtime.db.upsert_job(
        job_id=job_id,
        job_type=job_type,
        scene_id=scene_id,
        status="finished",
        result=result_payload,
    )
    runtime_emit(
        runtime,
        {
            **job_meta,
            "status": "finished",
            "progress": 100,
            "ts": datetime.now(timezone.utc).timestamp(),
        },
    )
    logger.info(
        "Keyframe selection job %s finished in %.2fs (selected %d/%d frames)",
        job_id,
        perf_counter() - start_time,
        len(selection.indices),
        len(frame_records),
    )
    return result_payload

def _handle_model_build(
    *,
    runtime: WorkerRuntime,
    payload: Mapping[str, Any],
    mode: str,
    streaming: bool,
    job_id: str,
    scene_id: str,
    frame_records: list[FrameRecord],
    inference: InferenceResult,
    session_settings: Mapping[str, Any],
    temp_dir: Path,
) -> dict[str, Any]:
    predictions = inference.predictions
    core = _generate_core_outputs(
        runtime=runtime,
        scene_id=scene_id,
        job_id=job_id,
        predictions=predictions,
        frame_records=frame_records,
        inference=inference,
        session_settings=session_settings,
        temp_dir=temp_dir,
    )
    artifacts = dict(core["artifacts"])
    selected_frames_payload = payload.get("_selected_frames_info")
    if selected_frames_payload:
        top_k = _as_int(payload.get("_selected_top_k"), len(selected_frames_payload))
        selected_frames = list(selected_frames_payload)
    else:
        top_k = _as_int(payload.get("top_k_frames") or payload.get("top_k"), 0)
        selected_frames = _compute_selected_frames(predictions, frame_records, top_k)
    selected_frames_url = _write_selected_frames(
        runtime=runtime,
        scene_id=scene_id,
        selected_frames=selected_frames,
        top_k=top_k,
        temp_dir=temp_dir,
    )
    if selected_frames_url:
        artifacts["selected_frames_url"] = selected_frames_url
    scene_glb_url = _save_scene_glb(
        runtime=runtime,
        scene_id=scene_id,
        predictions=predictions,
        temp_dir=temp_dir,
        payload=payload,
    )
    artifacts["scene_glb_url"] = scene_glb_url
    summary_payload = {
        "job_id": job_id,
        "job_type": "model_build",
        "scene_id": scene_id,
        "frame_count": inference.total_frames,
        "created_at": datetime.now(timezone.utc).isoformat(),
        "artifacts": artifacts,
        "selected_frames": selected_frames,
        "parameters": {
            "mode": mode,
            "streaming": streaming,
            "conf_thres": float(payload.get("conf_thres", 3.0)),
            "frame_filter": payload.get("frame_filter", "All"),
            "mask_black_bg": _as_bool(payload.get("mask_black_bg"), False),
            "mask_white_bg": _as_bool(payload.get("mask_white_bg"), False),
            # Default mirrors _save_scene_glb so the recorded parameter matches
            # what the export actually used.
            "show_cam": _as_bool(payload.get("show_cam"), False),
            "mask_sky": _as_bool(payload.get("mask_sky"), False),
            "prediction_mode": payload.get("prediction_mode", "Predicted Pointmap"),
        },
    }
    summary_url = _write_summary_json(
        runtime=runtime,
        scene_id=scene_id,
        summary=summary_payload,
        temp_dir=temp_dir,
    )
    artifacts["summary_url"] = summary_url
    result_record = dict(summary_payload)
    result_record["result_url"] = summary_url
    result_record_url = _upload_result_record(
        runtime=runtime,
        scene_id=scene_id,
        job_id=job_id,
        payload=result_record,
    )
    result_payload = {
        "job_id": job_id,
        "job_type": "model_build",
        "scene_id": scene_id,
        "mode": mode,
        "streaming": streaming,
        "frame_count": inference.total_frames,
        "created_at": summary_payload["created_at"],
        "artifacts": artifacts,
        "frames": core["frames"],
        "selected_frames": selected_frames,
        "summary_url": summary_url,
        "result_url": summary_url,
        "result_record_url": result_record_url,
        "model_dir": _model_dir_uri(runtime, scene_id),
    }
    return result_payload