File size: 9,489 Bytes
4c075ec
6805b8e
 
 
9d31508
 
 
 
 
 
 
 
4c075ec
9d31508
 
 
 
 
4c075ec
9d31508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c075ec
9d31508
 
 
 
 
 
 
 
 
 
4c075ec
9d31508
 
 
 
 
 
 
 
 
 
 
 
6805b8e
9d31508
 
 
 
 
 
 
 
 
 
 
 
 
 
6805b8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c075ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6805b8e
 
 
 
 
 
 
 
 
4c075ec
6805b8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c075ec
6805b8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d31508
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import json
import os
from typing import Any, Dict, Optional

import torch
from stream3r.models.stream3r import STream3R


class StreamSession:
    """
    A causal streaming inference session with KV cache management for STream3R.

    Frames are pushed through :meth:`forward_stream` one chunk at a time. The
    session passes its stored key/value caches to the model, keeps the caches
    the model returns for the next call, and accumulates the per-frame
    predictions.

    Supported attention modes:

    * ``"causal"``: caches returned by the model are stored unchanged.
    * ``"window"``: after each step every cache tensor is trimmed along dim 2
      to its leading (anchor) tokens plus the most recent ``window_size``
      frames' worth of tokens.
    """

    # Prediction entries accumulated across calls; other model outputs are ignored.
    _PREDICTION_KEYS = ("pose_enc", "world_points", "world_points_conf", "depth", "depth_conf", "images")
    _SUPPORTED_MODES = ("causal", "window")
    # Used when no override is given and the config file is missing or invalid.
    _DEFAULT_WINDOW_SIZE = 25

    def __init__(self, model: "STream3R", mode: str, *, window_size: Optional[int] = None, config_path: Optional[str] = None):
        """
        Create a session with empty caches and predictions.

        Args:
            model: STream3R model. ``model.aggregator.depth`` and
                ``model.camera_head.trunk_depth`` determine the cache layout.
            mode: Attention mode, ``"causal"`` or ``"window"``.
            window_size: Number of recent frames kept in ``"window"`` mode.
                When ``None`` it is read from ``config_path``.
            config_path: Optional path to a JSON file with a ``window_size`` key.

        Raises:
            ValueError: If ``mode`` is unsupported or ``window_size`` is not positive.
        """
        # Fail fast, before touching the model or reading any config file.
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unsupported attention mode when using kv_cache: {mode}")

        self.model = model
        self.mode = mode
        self.aggregator_kv_cache_depth = model.aggregator.depth
        self.camera_head_kv_cache_depth = model.camera_head.trunk_depth
        # Number of camera-head iterations that each keep their own per-layer cache.
        # NOTE(review): hard-coded; presumably must match the model's camera head -- confirm.
        self.camera_head_iterations = 4
        self.window_size = self._resolve_window_size(window_size, config_path)

        self.clear()

    def _clear_predictions(self):
        """Drop all accumulated predictions."""
        self.predictions = dict()

    def _update_predictions(self, predictions):
        """
        Append the model outputs for the new frames to the accumulated predictions.

        Only keys in ``_PREDICTION_KEYS`` are kept. Tensors are concatenated
        along dim 1 (presumably the frame axis -- confirm against the model).
        The first chunk for a key is stored as-is. Seeding the concatenation
        with an empty float32 tensor would promote bf16/fp16 outputs to float32.
        """
        for key in self._PREDICTION_KEYS:
            if key not in predictions:
                continue
            new_chunk = predictions[key]
            previous = self.predictions.get(key)
            if previous is None:
                self.predictions[key] = new_chunk
            else:
                self.predictions[key] = torch.cat([previous, new_chunk], dim=1)

    def _clear_cache(self):
        """Reset every cache slot to ``[None, None]`` (key and value, presumably)."""
        self.aggregator_kv_cache_list = [[None, None] for _ in range(self.aggregator_kv_cache_depth)]
        self.camera_head_kv_cache_list = [
            [[None, None] for _ in range(self.camera_head_kv_cache_depth)]
            for _ in range(self.camera_head_iterations)
        ]

    @staticmethod
    def _keep_anchor_and_window(tensor, anchor_len, window_len):
        """
        Keep the first ``anchor_len`` and the last ``window_len`` entries along dim 2.

        The window slice never starts before ``anchor_len``, so a short cache
        is returned unchanged without duplicating anchor entries.
        """
        window_start = max(anchor_len, tensor.size(2) - window_len)
        return torch.cat([tensor[:, :, :anchor_len], tensor[:, :, window_start:]], dim=2)

    def _update_cache(self, aggregator_kv_cache_list, camera_head_kv_cache_list):
        """
        Store the caches returned by the model for use in the next step.

        In ``"causal"`` mode the returned caches replace the stored ones.

        In ``"window"`` mode each cache tensor is trimmed with
        :meth:`_keep_anchor_and_window` and written into the session's own
        cache lists:

        * Aggregator caches keep one frame's worth of anchor tokens plus
          ``window_size`` frames' worth of recent tokens.
        * Camera-head caches keep one anchor entry plus the last
          ``window_size`` entries.

        Raises:
            KeyError: In ``"window"`` mode when no ``"depth"`` prediction is available.
            ValueError: If the session mode is unsupported.
        """
        if self.mode == "causal":
            self.aggregator_kv_cache_list = aggregator_kv_cache_list
            self.camera_head_kv_cache_list = camera_head_kv_cache_list
            return
        if self.mode != "window":
            raise ValueError(f"Unsupported attention mode when using kv_cache: {self.mode}")

        # Per-frame token count of the aggregator cache, derived from the
        # spatial size of the depth prediction (assumes depth is laid out as
        # (B, S, H, W, ...) -- TODO confirm) and the aggregator's patching.
        # This is loop-invariant, so compute it once.
        depth = self.predictions["depth"]
        height, width = depth.shape[2], depth.shape[3]
        aggregator = self.model.aggregator
        tokens_per_frame = height * width // aggregator.patch_size // aggregator.patch_size + aggregator.patch_start_idx
        aggregator_window = self.window_size * tokens_per_frame

        for layer_idx in range(self.aggregator_kv_cache_depth):
            for kv_idx in range(2):
                self.aggregator_kv_cache_list[layer_idx][kv_idx] = self._keep_anchor_and_window(
                    aggregator_kv_cache_list[layer_idx][kv_idx], tokens_per_frame, aggregator_window
                )

        for iteration in range(self.camera_head_iterations):
            for layer_idx in range(self.camera_head_kv_cache_depth):
                for kv_idx in range(2):
                    self.camera_head_kv_cache_list[iteration][layer_idx][kv_idx] = self._keep_anchor_and_window(
                        camera_head_kv_cache_list[iteration][layer_idx][kv_idx], 1, self.window_size
                    )

    def _get_cache(self):
        """Return ``(aggregator_cache, camera_head_cache)`` as stored (not copied)."""
        return self.aggregator_kv_cache_list, self.camera_head_kv_cache_list

    def get_all_predictions(self):
        """Return the accumulated predictions dict (the session's own dict, not a copy)."""
        return self.predictions

    def get_last_prediction(self):
        """Return, for each accumulated key, only the last entry along dim 1 (kept as a size-1 dim)."""
        return {
            key: self.predictions[key][:, -1:]
            for key in self._PREDICTION_KEYS
            if key in self.predictions
        }

    def clear(self):
        """Reset both the accumulated predictions and the KV caches."""
        self._clear_predictions()
        self._clear_cache()

    @staticmethod
    def _detach_to_cpu(cache_like):
        """Recursively detach tensors and move them to CPU, preserving list/tuple structure."""
        if isinstance(cache_like, torch.Tensor):
            return cache_like.detach().cpu()
        if isinstance(cache_like, list):
            return [StreamSession._detach_to_cpu(elem) for elem in cache_like]
        if isinstance(cache_like, tuple):
            return tuple(StreamSession._detach_to_cpu(elem) for elem in cache_like)
        # Non-tensor leaves (e.g. None for empty cache slots) pass through.
        return cache_like

    @staticmethod
    def _to_device(cache_like, device: torch.device):
        """Recursively move tensors to ``device``, preserving list/tuple structure."""
        if isinstance(cache_like, torch.Tensor):
            return cache_like.to(device)
        if isinstance(cache_like, list):
            return [StreamSession._to_device(elem, device) for elem in cache_like]
        if isinstance(cache_like, tuple):
            return tuple(StreamSession._to_device(elem, device) for elem in cache_like)
        return cache_like

    def _default_device(self) -> torch.device:
        """Device of the model's first parameter, or CPU if the model has none."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def _resolve_window_size(self, override: Optional[int], config_path: Optional[str]) -> int:
        """
        Determine the window size.

        Resolution order:

        1. ``override``, if given; it must be positive.
        2. A positive integer ``"window_size"`` in the JSON config at
           ``config_path``. If no path is given, ``../configs/stream_session.json``
           relative to this file is used.
        3. ``_DEFAULT_WINDOW_SIZE``. This applies if the file is missing or
           unreadable, is not valid UTF-8 JSON, is not a JSON object, or holds
           no valid value.

        Raises:
            ValueError: If ``override`` is not positive.
        """
        if override is not None:
            if override <= 0:
                raise ValueError(f"window_size must be a positive integer, got {override}")
            return override

        config_path = config_path or os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "configs", "stream_session.json")
        )

        try:
            with open(config_path, "r", encoding="utf-8") as handle:
                data = json.load(handle)
        except (OSError, ValueError):
            # OSError: missing/unreadable file. ValueError covers both
            # json.JSONDecodeError and UnicodeDecodeError (non-UTF-8 bytes).
            return self._DEFAULT_WINDOW_SIZE

        window_size = data.get("window_size") if isinstance(data, dict) else None

        # bool is a subclass of int; reject JSON true/false explicitly.
        if isinstance(window_size, int) and not isinstance(window_size, bool) and window_size > 0:
            return window_size

        return self._DEFAULT_WINDOW_SIZE

    def save_cache(self, file_path: str) -> None:
        """
        Serialize the caches, accumulated predictions and session metadata with ``torch.save``.

        All tensors are detached and moved to CPU first. Parent directories of
        ``file_path`` are created if needed.
        """
        aggregator_cache, camera_cache = self._get_cache()

        payload: Dict[str, Any] = {
            "metadata": {
                "mode": self.mode,
                "aggregator_depth": self.aggregator_kv_cache_depth,
                "camera_head_depth": self.camera_head_kv_cache_depth,
                "camera_head_iterations": self.camera_head_iterations,
                "window_size": self.window_size,
                "patch_size": getattr(self.model.aggregator, "patch_size", None),
                "patch_start_idx": getattr(self.model.aggregator, "patch_start_idx", None),
            },
            "aggregator_cache": self._detach_to_cpu(aggregator_cache),
            "camera_cache": self._detach_to_cpu(camera_cache),
            # load_cache accepts non-tensor prediction values, so tolerate them here too.
            "predictions": {k: self._detach_to_cpu(v) for k, v in self.predictions.items()},
        }

        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        torch.save(payload, file_path)

    def load_cache(self, file_path: str, *, device: Optional[torch.device] = None, strict: bool = True) -> None:
        """
        Restore caches and predictions written by :meth:`save_cache`.

        Args:
            file_path: Path of the saved payload.
            device: Target device for the loaded tensors; defaults to the
                model's parameter device.
            strict: When True, metadata must match this session. Mode, cache
                depths, camera-head iterations and window size must be equal.
                A stored ``patch_size`` / ``patch_start_idx`` must also equal
                the model's, unless that entry is ``None`` or absent.

        Raises:
            ValueError: On a metadata mismatch while ``strict`` is True.
        """
        if device is None:
            device = self._default_device()

        # NOTE(review): torch.load unpickles the file -- only load caches from
        # trusted sources. The payload holds only tensors and plain containers,
        # so weights_only=True would work on PyTorch versions that support it.
        payload = torch.load(file_path, map_location="cpu")

        metadata: Dict[str, Any] = payload.get("metadata", {})

        if strict:
            expected_metadata = {
                "mode": self.mode,
                "aggregator_depth": self.aggregator_kv_cache_depth,
                "camera_head_depth": self.camera_head_kv_cache_depth,
                "camera_head_iterations": self.camera_head_iterations,
                "window_size": self.window_size,
            }
            for key, expected_value in expected_metadata.items():
                actual_value = metadata.get(key)
                if actual_value != expected_value:
                    raise ValueError(
                        f"Loaded cache metadata mismatch for '{key}': expected {expected_value}, got {actual_value}"
                    )

            patch_size = getattr(self.model.aggregator, "patch_size", None)
            patch_start_idx = getattr(self.model.aggregator, "patch_start_idx", None)
            if metadata.get("patch_size") not in (None, patch_size):
                raise ValueError(
                    f"Loaded cache metadata mismatch for 'patch_size': expected {patch_size}, got {metadata.get('patch_size')}"
                )
            if metadata.get("patch_start_idx") not in (None, patch_start_idx):
                raise ValueError(
                    f"Loaded cache metadata mismatch for 'patch_start_idx': expected {patch_start_idx}, got {metadata.get('patch_start_idx')}"
                )

        self.aggregator_kv_cache_list = self._to_device(payload.get("aggregator_cache", []), device)
        self.camera_head_kv_cache_list = self._to_device(payload.get("camera_cache", []), device)
        self.predictions = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in payload.get("predictions", {}).items()
        }

    def forward_stream(self, images):
        """
        Run the model on ``images`` with the stored caches and record the outputs.

        The model output must contain ``"aggregator_kv_cache_list"`` and
        ``"camera_head_kv_cache_list"``. Prediction keys are appended first,
        because window-mode trimming reads the ``"depth"`` prediction.

        Returns:
            The accumulated predictions dict (the session's own dict, not a copy).
        """
        aggregator_kv_cache_list, camera_head_kv_cache_list = self._get_cache()

        outputs = self.model(
            images=images,
            mode=self.mode,
            aggregator_kv_cache_list=aggregator_kv_cache_list,
            camera_head_kv_cache_list=camera_head_kv_cache_list,
        )

        self._update_predictions(outputs)
        self._update_cache(outputs["aggregator_kv_cache_list"], outputs["camera_head_kv_cache_list"])

        return self.get_all_predictions()