import json
import os
from typing import Any, Dict, Optional

import torch

from stream3r.models.stream3r import STream3R


class StreamSession:
    """
    A causal streaming inference session with KV cache management for STream3R.

    Frames are fed incrementally through :meth:`forward_stream`; the session
    accumulates per-frame predictions and carries the transformer KV caches
    forward between calls.  Two attention modes are supported:

    - ``"causal"``: caches grow without bound as frames arrive.
    - ``"window"``: after each step the caches are truncated to an anchor
      (the first frame's tokens for the aggregator, the first token for the
      camera head) plus a sliding window covering the most recent
      ``window_size`` frames.
    """

    # Prediction tensors tracked across calls; all are concatenated along
    # dim=1 (the frame/sequence dimension).
    _PREDICTION_KEYS = (
        "pose_enc",
        "world_points",
        "world_points_conf",
        "depth",
        "depth_conf",
        "images",
    )

    def __init__(
        self,
        model: STream3R,
        mode: str,
        *,
        window_size: Optional[int] = None,
        config_path: Optional[str] = None,
    ):
        """
        Args:
            model: STream3R model used for inference.
            mode: Attention mode, either ``"causal"`` or ``"window"``.
            window_size: Sliding-window length in frames (used in ``"window"``
                mode).  When ``None`` the value is resolved from
                ``config_path`` (or the bundled default config), falling back
                to 25.
            config_path: Optional path to a JSON config containing a
                ``"window_size"`` entry.

        Raises:
            ValueError: If ``mode`` is neither ``"causal"`` nor ``"window"``.
        """
        self.model = model
        self.mode = mode
        self.aggregator_kv_cache_depth = model.aggregator.depth
        self.camera_head_kv_cache_depth = model.camera_head.trunk_depth
        # NOTE(review): assumes the camera head always runs 4 refinement
        # iterations — confirm against the camera head implementation.
        self.camera_head_iterations = 4
        self.window_size = self._resolve_window_size(window_size, config_path)
        if self.mode not in ("causal", "window"):
            raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")
        self.clear()

    def _clear_predictions(self):
        """Drop all accumulated per-frame predictions."""
        self.predictions = dict()

    def _update_predictions(self, predictions):
        """Append the latest step's outputs to the accumulated predictions.

        Fix: the previous implementation seeded the concatenation with a 1-D
        ``torch.empty(0)`` placeholder and cat'ed it along ``dim=1`` against
        N-D tensors.  That only works through torch.cat's deprecated legacy
        behavior of silently skipping 1-D empty tensors (the placeholder's
        ndim/dtype never actually match).  First insertion now stores the
        tensor directly; subsequent steps concatenate along the frame dim.
        """
        for key in self._PREDICTION_KEYS:
            if key not in predictions:
                continue
            if key in self.predictions:
                self.predictions[key] = torch.cat(
                    [self.predictions[key], predictions[key]], dim=1
                )
            else:
                self.predictions[key] = predictions[key]

    def _clear_cache(self):
        """Reset the KV caches to their empty (``None``) layout.

        Layout: ``aggregator_kv_cache_list[layer] == [keys, values]`` and
        ``camera_head_kv_cache_list[iteration][layer] == [keys, values]``.
        """
        self.aggregator_kv_cache_list = [
            [None, None] for _ in range(self.aggregator_kv_cache_depth)
        ]
        self.camera_head_kv_cache_list = [
            [[None, None] for _ in range(self.camera_head_kv_cache_depth)]
            for _ in range(self.camera_head_iterations)
        ]

    def _update_cache(self, aggregator_kv_cache_list, camera_head_kv_cache_list):
        """Store the caches returned by the model, truncating in window mode.

        In ``"window"`` mode each cache keeps an anchor prefix (first frame's
        tokens for the aggregator; the first token for the camera head) plus
        the most recent ``window_size`` frames' tokens along dim=2.

        Raises:
            ValueError: If ``self.mode`` is unsupported.
        """
        if self.mode == "causal":
            self.aggregator_kv_cache_list = aggregator_kv_cache_list
            self.camera_head_kv_cache_list = camera_head_kv_cache_list
        elif self.mode == "window":
            # Tokens contributed per frame to the aggregator cache: one token
            # per (patch_size x patch_size) image patch, plus the special
            # tokens preceding the patch tokens.  Loop-invariant, so computed
            # once here instead of inside the innermost loop as before.
            h = self.predictions["depth"].shape[2]
            w = self.predictions["depth"].shape[3]
            patch_size = self.model.aggregator.patch_size
            tokens_per_frame = (
                h * w // patch_size // patch_size
                + self.model.aggregator.patch_start_idx
            )
            for k in range(2):  # k=0: keys, k=1: values
                for i in range(self.aggregator_kv_cache_depth):
                    cache = aggregator_kv_cache_list[i][k]
                    anchor_token = cache[:, :, :tokens_per_frame]
                    # max(...) keeps the window from overlapping the anchor
                    # while the sequence is still shorter than the window.
                    window_start = max(
                        tokens_per_frame,
                        cache.size(2) - self.window_size * tokens_per_frame,
                    )
                    self.aggregator_kv_cache_list[i][k] = torch.cat(
                        [anchor_token, cache[:, :, window_start:]], dim=2
                    )
                for i in range(self.camera_head_iterations):
                    for j in range(self.camera_head_kv_cache_depth):
                        cache = camera_head_kv_cache_list[i][j][k]
                        anchor_token = cache[:, :, :1]
                        window_start = max(1, cache.size(2) - self.window_size)
                        self.camera_head_kv_cache_list[i][j][k] = torch.cat(
                            [anchor_token, cache[:, :, window_start:]], dim=2
                        )
        else:
            raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")

    def _get_cache(self):
        """Return ``(aggregator_kv_cache_list, camera_head_kv_cache_list)``."""
        return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list

    def get_all_predictions(self):
        """Return the dict of all predictions accumulated so far."""
        return self.predictions

    def get_last_prediction(self):
        """Return only the most recent frame (slice ``[:, -1:]``) of each
        accumulated prediction tensor."""
        last_predictions = dict()
        for key in self._PREDICTION_KEYS:
            if key in self.predictions:
                last_predictions[key] = self.predictions[key][:, -1:]
        return last_predictions

    def clear(self):
        """Reset the session: drop all predictions and KV caches."""
        self._clear_predictions()
        self._clear_cache()

    @staticmethod
    def _detach_to_cpu(cache_like):
        """Recursively detach tensors to CPU inside nested lists/tuples;
        non-tensor leaves (e.g. ``None``) pass through unchanged."""
        if isinstance(cache_like, torch.Tensor):
            return cache_like.detach().cpu()
        if isinstance(cache_like, list):
            return [StreamSession._detach_to_cpu(elem) for elem in cache_like]
        if isinstance(cache_like, tuple):
            return tuple(StreamSession._detach_to_cpu(elem) for elem in cache_like)
        return cache_like

    @staticmethod
    def _to_device(cache_like, device: torch.device):
        """Recursively move tensors inside nested lists/tuples to ``device``;
        non-tensor leaves pass through unchanged."""
        if isinstance(cache_like, torch.Tensor):
            return cache_like.to(device)
        if isinstance(cache_like, list):
            return [StreamSession._to_device(elem, device) for elem in cache_like]
        if isinstance(cache_like, tuple):
            return tuple(StreamSession._to_device(elem, device) for elem in cache_like)
        return cache_like

    def _default_device(self) -> torch.device:
        """Device of the model's first parameter; CPU for parameter-less models."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def _resolve_window_size(self, override: Optional[int], config_path: Optional[str]) -> int:
        """Resolve the sliding-window size.

        Precedence: explicit ``override`` > ``window_size`` in the JSON config
        at ``config_path`` (default: ``../configs/stream_session.json`` next to
        this file) > hard default of 25.  A missing, unreadable, or malformed
        config silently falls back to the default (best-effort by design).
        """
        if override is not None:
            return override
        config_path = config_path or os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "configs", "stream_session.json")
        )
        default_window_size = 25
        if not os.path.exists(config_path):
            return default_window_size
        try:
            with open(config_path, "r", encoding="utf-8") as handle:
                data = json.load(handle)
        except (json.JSONDecodeError, OSError):
            return default_window_size
        window_size = data.get("window_size")
        # Exclude bool explicitly: JSON `true` is an int subclass in Python
        # and would otherwise slip through as a window size of 1.
        if isinstance(window_size, int) and not isinstance(window_size, bool) and window_size > 0:
            return window_size
        return default_window_size

    def save_cache(self, file_path: str) -> None:
        """Serialize session state (caches, predictions, metadata) to disk.

        All tensors are detached to CPU so the checkpoint is device-agnostic.
        Parent directories are created as needed.

        Args:
            file_path: Destination path for the ``torch.save`` payload.
        """
        aggregator_cache, camera_cache = self._get_cache()
        payload: Dict[str, Any] = {
            "metadata": {
                "mode": self.mode,
                "aggregator_depth": self.aggregator_kv_cache_depth,
                "camera_head_depth": self.camera_head_kv_cache_depth,
                "camera_head_iterations": self.camera_head_iterations,
                "window_size": self.window_size,
                "patch_size": getattr(self.model.aggregator, "patch_size", None),
                "patch_start_idx": getattr(self.model.aggregator, "patch_start_idx", None),
            },
            "aggregator_cache": self._detach_to_cpu(aggregator_cache),
            "camera_cache": self._detach_to_cpu(camera_cache),
            # Guard non-tensor values for symmetry with load_cache, which
            # already tolerates non-tensor prediction entries.
            "predictions": {
                k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
                for k, v in self.predictions.items()
            },
        }
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        torch.save(payload, file_path)

    def load_cache(self, file_path: str, *, device: Optional[torch.device] = None, strict: bool = True) -> None:
        """Restore session state previously written by :meth:`save_cache`.

        Args:
            file_path: Path to the saved payload.
            device: Target device for restored tensors; defaults to the
                model's device.
            strict: When True, validate that the saved metadata (mode, depths,
                window size, patch geometry) matches this session.

        Raises:
            ValueError: When ``strict`` and any metadata field disagrees.
        """
        if device is None:
            device = self._default_device()
        payload = torch.load(file_path, map_location="cpu")
        metadata: Dict[str, Any] = payload.get("metadata", {})
        expected_metadata = {
            "mode": self.mode,
            "aggregator_depth": self.aggregator_kv_cache_depth,
            "camera_head_depth": self.camera_head_kv_cache_depth,
            "camera_head_iterations": self.camera_head_iterations,
            "window_size": self.window_size,
        }
        for key, expected_value in expected_metadata.items():
            actual_value = metadata.get(key)
            if strict and actual_value != expected_value:
                raise ValueError(
                    f"Loaded cache metadata mismatch for '{key}': expected {expected_value}, got {actual_value}"
                )
        if strict:
            # Patch geometry may be absent from older checkpoints (None);
            # only a present-but-different value is a mismatch.
            patch_size = getattr(self.model.aggregator, "patch_size", None)
            patch_start_idx = getattr(self.model.aggregator, "patch_start_idx", None)
            if metadata.get("patch_size") not in (None, patch_size):
                raise ValueError(
                    f"Loaded cache metadata mismatch for 'patch_size': expected {patch_size}, got {metadata.get('patch_size')}"
                )
            if metadata.get("patch_start_idx") not in (None, patch_start_idx):
                raise ValueError(
                    f"Loaded cache metadata mismatch for 'patch_start_idx': expected {patch_start_idx}, got {metadata.get('patch_start_idx')}"
                )
        self.aggregator_kv_cache_list = self._to_device(payload.get("aggregator_cache", []), device)
        self.camera_head_kv_cache_list = self._to_device(payload.get("camera_cache", []), device)
        self.predictions = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in payload.get("predictions", {}).items()
        }

    def forward_stream(self, images):
        """Run one streaming step on ``images`` and return all predictions.

        Feeds the current KV caches to the model, appends the new outputs to
        the accumulated predictions, then stores (and, in window mode,
        truncates) the updated caches.
        """
        aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache()
        outputs = self.model(
            images=images,
            mode=self.mode,
            aggregator_kv_cache_list=aggregator_kv_cache_list,
            camera_head_kv_cache_list=camera_head_kv_cache_list,
        )
        self._update_predictions(outputs)
        self._update_cache(outputs["aggregator_kv_cache_list"], outputs["camera_head_kv_cache_list"])
        return self.get_all_predictions()