| """MoSRAH sparse KV cache — single-layer implementation. |
| |
| MoSRAH routes each token to K of L available expert heads, so its KV cache is indexed |
| by head rather than by sequence position. The routing is dynamic and produces a ragged |
| distribution of token counts across (batch, head) slots — different batch items may |
| route different numbers of tokens to the same head, and different heads accumulate at |
| different rates. DynamicCache cannot represent this correctly: it concatenates along |
| the sequence dimension and assumes uniform token counts across the batch. MoSRAHCache |
| therefore uses a custom buffer design. |
| |
| Keys and values are stored in the CacheLayerMixin-standard self.keys and self.values |
| attributes as (B, L, T, u) tensors, where B is batch size, L is the number of expert |
| heads (num_mosrah_heads), T is the current buffer capacity, and u is the bottlenecked |
| head embedding width (head_dim). A (B, L) integer count tensor _counts tracks the |
| valid occupancy of each (batch, head) slot. Buffer capacity is exposed as the |
| buffer_capacity property and is derived directly from self.keys rather than tracked |
| as a separate variable. |
| |
| The primary interface is update(key_states, value_states, active_mask), which accepts |
| expert-choice layout, stores only active entries in causal order, and returns the full |
| accumulated (keys, values, active_mask) for immediate use by BEA. The returned |
| active_mask identifies valid cached positions; everything beyond each slot's count is |
| junk data that downstream attention must exclude. |
| |
| BEA applies RoPE and calls update() with post-RoPE keys (K̃). The occupancy counts |
| exposed by get_heads_lengths() must be read before update() if the caller needs the |
| pre-update occupancy for position computation (Unit 10.A). update() increments counts |
| in-place and the pre-update values are not recoverable afterward. |
| |
| All buffers are allocated at construction time. MoSRAHCache is constructed by |
| ShramLayerCache, which has access to batch size, device, and all model config parameters |
| needed to fully specify the storage layout upfront. |
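
Example (illustrative sketch; shapes and values are arbitrary, and in the real
pipeline ShramLayerCache constructs the cache and BEA drives update()):

    cache = MoSRAHCache(num_mosrah_heads=4, head_dim=8, batch_size=2,
                        device=torch.device("cpu"))
    k = torch.randn(2, 4, 5, 8)            # (B, L, T, u) post-RoPE keys
    v = torch.randn(2, 4, 5, 8)            # (B, L, T, u) values
    mask = torch.rand(2, 4, 5) > 0.5       # (B, L, T) expert-choice routing mask
    prev_lengths = cache.get_heads_lengths().clone()  # read before update (Unit 10.A)
    keys, values, active = cache.update(k, v, mask)   # full buffers + validity mask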
| """ |
|
|
import torch
from transformers.cache_utils import CacheLayerMixin
|
|
|
|
class MoSRAHCache(CacheLayerMixin):
    """KV cache for the MoSRAH sparse attention path — single decoder layer.

    Subclasses CacheLayerMixin to satisfy the HuggingFace per-layer cache role.
    Stores keys and values in the mixin-standard self.keys and self.values attributes
    using a custom (B, L, T, u) layout rather than delegating to DynamicCache,
    which cannot represent MoSRAH's ragged per-(batch, head) token counts correctly.

    All storage is allocated at construction time and is_initialized is True
    immediately. The caller (ShramLayerCache) provides batch size, device, and model
    config parameters so no lazy allocation is needed.

    Input is expected in expert-choice layout: (B, L, T, u) key/value tensors with a
    (B, L, T) boolean active_mask. Only positions where active_mask is True are written.
    This matches the packed representation produced by expert packing in the MoSRAH
    forward pass, where BEA has already applied RoPE before calling update().

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
            second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first dimension
            of all storage tensors.
        device: Device on which to allocate all tensors. Should match the model device.
        initial_buffer_size: Initial sequence capacity per (batch, head) slot. Grown
            by doubling (repeatedly if needed) whenever an update would overflow any
            slot. Defaults to 64 to avoid repeated reallocation during prompt
            processing.
    """
|
|
    is_compileable = False
    is_sliding = False
|
|
    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device

        self.keys: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )

        self.is_initialized = True
|
|
    @property
    def buffer_capacity(self) -> int:
        """Current number of slots allocated per (batch, head) pair.

        Derived directly from self.keys rather than tracked separately, so it is
        always consistent with the actual buffer after expansion.
        """
        return self.keys.shape[2]
|
|
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states into the buffer and return the full cache state.

        Accepts expert-choice layout: key_states and value_states are (B, L, T, u);
        active_mask is (B, L, T) bool with True marking real tokens. Only active
        positions are written; inactive positions are ignored.

        Uses a cumsum construction to derive the absolute buffer position for each
        active token without any Python loops. For a given (batch, head) slot,
        positions are assigned in the order tokens appear along the T dimension,
        preserving causal ordering.
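
        For example (illustrative values): if a (batch, head) slot already holds
        2 tokens and the incoming active_mask row along T is
        [True, False, True, True], the active tokens get within-slot ranks
        [0, 1, 2] and are written to absolute buffer positions [2, 3, 4], after
        which the slot's count becomes 5.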
| |
        Returns the full accumulated (keys, values, active_mask) across the cached
        sparse sequence. The returned active_mask is True exactly for slots t <
        counts[b, l]; everything beyond is junk data that BEA must exclude.

        Note: get_heads_lengths() must be called before update() if the caller needs
        the pre-update occupancy for position computation (Unit 10.A). update()
        increments counts in-place and the pre-update values are not recoverable.

        Args:
            key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
            value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
            active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.

        Returns:
            Tuple of (keys, values, active_mask), whose third dimension is the current
            buffer_capacity rather than the incoming chunk length T:
                keys: (B, L, buffer_capacity, u) float — full key buffer including junk slots.
                values: (B, L, buffer_capacity, u) float — full value buffer including junk slots.
                active_mask: (B, L, buffer_capacity) bool — True iff slot (b, l, t) has been written.
        """
        # Number of newly active tokens arriving for each (batch, head) slot: (B, L).
        incoming_delta = active_mask.long().sum(dim=2)

        # Grow the buffer until every slot can absorb its incoming tokens. A single
        # doubling may not be enough when a large prompt chunk arrives at once.
        while (self._counts + incoming_delta).max().item() > self.buffer_capacity:
            self._expand()

        # 0-based rank of each token among the active tokens of its (batch, head) row.
        # Values at inactive positions are meaningless and are filtered out below.
        within_slot = active_mask.long().cumsum(dim=2) - 1

        # Absolute buffer position = existing occupancy + rank within the incoming chunk.
        abs_pos = within_slot + self._counts.unsqueeze(-1)

        # Scatter only the active entries into their absolute buffer positions.
        b_idx, l_idx, t_idx = torch.where(active_mask)
        self.keys[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
            key_states[b_idx, l_idx, t_idx]
        )
        self.values[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
            value_states[b_idx, l_idx, t_idx]
        )

        self._counts += incoming_delta

        return self.keys, self.values, self._make_active_mask()
|
|
    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor, consumed by BEA for attention
        masking and by Unit 10.A for semantic-sequence position computation.

        Note: in the MoSRAH forward pass, this must be called before update() if the
        caller needs the pre-update occupancy. update() increments these counts in-place.

        Returns:
            Integer tensor of shape (B, L) where entry [b, l] is the number of valid
            tokens stored in the (b, l) slot. Zero for slots with no writes yet.
        """
        return self._counts
|
|
    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
        and is_initialized remains True — only the contents are cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()
|
|
    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam search
        must reorder all three together or the occupancy counts and buffer contents
        will correspond to different beam hypotheses.

        Overrides the parent because the parent's implementation calls get_seq_length(),
        which is not supported for this cache.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by the beam
                search algorithm.
        """
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self._counts = self._counts[beam_idx]
|
|
    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry repeats times.

        Used at beam search initialisation to expand the cache from batch size B to
        B * repeats, matching the expanded beam candidate batch. Applied atomically
        across keys, values, and _counts; batch_size is updated to reflect the new size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats
|
|
    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate entries.
        Applied atomically across keys, values, and _counts; batch_size is updated
        to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self._counts = self._counts[indices]
        self.batch_size = indices.shape[0]
|
|
    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not know
        about. All three tensors are moved together so device state remains consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)
|
|
    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not know
        about. _counts is synced to self.keys.device after the parent moves keys and
        values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)
|
|
    def lazy_initialization(
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op — storage is fully allocated at construction time."""
        pass
|
|
    def get_seq_length(self) -> int:
        """Not supported — no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently; (batch, head) slots have different
        lengths depending on routing history. There is no meaningful scalar summary.
        Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "MoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )
|
|
    def get_max_cache_shape(self) -> int:
        """Not supported — MoSRAHCache is dynamic and unbounded."""
        raise NotImplementedError(
            "MoSRAHCache is unbounded; get_max_cache_shape() is not supported."
        )
|
|
    def get_mask_sizes(
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported — MoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "MoSRAHCache does not support get_mask_sizes()."
        )
|
|
    def _make_active_mask(self) -> torch.Tensor:
        """Construct the (B, L, T) active mask from current counts.

        Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
        has been written. Positions at or beyond the count are junk and must be
        excluded by downstream attention.
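
        For example (illustrative values), with buffer_capacity 4 and counts of
        3 and 1 for two heads of one batch item, the corresponding mask rows are
        [True, True, True, False] and [True, False, False, False].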
| """ |
| cap = self.buffer_capacity |
| return ( |
| torch.arange(cap, device=self.keys.device) |
| .expand(self.batch_size, self.num_mosrah_heads, cap) |
| < self._counts.unsqueeze(-1) |
| ) |
|
|
    def _expand(self) -> None:
        """Double the buffer capacity, preserving existing data.

        Called by update() when an incoming batch of tokens would cause any
        (batch, head) slot to exceed the current buffer capacity; update() keeps
        doubling until every slot fits. All existing key and value data is copied
        into the low half of the new buffer; the high half is zero-initialised and
        will be filled by subsequent writes. After reassignment, buffer_capacity
        reflects the new size automatically.
        """
        old_cap = self.buffer_capacity
        new_cap = old_cap * 2
        dev = self.keys.device
        # Preserve dtype so expansion never silently changes the cache precision.
        new_keys = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim,
            device=dev, dtype=self.keys.dtype,
        )
        new_values = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim,
            device=dev, dtype=self.values.dtype,
        )
        new_keys[:, :, :old_cap, :] = self.keys
        new_values[:, :, :old_cap, :] = self.values
        self.keys = new_keys
        self.values = new_values
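

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the MoSRAH forward path: exercises
    # update(), occupancy tracking, and buffer expansion with arbitrary dummy tensors.
    cache = MoSRAHCache(
        num_mosrah_heads=2,
        head_dim=4,
        batch_size=1,
        device=torch.device("cpu"),
        initial_buffer_size=4,
    )
    for _ in range(3):
        k = torch.randn(1, 2, 3, 4)  # (B, L, T, u) dummy post-RoPE keys
        v = torch.randn(1, 2, 3, 4)  # (B, L, T, u) dummy values
        mask = torch.tensor([[[True, False, True], [True, True, True]]])  # (B, L, T)
        keys, values, active = cache.update(k, v, mask)
    # Head 0 received 2 tokens per call and head 1 received 3, so occupancy is [[6, 9]]
    # and the buffer capacity has doubled from 4 to 16 along the way.
    print(cache.get_heads_lengths())
    print(cache.buffer_capacity)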
|
|