"""Unit tests for ``StreamSession`` cache persistence and window sizing.

The tests use a tiny dummy model instead of a real network, so they only
exercise the session's bookkeeping: saving/loading the KV caches and
predictions, and resolving ``window_size`` from a config file or argument.
"""

import json
import os
import tempfile
import unittest

try:
    import torch
except ImportError:  # pragma: no cover - environment without torch
    torch = None

if torch is not None:
    from stream3r.stream_session import StreamSession


if torch is None:

    class StreamSessionCacheTest(unittest.TestCase):  # pragma: no cover - requires torch
        """Placeholder test case reported as skipped when PyTorch is missing."""

        @unittest.skip("PyTorch is required for this test")
        def test_requires_torch(self):
            pass

else:

    class _DummyAggregator:
        """Attribute-only stand-in assigned to ``model.aggregator``."""

        # NOTE(review): presumably these are the attributes StreamSession
        # reads from the aggregator; confirm against StreamSession.__init__.
        depth = 2
        patch_size = 4
        patch_start_idx = 3

    class _DummyCameraHead:
        """Attribute-only stand-in assigned to ``model.camera_head``."""

        trunk_depth = 3

    class _DummyModel(torch.nn.Module):
        """Minimal module carrying the dummy aggregator and camera head."""

        def __init__(self):
            super().__init__()
            self.aggregator = _DummyAggregator()
            self.camera_head = _DummyCameraHead()
            # A single registered parameter so the module is not parameter-less.
            self.register_parameter(
                "_dummy_param", torch.nn.Parameter(torch.zeros(1))
            )

        def forward(self, *args, **kwargs):
            # The tests never run inference; fail loudly if something tries to.
            raise NotImplementedError

    class StreamSessionCacheTest(unittest.TestCase):
        """Round-trip and configuration tests for ``StreamSession``."""

        def _make_populated_session(self) -> StreamSession:
            """Return a causal session whose caches and predictions are filled.

            The caches are populated with random tensors: one ``[k, v]``-style
            pair per aggregator layer, and one pair per camera-head layer for
            every camera-head iteration.
            """
            model = _DummyModel()
            session = StreamSession(model, mode="causal")

            session.aggregator_kv_cache_list = [
                [torch.randn(1, 2, 3), torch.randn(1, 2, 3)]
                for _ in range(session.aggregator_kv_cache_depth)
            ]
            session.camera_head_kv_cache_list = [
                [
                    [torch.randn(1, 4, 5), torch.randn(1, 4, 5)]
                    for _ in range(session.camera_head_kv_cache_depth)
                ]
                for _ in range(session.camera_head_iterations)
            ]
            session.predictions = {
                "depth": torch.randn(1, 2, 3, 3),
                "pose_enc": torch.randn(1, 2, 9),
            }
            return session

        def _assert_cache_equal(self, original, restored, path):
            """Recursively assert that two nested cache structures match.

            ``None`` entries must be restored as ``None``, tensors must compare
            equal with ``torch.equal``, and every nested sequence must have the
            same length on both sides.  The explicit length check matters:
            a bare ``zip`` stops at the shorter input, so a cache restored as
            an empty or truncated list would otherwise pass without any
            comparison being made.

            ``path`` is a human-readable location used in failure messages.
            """
            if original is None:
                self.assertIsNone(restored, f"{path}: expected None")
                return
            if torch.is_tensor(original):
                self.assertTrue(
                    torch.is_tensor(restored),
                    f"{path}: restored value is not a tensor",
                )
                self.assertTrue(
                    torch.equal(original, restored),
                    f"{path}: tensor contents differ",
                )
                return

            self.assertEqual(
                len(original), len(restored), f"{path}: length mismatch"
            )
            for index, (original_item, restored_item) in enumerate(
                zip(original, restored)
            ):
                self._assert_cache_equal(
                    original_item, restored_item, f"{path}[{index}]"
                )

        def test_round_trip_save_and_load(self):
            """Saving then loading a cache reproduces caches and predictions."""
            session = self._make_populated_session()
            model = session.model

            with tempfile.TemporaryDirectory() as tmpdir:
                file_path = os.path.join(tmpdir, "kv_cache.pt")
                session.save_cache(file_path)

                restored_session = StreamSession(model, mode="causal")
                restored_session.load_cache(file_path)

                self._assert_cache_equal(
                    session.aggregator_kv_cache_list,
                    restored_session.aggregator_kv_cache_list,
                    "aggregator_kv_cache_list",
                )
                self._assert_cache_equal(
                    session.camera_head_kv_cache_list,
                    restored_session.camera_head_kv_cache_list,
                    "camera_head_kv_cache_list",
                )

                for key, original_tensor in session.predictions.items():
                    # Fail with a readable message instead of a raw KeyError.
                    self.assertIn(key, restored_session.predictions)
                    restored_tensor = restored_session.predictions[key]
                    self.assertTrue(
                        torch.equal(original_tensor, restored_tensor),
                        f"predictions[{key!r}]: tensor contents differ",
                    )

        def test_window_size_from_config(self):
            """``window_size`` is picked up from the JSON config file."""
            model = _DummyModel()
            with tempfile.TemporaryDirectory() as tmpdir:
                config_path = os.path.join(tmpdir, "stream_session.json")
                with open(config_path, "w", encoding="utf-8") as handle:
                    json.dump({"window_size": 7}, handle)

                session = StreamSession(
                    model, mode="window", config_path=config_path
                )
                self.assertEqual(session.window_size, 7)

        def test_window_size_override(self):
            """An explicit ``window_size`` argument is stored on the session."""
            model = _DummyModel()
            session = StreamSession(model, mode="window", window_size=11)
            self.assertEqual(session.window_size, 11)


if __name__ == "__main__":  # pragma: no cover - manual execution
    unittest.main()