"""Unvectorized reference implementation of the MoSRAH sparse KV cache.
This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same
interface and storage layout as MoSRAHCache but uses an explicit Python loop over
(b, l, t) triples in update(). The loop is obviously correct by inspection: each active
position's key and value are written to the next available slot for that (batch, head)
pair, in the order positions appear along the T dimension, which directly enforces
causal ordering without any index arithmetic to verify.
SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
trusted ground truth against which the vectorized MoSRAHCache.update() is validated in
Unit 6.A tests, and as a reference for the Unit 10.A position decoder. Because the
vectorized implementation is validated by asserting exact agreement with this one on all
test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own test suite
(test_slow_mosrah_cache.py) must establish it is trustworthy before it can be used as
an oracle.
"""
import torch
from transformers.cache_utils import CacheLayerMixin
class SlowMoSRAHCache(CacheLayerMixin):
    """Unvectorized reference implementation of the MoSRAH KV cache.

    Identical storage layout to MoSRAHCache: (B, L, T, u) tensors in the
    mixin-standard self.keys and self.values attributes, plus a (B, L) _counts
    tensor, with the same constructor signature and the same CacheLayerMixin
    protocol methods. The sole difference is update(), which uses an explicit
    Python loop over (b, l, t) triples rather than vectorized index arithmetic.

    This class is not used in the model path. It exists so that
    MoSRAHCache.update() can be validated by asserting exact agreement with
    this implementation on all test inputs. See module docstring for the trust
    chain this enables.

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
            second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first
            dimension of all storage tensors.
        device: Device on which to allocate all tensors. Should match the model
            device.
        initial_buffer_size: Initial sequence capacity per (batch, head) slot.
            Doubled (repeatedly, if needed) when any slot would overflow.
            Defaults to 64 to avoid repeated reallocation during prompt
            processing.
    """

    is_compileable = False
    is_sliding = False

    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device
        # Allocate primary storage into the mixin-standard self.keys / self.values so
        # that inherited methods (offload, prefetch) operate on real tensors. _counts
        # tracks valid occupancy per (batch, head) slot.
        self.keys: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )
        # Storage is fully allocated at construction — the cache is initialized.
        self.is_initialized = True

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------
    @property
    def buffer_capacity(self) -> int:
        """Current number of slots allocated per (batch, head) pair.

        Derived directly from self.keys rather than tracked separately, so it is
        always consistent with the actual buffer after expansion.
        """
        return self.keys.shape[2]

    # ---------------------------------------------------------------------------
    # Primary API
    # ---------------------------------------------------------------------------
    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states using an explicit loop; return full cache state.

        Iterates over every (b, l, t) triple. For each position where active_mask is
        True, the key and value are written to the next available slot for that
        (batch, head) pair and the count is incremented. Causal ordering is guaranteed
        because the t dimension is traversed from 0 to T-1 and counts are updated
        immediately after each write.

        Buffer expansion (doubling buffer_capacity, repeatedly if needed) is
        triggered before any writes if the incoming tokens would cause any slot
        to overflow the current capacity.

        Args:
            key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
            value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
            active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.

        Returns:
            Tuple of (keys, values, active_mask):
                keys: (B, L, T, u) float — full key buffer including junk slots.
                values: (B, L, T, u) float — full value buffer including junk slots.
                active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
        """
        B, L, T = active_mask.shape
        # Expansion check uses the total active tokens per slot, same as the
        # vectorized implementation, so both expand under identical conditions.
        incoming_delta = active_mask.long().sum(dim=2)  # (B, L)
        # BUGFIX: this must loop, not run once. A single doubling is insufficient
        # when the incoming chunk needs more than twice the current capacity
        # (e.g. a 200-token prompt against initial_buffer_size=64: one doubling
        # yields 128 slots and the write loop below would index out of bounds).
        # Expand until every (batch, head) slot fits its incoming tokens.
        while (self._counts + incoming_delta).max().item() > self.buffer_capacity:
            self._expand()
        # Write each active position into the next available slot for its (batch, head)
        # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
        for b in range(B):
            for l in range(L):
                for t in range(T):
                    if active_mask[b, l, t]:
                        pos = self._counts[b, l].item()
                        self.keys[b, l, pos, :] = key_states[b, l, t, :]
                        self.values[b, l, pos, :] = value_states[b, l, t, :]
                        self._counts[b, l] += 1
        return self.keys, self.values, self._make_active_mask()

    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor consumed by BEA for attention
        masking and by position computation (Unit 10.A) for semantic-sequence
        position computation.

        Returns:
            Integer tensor of shape (B, L) where entry [b, h] is the number of valid
            tokens stored in the (b, h) slot. Zero for slots with no writes yet.
        """
        return self._counts

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — overridden coordination methods
    # ---------------------------------------------------------------------------
    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
        and is_initialized remains True — only the contents are cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam search
        must reorder all three together or the occupancy counts and buffer contents
        will correspond to different beam hypotheses.

        Overrides the parent because the parent's implementation calls get_seq_length(),
        which is not supported for this cache.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by the beam
                search algorithm.
        """
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self._counts = self._counts[beam_idx]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry repeats times.

        Used at beam search initialisation to expand the cache from batch size B to
        B * repeats, matching the expanded beam candidate batch. Applied atomically
        across keys, values, and _counts; batch_size is updated to reflect the new size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate entries.
        Applied atomically across keys, values, and _counts; batch_size is updated
        to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self._counts = self._counts[indices]
        self.batch_size = indices.shape[0]

    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not know
        about. All three tensors are moved atomically so device state remains consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not know
        about. _counts is synced to self.keys.device after the parent moves keys and
        values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op — storage is fully allocated at construction time."""
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — unsupported abstract methods
    # ---------------------------------------------------------------------------
    def get_seq_length(self) -> int:  # type: ignore[override]
        """Not supported — no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently; (batch, head) slots have different
        lengths depending on routing history. There is no meaningful scalar summary.
        Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "SlowMoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Not supported — SlowMoSRAHCache is dynamic and unbounded."""
        raise NotImplementedError(
            "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
        )

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported — SlowMoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "SlowMoSRAHCache does not support get_mask_sizes()."
        )

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------
    def _make_active_mask(self) -> torch.Tensor:
        """Construct the (B, L, T) active mask from current counts.

        Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
        has been written. Positions at or beyond the count are junk and must be
        excluded by downstream attention.
        """
        cap = self.buffer_capacity
        return (
            torch.arange(cap, device=self.keys.device)
            .expand(self.batch_size, self.num_mosrah_heads, cap)
            < self._counts.unsqueeze(-1)
        )

    def _expand(self) -> None:
        """Double the buffer capacity, preserving existing data.

        Called by update() when an incoming batch of tokens would cause any
        (batch, head) slot to exceed the current buffer capacity. All existing
        key and value data is copied into the low half of the new buffer; the
        high half is zero-initialised and will be filled by subsequent writes.
        After reassignment, buffer_capacity reflects the new size automatically.
        """
        old_cap = self.buffer_capacity
        new_cap = old_cap * 2
        dev = self.keys.device
        # Preserve dtype explicitly so reallocation can never silently change the
        # storage precision relative to the existing buffers.
        new_keys = torch.zeros(
            self.batch_size,
            self.num_mosrah_heads,
            new_cap,
            self.head_dim,
            dtype=self.keys.dtype,
            device=dev,
        )
        new_values = torch.zeros(
            self.batch_size,
            self.num_mosrah_heads,
            new_cap,
            self.head_dim,
            dtype=self.values.dtype,
            device=dev,
        )
        new_keys[:, :, :old_cap, :] = self.keys
        new_values[:, :, :old_cap, :] = self.values
        self.keys = new_keys
        self.values = new_values