| """Inference pipeline utilities reused by worker jobs.""" | |
| from __future__ import annotations | |
| from contextlib import nullcontext | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Callable, Iterable, Mapping | |
| import numpy as np | |
| import torch | |
| from stream3r.models.components.utils.load_fn import load_and_preprocess_images | |
| from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri | |
| from stream3r.stream_session import StreamSession | |
| from .runtime import WorkerRuntime | |
| ProgressCallback = Callable[[int, int], None] | |


@dataclass
class InferenceResult:
    """Container for the outputs of a completed inference run."""

    predictions: dict[str, np.ndarray]
    total_frames: int
    cache_path: Path | None


def _to_numpy(payload):
    """Recursively convert tensors (and containers of tensors) to numpy arrays."""
    if isinstance(payload, torch.Tensor):
        return payload.detach().cpu().numpy()
    if isinstance(payload, dict):
        return {k: _to_numpy(v) for k, v in payload.items()}
    if isinstance(payload, (list, tuple)):
        converted = [_to_numpy(item) for item in payload]
        return type(payload)(converted)
    return payload
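

# Illustrative example (not part of the worker flow): nested containers are
# converted recursively, with tensors detached and moved to CPU first, e.g.
#   _to_numpy({"pts": torch.zeros(2, 3), "conf": [torch.tensor(0.5)]})
# returns {"pts": <ndarray shape (2, 3)>, "conf": [<ndarray scalar 0.5>]}.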


def run_stream3r_inference(
    *,
    runtime: WorkerRuntime,
    image_paths: Iterable[Path],
    mode: str,
    streaming: bool,
    cache_output_path: Path | None,
    progress_cb: ProgressCallback | None = None,
    window_size: int | None = None,
) -> InferenceResult:
    """Execute STream3R inference for the provided frames.

    When ``streaming`` is True, frames are fed one at a time through a
    ``StreamSession``; otherwise the whole batch goes through the model in a
    single forward pass. ``progress_cb`` is called after each processed frame.
    """
    image_list = [Path(p) for p in image_paths]
    if not image_list:
        raise ValueError("No images provided to inference pipeline")

    model = runtime.get_model()
    device = runtime.model_device()

    images = load_and_preprocess_images([str(path) for path in image_list])
    total_frames = images.shape[0]

    autocast_dtype = runtime.autocast_dtype()
    autocast_ctx = (
        torch.autocast(device_type=device.type, dtype=autocast_dtype)
        if device.type == "cuda"
        else nullcontext()
    )

    predictions: Mapping[str, torch.Tensor]
    cache_path: Path | None = None

    model.eval()
    # A non-positive window size means "no windowing".
    if window_size is not None and window_size <= 0:
        window_size = None
    with torch.no_grad():
        if streaming:
            # Streaming: feed frames one at a time through a StreamSession so
            # the model's cached state is built incrementally.
            session_kwargs = {"mode": mode}
            if window_size is not None:
                session_kwargs["window_size"] = window_size
            session = StreamSession(model, **session_kwargs)
            session.clear()
            for idx in range(total_frames):
                frame = images[idx : idx + 1].to(device)
                with autocast_ctx:
                    session.forward_stream(frame)
                print(f"Processed frame {idx + 1}/{total_frames}")
                if progress_cb is not None:
                    progress_cb(idx + 1, total_frames)
            if cache_output_path is not None:
                session.save_cache(str(cache_output_path))
                cache_path = cache_output_path
            predictions = session.get_all_predictions()
        else:
            # Batched: run a single forward pass over all frames at once.
            inputs = images.to(device)
            with autocast_ctx:
                predictions = model(inputs, mode=mode)
            if progress_cb is not None:
                progress_cb(total_frames, total_frames)
    predictions = dict(predictions)

    # Augment predictions with camera matrices recovered from the pose encoding.
    height, width = images.shape[-2:]
    pose_enc = predictions.get("pose_enc")
    if pose_enc is None:
        raise RuntimeError("Model predictions missing 'pose_enc'")
    if not isinstance(pose_enc, torch.Tensor):
        pose_enc = torch.as_tensor(pose_enc)
    if pose_enc.dim() == 2:  # streaming cache might drop the batch dim
        pose_enc = pose_enc.unsqueeze(0)
    extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, (height, width))
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic
    # Convert every tensor (and nested container of tensors) to numpy.
    predictions_np = {key: _to_numpy(value) for key, value in predictions.items()}
    # Keep the raw pose encoding only when it still carries the batch dimension.
    pose_enc_np = predictions_np.pop("pose_enc", None)
    if pose_enc_np is not None and pose_enc_np.ndim >= 3:
        predictions_np["pose_enc"] = pose_enc_np

    # Remove the leading batch dimension if present.
    for key, value in list(predictions_np.items()):
        if isinstance(value, np.ndarray) and value.shape[0] == 1:
            predictions_np[key] = np.squeeze(value, axis=0)
    if device.type == "cuda":
        torch.cuda.empty_cache()

    return InferenceResult(
        predictions=predictions_np,
        total_frames=total_frames,
        cache_path=cache_path,
    )
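

# Minimal usage sketch (assumptions: `WorkerRuntime()` can be constructed with
# no arguments, "causal" is a valid STream3R mode, and "depth" is one of the
# prediction keys; adjust to the actual worker configuration):
#
#     runtime = WorkerRuntime()
#     result = run_stream3r_inference(
#         runtime=runtime,
#         image_paths=sorted(Path("frames").glob("*.png")),
#         mode="causal",
#         streaming=True,
#         cache_output_path=Path("cache/session_cache"),
#         progress_cb=lambda done, total: print(f"{done}/{total}"),
#     )
#     depth = result.predictions.get("depth")  # hypothetical prediction key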