Spaces:
Configuration error
Configuration error
Commit
·
6805b8e
1
Parent(s):
d255d9f
saving kv-cache
Browse files
- .gitignore +70 -1
- stream3r/__pycache__/__init__.cpython-311.pyc +0 -0
- stream3r/__pycache__/stream_session.cpython-311.pyc +0 -0
- stream3r/dust3r/__pycache__/__init__.cpython-311.pyc +0 -0
- stream3r/dust3r/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- stream3r/dust3r/utils/__pycache__/misc.cpython-311.pyc +0 -0
- stream3r/models/__pycache__/__init__.cpython-311.pyc +0 -0
- stream3r/models/__pycache__/stream3r.cpython-311.pyc +0 -0
- stream3r/models/components/aggregator/__pycache__/streamaggregator.cpython-311.pyc +0 -0
- stream3r/models/components/heads/__pycache__/camera_head.cpython-311.pyc +0 -0
- stream3r/models/components/heads/__pycache__/dpt_head.cpython-311.pyc +0 -0
- stream3r/models/components/heads/__pycache__/head_act.cpython-311.pyc +0 -0
- stream3r/models/components/heads/__pycache__/utils.cpython-311.pyc +0 -0
- stream3r/models/components/layers/__pycache__/__init__.cpython-311.pyc +0 -0
- stream3r/models/components/layers/__pycache__/attention.cpython-311.pyc +0 -0
- stream3r/models/components/layers/__pycache__/block.cpython-311.pyc +0 -0
- stream3r/models/components/layers/__pycache__/drop_path.cpython-311.pyc +0 -0
- stream3r/models/components/layers/__pycache__/layer_scale.cpython-311.pyc +0 -0
- stream3r/models/components/layers/__pycache__/mlp.cpython-311.pyc +0 -0
- stream3r/models/components/layers/__pycache__/patch_embed.cpython-311.pyc +0 -0
- stream3r/models/components/layers/__pycache__/rope.cpython-311.pyc +0 -0
- stream3r/models/components/layers/__pycache__/swiglu_ffn.cpython-311.pyc +0 -0
- stream3r/models/components/layers/__pycache__/vision_transformer.cpython-311.pyc +0 -0
- stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc +0 -0
- stream3r/models/components/utils/__pycache__/load_fn.cpython-311.pyc +0 -0
- stream3r/models/components/utils/__pycache__/pose_enc.cpython-311.pyc +0 -0
- stream3r/models/components/utils/__pycache__/rotation.cpython-311.pyc +0 -0
- stream3r/stream_session.py +94 -1
- stream3r/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- stream3r/utils/__pycache__/instantiators.cpython-311.pyc +0 -0
- stream3r/utils/__pycache__/logging_utils.cpython-311.pyc +0 -0
- stream3r/utils/__pycache__/pylogger.cpython-311.pyc +0 -0
- stream3r/utils/__pycache__/rich_utils.cpython-311.pyc +0 -0
- stream3r/utils/__pycache__/utils.cpython-311.pyc +0 -0
- stream3r/utils/__pycache__/visual_utils.cpython-311.pyc +0 -0
- tests/test_stream_session_cache.py +106 -0
.gitignore
CHANGED
|
@@ -1 +1,70 @@
|
|
| 1 |
-
demo_cache/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
demo_cache/
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Byte-compiled / optimized / DLL files
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.py[codz]
|
| 8 |
+
*$py.class
|
| 9 |
+
|
| 10 |
+
# C extensions
|
| 11 |
+
*.so
|
| 12 |
+
|
| 13 |
+
# Distribution / packaging
|
| 14 |
+
.Python
|
| 15 |
+
build/
|
| 16 |
+
develop-eggs/
|
| 17 |
+
dist/
|
| 18 |
+
downloads/
|
| 19 |
+
eggs/
|
| 20 |
+
.eggs/
|
| 21 |
+
lib/
|
| 22 |
+
lib64/
|
| 23 |
+
parts/
|
| 24 |
+
sdist/
|
| 25 |
+
var/
|
| 26 |
+
wheels/
|
| 27 |
+
share/python-wheels/
|
| 28 |
+
*.egg-info/
|
| 29 |
+
.installed.cfg
|
| 30 |
+
*.egg
|
| 31 |
+
MANIFEST
|
| 32 |
+
|
| 33 |
+
# PyInstaller
|
| 34 |
+
# Usually these files are written by a python script from a template
|
| 35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 36 |
+
*.manifest
|
| 37 |
+
*.spec
|
| 38 |
+
|
| 39 |
+
# Installer logs
|
| 40 |
+
pip-log.txt
|
| 41 |
+
pip-delete-this-directory.txt
|
| 42 |
+
|
| 43 |
+
# Unit test / coverage reports
|
| 44 |
+
htmlcov/
|
| 45 |
+
.tox/
|
| 46 |
+
.nox/
|
| 47 |
+
.coverage
|
| 48 |
+
.coverage.*
|
| 49 |
+
.cache
|
| 50 |
+
nosetests.xml
|
| 51 |
+
coverage.xml
|
| 52 |
+
*.cover
|
| 53 |
+
*.py.cover
|
| 54 |
+
.hypothesis/
|
| 55 |
+
.pytest_cache/
|
| 56 |
+
cover/
|
| 57 |
+
|
| 58 |
+
# Translations
|
| 59 |
+
*.mo
|
| 60 |
+
*.pot
|
| 61 |
+
|
| 62 |
+
# Django stuff:
|
| 63 |
+
*.log
|
| 64 |
+
local_settings.py
|
| 65 |
+
db.sqlite3
|
| 66 |
+
db.sqlite3-journal
|
| 67 |
+
|
| 68 |
+
# Flask stuff:
|
| 69 |
+
instance/
|
| 70 |
+
.webassets-cache
|
stream3r/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/__pycache__/__init__.cpython-311.pyc and b/stream3r/__pycache__/__init__.cpython-311.pyc differ
|
|
|
stream3r/__pycache__/stream_session.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/__pycache__/stream_session.cpython-311.pyc and b/stream3r/__pycache__/stream_session.cpython-311.pyc differ
|
|
|
stream3r/dust3r/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/dust3r/__pycache__/__init__.cpython-311.pyc and b/stream3r/dust3r/__pycache__/__init__.cpython-311.pyc differ
|
|
|
stream3r/dust3r/utils/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/dust3r/utils/__pycache__/__init__.cpython-311.pyc and b/stream3r/dust3r/utils/__pycache__/__init__.cpython-311.pyc differ
|
|
|
stream3r/dust3r/utils/__pycache__/misc.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/dust3r/utils/__pycache__/misc.cpython-311.pyc and b/stream3r/dust3r/utils/__pycache__/misc.cpython-311.pyc differ
|
|
|
stream3r/models/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/__pycache__/__init__.cpython-311.pyc and b/stream3r/models/__pycache__/__init__.cpython-311.pyc differ
|
|
|
stream3r/models/__pycache__/stream3r.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/__pycache__/stream3r.cpython-311.pyc and b/stream3r/models/__pycache__/stream3r.cpython-311.pyc differ
|
|
|
stream3r/models/components/aggregator/__pycache__/streamaggregator.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/aggregator/__pycache__/streamaggregator.cpython-311.pyc and b/stream3r/models/components/aggregator/__pycache__/streamaggregator.cpython-311.pyc differ
|
|
|
stream3r/models/components/heads/__pycache__/camera_head.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/heads/__pycache__/camera_head.cpython-311.pyc and b/stream3r/models/components/heads/__pycache__/camera_head.cpython-311.pyc differ
|
|
|
stream3r/models/components/heads/__pycache__/dpt_head.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/heads/__pycache__/dpt_head.cpython-311.pyc and b/stream3r/models/components/heads/__pycache__/dpt_head.cpython-311.pyc differ
|
|
|
stream3r/models/components/heads/__pycache__/head_act.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/heads/__pycache__/head_act.cpython-311.pyc and b/stream3r/models/components/heads/__pycache__/head_act.cpython-311.pyc differ
|
|
|
stream3r/models/components/heads/__pycache__/utils.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/heads/__pycache__/utils.cpython-311.pyc and b/stream3r/models/components/heads/__pycache__/utils.cpython-311.pyc differ
|
|
|
stream3r/models/components/layers/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/layers/__pycache__/__init__.cpython-311.pyc and b/stream3r/models/components/layers/__pycache__/__init__.cpython-311.pyc differ
|
|
|
stream3r/models/components/layers/__pycache__/attention.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/layers/__pycache__/attention.cpython-311.pyc and b/stream3r/models/components/layers/__pycache__/attention.cpython-311.pyc differ
|
|
|
stream3r/models/components/layers/__pycache__/block.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/layers/__pycache__/block.cpython-311.pyc and b/stream3r/models/components/layers/__pycache__/block.cpython-311.pyc differ
|
|
|
stream3r/models/components/layers/__pycache__/drop_path.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/layers/__pycache__/drop_path.cpython-311.pyc and b/stream3r/models/components/layers/__pycache__/drop_path.cpython-311.pyc differ
|
|
|
stream3r/models/components/layers/__pycache__/layer_scale.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/layers/__pycache__/layer_scale.cpython-311.pyc and b/stream3r/models/components/layers/__pycache__/layer_scale.cpython-311.pyc differ
|
|
|
stream3r/models/components/layers/__pycache__/mlp.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/layers/__pycache__/mlp.cpython-311.pyc and b/stream3r/models/components/layers/__pycache__/mlp.cpython-311.pyc differ
|
|
|
stream3r/models/components/layers/__pycache__/patch_embed.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/layers/__pycache__/patch_embed.cpython-311.pyc and b/stream3r/models/components/layers/__pycache__/patch_embed.cpython-311.pyc differ
|
|
|
stream3r/models/components/layers/__pycache__/rope.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/layers/__pycache__/rope.cpython-311.pyc and b/stream3r/models/components/layers/__pycache__/rope.cpython-311.pyc differ
|
|
|
stream3r/models/components/layers/__pycache__/swiglu_ffn.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/layers/__pycache__/swiglu_ffn.cpython-311.pyc and b/stream3r/models/components/layers/__pycache__/swiglu_ffn.cpython-311.pyc differ
|
|
|
stream3r/models/components/layers/__pycache__/vision_transformer.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/layers/__pycache__/vision_transformer.cpython-311.pyc and b/stream3r/models/components/layers/__pycache__/vision_transformer.cpython-311.pyc differ
|
|
|
stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc and b/stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc differ
|
|
|
stream3r/models/components/utils/__pycache__/load_fn.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/utils/__pycache__/load_fn.cpython-311.pyc and b/stream3r/models/components/utils/__pycache__/load_fn.cpython-311.pyc differ
|
|
|
stream3r/models/components/utils/__pycache__/pose_enc.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/utils/__pycache__/pose_enc.cpython-311.pyc and b/stream3r/models/components/utils/__pycache__/pose_enc.cpython-311.pyc differ
|
|
|
stream3r/models/components/utils/__pycache__/rotation.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/models/components/utils/__pycache__/rotation.cpython-311.pyc and b/stream3r/models/components/utils/__pycache__/rotation.cpython-311.pyc differ
|
|
|
stream3r/stream_session.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from stream3r.models.stream3r import STream3R
|
| 3 |
|
|
@@ -68,7 +71,7 @@ class StreamSession:
|
|
| 68 |
|
| 69 |
def _get_cache(self):
|
| 70 |
return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list
|
| 71 |
-
|
| 72 |
def get_all_predictions(self):
|
| 73 |
return self.predictions
|
| 74 |
|
|
@@ -83,6 +86,96 @@ class StreamSession:
|
|
| 83 |
self._clear_predictions()
|
| 84 |
self._clear_cache()
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
def forward_stream(self, images):
|
| 87 |
aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache()
|
| 88 |
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any, Dict, Optional
|
| 3 |
+
|
| 4 |
import torch
|
| 5 |
from stream3r.models.stream3r import STream3R
|
| 6 |
|
|
|
|
| 71 |
|
| 72 |
def _get_cache(self):
|
| 73 |
return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list
|
| 74 |
+
|
| 75 |
def get_all_predictions(self):
    """Return the session's ``predictions`` attribute (by reference, not a copy)."""
    stored_predictions = self.predictions
    return stored_predictions
|
| 77 |
|
|
|
|
| 86 |
self._clear_predictions()
|
| 87 |
self._clear_cache()
|
| 88 |
|
| 89 |
+
@staticmethod
|
| 90 |
+
def _detach_to_cpu(cache_like):
|
| 91 |
+
if isinstance(cache_like, torch.Tensor):
|
| 92 |
+
return cache_like.detach().cpu()
|
| 93 |
+
if isinstance(cache_like, list):
|
| 94 |
+
return [StreamSession._detach_to_cpu(elem) for elem in cache_like]
|
| 95 |
+
if isinstance(cache_like, tuple):
|
| 96 |
+
return tuple(StreamSession._detach_to_cpu(elem) for elem in cache_like)
|
| 97 |
+
return cache_like
|
| 98 |
+
|
| 99 |
+
@staticmethod
|
| 100 |
+
def _to_device(cache_like, device: torch.device):
|
| 101 |
+
if isinstance(cache_like, torch.Tensor):
|
| 102 |
+
return cache_like.to(device)
|
| 103 |
+
if isinstance(cache_like, list):
|
| 104 |
+
return [StreamSession._to_device(elem, device) for elem in cache_like]
|
| 105 |
+
if isinstance(cache_like, tuple):
|
| 106 |
+
return tuple(StreamSession._to_device(elem, device) for elem in cache_like)
|
| 107 |
+
return cache_like
|
| 108 |
+
|
| 109 |
+
def _default_device(self) -> torch.device:
|
| 110 |
+
try:
|
| 111 |
+
return next(self.model.parameters()).device
|
| 112 |
+
except StopIteration:
|
| 113 |
+
return torch.device("cpu")
|
| 114 |
+
|
| 115 |
+
def save_cache(self, file_path: str) -> None:
    """Write the session's KV caches and predictions to ``file_path``.

    The file is produced with ``torch.save`` and holds a dict with four keys:

    * ``"metadata"``: the session's ``mode``, cache depths and camera-head
      iteration count, plus the aggregator's ``patch_size`` and
      ``patch_start_idx`` (``None`` when the aggregator lacks the attribute).
    * ``"aggregator_cache"`` / ``"camera_cache"``: the caches returned by
      ``_get_cache()``, passed through ``_detach_to_cpu``.
    * ``"predictions"``: each value of ``self.predictions``, passed through
      ``_detach_to_cpu``.

    Args:
        file_path: Destination path. Missing parent directories are created.
    """
    aggregator_cache, camera_cache = self._get_cache()
    aggregator = self.model.aggregator

    payload: Dict[str, Any] = {
        "metadata": {
            "mode": self.mode,
            "aggregator_depth": self.aggregator_kv_cache_depth,
            "camera_head_depth": self.camera_head_kv_cache_depth,
            "camera_head_iterations": self.camera_head_iterations,
            "patch_size": getattr(aggregator, "patch_size", None),
            "patch_start_idx": getattr(aggregator, "patch_start_idx", None),
        },
        "aggregator_cache": self._detach_to_cpu(aggregator_cache),
        "camera_cache": self._detach_to_cpu(camera_cache),
        # Predictions go through the same helper as the caches: load_cache
        # accepts non-tensor prediction values, so saving must not assume
        # every value has .detach().
        "predictions": {
            name: self._detach_to_cpu(value)
            for name, value in self.predictions.items()
        },
    }

    # os.path.dirname returns "" for a bare filename; makedirs("") would raise.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(payload, file_path)
|
| 137 |
+
|
| 138 |
+
def load_cache(self, file_path: str, *, device: Optional[torch.device] = None, strict: bool = True) -> None:
    """Restore KV caches and predictions from a file written by ``save_cache``.

    Args:
        file_path: Path of the saved cache file.
        device: Device to place the restored tensors on. When ``None``, the
            result of ``self._default_device()`` is used.
        strict: When true, the stored metadata must match this session's
            ``mode``, cache depths and camera-head iteration count, and a
            stored ``patch_size`` / ``patch_start_idx`` (if not ``None``)
            must match the aggregator's. When false, no metadata is checked.

    Raises:
        ValueError: In strict mode, on the first metadata mismatch found.
    """
    target_device = self._default_device() if device is None else device

    # NOTE(review): torch.load unpickles the file; only load cache files
    # that come from a trusted source.
    payload = torch.load(file_path, map_location="cpu")
    metadata: Dict[str, Any] = payload.get("metadata", {})

    required_fields = (
        ("mode", self.mode),
        ("aggregator_depth", self.aggregator_kv_cache_depth),
        ("camera_head_depth", self.camera_head_kv_cache_depth),
        ("camera_head_iterations", self.camera_head_iterations),
    )

    if strict:
        # These fields must be present in the file and equal to the session's.
        for key, expected_value in required_fields:
            actual_value = metadata.get(key)
            if actual_value != expected_value:
                raise ValueError(
                    f"Loaded cache metadata mismatch for '{key}': expected {expected_value}, got {actual_value}"
                )

        # Patch geometry is only compared when the file actually stored it.
        aggregator = self.model.aggregator
        optional_fields = [
            (key, getattr(aggregator, key, None))
            for key in ("patch_size", "patch_start_idx")
        ]
        for key, expected_value in optional_fields:
            actual_value = metadata.get(key)
            if actual_value not in (None, expected_value):
                raise ValueError(
                    f"Loaded cache metadata mismatch for '{key}': expected {expected_value}, got {actual_value}"
                )

    self.aggregator_kv_cache_list = self._to_device(payload.get("aggregator_cache", []), target_device)
    self.camera_head_kv_cache_list = self._to_device(payload.get("camera_cache", []), target_device)

    restored_predictions = {}
    for name, value in payload.get("predictions", {}).items():
        if isinstance(value, torch.Tensor):
            value = value.to(target_device)
        restored_predictions[name] = value
    self.predictions = restored_predictions
|
| 178 |
+
|
| 179 |
def forward_stream(self, images):
|
| 180 |
aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache()
|
| 181 |
|
stream3r/utils/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/utils/__pycache__/__init__.cpython-311.pyc and b/stream3r/utils/__pycache__/__init__.cpython-311.pyc differ
|
|
|
stream3r/utils/__pycache__/instantiators.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/utils/__pycache__/instantiators.cpython-311.pyc and b/stream3r/utils/__pycache__/instantiators.cpython-311.pyc differ
|
|
|
stream3r/utils/__pycache__/logging_utils.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/utils/__pycache__/logging_utils.cpython-311.pyc and b/stream3r/utils/__pycache__/logging_utils.cpython-311.pyc differ
|
|
|
stream3r/utils/__pycache__/pylogger.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/utils/__pycache__/pylogger.cpython-311.pyc and b/stream3r/utils/__pycache__/pylogger.cpython-311.pyc differ
|
|
|
stream3r/utils/__pycache__/rich_utils.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/utils/__pycache__/rich_utils.cpython-311.pyc and b/stream3r/utils/__pycache__/rich_utils.cpython-311.pyc differ
|
|
|
stream3r/utils/__pycache__/utils.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/utils/__pycache__/utils.cpython-311.pyc and b/stream3r/utils/__pycache__/utils.cpython-311.pyc differ
|
|
|
stream3r/utils/__pycache__/visual_utils.cpython-311.pyc
CHANGED
|
Binary files a/stream3r/utils/__pycache__/visual_utils.cpython-311.pyc and b/stream3r/utils/__pycache__/visual_utils.cpython-311.pyc differ
|
|
|
tests/test_stream_session_cache.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
import unittest
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
import torch
|
| 7 |
+
except ImportError: # pragma: no cover - environment without torch
|
| 8 |
+
torch = None
|
| 9 |
+
|
| 10 |
+
if torch is not None:
|
| 11 |
+
from stream3r.stream_session import StreamSession
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
if torch is None:

    class StreamSessionCacheTest(unittest.TestCase):  # pragma: no cover - requires torch
        """Placeholder that reports a skip when PyTorch is not installed."""

        @unittest.skip("PyTorch is required for this test")
        def test_requires_torch(self):
            pass

else:

    class _DummyAggregator:
        """Bare stand-in for the model's aggregator.

        NOTE(review): ``depth`` is presumably read by ``StreamSession`` to
        size the aggregator KV cache; confirm against ``StreamSession.__init__``.
        ``patch_size`` and ``patch_start_idx`` are the attributes that
        ``save_cache`` / ``load_cache`` read via ``getattr``.
        """

        depth = 2
        patch_size = 4
        patch_start_idx = 3

    class _DummyCameraHead:
        """Bare stand-in for the model's camera head.

        NOTE(review): ``trunk_depth`` is presumably read by ``StreamSession``
        to size the camera-head KV cache; confirm against ``StreamSession.__init__``.
        """

        trunk_depth = 3

    class _DummyModel(torch.nn.Module):
        """Minimal module exposing ``aggregator``, ``camera_head`` and one parameter."""

        def __init__(self):
            super().__init__()
            self.aggregator = _DummyAggregator()
            self.camera_head = _DummyCameraHead()
            # One real parameter so that next(model.parameters()) yields a device.
            self.register_parameter("_dummy_param", torch.nn.Parameter(torch.zeros(1)))

        def forward(self, *args, **kwargs):
            raise NotImplementedError

    class StreamSessionCacheTest(unittest.TestCase):
        """Round-trip test for ``StreamSession.save_cache`` / ``load_cache``."""

        def _make_populated_session(self) -> "StreamSession":
            """Build a session whose caches and predictions hold random tensors."""
            model = _DummyModel()
            session = StreamSession(model, mode="causal")

            aggregator_cache = [
                [torch.randn(1, 2, 3), torch.randn(1, 2, 3)]
                for _ in range(session.aggregator_kv_cache_depth)
            ]

            camera_cache = [
                [
                    [torch.randn(1, 4, 5), torch.randn(1, 4, 5)]
                    for _ in range(session.camera_head_kv_cache_depth)
                ]
                for _ in range(session.camera_head_iterations)
            ]

            session.aggregator_kv_cache_list = aggregator_cache
            session.camera_head_kv_cache_list = camera_cache
            session.predictions = {
                "depth": torch.randn(1, 2, 3, 3),
                "pose_enc": torch.randn(1, 2, 9),
            }

            return session

        def _assert_cache_equal(self, original, restored):
            """Recursively assert that two nested cache structures match.

            Lengths are compared before zipping because ``zip`` stops at the
            shorter input: without the check, a restored cache that lost
            entries (or is empty) would compare as equal.
            """
            if isinstance(original, (list, tuple)):
                self.assertEqual(len(original), len(restored))
                for original_item, restored_item in zip(original, restored):
                    self._assert_cache_equal(original_item, restored_item)
            elif original is None:
                self.assertIsNone(restored)
            else:
                self.assertTrue(torch.equal(original, restored))

        def test_round_trip_save_and_load(self):
            session = self._make_populated_session()
            model = session.model

            with tempfile.TemporaryDirectory() as tmpdir:
                file_path = os.path.join(tmpdir, "kv_cache.pt")
                session.save_cache(file_path)

                restored_session = StreamSession(model, mode="causal")
                restored_session.load_cache(file_path)

            self._assert_cache_equal(
                session.aggregator_kv_cache_list,
                restored_session.aggregator_kv_cache_list,
            )
            self._assert_cache_equal(
                session.camera_head_kv_cache_list,
                restored_session.camera_head_kv_cache_list,
            )

            # Compare key sets first so extra or missing predictions are caught.
            self.assertEqual(set(session.predictions), set(restored_session.predictions))
            for key, original_tensor in session.predictions.items():
                restored_tensor = restored_session.predictions[key]
                self.assertTrue(torch.equal(original_tensor, restored_tensor))


if __name__ == "__main__":  # pragma: no cover - manual execution
    unittest.main()
|