# dwellbot_stream3r/tests/test_stream_session_cache.py
# Commit 4c075ec (brian4dwell): add saving and reloading of session
import json
import os
import tempfile
import unittest

# torch is an optional dependency here: import it defensively so that test
# collection still works in environments without PyTorch installed.
try:
    import torch
except ImportError:  # pragma: no cover - environment without torch
    torch = None

# StreamSession itself requires torch, so only import it when torch loaded.
if torch is not None:
    from stream3r.stream_session import StreamSession
if torch is None:

    class StreamSessionCacheTest(unittest.TestCase):  # pragma: no cover - requires torch
        """Placeholder suite used when PyTorch is not installed.

        Keeps the test-case name discoverable so runners report a skip
        instead of an import error.
        """

        @unittest.skip("PyTorch is required for this test")
        def test_requires_torch(self):
            """Always skipped; exists only to surface the missing dependency."""
            return None
else:
class _DummyAggregator:
depth = 2
patch_size = 4
patch_start_idx = 3
class _DummyCameraHead:
trunk_depth = 3
class _DummyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.aggregator = _DummyAggregator()
self.camera_head = _DummyCameraHead()
self.register_parameter("_dummy_param", torch.nn.Parameter(torch.zeros(1)))
def forward(self, *args, **kwargs):
raise NotImplementedError
class StreamSessionCacheTest(unittest.TestCase):
def _make_populated_session(self) -> StreamSession:
model = _DummyModel()
session = StreamSession(model, mode="causal")
aggregator_cache = [
[torch.randn(1, 2, 3), torch.randn(1, 2, 3)]
for _ in range(session.aggregator_kv_cache_depth)
]
camera_cache = []
for _ in range(session.camera_head_iterations):
iter_cache = []
for _ in range(session.camera_head_kv_cache_depth):
iter_cache.append([torch.randn(1, 4, 5), torch.randn(1, 4, 5)])
camera_cache.append(iter_cache)
session.aggregator_kv_cache_list = aggregator_cache
session.camera_head_kv_cache_list = camera_cache
session.predictions = {
"depth": torch.randn(1, 2, 3, 3),
"pose_enc": torch.randn(1, 2, 9),
}
return session
def test_round_trip_save_and_load(self):
session = self._make_populated_session()
model = session.model
with tempfile.TemporaryDirectory() as tmpdir:
file_path = os.path.join(tmpdir, "kv_cache.pt")
session.save_cache(file_path)
restored_session = StreamSession(model, mode="causal")
restored_session.load_cache(file_path)
for original_layer, restored_layer in zip(
session.aggregator_kv_cache_list, restored_session.aggregator_kv_cache_list
):
for original_tensor, restored_tensor in zip(original_layer, restored_layer):
if original_tensor is None:
self.assertIsNone(restored_tensor)
else:
self.assertTrue(torch.equal(original_tensor, restored_tensor))
for original_iter, restored_iter in zip(
session.camera_head_kv_cache_list, restored_session.camera_head_kv_cache_list
):
for original_layer, restored_layer in zip(original_iter, restored_iter):
for original_tensor, restored_tensor in zip(original_layer, restored_layer):
if original_tensor is None:
self.assertIsNone(restored_tensor)
else:
self.assertTrue(torch.equal(original_tensor, restored_tensor))
for key, original_tensor in session.predictions.items():
restored_tensor = restored_session.predictions[key]
self.assertTrue(torch.equal(original_tensor, restored_tensor))
def test_window_size_from_config(self):
model = _DummyModel()
with tempfile.TemporaryDirectory() as tmpdir:
config_path = os.path.join(tmpdir, "stream_session.json")
with open(config_path, "w", encoding="utf-8") as handle:
json.dump({"window_size": 7}, handle)
session = StreamSession(model, mode="window", config_path=config_path)
self.assertEqual(session.window_size, 7)
def test_window_size_override(self):
model = _DummyModel()
session = StreamSession(model, mode="window", window_size=11)
self.assertEqual(session.window_size, 11)
# Allow running this test module directly with `python test_stream_session_cache.py`.
if __name__ == "__main__":  # pragma: no cover - manual execution
    unittest.main()