Spaces:
Configuration error
Configuration error
| import json | |
| import os | |
| from typing import Any, Dict, Optional | |
| import torch | |
| from stream3r.models.stream3r import STream3R | |
class StreamSession:
    """A causal streaming inference session with KV cache management for STream3R.

    Frames are fed incrementally via :meth:`forward_stream`. Per-frame model
    outputs are accumulated along the frame axis (dim=1), and the aggregator /
    camera-head key/value caches are carried from call to call.

    Supported modes:
      * ``"causal"`` — caches grow without bound.
      * ``"window"`` — caches keep the first frame's tokens (the anchor) plus
        a sliding window of the most recent ``window_size`` frames.
    """

    # Prediction tensors tracked by the session; all are concatenated along
    # dim=1 (the frame axis). Single source of truth for update/query paths.
    _PREDICTION_KEYS = (
        "pose_enc", "world_points", "world_points_conf",
        "depth", "depth_conf", "images",
    )

    def __init__(self, model: "STream3R", mode: str, *, window_size: Optional[int] = None, config_path: Optional[str] = None):
        """Initialize a streaming session.

        Args:
            model: the STream3R model used for inference; must expose
                ``aggregator.depth`` and ``camera_head.trunk_depth``.
            mode: attention mode, ``"causal"`` or ``"window"``.
            window_size: sliding-window length in frames (``"window"`` mode).
                When ``None`` it is resolved from the JSON config file.
            config_path: optional path to a JSON config providing
                ``"window_size"``; defaults to ``../configs/stream_session.json``
                relative to this file.

        Raises:
            ValueError: if ``mode`` is not ``"causal"`` or ``"window"``.
        """
        # Validate first so no partially-initialized session escapes on error.
        if mode not in ["causal", "window"]:
            raise ValueError(f"Unsupported attention mode when using kv_cache: {mode}")
        self.model = model
        self.mode = mode
        self.aggregator_kv_cache_depth = model.aggregator.depth
        self.camera_head_kv_cache_depth = model.camera_head.trunk_depth
        # The camera head runs a fixed number of refinement iterations, each
        # with its own per-layer KV cache.
        self.camera_head_iterations = 4
        self.window_size = self._resolve_window_size(window_size, config_path)
        self.clear()

    def _clear_predictions(self):
        """Reset all accumulated predictions."""
        self.predictions = {}

    def _update_predictions(self, predictions):
        """Append the latest frame's outputs along the frame axis (dim=1).

        Keys absent from ``predictions`` are skipped; unknown keys are ignored.
        """
        for key in self._PREDICTION_KEYS:
            if key not in predictions:
                continue
            if key in self.predictions:
                self.predictions[key] = torch.cat(
                    [self.predictions[key], predictions[key]], dim=1
                )
            else:
                # First frame: store directly. Concatenating with a 1-D
                # torch.empty(0) (as the previous code did) only works via a
                # deprecated legacy special case of torch.cat.
                self.predictions[key] = predictions[key]

    def _clear_cache(self):
        """Reset per-layer KV caches to empty ``[key, value]`` pairs of None."""
        self.aggregator_kv_cache_list = [[None, None] for _ in range(self.aggregator_kv_cache_depth)]
        self.camera_head_kv_cache_list = [
            [[None, None] for _ in range(self.camera_head_kv_cache_depth)]
            for _ in range(self.camera_head_iterations)
        ]

    def _update_cache(self, aggregator_kv_cache_list, camera_head_kv_cache_list):
        """Adopt the caches returned by the model.

        ``"causal"`` keeps everything; ``"window"`` trims each cache along the
        token axis (dim=2) to the anchor (first frame) plus the last
        ``window_size`` frames.

        Raises:
            ValueError: if ``self.mode`` is not a supported mode.
        """
        if self.mode == "causal":
            self.aggregator_kv_cache_list = aggregator_kv_cache_list
            self.camera_head_kv_cache_list = camera_head_kv_cache_list
        elif self.mode == "window":
            # Tokens per frame in the aggregator cache: patch tokens of an
            # h x w image plus the special tokens before patch_start_idx.
            # Invariant across the loops below, so compute it once (the
            # original recomputed it in the innermost loop).
            h, w = self.predictions["depth"].shape[2], self.predictions["depth"].shape[3]
            patch_size = self.model.aggregator.patch_size
            P = h * w // patch_size // patch_size + self.model.aggregator.patch_start_idx
            for k in range(2):  # 0 = keys, 1 = values
                for i in range(self.aggregator_kv_cache_depth):
                    cache = aggregator_kv_cache_list[i][k]
                    anchor_token = cache[:, :, :P]
                    # max(...) keeps the window start at/after the anchor so
                    # anchor tokens are never duplicated into the window.
                    window_tokens = cache[:, :, max(P, cache.size(2) - self.window_size * P):]
                    self.aggregator_kv_cache_list[i][k] = torch.cat(
                        [anchor_token, window_tokens], dim=2
                    )
                for i in range(self.camera_head_iterations):
                    for j in range(self.camera_head_kv_cache_depth):
                        cache = camera_head_kv_cache_list[i][j][k]
                        # The camera head holds one token per frame.
                        anchor_token = cache[:, :, :1]
                        window_tokens = cache[:, :, max(1, cache.size(2) - self.window_size):]
                        self.camera_head_kv_cache_list[i][j][k] = torch.cat(
                            [anchor_token, window_tokens], dim=2
                        )
        else:
            raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")

    def _get_cache(self):
        """Return ``(aggregator_kv_cache_list, camera_head_kv_cache_list)``."""
        return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list

    def get_all_predictions(self):
        """Return all predictions accumulated so far (frames along dim=1)."""
        return self.predictions

    def get_last_prediction(self):
        """Return only the most recent frame of each tracked prediction."""
        last_predictions = {}
        for key in self._PREDICTION_KEYS:
            if key in self.predictions:
                last_predictions[key] = self.predictions[key][:, -1:]
        return last_predictions

    def clear(self):
        """Reset both the accumulated predictions and all KV caches."""
        self._clear_predictions()
        self._clear_cache()

    @staticmethod
    def _detach_to_cpu(cache_like):
        """Recursively detach tensors and move them to CPU (for serialization).

        Non-tensor leaves (e.g. ``None`` placeholders) pass through unchanged.

        NOTE: ``@staticmethod`` is required — the original definition lacked it,
        so instance calls like ``self._detach_to_cpu(x)`` bound ``self`` to
        ``cache_like`` and raised ``TypeError``.
        """
        if isinstance(cache_like, torch.Tensor):
            return cache_like.detach().cpu()
        if isinstance(cache_like, list):
            return [StreamSession._detach_to_cpu(elem) for elem in cache_like]
        if isinstance(cache_like, tuple):
            return tuple(StreamSession._detach_to_cpu(elem) for elem in cache_like)
        return cache_like

    @staticmethod
    def _to_device(cache_like, device: torch.device):
        """Recursively move tensors in a nested cache structure to ``device``.

        Non-tensor leaves pass through unchanged. ``@staticmethod`` is
        required for the same reason as :meth:`_detach_to_cpu`.
        """
        if isinstance(cache_like, torch.Tensor):
            return cache_like.to(device)
        if isinstance(cache_like, list):
            return [StreamSession._to_device(elem, device) for elem in cache_like]
        if isinstance(cache_like, tuple):
            return tuple(StreamSession._to_device(elem, device) for elem in cache_like)
        return cache_like

    def _default_device(self) -> torch.device:
        """Device of the model's first parameter; CPU for a parameterless model."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def _resolve_window_size(self, override: Optional[int], config_path: Optional[str]) -> int:
        """Resolve the sliding-window size.

        Priority: explicit ``override`` > ``"window_size"`` in the JSON config
        > built-in default (25). A missing or malformed config falls back to
        the default silently (best-effort by design).
        """
        if override is not None:
            return override
        config_path = config_path or os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "configs", "stream_session.json")
        )
        default_window_size = 25
        if not os.path.exists(config_path):
            return default_window_size
        try:
            with open(config_path, "r", encoding="utf-8") as handle:
                data = json.load(handle)
        except (json.JSONDecodeError, OSError):
            return default_window_size
        # Valid JSON need not be an object; guard before .get().
        if not isinstance(data, dict):
            return default_window_size
        window_size = data.get("window_size")
        # Exclude bool: `True` is an `int` subclass and would pass `> 0`.
        if isinstance(window_size, int) and not isinstance(window_size, bool) and window_size > 0:
            return window_size
        return default_window_size

    def save_cache(self, file_path: str) -> None:
        """Serialize the session (caches, predictions, metadata) to disk.

        All tensors are detached and moved to CPU so the payload can be
        reloaded onto any device; parent directories are created as needed.
        """
        aggregator_cache, camera_cache = self._get_cache()
        payload: Dict[str, Any] = {
            "metadata": {
                "mode": self.mode,
                "aggregator_depth": self.aggregator_kv_cache_depth,
                "camera_head_depth": self.camera_head_kv_cache_depth,
                "camera_head_iterations": self.camera_head_iterations,
                "window_size": self.window_size,
                "patch_size": getattr(self.model.aggregator, "patch_size", None),
                "patch_start_idx": getattr(self.model.aggregator, "patch_start_idx", None),
            },
            "aggregator_cache": self._detach_to_cpu(aggregator_cache),
            "camera_cache": self._detach_to_cpu(camera_cache),
            "predictions": {k: v.detach().cpu() for k, v in self.predictions.items()},
        }
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        # Payload contains only containers, primitives, and tensors, so it
        # should survive torch.load's weights_only default — TODO confirm on
        # the torch version pinned by the project.
        torch.save(payload, file_path)

    def load_cache(self, file_path: str, *, device: Optional[torch.device] = None, strict: bool = True) -> None:
        """Restore caches and predictions saved by :meth:`save_cache`.

        Args:
            file_path: path to a payload written by :meth:`save_cache`.
            device: target device for restored tensors; defaults to the
                model's device.
            strict: when True, validate saved metadata against this session's
                configuration before adopting the payload.

        Raises:
            ValueError: on metadata mismatch when ``strict`` is True.
        """
        if device is None:
            device = self._default_device()
        payload = torch.load(file_path, map_location="cpu")
        metadata: Dict[str, Any] = payload.get("metadata", {})
        if strict:
            expected_metadata = {
                "mode": self.mode,
                "aggregator_depth": self.aggregator_kv_cache_depth,
                "camera_head_depth": self.camera_head_kv_cache_depth,
                "camera_head_iterations": self.camera_head_iterations,
                "window_size": self.window_size,
            }
            for key, expected_value in expected_metadata.items():
                actual_value = metadata.get(key)
                if actual_value != expected_value:
                    raise ValueError(
                        f"Loaded cache metadata mismatch for '{key}': expected {expected_value}, got {actual_value}"
                    )
            # Patch geometry may legitimately be absent (saved as None);
            # only a conflicting non-None value is an error.
            patch_size = getattr(self.model.aggregator, "patch_size", None)
            patch_start_idx = getattr(self.model.aggregator, "patch_start_idx", None)
            if metadata.get("patch_size") not in (None, patch_size):
                raise ValueError(
                    f"Loaded cache metadata mismatch for 'patch_size': expected {patch_size}, got {metadata.get('patch_size')}"
                )
            if metadata.get("patch_start_idx") not in (None, patch_start_idx):
                raise ValueError(
                    f"Loaded cache metadata mismatch for 'patch_start_idx': expected {patch_start_idx}, got {metadata.get('patch_start_idx')}"
                )
        self.aggregator_kv_cache_list = self._to_device(payload.get("aggregator_cache", []), device)
        self.camera_head_kv_cache_list = self._to_device(payload.get("camera_cache", []), device)
        self.predictions = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in payload.get("predictions", {}).items()
        }

    def forward_stream(self, images):
        """Run the model on the next frame(s) and fold results into the session.

        Feeds the current KV caches to the model, appends the new predictions,
        updates the caches per ``self.mode``, and returns everything
        accumulated so far.
        """
        aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache()
        outputs = self.model(
            images=images,
            mode=self.mode,
            aggregator_kv_cache_list=aggregator_kv_cache_list,
            camera_head_kv_cache_list=camera_head_kv_cache_list,
        )
        self._update_predictions(outputs)
        self._update_cache(outputs["aggregator_kv_cache_list"], outputs["camera_head_kv_cache_list"])
        return self.get_all_predictions()