# dwellbot_stream3r / stream3r / stream_session.py
# Author: brian4dwell
# Commit 4c075ec: add saving and reloading of session
import json
import os
from typing import Any, Dict, Optional
import torch
from stream3r.models.stream3r import STream3R
class StreamSession:
    """
    A causal streaming inference session with KV cache management for STream3R.

    Frames are fed to the model incrementally via :meth:`forward_stream`; the
    aggregator and camera-head key/value caches are carried between calls so
    earlier frames are not re-encoded.  In ``"window"`` mode the caches are
    truncated to the first frame's anchor tokens plus a sliding window of the
    most recent ``window_size`` frames; in ``"causal"`` mode they grow without
    bound.  Accumulated predictions and both caches can be persisted with
    :meth:`save_cache` and restored with :meth:`load_cache`.
    """

    def __init__(self, model: STream3R, mode: str, *, window_size: Optional[int] = None, config_path: Optional[str] = None):
        """
        Args:
            model: The STream3R model driven by this session.
            mode: Attention mode; must be ``"causal"`` or ``"window"``.
            window_size: Number of frames kept in the sliding window
                (``"window"`` mode only).  When ``None``, the value is read
                from the JSON config file, falling back to a default of 25.
            config_path: Optional path of a JSON file providing
                ``window_size``; defaults to
                ``../configs/stream_session.json`` relative to this module.

        Raises:
            ValueError: If ``mode`` is unsupported, or an explicit
                ``window_size`` is not a positive integer.
        """
        self.model = model
        self.mode = mode
        self.aggregator_kv_cache_depth = model.aggregator.depth
        self.camera_head_kv_cache_depth = model.camera_head.trunk_depth
        # The camera head refines its estimate over a fixed number of
        # iterations, each with its own per-layer KV cache.
        self.camera_head_iterations = 4
        self.window_size = self._resolve_window_size(window_size, config_path)
        if self.mode not in ("causal", "window"):
            raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")
        self.clear()

    def _clear_predictions(self):
        """Drop all accumulated predictions."""
        self.predictions = dict()

    def _update_predictions(self, predictions):
        """Append a new chunk of model outputs along the frame dimension (dim 1)."""
        for key in ["pose_enc", "world_points", "world_points_conf", "depth", "depth_conf", "images"]:
            if key not in predictions:
                continue
            if key in self.predictions:
                self.predictions[key] = torch.cat([self.predictions[key], predictions[key]], dim=1)
            else:
                # First chunk: store directly.  The previous implementation
                # concatenated with a 1-D empty float32 placeholder, which
                # relied on torch.cat's legacy empty-tensor special case and
                # could clash with half-precision outputs.
                self.predictions[key] = predictions[key]

    def _clear_cache(self):
        """Reset aggregator and camera-head KV caches to empty [key, value] slots."""
        self.aggregator_kv_cache_list = [[None, None] for _ in range(self.aggregator_kv_cache_depth)]
        self.camera_head_kv_cache_list = [
            [[None, None] for _ in range(self.camera_head_kv_cache_depth)]
            for _ in range(self.camera_head_iterations)
        ]

    def _update_cache(self, aggregator_kv_cache_list, camera_head_kv_cache_list):
        """Adopt the caches returned by the latest forward pass.

        In ``"causal"`` mode the full caches are kept.  In ``"window"`` mode
        each cache is truncated along the token dimension (dim 2) to the first
        frame's anchor tokens plus the most recent ``window_size`` frames.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        if self.mode == "causal":
            self.aggregator_kv_cache_list = aggregator_kv_cache_list
            self.camera_head_kv_cache_list = camera_head_kv_cache_list
        elif self.mode == "window":
            # Tokens per frame in the aggregator cache: image patches plus the
            # model's special tokens (patch_start_idx).  Loop-invariant, so
            # computed once instead of once per layer per key/value slot.
            h, w = self.predictions["depth"].shape[2], self.predictions["depth"].shape[3]
            patch_size = self.model.aggregator.patch_size
            tokens_per_frame = h * w // patch_size // patch_size + self.model.aggregator.patch_start_idx
            for k in range(2):  # k == 0: keys, k == 1: values
                for i in range(self.aggregator_kv_cache_depth):
                    cache = aggregator_kv_cache_list[i][k]
                    anchor_token = cache[:, :, :tokens_per_frame]
                    # Keep at most window_size frames; never overlap the anchor.
                    start = max(tokens_per_frame, cache.size(2) - self.window_size * tokens_per_frame)
                    self.aggregator_kv_cache_list[i][k] = torch.cat([anchor_token, cache[:, :, start:]], dim=2)
                for i in range(self.camera_head_iterations):
                    for j in range(self.camera_head_kv_cache_depth):
                        cache = camera_head_kv_cache_list[i][j][k]
                        # The camera head holds one token per frame.
                        anchor_token = cache[:, :, :1]
                        start = max(1, cache.size(2) - self.window_size)
                        self.camera_head_kv_cache_list[i][j][k] = torch.cat([anchor_token, cache[:, :, start:]], dim=2)
        else:
            raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")

    def _get_cache(self):
        """Return the current ``(aggregator, camera_head)`` KV cache pair."""
        return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list

    def get_all_predictions(self):
        """Return the dict of all predictions accumulated so far."""
        return self.predictions

    def get_last_prediction(self):
        """Return a dict with only the most recent frame (slice ``[:, -1:]``) of each prediction."""
        last_predictions = dict()
        for k in ["pose_enc", "world_points", "world_points_conf", "depth", "depth_conf", "images"]:
            if k in self.predictions:
                last_predictions[k] = self.predictions[k][:, -1:]
        return last_predictions

    def clear(self):
        """Reset the session: drop all predictions and empty both KV caches."""
        self._clear_predictions()
        self._clear_cache()

    @staticmethod
    def _detach_to_cpu(cache_like):
        """Recursively detach tensors and move them to CPU.

        Lists and tuples are rebuilt with the same structure; non-tensor
        leaves (e.g. ``None`` slots) pass through unchanged.
        """
        if isinstance(cache_like, torch.Tensor):
            return cache_like.detach().cpu()
        if isinstance(cache_like, list):
            return [StreamSession._detach_to_cpu(elem) for elem in cache_like]
        if isinstance(cache_like, tuple):
            return tuple(StreamSession._detach_to_cpu(elem) for elem in cache_like)
        return cache_like

    @staticmethod
    def _to_device(cache_like, device: torch.device):
        """Recursively move tensors to ``device``, preserving list/tuple structure."""
        if isinstance(cache_like, torch.Tensor):
            return cache_like.to(device)
        if isinstance(cache_like, list):
            return [StreamSession._to_device(elem, device) for elem in cache_like]
        if isinstance(cache_like, tuple):
            return tuple(StreamSession._to_device(elem, device) for elem in cache_like)
        return cache_like

    def _default_device(self) -> torch.device:
        """Return the model's device, or CPU when the model has no parameters."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def _resolve_window_size(self, override: Optional[int], config_path: Optional[str]) -> int:
        """Determine the sliding-window size.

        Precedence: explicit ``override`` > ``window_size`` from the JSON
        config at ``config_path`` > built-in default of 25.  A missing,
        unreadable, or invalid config file falls back to the default rather
        than raising.

        Raises:
            ValueError: If ``override`` is given but is not a positive int.
        """
        default_window_size = 25
        if override is not None:
            # Enforce the same validity rule applied to config values;
            # previously an invalid override was silently accepted.
            if not isinstance(override, int) or isinstance(override, bool) or override <= 0:
                raise ValueError(f"window_size must be a positive integer, got {override!r}")
            return override
        config_path = config_path or os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "configs", "stream_session.json")
        )
        if not os.path.exists(config_path):
            return default_window_size
        try:
            with open(config_path, "r", encoding="utf-8") as handle:
                data = json.load(handle)
        except (json.JSONDecodeError, OSError):
            return default_window_size
        if not isinstance(data, dict):
            # Config root must be an object; anything else is treated as invalid.
            return default_window_size
        window_size = data.get("window_size")
        # Exclude bool explicitly: JSON `true` is an int subclass and would
        # otherwise sneak through as window_size == 1.
        if isinstance(window_size, int) and not isinstance(window_size, bool) and window_size > 0:
            return window_size
        return default_window_size

    def save_cache(self, file_path: str) -> None:
        """Serialize caches, predictions, and configuration metadata to ``file_path``.

        All tensors are detached and moved to CPU so the checkpoint is
        device-independent.  Parent directories are created as needed.
        """
        aggregator_cache, camera_cache = self._get_cache()
        payload: Dict[str, Any] = {
            "metadata": {
                "mode": self.mode,
                "aggregator_depth": self.aggregator_kv_cache_depth,
                "camera_head_depth": self.camera_head_kv_cache_depth,
                "camera_head_iterations": self.camera_head_iterations,
                "window_size": self.window_size,
                "patch_size": getattr(self.model.aggregator, "patch_size", None),
                "patch_start_idx": getattr(self.model.aggregator, "patch_start_idx", None),
            },
            "aggregator_cache": self._detach_to_cpu(aggregator_cache),
            "camera_cache": self._detach_to_cpu(camera_cache),
            "predictions": {k: v.detach().cpu() for k, v in self.predictions.items()},
        }
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        torch.save(payload, file_path)

    def load_cache(self, file_path: str, *, device: Optional[torch.device] = None, strict: bool = True) -> None:
        """Restore caches and predictions previously written by :meth:`save_cache`.

        Args:
            file_path: Checkpoint path.
            device: Target device for the restored tensors; defaults to the
                model's device (CPU when the model has no parameters).
            strict: When True, validate the saved metadata against this
                session's configuration before adopting the caches.

        Raises:
            ValueError: In strict mode, when the checkpoint metadata does not
                match this session.
        """
        if device is None:
            device = self._default_device()
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        payload = torch.load(file_path, map_location="cpu")
        metadata: Dict[str, Any] = payload.get("metadata", {})
        if strict:
            expected_metadata = {
                "mode": self.mode,
                "aggregator_depth": self.aggregator_kv_cache_depth,
                "camera_head_depth": self.camera_head_kv_cache_depth,
                "camera_head_iterations": self.camera_head_iterations,
                "window_size": self.window_size,
                "patch_size": getattr(self.model.aggregator, "patch_size", None),
                "patch_start_idx": getattr(self.model.aggregator, "patch_start_idx", None),
            }
            for key, expected_value in expected_metadata.items():
                actual_value = metadata.get(key)
                # Patch geometry may legitimately be absent from older
                # checkpoints; every other field must match exactly.
                if key in ("patch_size", "patch_start_idx"):
                    if actual_value not in (None, expected_value):
                        raise ValueError(
                            f"Loaded cache metadata mismatch for '{key}': expected {expected_value}, got {actual_value}"
                        )
                elif actual_value != expected_value:
                    raise ValueError(
                        f"Loaded cache metadata mismatch for '{key}': expected {expected_value}, got {actual_value}"
                    )
        self.aggregator_kv_cache_list = self._to_device(payload.get("aggregator_cache", []), device)
        self.camera_head_kv_cache_list = self._to_device(payload.get("camera_cache", []), device)
        self.predictions = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in payload.get("predictions", {}).items()
        }

    def forward_stream(self, images):
        """Run one streaming forward pass and return all predictions so far.

        Args:
            images: Image tensor for the new frame(s); forwarded to the model
                together with the KV caches carried over from earlier calls.

        Returns:
            The accumulated predictions dict (same object returned by
            :meth:`get_all_predictions`).
        """
        aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache()
        outputs = self.model(
            images=images,
            mode=self.mode,
            aggregator_kv_cache_list=aggregator_kv_cache_list,
            camera_head_kv_cache_list=camera_head_kv_cache_list,
        )
        self._update_predictions(outputs)
        self._update_cache(outputs["aggregator_kv_cache_list"], outputs["camera_head_kv_cache_list"])
        return self.get_all_predictions()