"""RQ job entrypoints for STream3R worker."""
from __future__ import annotations
import base64
import json
import logging
import os
import re
import shutil
import tempfile
import traceback
import uuid
from datetime import datetime, timezone
from pathlib import Path
from contextlib import nullcontext
from time import perf_counter
from typing import Any, Callable, Mapping
import numpy as np
import requests
from rq import get_current_job
from stream3r.utils.visual_utils import predictions_to_glb
from .keyframes import (
FrameRecord,
KeyframeSelectionResult,
build_keyframe_uploads,
extract_video_frames,
linear_sample_indices,
pose_confidence,
run_keyframe_prepass,
)
from .pipeline import InferenceResult, run_stream3r_inference
from .runtime import WorkerRuntime, WorkerSettings, get_runtime
logger = logging.getLogger(__name__)
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
_SAFE_CHARS = re.compile(r"[^0-9A-Za-z_-]")
def _as_bool(value: Any, default: bool) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, str):
lowered = value.strip().lower()
if lowered in {"1", "true", "yes", "y", "on"}:
return True
if lowered in {"0", "false", "no", "n", "off"}:
return False
return default
def _as_int(value: Any, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
class ProgressTracker:
"""Aggregates frame progress to percentage updates."""
def __init__(self, runtime: WorkerRuntime, job_meta: Mapping[str, str | None]):
self.runtime = runtime
self.job_meta = job_meta
self.last_value = -1
def __call__(self, processed: int, total: int) -> None:
if total <= 0:
return
percent = int(round((processed / total) * 100))
percent = max(0, min(100, percent))
if percent == self.last_value:
return
self.last_value = percent
payload = {
**self.job_meta,
"status": "progress",
"progress": percent,
"ts": datetime.now(timezone.utc).timestamp(),
}
runtime_emit(self.runtime, payload)
def runtime_emit(runtime: WorkerRuntime, payload: Mapping[str, Any]) -> None:
runtime.emit_event(payload)
def _slugify(value: str, fallback: str) -> str:
candidate = _SAFE_CHARS.sub("_", value).strip("_")
if not candidate:
candidate = fallback
return candidate[:128]
def _is_url(value: str) -> bool:
return value.startswith("http://") or value.startswith("https://")
def _download_to_path(url: str, destination: Path) -> None:
response = requests.get(url, stream=True, timeout=60)
response.raise_for_status()
with destination.open("wb") as handle:
for chunk in response.iter_content(chunk_size=1 << 16):
if chunk:
handle.write(chunk)
def _write_base64(content: str, destination: Path) -> None:
data = base64.b64decode(content)
destination.write_bytes(data)
def _register_scene_media_entries(runtime: WorkerRuntime, scene_id: str, entries: list[dict[str, Any]]) -> None:
if not entries:
return
base_url = runtime.settings.scene_media_api_base_url
if not base_url:
logger.info("Scene media API base URL not configured; skipping registration for %s", scene_id)
return
url = f"{base_url.rstrip('/')}/scenes/{scene_id}/media"
headers: dict[str, str] = {"Content-Type": "application/json"}
token = runtime.settings.scene_media_api_token
if token:
headers["Authorization"] = f"Bearer {token}"
secret = runtime.settings.scene_media_api_secret
if secret:
headers["x-internal-secret"] = secret
try:
response = requests.post(url, json={"entries": entries}, headers=headers, timeout=30)
if response.status_code == 405:
logger.info(
"Scene media API does not accept POST at %s (status %s); skipping registration",
url,
response.status_code,
)
return
response.raise_for_status()
except requests.HTTPError as exc:
status = exc.response.status_code if exc.response is not None else None
if status == 422:
logger.warning(
"Scene media API rejected payload for scene %s (422): %s",
scene_id,
exc.response.text if exc.response is not None else "",
)
return
if status == 500:
logger.warning(
"Scene media API encountered server error (500) for scene %s; skipping registration",
scene_id,
)
return
logger.exception("Failed to register scene media entries for scene %s", scene_id)
except requests.RequestException:
logger.exception("Failed to register scene media entries for scene %s", scene_id)
def _resolve_frame_entry(
entry: Any,
*,
index: int,
dest_dir: Path,
runtime: WorkerRuntime | None = None,
) -> FrameRecord:
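    """Materialize a single frame entry (local path, URL, storage key, or
    base64 content) into ``dest_dir`` and return its FrameRecord."""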
metadata: dict[str, Any] = {}
timestamp = None
source = None
dest_dir.mkdir(parents=True, exist_ok=True)
if isinstance(entry, str):
if _is_url(entry):
source = entry
frame_id = _slugify(Path(entry).stem or f"frame_{index:06d}", f"frame_{index:06d}")
destination = dest_dir / f"{frame_id}.jpg"
_download_to_path(entry, destination)
else:
path = Path(entry)
if not path.exists():
raise FileNotFoundError(f"Frame path does not exist: {entry}")
frame_id = _slugify(path.stem, f"frame_{index:06d}")
destination = dest_dir / path.name
shutil.copy2(path, destination)
elif isinstance(entry, Mapping):
frame_id = _slugify(str(entry.get("frame_id") or entry.get("id") or f"frame_{index:06d}"), f"frame_{index:06d}")
timestamp = entry.get("timestamp")
metadata = {k: v for k, v in entry.items() if k not in {"path", "url", "content", "frame_id", "id", "timestamp"}}
if path := entry.get("path") or entry.get("local_path"):
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"Frame path does not exist: {path}")
destination = dest_dir / (path.name if path.suffix else f"{frame_id}.jpg")
shutil.copy2(path, destination)
elif storage_key := entry.get("storage_key") or entry.get("file") or entry.get("key"):
if runtime is None:
raise ValueError("Frame entry provided storage key but runtime is not available")
filename = Path(str(storage_key)).name or f"{frame_id}.jpg"
destination = dest_dir / filename
runtime.storage.download_to_path(str(storage_key), destination)
source = str(storage_key)
elif url := entry.get("url"):
source = url
suffix = Path(url).suffix or ".jpg"
destination = dest_dir / f"{frame_id}{suffix}"
_download_to_path(url, destination)
elif content := entry.get("content"):
destination = dest_dir / f"{frame_id}.jpg"
_write_base64(content, destination)
else:
raise ValueError("Frame entry must include 'path', 'url', or 'content'")
    else:
        raise TypeError(f"Unsupported frame entry type: {type(entry).__name__}")
    if destination.suffix.lower() not in IMAGE_EXTENSIONS:
        # Rename the file on disk as well; reassigning only the Path would
        # leave the FrameRecord pointing at a file that does not exist.
        destination = destination.rename(destination.with_suffix(".png"))
return FrameRecord(
index=index,
frame_id=_slugify(destination.stem, f"frame_{index:06d}"),
path=destination,
source=source,
timestamp=timestamp,
metadata=metadata,
)
def _collect_frames(
runtime: WorkerRuntime,
scene_id: str,
payload: Mapping[str, Any],
tmp_dir: Path,
) -> list[FrameRecord]:
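    """Gather frame records, trying payload entries first, then a local
    frames directory, and finally the scene media API."""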
frames_dir = tmp_dir / "frames"
frames_payload = payload.get("frames") or []
frame_limit = runtime.settings.max_frames_per_job
records: list[FrameRecord] = []
if frames_payload:
for entry in frames_payload:
if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
break
records.append(
_resolve_frame_entry(entry, index=len(records), dest_dir=frames_dir, runtime=runtime)
)
else:
directory = payload.get("frames_dir") or payload.get("images_dir")
if directory:
directory_path = Path(directory)
if not directory_path.is_dir():
raise ValueError(f"frames_dir does not exist: {directory}")
for idx, file in enumerate(sorted(directory_path.iterdir())):
if file.suffix.lower() not in IMAGE_EXTENSIONS:
continue
if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
break
destination = frames_dir / file.name
shutil.copy2(file, destination)
records.append(
FrameRecord(
index=len(records),
frame_id=_slugify(file.stem, f"frame_{idx:06d}"),
path=destination,
)
)
if not records:
records = _collect_frames_from_scene_media(runtime, scene_id, frames_dir)
if not records:
raise ValueError(f"No valid frames found for scene '{scene_id}'")
limit = runtime.settings.max_frames_per_job
if limit and limit > 0 and len(records) > limit:
records = records[:limit]
for new_idx, record in enumerate(records):
if record.index != new_idx:
record.index = new_idx
return records
def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]:
result = dict(payload)
frames = result.pop("frames", None)
if frames is not None:
result["frame_count"] = len(frames)
if "frames_dir" in result:
result["frames_dir"] = str(result["frames_dir"])
return result
def _prepare_session_settings(
payload: Mapping[str, Any],
*,
mode: str,
streaming: bool,
frame_records: list[FrameRecord],
window_size: int | None = None,
) -> dict[str, Any]:
base_settings = payload.get("session_settings") or {}
session_settings = dict(base_settings)
session_settings.update(
{
"mode": mode,
"streaming": streaming,
"frame_count": len(frame_records),
}
)
window_setting = window_size if window_size is not None else payload.get("window_size")
if window_setting:
try:
session_settings["window_size"] = int(window_setting)
except (TypeError, ValueError):
pass
return session_settings
def _collect_frames_from_scene_media(
runtime: WorkerRuntime,
scene_id: str,
dest_dir: Path,
) -> list[FrameRecord]:
base_url = runtime.settings.scene_media_api_base_url
if not base_url:
raise ValueError(
"Scene media API base URL is not configured. Set API_BASE_URL"
)
base_url = base_url.rstrip("/")
dest_dir.mkdir(parents=True, exist_ok=True)
per_page = runtime.settings.scene_media_page_size
if per_page <= 0:
per_page = 100
per_page = max(1, min(per_page, 1000))
frame_limit = runtime.settings.max_frames_per_job
headers = {}
token = runtime.settings.scene_media_api_token
if token:
headers["Authorization"] = f"Bearer {token}"
url = f"{base_url}/scenes/{scene_id}/media"
session = requests.Session()
records: list[FrameRecord] = []
offset = 0
while True:
if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
break
request_limit = per_page
if frame_limit and frame_limit > 0:
remaining = frame_limit - len(records)
if remaining <= 0:
break
request_limit = min(request_limit, remaining)
params = {
"limit": request_limit,
"offset": offset,
"media_type": "image",
}
try:
response = session.get(url, params=params, headers=headers, timeout=30)
response.raise_for_status()
except requests.RequestException as exc:
raise RuntimeError(f"Failed to fetch media for scene '{scene_id}': {exc}") from exc
data = response.json()
items = data.get("items") or []
if not items:
break
for item in items:
if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
break
file_key = item.get("file")
if not file_key:
continue
idx = len(records)
source_path = Path(str(file_key))
suffix = source_path.suffix if source_path.suffix else ".png"
frame_id = _slugify(source_path.stem or f"frame_{idx:06d}", f"frame_{idx:06d}")
destination = dest_dir / f"{frame_id}{suffix}"
try:
runtime.storage.download_to_path(str(file_key), destination)
except Exception as exc: # pragma: no cover - download depends on external storage
raise RuntimeError(f"Failed to download media '{file_key}' for scene '{scene_id}': {exc}") from exc
records.append(
FrameRecord(
index=idx,
frame_id=frame_id,
path=destination,
source=str(file_key),
timestamp=item.get("captured_at"),
metadata={
"media_id": item.get("id"),
"media_type": item.get("media_type"),
},
)
)
if len(items) < request_limit:
break
offset += request_limit
return records
def _save_pointmaps(
*,
runtime: WorkerRuntime,
scene_id: str,
predictions: Mapping[str, np.ndarray],
frame_records: list[FrameRecord],
temp_dir: Path,
) -> dict[str, Any]:
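    """Save per-frame point maps (xyz + confidence) as .npz files, upload
    them, and return their storage entries."""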
world_points = predictions.get("world_points")
if world_points is None:
world_points = predictions.get("world_points_from_depth")
if world_points is None:
raise RuntimeError("Predictions missing world points")
world_points = np.asarray(world_points)
confidence = pose_confidence(predictions)
if confidence is None:
confidence = np.ones(world_points.shape[:-1], dtype=np.float32)
local_dir = temp_dir / "pointmaps"
local_dir.mkdir(parents=True, exist_ok=True)
entries: list[dict[str, Any]] = []
for record in frame_records:
idx = record.index
filename = f"{record.frame_id}.npz"
local_file = local_dir / filename
np.savez(
local_file,
xyz=np.asarray(world_points[idx], dtype=np.float32),
confidence=np.asarray(confidence[idx], dtype=np.float32),
)
key = runtime.storage.build_key(scene_id, runtime.settings.pointmap_dir, filename)
uri = runtime.storage.upload_file(local_file, key, content_type="application/octet-stream")
entries.append(
{
"frame_id": record.frame_id,
"frame_index": record.index,
"url": uri,
"timestamp": record.timestamp,
}
)
directory_uri = runtime.storage.build_uri(
runtime.storage.build_key(scene_id, runtime.settings.pointmap_dir)
)
return {
"pointmaps": entries,
"pointmap_dir": directory_uri,
}
def _write_poses_jsonl(
*,
runtime: WorkerRuntime,
scene_id: str,
job_id: str,
predictions: Mapping[str, np.ndarray],
frame_records: list[FrameRecord],
temp_dir: Path,
) -> str:
    extrinsic = predictions.get("extrinsic")
    if extrinsic is None:
        raise RuntimeError("Predictions missing extrinsic camera parameters")
    extrinsic = np.asarray(extrinsic)
intrinsic = predictions.get("intrinsic")
if intrinsic is not None:
intrinsic = np.asarray(intrinsic)
local_file = temp_dir / "poses.jsonl"
with local_file.open("w", encoding="utf-8") as handle:
for record in frame_records:
idx = record.index
payload = {
"job_id": job_id,
"scene_id": scene_id,
"frame_id": record.frame_id,
"frame_index": record.index,
"extrinsic": extrinsic[idx].tolist(),
}
if intrinsic is not None:
payload["intrinsic"] = intrinsic[idx].tolist()
if record.timestamp is not None:
payload["timestamp"] = record.timestamp
if record.source is not None:
payload["source"] = record.source
if record.metadata:
payload["metadata"] = record.metadata
handle.write(json.dumps(payload))
handle.write("\n")
key = runtime.storage.build_key(
scene_id,
runtime.settings.models_dir,
runtime.settings.poses_filename,
)
return runtime.storage.upload_file(local_file, key, content_type="application/json")
def _upload_cache(
*,
runtime: WorkerRuntime,
scene_id: str,
cache_path: Path | None,
) -> str | None:
if cache_path is None or not cache_path.exists():
return None
if not runtime.settings.upload_session_cache:
logger.debug(
"Skipping session cache upload for scene %s (disabled via settings)",
scene_id,
)
return None
key = runtime.storage.build_key(
scene_id,
runtime.settings.models_dir,
runtime.settings.session_cache_filename,
)
return runtime.storage.upload_file(cache_path, key, content_type="application/octet-stream")
def _write_predictions_npz(
*,
runtime: WorkerRuntime,
scene_id: str,
predictions: Mapping[str, np.ndarray],
temp_dir: Path,
) -> str:
payload = {k: v for k, v in predictions.items() if isinstance(v, np.ndarray)}
local_file = temp_dir / runtime.settings.predictions_filename
np.savez(local_file, **payload)
key = runtime.storage.build_key(
scene_id,
runtime.settings.models_dir,
runtime.settings.predictions_filename,
)
return runtime.storage.upload_file(local_file, key, content_type="application/octet-stream")
def _write_session_settings(
*,
runtime: WorkerRuntime,
scene_id: str,
session_settings: Mapping[str, Any],
temp_dir: Path,
) -> str:
local_file = temp_dir / runtime.settings.session_settings_filename
local_file.write_text(json.dumps(session_settings, indent=2), encoding="utf-8")
key = runtime.storage.build_key(
scene_id,
runtime.settings.models_dir,
runtime.settings.session_settings_filename,
)
return runtime.storage.upload_file(local_file, key, content_type="application/json")
def _write_selected_frames(
*,
runtime: WorkerRuntime,
scene_id: str,
selected_frames: list[dict[str, Any]],
top_k: int,
temp_dir: Path,
) -> str | None:
if not selected_frames:
return None
local_file = temp_dir / runtime.settings.selected_frames_filename
payload = {"top_k": top_k, "frames": selected_frames}
local_file.write_text(json.dumps(payload, indent=2), encoding="utf-8")
key = runtime.storage.build_key(
scene_id,
runtime.settings.models_dir,
runtime.settings.selected_frames_filename,
)
return runtime.storage.upload_file(local_file, key, content_type="application/json")
def _camera_poses(extrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
matrices = np.asarray(extrinsic, dtype=np.float64)
if matrices.ndim != 3 or matrices.shape[1:] != (3, 4):
raise ValueError("Extrinsic array must have shape (N, 3, 4)")
count = matrices.shape[0]
rotations = np.empty((count, 3, 3), dtype=np.float64)
translations = np.empty((count, 3), dtype=np.float64)
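    # Each extrinsic is a world-to-camera [R|t]; embed it in a homogeneous
    # 4x4 matrix and invert to recover the camera-to-world pose.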
for idx in range(count):
mat = np.eye(4, dtype=np.float64)
mat[:3, :4] = matrices[idx]
cam_to_world = np.linalg.inv(mat)
rotations[idx] = cam_to_world[:3, :3]
translations[idx] = cam_to_world[:3, 3]
return rotations, translations
def _compute_motion_deltas(rotations: np.ndarray, translations: np.ndarray, rot_weight: float) -> np.ndarray:
count = rotations.shape[0]
deltas = np.zeros(count, dtype=np.float64)
if count <= 1:
return deltas
for idx in range(1, count):
delta_t = np.linalg.norm(translations[idx] - translations[idx - 1])
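        # Geodesic rotation angle between consecutive frames: for the relative
        # rotation R_rel = R_{i-1}^T @ R_i, theta = arccos((tr(R_rel) - 1) / 2),
        # with the trace clipped for numerical safety.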
rel = rotations[idx - 1].T @ rotations[idx]
trace = np.clip((np.trace(rel) - 1.0) / 2.0, -1.0, 1.0)
delta_r = float(np.arccos(trace))
deltas[idx] = delta_t + rot_weight * delta_r
return deltas
def _hash_quantized_voxels(coords: np.ndarray) -> np.ndarray:
coords = coords.astype(np.int64, copy=False)
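    # Spatial hash: dot each integer voxel coordinate with three large primes
    # to fold (x, y, z) into a single int64 key; rare collisions are
    # acceptable for coverage estimation.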
primes = np.array([73856093, 19349663, 83492791], dtype=np.int64)
return coords @ primes
def _frame_voxel_sets(
world_points: np.ndarray,
confidence: np.ndarray,
*,
threshold: float,
voxel_size: float,
max_points: int,
) -> tuple[list[set[int]], int]:
rng = np.random.default_rng(42)
frames = world_points.shape[0]
voxel_sets: list[set[int]] = []
global_union: set[int] = set()
if voxel_size <= 0.0:
return [set() for _ in range(frames)], 0
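    # Quantize confident points onto a voxel grid and hash each voxel so that
    # per-frame coverage can be compared with cheap set operations.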
for idx in range(frames):
conf_frame = confidence[idx]
mask = conf_frame >= threshold
if not np.any(mask):
voxel_sets.append(set())
continue
points = world_points[idx][mask]
if points.shape[0] > max_points:
sample_idx = rng.choice(points.shape[0], max_points, replace=False)
points = points[sample_idx]
quantized = np.floor(points / voxel_size).astype(np.int64, copy=False)
hashes = np.unique(_hash_quantized_voxels(quantized))
voxel_set = set(int(v) for v in hashes.tolist())
voxel_sets.append(voxel_set)
global_union.update(voxel_set)
return voxel_sets, len(global_union)
def _select_motion_indices(
motion_deltas: np.ndarray,
*,
threshold: float,
min_gap: int,
max_gap: int,
) -> tuple[list[int], dict[int, dict[str, float]]]:
total_frames = motion_deltas.shape[0]
if total_frames == 0:
return [], {}
selected = [0]
diagnostics: dict[int, dict[str, float]] = {0: {"motion_delta": 0.0, "cum_motion": 0.0}}
cumulative = 0.0
gap = 0
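    # Accumulate motion frame by frame; emit a keyframe once the cumulative
    # motion crosses the threshold (or the max gap is hit), then reset.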
for idx in range(1, total_frames):
delta = float(motion_deltas[idx])
cumulative += delta
gap += 1
if gap < max(1, min_gap):
continue
should_select = cumulative >= threshold
if max_gap > 0 and gap >= max_gap:
should_select = True
if should_select:
selected.append(idx)
diagnostics[idx] = {"motion_delta": delta, "cum_motion": cumulative}
cumulative = 0.0
gap = 0
if selected[-1] != total_frames - 1:
selected.append(total_frames - 1)
diagnostics.setdefault(total_frames - 1, {"motion_delta": float(motion_deltas[-1]), "cum_motion": cumulative})
return selected, diagnostics
def _select_keyframes_motion_coverage(
frame_records: list[FrameRecord],
predictions: Mapping[str, np.ndarray],
settings: WorkerSettings,
requested_top_k: int,
) -> KeyframeSelectionResult | None:
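    """Two-stage keyframe selection: seed with motion-threshold frames, then
    greedily add frames by novel voxel-coverage gain up to ``top_k``."""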
    extrinsic = predictions.get("extrinsic")
    if extrinsic is None:
        return None
    # np.asarray(None) yields a size-1 object array, so the missing-key case
    # must be checked explicitly before the size guard.
    extrinsic = np.asarray(extrinsic)
    if extrinsic.size == 0:
        return None
rotations, translations = _camera_poses(extrinsic)
motion_deltas = _compute_motion_deltas(rotations, translations, settings.keyframe_rotation_weight)
motion_indices, motion_diag = _select_motion_indices(
motion_deltas,
threshold=settings.keyframe_motion_threshold,
min_gap=max(1, settings.keyframe_min_gap_frames),
max_gap=max(0, settings.keyframe_max_gap_frames),
)
total_frames = len(frame_records)
confidence = pose_confidence(predictions)
world_points = predictions.get("world_points")
if world_points is None:
world_points = predictions.get("world_points_from_depth")
voxel_sets: list[set[int]] = [set() for _ in range(total_frames)]
total_voxels = 0
mean_conf = np.zeros(total_frames, dtype=np.float32)
if confidence is not None:
mean_conf = confidence.reshape(confidence.shape[0], -1).mean(axis=1)
if confidence is not None and world_points is not None:
voxel_sets, total_voxels = _frame_voxel_sets(
np.asarray(world_points),
np.asarray(confidence),
threshold=settings.keyframe_coverage_confidence,
voxel_size=settings.keyframe_coverage_voxel_size,
max_points=max(1000, settings.keyframe_coverage_max_points),
)
total_voxels = max(total_voxels, 1)
top_k = requested_top_k if requested_top_k > 0 else settings.keyframe_default_top_k
top_k = max(min(top_k, total_frames), len(motion_indices))
selected_set: set[int] = set(motion_indices)
diagnostics: dict[int, dict[str, Any]] = {}
covered: set[int] = set()
for idx in motion_indices:
gain_count = len(voxel_sets[idx] - covered) if voxel_sets[idx] else 0
gain_ratio = gain_count / total_voxels
covered.update(voxel_sets[idx])
diagnostics[idx] = {
"frame_id": frame_records[idx].frame_id,
"frame_index": frame_records[idx].index,
"reason": "motion",
"motion_delta": float(motion_deltas[idx]),
"cum_motion": float(motion_diag.get(idx, {}).get("cum_motion", 0.0)),
"coverage_gain_ratio": float(gain_ratio),
"coverage_gain_count": int(gain_count),
"mean_confidence": float(mean_conf[idx]) if confidence is not None else None,
}
if len(selected_set) < top_k and total_voxels > 0:
min_gain_ratio = settings.keyframe_min_gain_ratio
remaining = [i for i in range(total_frames) if i not in selected_set and voxel_sets[i]]
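        # Greedy max-coverage: repeatedly pick the frame contributing the most
        # voxels not yet covered, stopping once gains fall below the minimum
        # ratio or top_k is reached.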
while remaining and len(selected_set) < top_k:
best_idx = -1
best_gain = -1
best_ratio = -1.0
for idx in remaining:
gain = len(voxel_sets[idx] - covered)
if gain <= 0:
continue
ratio = gain / total_voxels
if ratio > best_ratio or (np.isclose(ratio, best_ratio) and gain > best_gain):
best_idx = idx
best_gain = gain
best_ratio = ratio
if best_idx == -1 or best_ratio < min_gain_ratio:
break
selected_set.add(best_idx)
covered.update(voxel_sets[best_idx])
diagnostics[best_idx] = {
"frame_id": frame_records[best_idx].frame_id,
"frame_index": frame_records[best_idx].index,
"reason": "coverage",
"motion_delta": float(motion_deltas[best_idx]),
"cum_motion": float(motion_diag.get(best_idx, {}).get("cum_motion", 0.0)),
"coverage_gain_ratio": float(best_ratio),
"coverage_gain_count": int(best_gain),
"mean_confidence": float(mean_conf[best_idx]) if confidence is not None else None,
}
remaining.remove(best_idx)
if requested_top_k > 0 and len(selected_set) > requested_top_k:
coverage_candidates = [idx for idx in selected_set if diagnostics[idx]["reason"] == "coverage"]
coverage_candidates.sort(key=lambda idx: diagnostics[idx].get("coverage_gain_ratio", 0.0))
while len(selected_set) > requested_top_k and coverage_candidates:
drop_idx = coverage_candidates.pop(0)
selected_set.remove(drop_idx)
diagnostics.pop(drop_idx, None)
final_indices = sorted(selected_set)
final_diags = [diagnostics[idx] for idx in final_indices]
return KeyframeSelectionResult(indices=final_indices, diagnostics=final_diags, top_k=len(final_indices))
def _compute_selected_frames(
predictions: Mapping[str, np.ndarray],
frame_records: list[FrameRecord],
top_k: int,
) -> list[dict[str, Any]]:
if top_k <= 0:
return []
confidence = pose_confidence(predictions)
if confidence is None:
return []
scores = confidence.reshape(confidence.shape[0], -1).mean(axis=1)
indices = np.argsort(scores)[::-1][:top_k]
result = []
for idx in indices:
record = frame_records[int(idx)]
result.append(
{
"frame_id": record.frame_id,
"frame_index": record.index,
"score": float(scores[idx]),
}
)
return result
def _run_keyframe_prepass(
*,
runtime: WorkerRuntime,
payload: Mapping[str, Any],
frame_records: list[FrameRecord],
mode: str,
streaming: bool,
window_size: int | None,
) -> KeyframeSelectionResult | None:
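    """Run a first inference pass over all frames purely to score them for
    keyframe selection; the predictions are discarded afterwards."""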
if len(frame_records) <= 1:
return None
settings = runtime.settings
top_k_payload = _as_int(payload.get("prepass_top_k") or payload.get("top_k_frames") or payload.get("top_k"), 0)
try:
inference = run_stream3r_inference(
runtime=runtime,
image_paths=[record.path for record in frame_records],
mode=mode,
streaming=streaming,
cache_output_path=None,
progress_cb=None,
window_size=window_size if streaming and mode == "window" else None,
)
except Exception:
logger.exception("Keyframe pre-pass inference failed")
return None
try:
selection = _select_keyframes_motion_coverage(
frame_records,
inference.predictions,
settings,
requested_top_k=top_k_payload,
)
finally:
del inference
return selection
def _save_scene_glb(
*,
runtime: WorkerRuntime,
scene_id: str,
predictions: Mapping[str, np.ndarray],
temp_dir: Path,
payload: Mapping[str, Any],
) -> str:
local_file = temp_dir / runtime.settings.scene_glb_filename
ceiling_percentile = payload.get("ceiling_percentile")
try:
ceiling_percentile_value = float(ceiling_percentile) if ceiling_percentile is not None else None
except (TypeError, ValueError):
ceiling_percentile_value = None
ceiling_margin_value = payload.get("ceiling_margin")
try:
ceiling_margin_value = float(ceiling_margin_value) if ceiling_margin_value is not None else 0.05
except (TypeError, ValueError):
ceiling_margin_value = 0.05
ceiling_z_max = payload.get("ceiling_z_max")
try:
ceiling_z_max_value = float(ceiling_z_max) if ceiling_z_max is not None else None
except (TypeError, ValueError):
ceiling_z_max_value = None
scene = predictions_to_glb(
dict(predictions),
conf_thres=float(payload.get("conf_thres", 3.0)),
filter_by_frames=payload.get("frame_filter", "All"),
mask_black_bg=_as_bool(payload.get("mask_black_bg"), False),
mask_white_bg=_as_bool(payload.get("mask_white_bg"), False),
show_cam=_as_bool(payload.get("show_cam"), False),
mask_sky=_as_bool(payload.get("mask_sky"), False),
target_dir=str(temp_dir),
prediction_mode=payload.get("prediction_mode", "Predicted Pointmap"),
ceiling_percentile=ceiling_percentile_value,
ceiling_margin=ceiling_margin_value,
ceiling_z_max=ceiling_z_max_value,
)
scene.export(file_obj=str(local_file))
key = runtime.storage.build_key(
scene_id,
runtime.settings.models_dir,
runtime.settings.scene_glb_filename,
)
return runtime.storage.upload_file(local_file, key, content_type="model/gltf-binary")
def _write_summary_json(
*,
runtime: WorkerRuntime,
scene_id: str,
summary: Mapping[str, Any],
temp_dir: Path,
) -> str:
filename = runtime.settings.result_filename
local_file = temp_dir / filename
local_file.write_text(json.dumps(summary, indent=2), encoding="utf-8")
key = runtime.storage.build_key(
scene_id,
runtime.settings.models_dir,
filename,
)
return runtime.storage.upload_file(local_file, key, content_type="application/json")
def _upload_result_record(
*,
runtime: WorkerRuntime,
scene_id: str,
job_id: str,
payload: Mapping[str, Any],
) -> str:
local = json.dumps(payload, indent=2).encode("utf-8")
key = runtime.storage.build_key(
scene_id,
runtime.settings.results_dir,
f"{job_id}.json",
)
return runtime.storage.upload_bytes(local, key, content_type="application/json")
def _model_dir_uri(runtime: WorkerRuntime, scene_id: str) -> str:
return runtime.storage.build_uri(
runtime.storage.build_key(scene_id, runtime.settings.models_dir)
)
def _generate_core_outputs(
*,
runtime: WorkerRuntime,
scene_id: str,
job_id: str,
predictions: Mapping[str, np.ndarray],
frame_records: list[FrameRecord],
inference: InferenceResult,
session_settings: Mapping[str, Any],
temp_dir: Path,
) -> dict[str, Any]:
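    """Produce and upload the artifacts shared by every job type: point maps,
    poses.jsonl, raw predictions, session settings, and the optional KV cache."""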
pointmap_info = _save_pointmaps(
runtime=runtime,
scene_id=scene_id,
predictions=predictions,
frame_records=frame_records,
temp_dir=temp_dir,
)
poses_url = _write_poses_jsonl(
runtime=runtime,
scene_id=scene_id,
job_id=job_id,
predictions=predictions,
frame_records=frame_records,
temp_dir=temp_dir,
)
cache_url = _upload_cache(
runtime=runtime,
scene_id=scene_id,
cache_path=inference.cache_path,
)
predictions_url = _write_predictions_npz(
runtime=runtime,
scene_id=scene_id,
predictions=predictions,
temp_dir=temp_dir,
)
session_settings_url = _write_session_settings(
runtime=runtime,
scene_id=scene_id,
session_settings=session_settings,
temp_dir=temp_dir,
)
    extrinsic = predictions.get("extrinsic")
    if extrinsic is None:
        raise RuntimeError("Predictions missing extrinsic camera parameters")
    extrinsic = np.asarray(extrinsic)
intrinsic = predictions.get("intrinsic")
if intrinsic is not None:
intrinsic = np.asarray(intrinsic)
frames_payload: list[dict[str, Any]] = []
for entry in pointmap_info["pointmaps"]:
idx = entry["frame_index"]
frame = frame_records[idx]
frame_payload = {
"frame_id": frame.frame_id,
"frame_index": frame.index,
"pointmap_url": entry["url"],
"extrinsic": extrinsic[idx].tolist(),
}
if intrinsic is not None:
frame_payload["intrinsic"] = intrinsic[idx].tolist()
if frame.timestamp is not None:
frame_payload["timestamp"] = frame.timestamp
if frame.source is not None:
frame_payload["source"] = frame.source
frames_payload.append(frame_payload)
artifacts = {
"poses_url": poses_url,
"pointmap_dir": pointmap_info["pointmap_dir"],
"pointmaps": pointmap_info["pointmaps"],
"predictions_url": predictions_url,
"session_settings_url": session_settings_url,
}
if cache_url:
artifacts["kv_cache_url"] = cache_url
return {
"artifacts": artifacts,
"frames": frames_payload,
}
def _handle_pose_pointmap(
*,
runtime: WorkerRuntime,
payload: Mapping[str, Any],
mode: str,
streaming: bool,
job_id: str,
scene_id: str,
frame_records: list[FrameRecord],
inference: InferenceResult,
session_settings: Mapping[str, Any],
temp_dir: Path,
) -> dict[str, Any]:
predictions = inference.predictions
core = _generate_core_outputs(
runtime=runtime,
scene_id=scene_id,
job_id=job_id,
predictions=predictions,
frame_records=frame_records,
inference=inference,
session_settings=session_settings,
temp_dir=temp_dir,
)
result_payload = {
"job_id": job_id,
"job_type": "pose_pointmap",
"scene_id": scene_id,
"mode": mode,
"streaming": streaming,
"frame_count": inference.total_frames,
"created_at": datetime.now(timezone.utc).isoformat(),
"artifacts": core["artifacts"],
"frames": core["frames"],
}
selected_frames_payload = payload.get("_selected_frames_info")
if selected_frames_payload:
result_payload["selected_frames"] = list(selected_frames_payload)
try:
selected_frames_url = _write_selected_frames(
runtime=runtime,
scene_id=scene_id,
selected_frames=list(selected_frames_payload),
top_k=_as_int(payload.get("_selected_top_k"), len(selected_frames_payload)),
temp_dir=temp_dir,
)
if selected_frames_url:
result_payload["artifacts"]["selected_frames_url"] = selected_frames_url
except Exception:
logger.exception("Failed to persist selected frames artifact for pose_pointmap job")
result_url = _upload_result_record(
runtime=runtime,
scene_id=scene_id,
job_id=job_id,
payload=result_payload,
)
result_payload["result_url"] = result_url
result_payload["model_dir"] = _model_dir_uri(runtime, scene_id)
return result_payload
JobHandler = Callable[..., dict[str, Any]]
def _execute_job(job_type: str, payload: Mapping[str, Any], handler: JobHandler) -> dict[str, Any]:
runtime = get_runtime()
job = get_current_job()
payload = dict(payload)
job_id = str(payload.get("job_id") or (job.id if job else uuid.uuid4()))
scene_id = payload.get("scene_id")
if not scene_id:
raise ValueError("Job payload is missing 'scene_id'")
payload.setdefault("job_type", job_type)
payload.setdefault("scene_id", scene_id)
mode = payload.get("mode") or runtime.settings.default_mode
streaming = _as_bool(payload.get("streaming"), runtime.settings.default_streaming)
window_size: int | None = None
if mode == "window":
streaming = True
payload["streaming"] = True
window_candidate = payload.get("window_size") or runtime.settings.stream_window_size
try:
window_size = int(window_candidate) if window_candidate else None
except (TypeError, ValueError):
window_size = runtime.settings.stream_window_size or None
if window_size and window_size > 0:
payload["window_size"] = window_size
else:
window_size = None
payload["mode"] = mode
desired_timeout = runtime.settings.default_job_timeout
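    # Timeout precedence: an explicit payload override wins; otherwise raise
    # the RQ job timeout up to the configured default so long-running scenes
    # are not killed prematurely.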
timeout_override = payload.get("timeout")
applied_timeout: int | None = None
if timeout_override is not None:
try:
applied_timeout = int(timeout_override)
if job is not None:
job.timeout = applied_timeout
except (TypeError, ValueError):
applied_timeout = None
if applied_timeout is None and desired_timeout and desired_timeout > 0:
if job is not None:
current_timeout = getattr(job, "timeout", None)
try:
current_timeout_value = int(current_timeout) if current_timeout is not None else None
except (TypeError, ValueError):
current_timeout_value = None
if current_timeout_value is None or current_timeout_value < desired_timeout:
job.timeout = desired_timeout
applied_timeout = desired_timeout
else:
applied_timeout = current_timeout_value
else:
applied_timeout = desired_timeout
if applied_timeout is not None:
payload["timeout"] = applied_timeout
sanitized_payload = _sanitize_payload(payload)
job_meta = {
"job_id": job_id,
"job_type": job_type,
"scene_id": scene_id,
}
logger.info(
"Job %s (%s) started for scene %s (timeout=%s)",
job_id,
job_type,
scene_id,
applied_timeout or desired_timeout or "default",
)
start_time = perf_counter()
last_time = start_time
def log_progress(stage: str) -> None:
nonlocal last_time
now = perf_counter()
logger.info(
"Job %s (%s): %s [delta=%.2fs total=%.2fs]",
job_id,
job_type,
stage,
now - last_time,
now - start_time,
)
last_time = now
runtime.db.upsert_job(
job_id=job_id,
job_type=job_type,
scene_id=scene_id,
status="started",
payload=sanitized_payload,
)
runtime_emit(
runtime,
{
**job_meta,
"status": "started",
"progress": 0,
"ts": datetime.now(timezone.utc).timestamp(),
},
)
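    # If a parent process already holds the GPU lock (signalled via
    # STREAM3R_GPU_LOCK_HELD), substitute a no-op context instead of
    # acquiring it again.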
lock_ctx = nullcontext() if os.getenv("STREAM3R_GPU_LOCK_HELD") == "1" else runtime.gpu_lock()
try:
with lock_ctx:
with tempfile.TemporaryDirectory(prefix=f"stream3r_{job_id}_") as tmp_dir:
temp_path = Path(tmp_dir)
frame_records = _collect_frames(runtime, scene_id, payload, temp_path)
log_progress(f"collected frames ({len(frame_records)} items)")
selection_result: KeyframeSelectionResult | None = None
if runtime.settings.keyframe_prepass_enabled and len(frame_records) > 1:
log_progress("starting keyframe pre-pass")
try:
selection_result = _run_keyframe_prepass(
runtime=runtime,
payload=payload,
frame_records=frame_records,
mode=mode,
streaming=streaming,
window_size=window_size,
)
except Exception:
selection_result = None
logger.exception("Keyframe pre-pass failed; falling back to full frame set")
if selection_result and selection_result.indices:
log_progress(
f"pre-pass selected {len(selection_result.indices)} frames from {len(frame_records)}"
)
frame_records = [frame_records[i] for i in selection_result.indices]
for new_idx, record in enumerate(frame_records):
record.index = new_idx
payload["_selected_frames_info"] = selection_result.diagnostics
payload["_selected_top_k"] = selection_result.top_k
payload["_selected_frame_indices"] = selection_result.indices
if len(frame_records) <= runtime.settings.keyframe_full_mode_max_frames:
mode = "full"
streaming = False
window_size = None
payload["mode"] = mode
payload["streaming"] = streaming
else:
selection_result = None
cache_path = temp_path / runtime.settings.session_cache_filename if streaming else None
tracker = ProgressTracker(runtime, job_meta)
inference = run_stream3r_inference(
runtime=runtime,
image_paths=[record.path for record in frame_records],
mode=mode,
streaming=streaming,
cache_output_path=cache_path,
progress_cb=tracker,
window_size=window_size if streaming and mode == "window" else None,
)
log_progress(f"inference completed ({inference.total_frames} frames)")
session_settings = _prepare_session_settings(
payload,
mode=mode,
streaming=streaming,
frame_records=frame_records,
window_size=window_size,
)
result_payload = handler(
runtime=runtime,
payload=payload,
mode=mode,
streaming=streaming,
job_id=job_id,
scene_id=scene_id,
frame_records=frame_records,
inference=inference,
session_settings=session_settings,
temp_dir=temp_path,
)
log_progress("artifact generation completed")
except Exception as exc:
error_text = traceback.format_exc()
runtime.db.upsert_job(
job_id=job_id,
job_type=job_type,
scene_id=scene_id,
status="failed",
error=error_text,
)
runtime_emit(
runtime,
{
**job_meta,
"status": "failed",
"ts": datetime.now(timezone.utc).timestamp(),
"error": str(exc),
},
)
logger.exception(
"Job %s (%s) failed after %.2fs: %s",
job_id,
job_type,
perf_counter() - start_time,
exc,
)
raise
log_progress("job finished")
runtime.db.upsert_job(
job_id=job_id,
job_type=job_type,
scene_id=scene_id,
status="finished",
result=result_payload,
)
runtime_emit(
runtime,
{
**job_meta,
"status": "finished",
"progress": 100,
"result_url": result_payload.get("result_url"),
"model_dir": result_payload.get("model_dir"),
"ts": datetime.now(timezone.utc).timestamp(),
},
)
return result_payload
def pose_pointmap_job(payload: Mapping[str, Any]) -> dict[str, Any]:
"""Process a pose + pointmap job."""
return _execute_job("pose_pointmap", payload, _handle_pose_pointmap)
def model_build_job(payload: Mapping[str, Any]) -> dict[str, Any]:
"""Process a full model build job."""
return _execute_job("model_build", payload, _handle_model_build)
def _fallback_selection(frame_records: list[FrameRecord], top_k: int) -> KeyframeSelectionResult:
indices = linear_sample_indices(len(frame_records), top_k)
diagnostics = [
{
"frame_id": frame_records[idx].frame_id,
"frame_index": frame_records[idx].index,
"reason": "linear",
}
for idx in indices
]
return KeyframeSelectionResult(indices=indices, diagnostics=diagnostics, top_k=len(indices))
def keyframe_selection_job(payload: Mapping[str, Any]) -> dict[str, Any]:
runtime = get_runtime()
job = get_current_job()
payload = dict(payload)
job_id = str(payload.get("job_id") or (job.id if job else uuid.uuid4()))
scene_id = payload.get("scene_id")
if not scene_id:
raise ValueError("Keyframe job payload is missing 'scene_id'")
video_key = payload.get("video_key")
if not video_key:
raise ValueError("Keyframe job payload is missing 'video_key'")
job_type = "keyframe_selection"
job_meta = {
"job_id": job_id,
"job_type": job_type,
"scene_id": scene_id,
}
sanitized_payload = {
"scene_id": scene_id,
"video_key": video_key,
"top_k": payload.get("top_k"),
"extract_fps": payload.get("extract_fps"),
"extract_max_frames": payload.get("extract_max_frames"),
}
runtime.db.upsert_job(
job_id=job_id,
job_type=job_type,
scene_id=scene_id,
status="started",
payload=sanitized_payload,
)
runtime_emit(
runtime,
{
**job_meta,
"status": "started",
"progress": 0,
"ts": datetime.now(timezone.utc).timestamp(),
},
)
start_time = perf_counter()
try:
with tempfile.TemporaryDirectory(prefix=f"keyframe_{job_id}_") as tmp_dir:
temp_path = Path(tmp_dir)
video_path = temp_path / "input_video"
runtime.storage.download_to_path(video_key, video_path)
extract_fps = payload.get("extract_fps")
try:
extract_fps_value = float(extract_fps) if extract_fps is not None else runtime.settings.keyframe_extract_fps
except (TypeError, ValueError):
extract_fps_value = runtime.settings.keyframe_extract_fps
max_frames = _as_int(
payload.get("extract_max_frames"),
runtime.settings.keyframe_extract_max_frames,
)
frame_records, native_fps = extract_video_frames(
video_path,
temp_path / "frames",
target_fps=extract_fps_value,
max_frames=max_frames,
)
selection = run_keyframe_prepass(
runtime=runtime,
payload=payload,
frame_records=frame_records,
mode="window",
streaming=True,
window_size=runtime.settings.stream_window_size,
)
if selection is None or not selection.indices:
requested_top_k = _as_int(payload.get("top_k"), runtime.settings.keyframe_default_top_k)
selection = _fallback_selection(frame_records, requested_top_k)
selected_records = [frame_records[i] for i in selection.indices]
storage_entries, media_entries = build_keyframe_uploads(
runtime,
scene_id,
selected_records,
selection.diagnostics,
subdir=runtime.settings.keyframe_upload_dir,
)
_register_scene_media_entries(runtime, scene_id, media_entries)
result_payload = {
"job_id": job_id,
"job_type": job_type,
"scene_id": scene_id,
"video_key": video_key,
"native_fps": native_fps,
"total_frames": len(frame_records),
"selected_frames": storage_entries,
"selection": selection.diagnostics,
}
except Exception as exc:
error_text = traceback.format_exc()
runtime.db.upsert_job(
job_id=job_id,
job_type=job_type,
scene_id=scene_id,
status="failed",
error=error_text,
)
runtime_emit(
runtime,
{
**job_meta,
"status": "failed",
"ts": datetime.now(timezone.utc).timestamp(),
"error": str(exc),
},
)
logger.exception("Keyframe selection job %s failed", job_id)
raise
runtime.db.upsert_job(
job_id=job_id,
job_type=job_type,
scene_id=scene_id,
status="finished",
result=result_payload,
)
runtime_emit(
runtime,
{
**job_meta,
"status": "finished",
"progress": 100,
"ts": datetime.now(timezone.utc).timestamp(),
},
)
logger.info(
"Keyframe selection job %s finished in %.2fs (selected %d/%d frames)",
job_id,
perf_counter() - start_time,
len(selection.indices),
len(frame_records),
)
return result_payload
def _handle_model_build(
*,
runtime: WorkerRuntime,
payload: Mapping[str, Any],
mode: str,
streaming: bool,
job_id: str,
scene_id: str,
frame_records: list[FrameRecord],
inference: InferenceResult,
session_settings: Mapping[str, Any],
temp_dir: Path,
) -> dict[str, Any]:
predictions = inference.predictions
core = _generate_core_outputs(
runtime=runtime,
scene_id=scene_id,
job_id=job_id,
predictions=predictions,
frame_records=frame_records,
inference=inference,
session_settings=session_settings,
temp_dir=temp_dir,
)
artifacts = dict(core["artifacts"])
selected_frames_payload = payload.get("_selected_frames_info")
if selected_frames_payload:
top_k = _as_int(payload.get("_selected_top_k"), len(selected_frames_payload))
selected_frames = list(selected_frames_payload)
else:
top_k = _as_int(payload.get("top_k_frames") or payload.get("top_k"), 0)
selected_frames = _compute_selected_frames(predictions, frame_records, top_k)
selected_frames_url = _write_selected_frames(
runtime=runtime,
scene_id=scene_id,
selected_frames=selected_frames,
top_k=top_k,
temp_dir=temp_dir,
)
if selected_frames_url:
artifacts["selected_frames_url"] = selected_frames_url
scene_glb_url = _save_scene_glb(
runtime=runtime,
scene_id=scene_id,
predictions=predictions,
temp_dir=temp_dir,
payload=payload,
)
artifacts["scene_glb_url"] = scene_glb_url
summary_payload = {
"job_id": job_id,
"job_type": "model_build",
"scene_id": scene_id,
"frame_count": inference.total_frames,
"created_at": datetime.now(timezone.utc).isoformat(),
"artifacts": artifacts,
"selected_frames": selected_frames,
"parameters": {
"mode": mode,
"streaming": streaming,
"conf_thres": float(payload.get("conf_thres", 3.0)),
"frame_filter": payload.get("frame_filter", "All"),
"mask_black_bg": _as_bool(payload.get("mask_black_bg"), False),
"mask_white_bg": _as_bool(payload.get("mask_white_bg"), False),
"show_cam": _as_bool(payload.get("show_cam"), True),
"mask_sky": _as_bool(payload.get("mask_sky"), False),
"prediction_mode": payload.get("prediction_mode", "Predicted Pointmap"),
},
}
summary_url = _write_summary_json(
runtime=runtime,
scene_id=scene_id,
summary=summary_payload,
temp_dir=temp_dir,
)
artifacts["summary_url"] = summary_url
result_record = dict(summary_payload)
result_record["result_url"] = summary_url
result_record_url = _upload_result_record(
runtime=runtime,
scene_id=scene_id,
job_id=job_id,
payload=result_record,
)
result_payload = {
"job_id": job_id,
"job_type": "model_build",
"scene_id": scene_id,
"mode": mode,
"streaming": streaming,
"frame_count": inference.total_frames,
"created_at": summary_payload["created_at"],
"artifacts": artifacts,
"frames": core["frames"],
"selected_frames": selected_frames,
"summary_url": summary_url,
"result_url": summary_url,
"result_record_url": result_record_url,
"model_dir": _model_dir_uri(runtime, scene_id),
}
return result_payload