File size: 4,500 Bytes
4c075ec
6805b8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c075ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6805b8e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import json
import os
import tempfile
import unittest

try:
    import torch
except ImportError:  # pragma: no cover - environment without torch
    torch = None

if torch is not None:
    from stream3r.stream_session import StreamSession


if torch is None:

    class StreamSessionCacheTest(unittest.TestCase):  # pragma: no cover - requires torch
        """Stand-in suite registered when PyTorch is not installed."""

        @unittest.skip("PyTorch is required for this test")
        def test_requires_torch(self):
            """Skipped marker that keeps the torch requirement visible in test output."""

else:

    class _DummyAggregator:
        depth = 2
        patch_size = 4
        patch_start_idx = 3


    class _DummyCameraHead:
        trunk_depth = 3


    class _DummyModel(torch.nn.Module):
        """Lightweight module wiring the stub aggregator and camera head together."""

        def __init__(self):
            super().__init__()
            self.aggregator = _DummyAggregator()
            self.camera_head = _DummyCameraHead()
            # Assigning a Parameter attribute registers it under the same name
            # ("_dummy_param"), so the otherwise-empty module owns one tensor.
            self._dummy_param = torch.nn.Parameter(torch.zeros(1))

        def forward(self, *args, **kwargs):
            # The tests never run a forward pass; fail loudly if one is attempted.
            raise NotImplementedError


    class StreamSessionCacheTest(unittest.TestCase):
        """Tests for StreamSession KV-cache persistence and window-size configuration."""

        def _make_populated_session(self) -> StreamSession:
            """Build a causal-mode session with random caches and prediction tensors."""
            model = _DummyModel()
            session = StreamSession(model, mode="causal")

            aggregator_cache = [
                [torch.randn(1, 2, 3), torch.randn(1, 2, 3)]
                for _ in range(session.aggregator_kv_cache_depth)
            ]

            camera_cache = []
            for _ in range(session.camera_head_iterations):
                iter_cache = []
                for _ in range(session.camera_head_kv_cache_depth):
                    iter_cache.append([torch.randn(1, 4, 5), torch.randn(1, 4, 5)])
                camera_cache.append(iter_cache)

            session.aggregator_kv_cache_list = aggregator_cache
            session.camera_head_kv_cache_list = camera_cache
            session.predictions = {
                "depth": torch.randn(1, 2, 3, 3),
                "pose_enc": torch.randn(1, 2, 9),
            }

            return session

        def _assert_layer_equal(self, original_layer, restored_layer):
            """Assert two per-layer tensor lists match element-wise, treating None as None."""
            self.assertEqual(len(original_layer), len(restored_layer))
            for original_tensor, restored_tensor in zip(original_layer, restored_layer):
                if original_tensor is None:
                    self.assertIsNone(restored_tensor)
                else:
                    self.assertTrue(torch.equal(original_tensor, restored_tensor))

        def test_round_trip_save_and_load(self):
            """save_cache followed by load_cache must reproduce caches and predictions exactly."""
            session = self._make_populated_session()
            model = session.model

            with tempfile.TemporaryDirectory() as tmpdir:
                file_path = os.path.join(tmpdir, "kv_cache.pt")
                session.save_cache(file_path)

                restored_session = StreamSession(model, mode="causal")
                restored_session.load_cache(file_path)

            # zip() truncates at the shorter sequence, so a restored cache that
            # came back too short would otherwise pass silently — compare
            # lengths explicitly before the element-wise checks.
            self.assertEqual(
                len(session.aggregator_kv_cache_list),
                len(restored_session.aggregator_kv_cache_list),
            )
            for original_layer, restored_layer in zip(
                session.aggregator_kv_cache_list, restored_session.aggregator_kv_cache_list
            ):
                self._assert_layer_equal(original_layer, restored_layer)

            self.assertEqual(
                len(session.camera_head_kv_cache_list),
                len(restored_session.camera_head_kv_cache_list),
            )
            for original_iter, restored_iter in zip(
                session.camera_head_kv_cache_list, restored_session.camera_head_kv_cache_list
            ):
                self.assertEqual(len(original_iter), len(restored_iter))
                for original_layer, restored_layer in zip(original_iter, restored_iter):
                    self._assert_layer_equal(original_layer, restored_layer)

            # Guard against dropped or extra prediction keys before comparing values.
            self.assertEqual(set(session.predictions), set(restored_session.predictions))
            for key, original_tensor in session.predictions.items():
                restored_tensor = restored_session.predictions[key]
                self.assertTrue(torch.equal(original_tensor, restored_tensor))

        def test_window_size_from_config(self):
            """window_size is read from the JSON file passed via config_path."""
            model = _DummyModel()
            with tempfile.TemporaryDirectory() as tmpdir:
                config_path = os.path.join(tmpdir, "stream_session.json")
                with open(config_path, "w", encoding="utf-8") as handle:
                    json.dump({"window_size": 7}, handle)

                session = StreamSession(model, mode="window", config_path=config_path)

            self.assertEqual(session.window_size, 7)

        def test_window_size_override(self):
            """An explicitly passed window_size keyword is stored on the session."""
            model = _DummyModel()
            session = StreamSession(model, mode="window", window_size=11)
            self.assertEqual(session.window_size, 11)


# Allow running this module directly (`python <file>`) to execute the suite.
if __name__ == "__main__":  # pragma: no cover - manual execution
    unittest.main()