# Commit f1e0138 (brian4dwell): working for larger batches now
"""Inference pipeline utilities reused by worker jobs."""
from __future__ import annotations
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, Mapping
import numpy as np
import torch
from stream3r.models.components.utils.load_fn import load_and_preprocess_images
from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri
from stream3r.stream_session import StreamSession
from .runtime import WorkerRuntime
# Progress reporting hook: called as (frames_done, total_frames) -> None.
ProgressCallback = Callable[[int, int], None]
@dataclass
class InferenceResult:
    """Outcome of one STream3R inference run."""

    # Per-key numpy outputs from the model (e.g. depth, extrinsic, intrinsic).
    predictions: dict[str, np.ndarray]
    # Number of input frames that were processed.
    total_frames: int
    # Where the streaming KV cache was saved, or None if no cache was written.
    cache_path: Path | None
def _to_numpy(payload):
if isinstance(payload, torch.Tensor):
return payload.detach().cpu().numpy()
if isinstance(payload, dict):
return {k: _to_numpy(v) for k, v in payload.items()}
if isinstance(payload, (list, tuple)):
converted = [_to_numpy(item) for item in payload]
return type(payload)(converted)
return payload
def run_stream3r_inference(
    *,
    runtime: WorkerRuntime,
    image_paths: Iterable[Path],
    mode: str,
    streaming: bool,
    cache_output_path: Path | None,
    progress_cb: ProgressCallback | None = None,
    window_size: int | None = None,
) -> InferenceResult:
    """Execute STream3R inference for the provided frames.

    Args:
        runtime: Worker runtime providing model, device, and autocast dtype.
        image_paths: Ordered paths of the frames to process.
        mode: Inference mode forwarded to the model / stream session.
        streaming: If True, feed frames one at a time through a StreamSession;
            otherwise run a single batched forward pass.
        cache_output_path: Where to persist the streaming cache, or None.
        progress_cb: Optional hook invoked as (frames_done, total_frames).
        window_size: Optional sliding-window size for streaming; values <= 0
            are treated as "no window".

    Returns:
        InferenceResult with numpy predictions (singleton batch dimension
        removed), the total frame count, and the cache path if one was saved.

    Raises:
        ValueError: If *image_paths* is empty.
        RuntimeError: If the model output lacks a 'pose_enc' entry.
    """
    image_list = [Path(p) for p in image_paths]
    if not image_list:
        raise ValueError("No images provided to inference pipeline")
    model = runtime.get_model()
    device = runtime.model_device()
    images = load_and_preprocess_images([str(path) for path in image_list])
    total_frames = images.shape[0]
    autocast_dtype = runtime.autocast_dtype()
    # Autocast only makes sense on CUDA; elsewhere run at full precision.
    autocast_ctx = (
        torch.autocast(device_type=device.type, dtype=autocast_dtype)
        if device.type == "cuda"
        else nullcontext()
    )
    predictions: Mapping[str, torch.Tensor]
    cache_path: Path | None = None
    model.eval()
    # Normalize non-positive window sizes to "unwindowed".
    if window_size is not None and window_size <= 0:
        window_size = None
    with torch.no_grad():
        if streaming:
            session_kwargs = {"mode": mode}
            if window_size is not None:
                session_kwargs["window_size"] = window_size
            session = StreamSession(model, **session_kwargs)
            session.clear()
            for idx in range(total_frames):
                frame = images[idx : idx + 1].to(device)
                with autocast_ctx:
                    session.forward_stream(frame)
                print(f"Processed frame {idx + 1}/{total_frames}")
                if progress_cb is not None:
                    progress_cb(idx + 1, total_frames)
            if cache_output_path is not None:
                session.save_cache(str(cache_output_path))
                cache_path = cache_output_path
            predictions = session.get_all_predictions()
        else:
            inputs = images.to(device)
            with autocast_ctx:
                predictions = model(inputs, mode=mode)
            if progress_cb is not None:
                progress_cb(total_frames, total_frames)
    predictions = dict(predictions)
    # Augment predictions with camera matrices derived from the pose encoding.
    height, width = images.shape[-2:]
    pose_enc = predictions.get("pose_enc")
    if pose_enc is None:
        raise RuntimeError("Model predictions missing 'pose_enc'")
    if not isinstance(pose_enc, torch.Tensor):
        pose_enc = torch.as_tensor(pose_enc)
    if pose_enc.dim() == 2:  # streaming cache might drop batch dim
        pose_enc = pose_enc.unsqueeze(0)
    # BUGFIX: store the normalized (batched) pose_enc back into predictions.
    # Previously the unsqueezed tensor was used only for the extri/intri
    # computation, and a 2-D pose_enc was then silently dropped from the
    # numpy output by a pop/re-add that required ndim >= 3.
    predictions["pose_enc"] = pose_enc
    extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, (height, width))
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic
    predictions_np = {key: _to_numpy(value) for key, value in predictions.items()}
    # Remove a singleton leading batch dimension. The ndim guard avoids an
    # IndexError on scalar (0-d) arrays, which have no shape[0].
    for key, value in list(predictions_np.items()):
        if isinstance(value, np.ndarray) and value.ndim > 0 and value.shape[0] == 1:
            predictions_np[key] = np.squeeze(value, axis=0)
    # No-op when CUDA was never initialized; frees cached GPU memory otherwise.
    torch.cuda.empty_cache()
    return InferenceResult(
        predictions=predictions_np,
        total_frames=total_frames,
        cache_path=cache_path,
    )