"""Unvectorized reference implementation of the MoSRAH sparse KV cache.

This module exists solely as a correctness oracle. SlowMoSRAHCache implements the
same interface and storage layout as MoSRAHCache but uses an explicit Python loop
over (b, l, t) triples in update(). The loop is obviously correct by inspection:
each active position's key and value are written to the next available slot for
that (batch, head) pair, in the order positions appear along the T dimension,
which directly enforces causal ordering without any index arithmetic to verify.

SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
trusted ground truth against which the vectorized MoSRAHCache.update() is
validated in Unit 6.A tests, and as a reference for the Unit 10.A position
decoder. Because the vectorized implementation is validated by asserting exact
agreement with this one on all test inputs, the correctness of SlowMoSRAHCache is
load-bearing: its own test suite (test_slow_mosrah_cache.py) must establish it is
trustworthy before it can be used as an oracle.
"""

import torch
from transformers.cache_utils import CacheLayerMixin


class SlowMoSRAHCache(CacheLayerMixin):
    """Unvectorized reference implementation of the MoSRAH KV cache.

    Identical storage layout to MoSRAHCache: (B, L, T, u) tensors in the
    mixin-standard self.keys and self.values attributes, plus a (B, L) _counts
    tensor, with the same constructor signature and the same CacheLayerMixin
    protocol methods. The sole difference is update(), which uses an explicit
    Python loop over (b, l, t) triples rather than vectorized index arithmetic.

    This class is not used in the model path. It exists so that
    MoSRAHCache.update() can be validated by asserting exact agreement with this
    implementation on all test inputs. See module docstring for the trust chain
    this enables.

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
            second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first
            dimension of all storage tensors. Kept in sync with the storage by the
            batch-manipulation methods (reorder/repeat/select).
        device: Device on which to allocate all tensors. Should match the model
            device.
        mosrah_cache_length: Static sequence capacity per (batch, head) slot.
            Equal to config.mosrah_cache_length. The buffer never grows; if any
            slot would exceed this capacity, update() raises a RuntimeError.
    """

    is_compileable = False
    is_sliding = False

    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        mosrah_cache_length: int,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device
        self.mosrah_cache_length = mosrah_cache_length

        # Allocate primary storage into the mixin-standard self.keys / self.values so
        # that inherited methods (offload, prefetch) operate on real tensors. _counts
        # tracks valid occupancy per (batch, head) slot.
        self.keys: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )

        # Storage is fully allocated at construction — the cache is initialized.
        self.is_initialized = True

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------

    @property
    def buffer_capacity(self) -> int:
        """Current number of slots allocated per (batch, head) pair.

        Equal to mosrah_cache_length as supplied at construction. Derived from
        self.keys so it remains consistent with the actual buffer shape.
        """
        return self.keys.shape[2]

    # ---------------------------------------------------------------------------
    # Primary API
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states using an explicit loop; return full cache state.

        Iterates over every (b, l, t) triple. For each position where active_mask is
        True, the key and value are written to the next available slot for that
        (batch, head) pair and the count is incremented. Causal ordering is
        guaranteed because the t dimension is traversed from 0 to T-1 and counts are
        updated immediately after each write.

        Raises RuntimeError before any writes if the incoming tokens would cause any
        slot to exceed the static mosrah_cache_length capacity.

        Args:
            key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice
                layout.
            value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
            active_mask: Shape (B, L, T) bool — True for real tokens, False for
                padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.

        Returns:
            Tuple of (keys, values, active_mask):
                keys: (B, L, mosrah_cache_length, u) float — full key buffer
                    including junk slots.
                values: (B, L, mosrah_cache_length, u) float — full value buffer
                    including junk slots.
                active_mask: (B, L, mosrah_cache_length) bool — True iff slot t has
                    been written.

        Raises:
            RuntimeError: If any (batch, head) slot would exceed mosrah_cache_length.
        """
        B, L, T = active_mask.shape

        # Validate capacity up front so a failing call leaves the cache untouched.
        # `.any()` is used rather than `.max()` because it is well-defined on empty
        # (B == 0 or L == 0) tensors, where `.max()` raises.
        incoming_delta = active_mask.long().sum(dim=2)  # (B, L)
        projected_counts = self._counts + incoming_delta
        if bool((projected_counts > self.mosrah_cache_length).any()):
            raise RuntimeError(
                f"SlowMoSRAHCache overflow: a (batch, head) slot would exceed the "
                f"static buffer capacity of {self.mosrah_cache_length}. Increase "
                f"mosrah_overallocation_factor in ShramConfig."
            )

        # Write each active position into the next available slot for its (batch, head)
        # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
        for b in range(B):
            for l in range(L):
                for t in range(T):
                    if active_mask[b, l, t]:
                        pos = self._counts[b, l].item()
                        self.keys[b, l, pos, :] = key_states[b, l, t, :]
                        self.values[b, l, pos, :] = value_states[b, l, t, :]
                        self._counts[b, l] += 1

        return self.keys, self.values, self._make_active_mask()

    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor consumed by BEA for attention
        masking and by position computation (Unit 10.A) for semantic-sequence
        position computation.

        Returns:
            Integer tensor of shape (B, L) where entry [b, h] is the number of valid
            tokens stored in the (b, h) slot. Zero for slots with no writes yet.
            This is the live internal tensor, not a copy.
        """
        return self._counts

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — overridden coordination methods
    # ---------------------------------------------------------------------------

    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains
        allocated and is_initialized remains True — only the contents are cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam search
        must reorder all three together or the occupancy counts and buffer contents
        will correspond to different beam hypotheses. batch_size is re-derived from
        the reordered storage so it always matches the first storage dimension.

        Overrides the parent because the parent's implementation calls
        get_seq_length(), which is not supported for this cache.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by the beam
                search algorithm. Moved to each tensor's device before indexing (as
                the parent implementation does), so reordering also works while the
                cache is offloaded.
        """
        self.keys = self.keys[beam_idx.to(self.keys.device)]
        self.values = self.values[beam_idx.to(self.values.device)]
        self._counts = self._counts[beam_idx.to(self._counts.device)]
        self.batch_size = self._counts.shape[0]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry repeats times.

        Used at beam search initialisation to expand the cache from batch size B to
        B * repeats, matching the expanded beam candidate batch. Applied atomically
        across keys, values, and _counts; batch_size is updated to reflect the new
        size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate entries.
        Applied atomically across keys, values, and _counts; batch_size is updated
        to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain. Moved to each
                tensor's device before indexing.
        """
        self.keys = self.keys[indices.to(self.keys.device)]
        self.values = self.values[indices.to(self.values.device)]
        self._counts = self._counts[indices.to(self._counts.device)]
        # Derived from the selected storage so it matches the real first dimension.
        self.batch_size = self._counts.shape[0]

    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not know
        about. All three tensors are moved atomically so device state remains
        consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not know
        about. _counts is synced to self.keys.device after the parent moves keys and
        values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op — storage is fully allocated at construction time."""
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — unsupported abstract methods
    # ---------------------------------------------------------------------------

    def get_seq_length(self) -> int:  # type: ignore[override]
        """Not supported — no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently; (batch, head) slots have different
        lengths depending on routing history. There is no meaningful scalar summary.
        Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "SlowMoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Not supported — this cache does not expose an HF-style max cache shape.

        Capacity is static (mosrah_cache_length slots per (batch, head) pair, see
        buffer_capacity); the buffer never grows and update() raises on overflow.
        Callers needing the capacity should read buffer_capacity instead.
        """
        raise NotImplementedError(
            "SlowMoSRAHCache has a static per-(batch, head) capacity (see "
            "buffer_capacity); get_max_cache_shape() is not supported."
        )

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported — SlowMoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "SlowMoSRAHCache does not support get_mask_sizes()."
        )

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------

    def _make_active_mask(self) -> torch.Tensor:
        """Construct the (B, L, capacity) active mask from current counts.

        Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot has
        been written. Positions at or beyond the count are junk and must be excluded
        by downstream attention.
        """
        slot_index = torch.arange(self.buffer_capacity, device=self._counts.device)
        # Broadcasting (capacity,) against (B, L, 1) yields (B, L, capacity). The
        # leading dims come from _counts itself rather than self.batch_size, so the
        # mask always agrees with the stored occupancy.
        return slot_index < self._counts.unsqueeze(-1)