Spaces:
Configuration error
Configuration error
File size: 9,489 Bytes
4c075ec 6805b8e 9d31508 4c075ec 9d31508 4c075ec 9d31508 4c075ec 9d31508 4c075ec 9d31508 6805b8e 9d31508 6805b8e 4c075ec 6805b8e 4c075ec 6805b8e 4c075ec 6805b8e 9d31508 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
import json
import os
from typing import Any, Dict, Optional
import torch
from stream3r.models.stream3r import STream3R
class StreamSession:
    """
    A causal streaming inference session with KV cache management for STream3R.

    Frames are fed to the model incrementally via :meth:`forward_stream`. The
    session accumulates the per-frame predictions and carries the aggregator and
    camera-head KV caches from one call to the next.

    In ``"causal"`` mode the caches returned by the model are kept as-is. In
    ``"window"`` mode each cache is trimmed after every call to the first
    (anchor) frame plus the most recent ``window_size`` frames.
    """

    # Prediction keys that are accumulated across calls and exposed by the getters.
    _PREDICTION_KEYS = ("pose_enc", "world_points", "world_points_conf", "depth", "depth_conf", "images")

    def __init__(self, model: "STream3R", mode: str, *, window_size: Optional[int] = None, config_path: Optional[str] = None):
        """Create a session bound to ``model``.

        Args:
            model: The STream3R model. Its ``aggregator.depth`` and
                ``camera_head.trunk_depth`` size the caches.
            mode: Either ``"causal"`` or ``"window"``.
            window_size: Number of recent frames kept in window mode. When None,
                the value is read from ``config_path`` (see ``_resolve_window_size``).
            config_path: Optional path to a JSON file with a ``"window_size"`` key.

        Raises:
            ValueError: If ``mode`` is not supported.
        """
        # Validate first so an invalid session fails before any config-file I/O.
        if mode not in ("causal", "window"):
            raise ValueError(f"Unsupported attention mode when using kv_cache: {mode}")
        self.model = model
        self.mode = mode
        self.aggregator_kv_cache_depth = model.aggregator.depth
        self.camera_head_kv_cache_depth = model.camera_head.trunk_depth
        # Hard-coded; each iteration gets its own list of per-layer caches.
        # NOTE(review): presumably must match the model's camera-head iteration count — confirm.
        self.camera_head_iterations = 4
        self.window_size = self._resolve_window_size(window_size, config_path)
        self.clear()

    def _clear_predictions(self):
        """Drop all accumulated predictions."""
        self.predictions = dict()

    def _update_predictions(self, predictions):
        """Append one model output to the accumulated predictions along dim 1.

        Only keys in ``_PREDICTION_KEYS`` are kept; other entries (e.g. the
        returned KV caches) are ignored.
        """
        for key in self._PREDICTION_KEYS:
            if key not in predictions:
                continue
            new_chunk = predictions[key]
            previous = self.predictions.get(key)
            # Store the first chunk directly. Concatenating with a float32 empty
            # placeholder would type-promote half-precision outputs to float32.
            self.predictions[key] = new_chunk if previous is None else torch.cat([previous, new_chunk], dim=1)

    def _clear_cache(self):
        """Reset both KV caches to ``[key, value]`` pairs of None.

        The aggregator cache has one pair per layer. The camera-head cache has
        one list of per-layer pairs for each camera-head iteration.
        """
        self.aggregator_kv_cache_list = [[None, None] for _ in range(self.aggregator_kv_cache_depth)]
        self.camera_head_kv_cache_list = [
            [[None, None] for _ in range(self.camera_head_kv_cache_depth)]
            for _ in range(self.camera_head_iterations)
        ]

    @staticmethod
    def _trim_to_window(cache, tokens_per_frame: int, window_size: int):
        """Keep the first frame's tokens plus the last ``window_size`` frames along dim 2."""
        anchor = cache[:, :, :tokens_per_frame]
        # Never start before the anchor, so anchor tokens are not duplicated.
        start = max(tokens_per_frame, cache.size(2) - window_size * tokens_per_frame)
        return torch.cat([anchor, cache[:, :, start:]], dim=2)

    def _update_cache(self, aggregator_kv_cache_list, camera_head_kv_cache_list):
        """Store the KV caches returned by the model, trimming them in window mode.

        Raises:
            ValueError: If ``self.mode`` is not supported.
        """
        if self.mode == "causal":
            self.aggregator_kv_cache_list = aggregator_kv_cache_list
            self.camera_head_kv_cache_list = camera_head_kv_cache_list
            return
        if self.mode != "window":
            raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")

        aggregator = self.model.aggregator
        # Assumes depth is laid out as (B, S, H, W, ...) — TODO confirm.
        depth = self.predictions["depth"]
        h, w = depth.shape[2], depth.shape[3]
        # Tokens per frame: the patch grid plus `patch_start_idx` leading tokens.
        # Each side is floored separately: h*w//p//p over-counts when h or w is
        # not a multiple of p (presumably matching the patch embedding — confirm).
        tokens_per_frame = (h // aggregator.patch_size) * (w // aggregator.patch_size) + aggregator.patch_start_idx

        for layer in range(self.aggregator_kv_cache_depth):
            for kv in range(2):
                self.aggregator_kv_cache_list[layer][kv] = self._trim_to_window(
                    aggregator_kv_cache_list[layer][kv], tokens_per_frame, self.window_size
                )
        # Camera-head caches are trimmed with one token per frame.
        for iteration in range(self.camera_head_iterations):
            for layer in range(self.camera_head_kv_cache_depth):
                for kv in range(2):
                    self.camera_head_kv_cache_list[iteration][layer][kv] = self._trim_to_window(
                        camera_head_kv_cache_list[iteration][layer][kv], 1, self.window_size
                    )

    def _get_cache(self):
        """Return ``(aggregator_kv_cache_list, camera_head_kv_cache_list)``."""
        return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list

    def get_all_predictions(self):
        """Return the dict of predictions accumulated so far (not a copy)."""
        return self.predictions

    def get_last_prediction(self):
        """Return each accumulated prediction sliced to its last entry along dim 1."""
        return {key: self.predictions[key][:, -1:] for key in self._PREDICTION_KEYS if key in self.predictions}

    def clear(self):
        """Reset predictions and KV caches."""
        self._clear_predictions()
        self._clear_cache()

    @staticmethod
    def _detach_to_cpu(cache_like):
        """Recursively detach tensors and move them to CPU, preserving list/tuple nesting."""
        if isinstance(cache_like, torch.Tensor):
            return cache_like.detach().cpu()
        if isinstance(cache_like, list):
            return [StreamSession._detach_to_cpu(elem) for elem in cache_like]
        if isinstance(cache_like, tuple):
            return tuple(StreamSession._detach_to_cpu(elem) for elem in cache_like)
        return cache_like

    @staticmethod
    def _to_device(cache_like, device: torch.device):
        """Recursively move tensors to ``device``, preserving list/tuple nesting."""
        if isinstance(cache_like, torch.Tensor):
            return cache_like.to(device)
        if isinstance(cache_like, list):
            return [StreamSession._to_device(elem, device) for elem in cache_like]
        if isinstance(cache_like, tuple):
            return tuple(StreamSession._to_device(elem, device) for elem in cache_like)
        return cache_like

    def _default_device(self) -> torch.device:
        """Device of the model's first parameter, or CPU if it has none."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def _resolve_window_size(self, override: Optional[int], config_path: Optional[str]) -> int:
        """Pick the window size: explicit override, then config file, then 25.

        ``override`` is returned unvalidated when not None. The config value is
        used only if it is a positive int (bools rejected). A missing, unreadable,
        malformed, or non-object config file falls back to the default.
        """
        default_window_size = 25
        if override is not None:
            return override
        config_path = config_path or os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "configs", "stream_session.json")
        )
        try:
            with open(config_path, "r", encoding="utf-8") as handle:
                data = json.load(handle)
        except (OSError, ValueError):
            # OSError covers a missing/unreadable file. ValueError covers both
            # JSONDecodeError and UnicodeDecodeError.
            return default_window_size
        window_size = data.get("window_size") if isinstance(data, dict) else None
        # bool is a subclass of int, so `true` would otherwise be accepted as 1.
        if isinstance(window_size, int) and not isinstance(window_size, bool) and window_size > 0:
            return window_size
        return default_window_size

    def save_cache(self, file_path: str) -> None:
        """Write caches, predictions and session metadata to ``file_path`` with ``torch.save``.

        All tensors are detached and moved to CPU. Parent directories are created
        as needed.
        """
        aggregator_cache, camera_cache = self._get_cache()
        payload: Dict[str, Any] = {
            "metadata": {
                "mode": self.mode,
                "aggregator_depth": self.aggregator_kv_cache_depth,
                "camera_head_depth": self.camera_head_kv_cache_depth,
                "camera_head_iterations": self.camera_head_iterations,
                "window_size": self.window_size,
                "patch_size": getattr(self.model.aggregator, "patch_size", None),
                "patch_start_idx": getattr(self.model.aggregator, "patch_start_idx", None),
            },
            "aggregator_cache": self._detach_to_cpu(aggregator_cache),
            "camera_cache": self._detach_to_cpu(camera_cache),
            "predictions": {k: self._detach_to_cpu(v) for k, v in self.predictions.items()},
        }
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        torch.save(payload, file_path)

    def _check_metadata(self, metadata: Dict[str, Any]) -> None:
        """Raise ValueError if saved metadata is incompatible with this session/model.

        Patch metadata that is absent (None) in the file is accepted.
        """
        expected_metadata = {
            "mode": self.mode,
            "aggregator_depth": self.aggregator_kv_cache_depth,
            "camera_head_depth": self.camera_head_kv_cache_depth,
            "camera_head_iterations": self.camera_head_iterations,
            "window_size": self.window_size,
        }
        for key, expected_value in expected_metadata.items():
            actual_value = metadata.get(key)
            if actual_value != expected_value:
                raise ValueError(
                    f"Loaded cache metadata mismatch for '{key}': expected {expected_value}, got {actual_value}"
                )
        for key in ("patch_size", "patch_start_idx"):
            expected_value = getattr(self.model.aggregator, key, None)
            actual_value = metadata.get(key)
            if actual_value not in (None, expected_value):
                raise ValueError(
                    f"Loaded cache metadata mismatch for '{key}': expected {expected_value}, got {actual_value}"
                )

    def load_cache(self, file_path: str, *, device: Optional[torch.device] = None, strict: bool = True) -> None:
        """Restore caches and predictions written by :meth:`save_cache`.

        Args:
            file_path: File produced by ``save_cache``.
            device: Target device. Defaults to the model's parameter device.
            strict: When True, metadata must match this session (see ``_check_metadata``).

        Raises:
            ValueError: If ``strict`` and the metadata does not match. The session
                is left unchanged in that case.
        """
        if device is None:
            device = self._default_device()
        # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
        # torch version's weights_only default); only load trusted cache files.
        payload = torch.load(file_path, map_location="cpu")
        metadata: Dict[str, Any] = payload.get("metadata", {})
        if strict:
            self._check_metadata(metadata)
        # Start from a correctly shaped empty cache, so a file missing either cache
        # does not leave a bare [] that later per-layer indexing would fail on.
        self._clear_cache()
        if "aggregator_cache" in payload:
            self.aggregator_kv_cache_list = self._to_device(payload["aggregator_cache"], device)
        if "camera_cache" in payload:
            self.camera_head_kv_cache_list = self._to_device(payload["camera_cache"], device)
        self.predictions = {k: self._to_device(v, device) for k, v in payload.get("predictions", {}).items()}

    def forward_stream(self, images):
        """Run the model on new ``images`` using the current caches.

        Appends the outputs to the predictions and updates the caches.

        Returns:
            The accumulated predictions dict (same object as ``get_all_predictions()``).
        """
        aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache()
        outputs = self.model(
            images=images,
            mode=self.mode,
            aggregator_kv_cache_list=aggregator_kv_cache_list,
            camera_head_kv_cache_list=camera_head_kv_cache_list,
        )
        self._update_predictions(outputs)
        self._update_cache(outputs["aggregator_kv_cache_list"], outputs["camera_head_kv_cache_list"])
        return self.get_all_predictions()
|