"""Unvectorized reference implementation of the MoSRAH sparse KV cache.

This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same
interface and storage layout as MoSRAHCache but uses an explicit Python loop over
(b, l, t) triples in update(). The loop is obviously correct by inspection: each active
position's key and value are written to the next available slot for that (batch, head)
pair, in the order positions appear along the T dimension, which directly enforces
causal ordering without any index arithmetic to verify.

SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
trusted ground truth against which the vectorized MoSRAHCache.update() is validated in
Unit 6.A tests, and as a reference for the Unit 10.A position decoder. Because the
vectorized implementation is validated by asserting exact agreement with this one on all
test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own test suite
(test_slow_mosrah_cache.py) must establish it is trustworthy before it can be used as
an oracle.
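
A minimal usage sketch (illustrative only; the shapes and sizes below are chosen for
demonstration, and the real validation lives in test_slow_mosrah_cache.py):

    >>> import torch
    >>> cache = SlowMoSRAHCache(
    ...     num_mosrah_heads=2, head_dim=4, batch_size=1,
    ...     device=torch.device("cpu"), initial_buffer_size=8,
    ... )
    >>> k = torch.randn(1, 2, 3, 4)
    >>> v = torch.randn(1, 2, 3, 4)
    >>> mask = torch.tensor([[[True, False, True], [True, True, True]]])
    >>> keys, values, active = cache.update(k, v, mask)
    >>> cache.get_heads_lengths().tolist()
    [[2, 3]]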
"""

import torch
from transformers.cache_utils import CacheLayerMixin


class SlowMoSRAHCache(CacheLayerMixin):
    """Unvectorized reference implementation of the MoSRAH KV cache.

    Identical storage layout to MoSRAHCache: (B, L, C, u) tensors, where C is the
    current buffer capacity, in the mixin-standard self.keys and self.values attributes,
    plus a (B, L) _counts tensor, with the same constructor signature and the same
    CacheLayerMixin protocol methods. The sole difference is update(), which uses an
    explicit Python loop over (b, l, t) triples rather than vectorized index arithmetic.

    This class is not used in the model path. It exists so that MoSRAHCache.update()
    can be validated by asserting exact agreement with this implementation on all test
    inputs. See module docstring for the trust chain this enables.

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
            second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first dimension
            of all storage tensors.
        device: Device on which to allocate all tensors. Should match the model device.
        initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
            when any slot overflows. Defaults to 64 to avoid repeated reallocation
            during prompt processing.
    """

    is_compileable = False
    is_sliding = False

    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device

        # Allocate primary storage into the mixin-standard self.keys / self.values so
        # that inherited methods (offload, prefetch) operate on real tensors. _counts
        # tracks valid occupancy per (batch, head) slot.
        self.keys: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )

        # Storage is fully allocated at construction; the cache is initialized.
        self.is_initialized = True

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------

    @property
    def buffer_capacity(self) -> int:
        """Current number of slots allocated per (batch, head) pair.

        Derived directly from self.keys rather than tracked separately, so it is
        always consistent with the actual buffer after expansion.
        """
        return self.keys.shape[2]

    # ---------------------------------------------------------------------------
    # Primary API
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states using an explicit loop; return full cache state.

        Iterates over every (b, l, t) triple. For each position where active_mask is
        True, the key and value are written to the next available slot for that
        (batch, head) pair and the count is incremented. Causal ordering is guaranteed
        because the t dimension is traversed from 0 to T-1 and counts are updated
        immediately after each write.

        Buffer expansion (doubling buffer_capacity, repeatedly if needed) is triggered
        before any writes whenever the incoming tokens would cause any slot to overflow
        the current capacity.

        Args:
            key_states: Shape (B, L, T, u). Post-RoPE key vectors in expert-choice layout.
            value_states: Shape (B, L, T, u). Value vectors in expert-choice layout.
            active_mask: Shape (B, L, T) bool. True for real tokens, False for padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.

        Returns:
            Tuple of (keys, values, active_mask), where C is the current buffer capacity:
              keys: (B, L, C, u) float. Full key buffer, including junk slots.
              values: (B, L, C, u) float. Full value buffer, including junk slots.
              active_mask: (B, L, C) bool. True iff slot [b, l, c] has been written.
        """
        B, L, T = active_mask.shape

        # Expansion check uses the total active tokens per slot, same as the
        # vectorized implementation, so both expand under identical conditions.
        # Doubling repeats until every slot fits: a single doubling may not be
        # enough when a long prompt routes many tokens to one head.
        incoming_delta = active_mask.long().sum(dim=2)  # (B, L)
        while (self._counts + incoming_delta).max().item() > self.buffer_capacity:
            self._expand()

        # Write each active position into the next available slot for its (batch, head)
        # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
        for b in range(B):
            for l in range(L):
                for t in range(T):
                    if active_mask[b, l, t]:
                        pos = self._counts[b, l].item()
                        self.keys[b, l, pos, :] = key_states[b, l, t, :]
                        self.values[b, l, pos, :] = value_states[b, l, t, :]
                        self._counts[b, l] += 1

        return self.keys, self.values, self._make_active_mask()

    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor consumed by BEA for attention
        masking and by the Unit 10.A position decoder for semantic-sequence
        position computation.

        Returns:
            Integer tensor of shape (B, L) where entry [b, h] is the number of valid
            tokens stored in the (b, h) slot. Zero for slots with no writes yet.
        """
        return self._counts

    # ---------------------------------------------------------------------------
    # CacheLayerMixin: overridden coordination methods
    # ---------------------------------------------------------------------------

    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
        and is_initialized remains True; only the contents are cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam search
        must reorder all three together or the occupancy counts and buffer contents
        will correspond to different beam hypotheses.

        Overrides the parent because the parent's implementation calls get_seq_length(),
        which is not supported for this cache.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by the beam
                search algorithm.
        """
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self._counts = self._counts[beam_idx]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry repeats times.

        Used at beam search initialisation to expand the cache from batch size B to
        B * repeats, matching the expanded beam candidate batch. Applied atomically
        across keys, values, and _counts; batch_size is updated to reflect the new size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate entries.
        Applied atomically across keys, values, and _counts; batch_size is updated
        to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self._counts = self._counts[indices]
        self.batch_size = indices.shape[0]

    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not know
        about. All three tensors are moved atomically so device state remains consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not know
        about. _counts is synced to self.keys.device after the parent moves keys and
        values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op β€” storage is fully allocated at construction time."""
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin: unsupported abstract methods
    # ---------------------------------------------------------------------------

    def get_seq_length(self) -> int:  # type: ignore[override]
        """Not supported β€” no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently; (batch, head) slots have different
        lengths depending on routing history. There is no meaningful scalar summary.
        Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "SlowMoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Not supported β€” SlowMoSRAHCache is dynamic and unbounded."""
        raise NotImplementedError(
            "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
        )

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported β€” SlowMoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "SlowMoSRAHCache does not support get_mask_sizes()."
        )

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------

    def _make_active_mask(self) -> torch.Tensor:
        """Construct the (B, L, T) active mask from current counts.

        Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
        has been written. Positions at or beyond the count are junk and must be
        excluded by downstream attention.
        """
        cap = self.buffer_capacity
        return (
            torch.arange(cap, device=self.keys.device)
            .expand(self.batch_size, self.num_mosrah_heads, cap)
            < self._counts.unsqueeze(-1)
        )

    def _expand(self) -> None:
        """Double the buffer capacity, preserving existing data.

        Called by update() when an incoming batch of tokens would cause any
        (batch, head) slot to exceed the current buffer capacity. All existing
        key and value data is copied into the low half of the new buffer; the
        high half is zero-initialised and will be filled by subsequent writes.
        After reassignment, buffer_capacity reflects the new size automatically.
        """
        old_cap = self.buffer_capacity
        new_cap = old_cap * 2
        dev = self.keys.device
        # Preserve dtype across expansion so the new buffer matches the existing storage.
        new_keys = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim,
            device=dev, dtype=self.keys.dtype,
        )
        new_values = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim,
            device=dev, dtype=self.values.dtype,
        )
        new_keys[:, :, :old_cap, :] = self.keys
        new_values[:, :, :old_cap, :] = self.values
        self.keys = new_keys
        self.values = new_values
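

if __name__ == "__main__":
    # Illustrative smoke-check sketch, not part of the Unit 6.A test suite. The sizes
    # below are arbitrary demonstration values; the small initial_buffer_size is chosen
    # so that routing more than four tokens to a head exercises the expansion path.
    torch.manual_seed(0)
    B, L, T, u = 2, 3, 10, 4
    cache = SlowMoSRAHCache(
        num_mosrah_heads=L,
        head_dim=u,
        batch_size=B,
        device=torch.device("cpu"),
        initial_buffer_size=4,
    )
    keys = torch.randn(B, L, T, u)
    values = torch.randn(B, L, T, u)
    active = torch.rand(B, L, T) > 0.3  # random routing pattern

    out_keys, out_values, out_mask = cache.update(keys, values, active)

    # Occupancy must match the number of active positions per (batch, head), the
    # returned mask must agree with the counts, and the buffer must have grown enough
    # to hold the fullest slot.
    assert torch.equal(cache.get_heads_lengths(), active.long().sum(dim=2))
    assert torch.equal(out_mask.long().sum(dim=2), cache.get_heads_lengths())
    assert cache.buffer_capacity >= int(cache.get_heads_lengths().max())
    print("buffer_capacity:", cache.buffer_capacity)
    print("counts per (batch, head):", cache.get_heads_lengths().tolist())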