| """Unvectorized reference implementation of the MoSRAH sparse KV cache. |
| |
| This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same |
| interface and storage layout as MoSRAHCache but uses an explicit Python loop over |
| (b, l, t) triples in update(). The loop is obviously correct by inspection: each active |
| position's key and value are written to the next available slot for that (batch, head) |
| pair, in the order positions appear along the T dimension, which directly enforces |
| causal ordering without any index arithmetic to verify. |
| |
| SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a |
| trusted ground truth against which the vectorized MoSRAHCache.update() is validated in |
| Unit 6.A tests, and as a reference for the Unit 10.A position decoder. Because the |
| vectorized implementation is validated by asserting exact agreement with this one on all |
| test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own test suite |
| (test_slow_mosrah_cache.py) must establish it is trustworthy before it can be used as |
| an oracle. |
| """ |
|
|
| import torch |
| from transformers.cache_utils import CacheLayerMixin |
|
|
|
|
class SlowMoSRAHCache(CacheLayerMixin):
    """Unvectorized reference implementation of the MoSRAH KV cache.

    Identical storage layout to MoSRAHCache: (B, L, T, u) tensors in the
    mixin-standard self.keys and self.values attributes, plus a (B, L) _counts tensor,
    with the same constructor signature and the same CacheLayerMixin protocol methods.
    The sole difference is update(), which uses an explicit Python loop over (b, l, t)
    triples rather than vectorized index arithmetic.

    This class is not used in the model path. It exists so that MoSRAHCache.update()
    can be validated by asserting exact agreement with this implementation on all test
    inputs. See module docstring for the trust chain this enables.

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
            second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first dimension
            of all storage tensors.
        device: Device on which to allocate all tensors. Should match the model device.
        initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
            (possibly repeatedly) when any slot would overflow. Defaults to 64 to
            avoid repeated reallocation during prompt processing.
    """

    is_compileable = False
    is_sliding = False

    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device

        # Storage is fully allocated up front (see lazy_initialization, a no-op).
        # NOTE: buffers use torch's default dtype (float32 unless changed globally);
        # _expand() preserves whatever dtype the buffers currently hold.
        self.keys: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )

        self.is_initialized = True

    @property
    def buffer_capacity(self) -> int:
        """Current number of slots allocated per (batch, head) pair.

        Derived directly from self.keys rather than tracked separately, so it is
        always consistent with the actual buffer after expansion.
        """
        return self.keys.shape[2]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states using an explicit loop; return full cache state.

        Iterates over every (b, l, t) triple. For each position where active_mask is
        True, the key and value are written to the next available slot for that
        (batch, head) pair and the count is incremented. Causal ordering is guaranteed
        because the t dimension is traversed from 0 to T-1 and counts are updated
        immediately after each write.

        Buffer expansion (doubling buffer_capacity, repeatedly if needed) is
        triggered before any writes if the incoming tokens would cause any slot to
        overflow the current capacity.

        Args:
            key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
            value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
            active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.

        Returns:
            Tuple of (keys, values, active_mask):
                keys: (B, L, buffer_capacity, u) float — full key buffer including junk slots.
                values: (B, L, buffer_capacity, u) float — full value buffer including junk slots.
                active_mask: (B, L, buffer_capacity) bool — True iff slot (b, l, t) has
                    been written.
        """
        B, L, T = active_mask.shape

        # Expand until every (batch, head) slot can absorb its incoming tokens.
        # A single doubling is not always enough: a long prompt can require
        # several doublings (e.g. counts=0, T=200, capacity=64 needs 64->128->256).
        # The numel() guard avoids calling .max() on an empty tensor when B or L is 0.
        incoming_delta = active_mask.long().sum(dim=2)
        if incoming_delta.numel() > 0:
            while (self._counts + incoming_delta).max().item() > self.buffer_capacity:
                self._expand()

        # The obviously-correct scatter: one write per active position, in
        # t-order, with the count bumped immediately after each write.
        for b in range(B):
            for l in range(L):
                for t in range(T):
                    if active_mask[b, l, t]:
                        pos = self._counts[b, l].item()
                        self.keys[b, l, pos, :] = key_states[b, l, t, :]
                        self.values[b, l, pos, :] = value_states[b, l, t, :]
                        self._counts[b, l] += 1

        return self.keys, self.values, self._make_active_mask()

    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor consumed by BEA for attention
        masking and by position computation (Unit 10.A) for semantic-sequence
        position computation.

        Returns:
            Integer tensor of shape (B, L) where entry [b, h] is the number of valid
            tokens stored in the (b, h) slot. Zero for slots with no writes yet.
        """
        return self._counts

    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
        and is_initialized remains True — only the contents are cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam search
        must reorder all three together or the occupancy counts and buffer contents
        will correspond to different beam hypotheses.

        Overrides the parent because the parent's implementation calls get_seq_length(),
        which is not supported for this cache.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by the beam
                search algorithm.
        """
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self._counts = self._counts[beam_idx]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry repeats times.

        Used at beam search initialisation to expand the cache from batch size B to
        B * repeats, matching the expanded beam candidate batch. Applied atomically
        across keys, values, and _counts; batch_size is updated to reflect the new size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate entries.
        Applied atomically across keys, values, and _counts; batch_size is updated
        to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self._counts = self._counts[indices]
        self.batch_size = indices.shape[0]

    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not know
        about. All three tensors are moved atomically so device state remains consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not know
        about. _counts is synced to self.keys.device after the parent moves keys and
        values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)

    def lazy_initialization(
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op — storage is fully allocated at construction time."""
        pass

    def get_seq_length(self) -> int:
        """Not supported — no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently; (batch, head) slots have different
        lengths depending on routing history. There is no meaningful scalar summary.
        Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "SlowMoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )

    def get_max_cache_shape(self) -> int:
        """Not supported — SlowMoSRAHCache is dynamic and unbounded."""
        raise NotImplementedError(
            "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
        )

    def get_mask_sizes(
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported — SlowMoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "SlowMoSRAHCache does not support get_mask_sizes()."
        )

    def _make_active_mask(self) -> torch.Tensor:
        """Construct the (B, L, buffer_capacity) active mask from current counts.

        Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
        has been written. Positions at or beyond the count are junk and must be
        excluded by downstream attention.
        """
        cap = self.buffer_capacity
        return (
            torch.arange(cap, device=self.keys.device)
            .expand(self.batch_size, self.num_mosrah_heads, cap)
            < self._counts.unsqueeze(-1)
        )

    def _expand(self) -> None:
        """Double the buffer capacity, preserving existing data.

        Called by update() when an incoming batch of tokens would cause any
        (batch, head) slot to exceed the current buffer capacity. All existing
        key and value data is copied into the low half of the new buffer; the
        high half is zero-initialised and will be filled by subsequent writes.
        After reassignment, buffer_capacity reflects the new size automatically.

        The new buffers are allocated with the current buffers' dtype so that
        expansion never silently changes the cache's precision.
        """
        old_cap = self.buffer_capacity
        new_cap = old_cap * 2
        dev = self.keys.device
        dtype = self.keys.dtype
        new_keys = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim,
            device=dev, dtype=dtype,
        )
        new_values = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim,
            device=dev, dtype=dtype,
        )
        new_keys[:, :, :old_cap, :] = self.keys
        new_values[:, :, :old_cap, :] = self.values
        self.keys = new_keys
        self.values = new_values
|