Spaces:
Configuration error
Configuration error
| import json | |
| import os | |
| import tempfile | |
| import unittest | |
| try: | |
| import torch | |
| except ImportError: # pragma: no cover - environment without torch | |
| torch = None | |
| if torch is not None: | |
| from stream3r.stream_session import StreamSession | |
if torch is None:

    class StreamSessionCacheTest(unittest.TestCase):  # pragma: no cover - requires torch
        """Stand-in test case used when ``torch`` could not be imported.

        Reports the StreamSession tests as *skipped* rather than silently
        passing, so an environment without torch does not look like the cache
        tests ran and succeeded.
        """

        def test_requires_torch(self):
            self.skipTest("torch is not installed; StreamSession cache tests were not run")

else:

    class _DummyAggregator:
        """Minimal stand-in for the model's aggregator.

        Only carries plain attributes; NOTE(review): these are presumably the
        fields StreamSession reads to size its aggregator cache — confirm
        against StreamSession.__init__.
        """

        depth = 2
        patch_size = 4
        patch_start_idx = 3

    class _DummyCameraHead:
        """Minimal stand-in for the model's camera head (exposes ``trunk_depth`` only)."""

        trunk_depth = 3

    class _DummyModel(torch.nn.Module):
        """Tiny ``nn.Module`` exposing ``aggregator`` and ``camera_head`` without real layers."""

        def __init__(self):
            super().__init__()
            self.aggregator = _DummyAggregator()
            self.camera_head = _DummyCameraHead()
            # One real parameter so the module is not parameter-less; presumably
            # StreamSession inspects model parameters (e.g. device/dtype) — TODO confirm.
            self.register_parameter("_dummy_param", torch.nn.Parameter(torch.zeros(1)))

        def forward(self, *args, **kwargs):
            # These tests exercise cache bookkeeping only; inference is never run.
            raise NotImplementedError

    class StreamSessionCacheTest(unittest.TestCase):
        """Tests for StreamSession cache persistence and window-size configuration."""

        def _make_populated_session(self) -> StreamSession:
            """Return a causal session whose caches and predictions hold random tensors.

            The cache list lengths come from the session's own
            ``aggregator_kv_cache_depth``, ``camera_head_iterations`` and
            ``camera_head_kv_cache_depth`` so the nesting matches what the
            session itself reports.
            """
            model = _DummyModel()
            session = StreamSession(model, mode="causal")
            # One [tensor, tensor] entry per aggregator cache layer.
            session.aggregator_kv_cache_list = [
                [torch.randn(1, 2, 3), torch.randn(1, 2, 3)]
                for _ in range(session.aggregator_kv_cache_depth)
            ]
            # camera_head_iterations outer entries, each holding one
            # [tensor, tensor] entry per camera-head cache layer.
            session.camera_head_kv_cache_list = [
                [
                    [torch.randn(1, 4, 5), torch.randn(1, 4, 5)]
                    for _ in range(session.camera_head_kv_cache_depth)
                ]
                for _ in range(session.camera_head_iterations)
            ]
            session.predictions = {
                "depth": torch.randn(1, 2, 3, 3),
                "pose_enc": torch.randn(1, 2, 9),
            }
            return session

        def _assert_cache_entry_equal(self, original_entry, restored_entry):
            """Assert two cache entries (sequences of tensors or ``None``) match element-wise."""
            # zip() stops at the shorter input, so check lengths explicitly.
            self.assertEqual(len(original_entry), len(restored_entry))
            for original_tensor, restored_tensor in zip(original_entry, restored_entry):
                if original_tensor is None:
                    self.assertIsNone(restored_tensor)
                else:
                    self.assertIsNotNone(restored_tensor)
                    self.assertTrue(torch.equal(original_tensor, restored_tensor))

        def test_round_trip_save_and_load(self):
            """Caches and predictions written by ``save_cache`` come back intact via ``load_cache``."""
            session = self._make_populated_session()
            model = session.model
            with tempfile.TemporaryDirectory() as tmpdir:
                file_path = os.path.join(tmpdir, "kv_cache.pt")
                session.save_cache(file_path)
                restored_session = StreamSession(model, mode="causal")
                restored_session.load_cache(file_path)

                # Length checks first: without them an empty or truncated
                # restored cache would make every zip() below a no-op and the
                # test would pass vacuously.
                original_agg = session.aggregator_kv_cache_list
                restored_agg = restored_session.aggregator_kv_cache_list
                self.assertEqual(len(original_agg), len(restored_agg))
                for original_layer, restored_layer in zip(original_agg, restored_agg):
                    self._assert_cache_entry_equal(original_layer, restored_layer)

                original_cam = session.camera_head_kv_cache_list
                restored_cam = restored_session.camera_head_kv_cache_list
                self.assertEqual(len(original_cam), len(restored_cam))
                for original_iter, restored_iter in zip(original_cam, restored_cam):
                    self.assertEqual(len(original_iter), len(restored_iter))
                    for original_layer, restored_layer in zip(original_iter, restored_iter):
                        self._assert_cache_entry_equal(original_layer, restored_layer)

                for key, original_tensor in session.predictions.items():
                    with self.subTest(key=key):
                        self.assertIn(key, restored_session.predictions)
                        restored_tensor = restored_session.predictions[key]
                        self.assertTrue(torch.equal(original_tensor, restored_tensor))

        def test_window_size_from_config(self):
            """``window_size`` is taken from the JSON file passed as ``config_path``."""
            model = _DummyModel()
            with tempfile.TemporaryDirectory() as tmpdir:
                config_path = os.path.join(tmpdir, "stream_session.json")
                with open(config_path, "w", encoding="utf-8") as handle:
                    json.dump({"window_size": 7}, handle)
                # Construct while the config file still exists; the directory
                # is deleted when the context manager exits.
                session = StreamSession(model, mode="window", config_path=config_path)
            self.assertEqual(session.window_size, 7)

        def test_window_size_override(self):
            """An explicit ``window_size`` argument is honoured."""
            model = _DummyModel()
            session = StreamSession(model, mode="window", window_size=11)
            self.assertEqual(session.window_size, 11)
if __name__ == "__main__":  # pragma: no cover - manual execution
    # Direct execution hands control to unittest's command-line runner,
    # which loads the TestCase classes defined in this (the __main__) module.
    unittest.main(module=__name__)