smithblack-0 commited on
Commit
a86502d
·
verified ·
1 Parent(s): 0c295ed

Update architecture and tokenizer

Browse files
README.md CHANGED
@@ -79,13 +79,15 @@ contains no weights. All values are overridable via kwargs.
79
  | `attention_dropout` | 0.0 |
80
  | `beta` | 32.0 |
81
  | `dtype` | None |
 
82
  | `head_dim` | 16 |
83
- | `hidden_size` | 512 |
84
  | `inference_sequence_length` | 1024 |
85
- | `intermediate_size` | 1366 |
86
  | `local_rope_theta` | 10000.0 |
 
 
87
  | `mosrah_rope_theta` | 10000.0 |
88
- | `num_hidden_layers` | 12 |
89
  | `num_mosrah_heads` | 16 |
90
  | `num_selected_heads` | 16 |
91
  | `num_sliding_window_heads` | 16 |
 
79
  | `attention_dropout` | 0.0 |
80
  | `beta` | 32.0 |
81
  | `dtype` | None |
82
+ | `embedding_width` | 512 |
83
  | `head_dim` | 16 |
 
84
  | `inference_sequence_length` | 1024 |
85
+ | `load_balance_p` | 2.0 |
86
  | `local_rope_theta` | 10000.0 |
87
+ | `mlp_width` | 1366 |
88
+ | `mosrah_overallocation_factor` | 2.0 |
89
  | `mosrah_rope_theta` | 10000.0 |
90
+ | `num_decoder_layers` | 12 |
91
  | `num_mosrah_heads` | 16 |
92
  | `num_selected_heads` | 16 |
93
  | `num_sliding_window_heads` | 16 |
__attention__bottlenecked_ensemble_attention.py CHANGED
@@ -38,14 +38,14 @@ class BottleneckedEnsembleAttention(nn.Module):
38
 
39
  Args:
40
  config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
41
- `head_dim`, `mosrah_rope_theta`, `training_sequence_length`,
42
- `inference_sequence_length`, `alpha`, and `beta`.
43
  """
44
 
45
  def __init__(self, config: ShramConfig) -> None:
46
  super().__init__()
47
 
48
- self.hidden_size = config.hidden_size
49
  self.num_heads = config.num_mosrah_heads
50
  self.head_dim = config.head_dim
51
 
@@ -68,11 +68,22 @@ class BottleneckedEnsembleAttention(nn.Module):
68
  # BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
69
  # this unit only consumes it. In training modes, dilation will be 1.0 and so
70
  # no yarn dilation occurs.
 
 
 
 
 
 
 
 
 
 
 
71
  self.rope = RotaryEmbedding(
72
  mode="yarn",
73
  head_dim=self.head_dim,
74
  theta=config.mosrah_rope_theta,
75
- initial_seq_length=config.training_sequence_length,
76
  dilation=config.scale,
77
  alpha=config.alpha,
78
  beta=config.beta,
 
38
 
39
  Args:
40
  config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
41
+ `head_dim`, `mosrah_rope_theta`, `inference_sequence_length`,
42
+ `scale`, `alpha`, and `beta`.
43
  """
44
 
45
  def __init__(self, config: ShramConfig) -> None:
46
  super().__init__()
47
 
48
+ self.hidden_size = config.embedding_width
49
  self.num_heads = config.num_mosrah_heads
50
  self.head_dim = config.head_dim
51
 
 
68
  # BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
69
  # this unit only consumes it. In training modes, dilation will be 1.0 and so
70
  # no yarn dilation occurs.
71
+ #
72
+ # The required table size depends on position semantics:
73
+ # main_sequence — positions are original token positions, bounded by
74
+ # inference_sequence_length.
75
+ # semantic_sequence — positions are local per-expert slot indices, bounded
76
+ # by mosrah_packed_length.
77
+ maximum_rope_length = (
78
+ config.mosrah_packed_length
79
+ if config.rope_mode == "semantic_sequence"
80
+ else config.inference_sequence_length
81
+ )
82
  self.rope = RotaryEmbedding(
83
  mode="yarn",
84
  head_dim=self.head_dim,
85
  theta=config.mosrah_rope_theta,
86
+ maximum_sequence_length=maximum_rope_length,
87
  dilation=config.scale,
88
  alpha=config.alpha,
89
  beta=config.beta,
__attention__expert_packing.py CHANGED
@@ -3,26 +3,30 @@
3
  This module implements the low-level token-choice -> expert-choice -> token-choice
4
  conversion boundary specified in the paper. The externally visible behavior is fixed:
5
 
6
- - setup_packing() prepares the auxiliary ordering data.
7
- - pack_experts() converts routed token-choice state into packed expert-choice state.
 
 
 
8
  - unpack_experts() restores token-choice ordering afterward.
9
 
10
  Stable sort is a correctness requirement. It preserves causal ordering inside each
11
  expert bucket, which is the foundation on which BEA's later triangular causal mask
12
  is correct.
13
 
14
- pack_experts() returns two distinct masks that serve different roles and must not
15
- be interchanged:
16
 
17
  - unpacking_mask: marks every packed slot that contains a routed token copy,
18
  live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
19
  so its reshape invariant holds regardless of outer token liveness.
20
- - active_mask: marks only the packed slots whose source token was semantically
21
- live. This is what BEA consumes for attention gating. Dead outer tokens must
22
- not influence sparse attention outputs.
23
  """
24
 
25
  import torch
 
26
 
27
 
28
  # ---------------------------------------------------------------------------
@@ -31,7 +35,7 @@ import torch
31
 
32
  def setup_packing(
33
  selected_heads: torch.Tensor,
34
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
35
  """Prepare the auxiliary ordering data used by pack/unpack.
36
 
37
  Routing produces token-choice state I of shape (B, N, K): for each token, which
@@ -48,10 +52,11 @@ def setup_packing(
48
  selected_heads: Routed token-choice head selections I of shape (B, N, K).
49
 
50
  Returns:
51
- Tuple of:
52
- - flattened_selected_heads: H of shape (B, N*K)
53
- - permutation: stable expert-major permutation Pi of shape (B, N*K)
54
- - inverse_permutation: inverse permutation Pi^{-1} of shape (B, N*K)
 
55
  """
56
  batch_size, sequence_length, num_selected_heads = selected_heads.shape
57
  flattened_selected_heads = selected_heads.reshape(
@@ -62,7 +67,11 @@ def setup_packing(
62
  permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
63
  inverse_permutation = torch.argsort(permutation, dim=-1)
64
 
65
- return flattened_selected_heads, permutation, inverse_permutation
 
 
 
 
66
 
67
 
68
  # ---------------------------------------------------------------------------
@@ -70,27 +79,22 @@ def setup_packing(
70
  # ---------------------------------------------------------------------------
71
 
72
  def pack_experts(
73
- hidden_states: torch.Tensor,
74
- position_ids: torch.Tensor,
75
  selected_heads: torch.Tensor,
76
  num_experts: int,
77
- flattened_selected_heads: torch.Tensor,
78
- permutation: torch.Tensor,
79
- outer_active_mask: torch.Tensor,
80
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
81
- """Pack token-choice hidden states into expert-choice padded form.
82
 
83
  The paper's packing path has two jobs:
84
 
85
  1. Convert routed token-choice copies into expert-major order.
86
  2. Materialize that expert-major order into a padded tensor layout BEA can consume.
87
 
88
- The routed hidden-state copies are not stored explicitly in token-choice form.
89
- Instead, the same token hidden state is conceptually copied once per selected expert.
90
- The packing step reconstructs those copies by expanding local source-token indices,
91
- reordering those indices with Pi, then gathering hidden states, positions, and outer
92
- liveness in that packed order. All three are carried through the same expert-major
93
- rearrangement so they remain aligned in the packed frame.
94
 
95
  Packed positions are sourced from the authoritative upstream position_ids tensor
96
  rather than synthesized locally from arange(N). This preserves advanced positions
@@ -98,40 +102,40 @@ def pack_experts(
98
  unchanged when position_ids is the ordinary sequential token positions.
99
 
100
  Args:
101
- hidden_states: Token-choice hidden states x of shape (B, N, d).
102
- position_ids: Authoritative upstream token positions J of shape (B, N).
 
 
103
  selected_heads: Routed head selections I of shape (B, N, K).
104
  num_experts: Total number of experts L.
105
- flattened_selected_heads: H from setup_packing(), shape (B, N*K).
106
- permutation: Pi from setup_packing(), shape (B, N*K).
107
- outer_active_mask: Current-chunk active mask of shape (B, N), where True
108
- means the token is semantically live. Dead tokens do not become
109
- semantically active in the packed sparse representation.
110
 
111
  Returns:
112
  Tuple of:
113
- - packed_hidden_states: x' of shape (B, L, T, d)
114
- - packed_positions: J' of shape (B, L, T)
115
- - unpacking_mask: of shape (B, L, T). True where a slot contains any
116
- routed token copy, live or dead. Always has exactly B*N*K True entries.
117
- Pass this to unpack_experts — not active_mask.
118
- - active_mask: of shape (B, L, T). True only where a slot contains a
119
- copy of a live outer token. Pass this to BEA for attention gating.
120
  """
121
- batch_size, sequence_length, hidden_dim = hidden_states.shape
122
- _, _, num_selected_heads = selected_heads.shape
 
 
123
 
124
  # -----------------------------------------------------------------------
125
  # Reconstruct routed local source-token indices in token-choice order.
126
  #
127
- # The internal arange(N) is no longer the packed position tensor. It is only
128
- # the local source-row index object used to gather from the current chunk
129
- # tensor x. Flattening this object gives a (B, N*K) tensor aligned with H's
130
- # token-major routed-copy order.
131
  # -----------------------------------------------------------------------
132
  source_token_indices = torch.arange(
133
  sequence_length,
134
- device=hidden_states.device,
135
  dtype=torch.long,
136
  ).view(1, sequence_length, 1).expand(
137
  batch_size,
@@ -147,89 +151,71 @@ def pack_experts(
147
  # Reorder source-token indices into expert-major order.
148
  #
149
  # Applying Pi yields the local source-token rows in the packed expert-major
150
- # order required by the paper. Those same reordered source indices are then
151
- # used to gather hidden states, authoritative upstream positions, and outer
152
- # liveness so all three remain aligned under the exact same packing
153
- # transformation.
154
  # -----------------------------------------------------------------------
155
  sorted_source_indices = flattened_source_indices.gather(
156
  dim=1,
157
  index=permutation,
158
  )
159
- sorted_hidden_states = hidden_states.gather(
160
- dim=1,
161
- index=sorted_source_indices.unsqueeze(-1).expand(-1, -1, hidden_dim),
162
- )
163
- sorted_positions = position_ids.gather(
164
- dim=1,
165
- index=sorted_source_indices,
166
- )
167
- sorted_active_mask = outer_active_mask.gather(
168
- dim=1,
169
- index=sorted_source_indices,
170
- )
171
 
172
  # -----------------------------------------------------------------------
173
- # Count how many routed copies land in each expert bucket.
 
174
  #
175
- # S[b, l] is the number of routed token copies assigned to expert l in batch b.
176
- # T is the maximum such count across all batches and experts; it determines the
177
- # padded expert-length dimension of the packed representation.
 
178
  # -----------------------------------------------------------------------
179
- tokens_per_expert = _bincount_rows(
180
- values=flattened_selected_heads,
181
- num_bins=num_experts,
182
- )
183
- max_tokens_per_expert = int(tokens_per_expert.max().item())
184
 
185
  # -----------------------------------------------------------------------
186
- # Construct the active-token mask M.
187
  #
188
  # Each expert bucket is left-justified: if S[b, l] = s, then slots
189
- # t = 0, ..., s-1 are active and all later slots are padding. The resulting
190
- # mask therefore both identifies real packed tokens and enforces left-justified
191
- # packing. This is the unpacking_mask — it marks slot occupancy regardless of
192
- # outer token liveness, and always has exactly B*N*K True entries.
193
  # -----------------------------------------------------------------------
194
  time_axis = torch.arange(
195
- max_tokens_per_expert,
196
- device=hidden_states.device,
197
  dtype=torch.long,
198
- ).view(1, 1, max_tokens_per_expert)
199
  unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)
200
 
201
  # -----------------------------------------------------------------------
202
- # Materialize the padded packed tensors.
203
  #
204
- # The packed hidden states x', packed original-token positions J', and packed
205
- # active-token mask are allocated as zero-filled tensors. Active entries are
206
- # then written into those buffers in the expert-major order established above.
207
- # Padding remains zero / inactive.
208
  # -----------------------------------------------------------------------
209
- packed_hidden_states = hidden_states.new_zeros(
210
- batch_size,
211
- num_experts,
212
- max_tokens_per_expert,
213
- hidden_dim,
214
- )
215
- packed_positions = position_ids.new_zeros(
216
- batch_size,
217
- num_experts,
218
- max_tokens_per_expert,
219
- )
220
- active_mask = torch.zeros(
221
- batch_size,
222
- num_experts,
223
- max_tokens_per_expert,
224
- dtype=torch.bool,
225
- device=hidden_states.device,
226
- )
227
 
228
- packed_hidden_states[unpacking_mask] = sorted_hidden_states.reshape(-1, hidden_dim)
229
- packed_positions[unpacking_mask] = sorted_positions.reshape(-1)
230
- active_mask[unpacking_mask] = sorted_active_mask.reshape(-1)
 
 
 
231
 
232
- return packed_hidden_states, packed_positions, unpacking_mask, active_mask
233
 
234
 
235
  # ---------------------------------------------------------------------------
@@ -238,9 +224,9 @@ def pack_experts(
238
 
239
  def unpack_experts(
240
  expert_outputs: torch.Tensor,
241
- selected_heads: torch.Tensor,
242
  unpacking_mask: torch.Tensor,
243
- inverse_permutation: torch.Tensor,
244
  ) -> torch.Tensor:
245
  """Restore token-choice ordering from BEA expert-choice output.
246
 
@@ -257,14 +243,16 @@ def unpack_experts(
257
 
258
  Args:
259
  expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
260
- selected_heads: Routed head selections I of shape (B, N, K).
261
  unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
262
  occupied packed slots regardless of outer token liveness.
263
- inverse_permutation: Pi^{-1} from setup_packing(), shape (B, N*K).
264
 
265
  Returns:
266
  Restored token-choice tensor y_tilde of shape (B, N, K, d).
267
  """
 
 
268
  batch_size, sequence_length, num_selected_heads = selected_heads.shape
269
  hidden_dim = expert_outputs.shape[-1]
270
 
@@ -291,45 +279,63 @@ def unpack_experts(
291
  # Helpers
292
  # ---------------------------------------------------------------------------
293
 
294
- def _bincount_rows(
295
- values: torch.Tensor,
296
- num_bins: int,
297
- ) -> torch.Tensor:
298
- """Count per-row integer occurrences for a 2D tensor.
299
 
300
- torch.bincount operates on a flat 1D vector, but the packing algorithm needs
301
- one bincount per batch row. The trick used here is to shift each row into its
302
- own disjoint bin range before flattening:
 
303
 
304
- row 0 uses bins [0, ..., num_bins - 1]
305
- row 1 uses bins [num_bins, ..., 2*num_bins - 1]
306
- row 2 uses bins [2*num_bins, ..., 3*num_bins - 1]
307
- ...
308
 
309
- After that shift, one global torch.bincount produces all row-local counts at
310
- once. Reshaping the result back to (B, num_bins) recovers the per-row bincount.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
- This is a vectorized implementation detail only; externally visible behavior
313
- remains exactly the paper's S tensor of per-batch per-expert token counts.
 
314
 
315
  Args:
316
- values: Integer tensor of shape (B, M) with entries in [0, num_bins).
317
- num_bins: Number of bins.
 
318
 
319
  Returns:
320
- Counts tensor of shape (B, num_bins).
321
  """
322
- batch_size = values.shape[0]
323
-
324
- row_offsets = torch.arange(
325
  batch_size,
326
- device=values.device,
327
- dtype=values.dtype,
328
- ).unsqueeze(1) * num_bins
329
- shifted_values = values + row_offsets
330
-
331
- counts = torch.bincount(
332
- shifted_values.reshape(-1),
333
- minlength=batch_size * num_bins,
334
  )
335
- return counts.reshape(batch_size, num_bins)
 
3
  This module implements the low-level token-choice -> expert-choice -> token-choice
4
  conversion boundary specified in the paper. The externally visible behavior is fixed:
5
 
6
+ - setup_packing() prepares the auxiliary ordering data and returns it as a dict
7
+ payload forwarded whole to pack_experts and unpack_experts.
8
+ - pack_experts() converts a dict of routed token-choice tensors into packed
9
+ expert-choice form. Each entry is paired with its intended padding value; all
10
+ entries undergo the same expert-major gather-scatter so they remain aligned.
11
  - unpack_experts() restores token-choice ordering afterward.
12
 
13
  Stable sort is a correctness requirement. It preserves causal ordering inside each
14
  expert bucket, which is the foundation on which BEA's later triangular causal mask
15
  is correct.
16
 
17
+ pack_experts() returns the packed entries dict together with a separate unpacking_mask.
18
+ Two masks serve different roles and must not be interchanged:
19
 
20
  - unpacking_mask: marks every packed slot that contains a routed token copy,
21
  live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
22
  so its reshape invariant holds regardless of outer token liveness.
23
+ - active_mask (caller-supplied entry): marks only the packed slots whose source
24
+ token was semantically live. This is what BEA consumes for attention gating.
25
+ Dead outer tokens must not influence sparse attention outputs.
26
  """
27
 
28
  import torch
29
+ from typing import Any
30
 
31
 
32
  # ---------------------------------------------------------------------------
 
35
 
36
  def setup_packing(
37
  selected_heads: torch.Tensor,
38
+ ) -> dict[str, torch.Tensor]:
39
  """Prepare the auxiliary ordering data used by pack/unpack.
40
 
41
  Routing produces token-choice state I of shape (B, N, K): for each token, which
 
52
  selected_heads: Routed token-choice head selections I of shape (B, N, K).
53
 
54
  Returns:
55
+ Auxiliary payload dict with keys:
56
+ - "flattened_selected_heads": H of shape (B, N*K)
57
+ - "permutation": stable expert-major permutation Pi of shape (B, N*K)
58
+ - "inverse_permutation": inverse permutation Pi^{-1} of shape (B, N*K)
59
+ This dict is forwarded whole to pack_experts and unpack_experts.
60
  """
61
  batch_size, sequence_length, num_selected_heads = selected_heads.shape
62
  flattened_selected_heads = selected_heads.reshape(
 
67
  permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
68
  inverse_permutation = torch.argsort(permutation, dim=-1)
69
 
70
+ return {
71
+ "flattened_selected_heads": flattened_selected_heads,
72
+ "permutation": permutation,
73
+ "inverse_permutation": inverse_permutation,
74
+ }
75
 
76
 
77
  # ---------------------------------------------------------------------------
 
79
  # ---------------------------------------------------------------------------
80
 
81
  def pack_experts(
82
+ entries: dict[str, tuple[torch.Tensor, Any]],
83
+ setup: dict[str, torch.Tensor],
84
  selected_heads: torch.Tensor,
85
  num_experts: int,
86
+ packed_length: int,
87
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
88
+ """Pack token-choice tensors into expert-choice padded form.
 
 
89
 
90
  The paper's packing path has two jobs:
91
 
92
  1. Convert routed token-choice copies into expert-major order.
93
  2. Materialize that expert-major order into a padded tensor layout BEA can consume.
94
 
95
+ All entries in the provided dict undergo the same expert-major gather-scatter so
96
+ they remain mutually aligned in the packed frame. Each entry is paired with its
97
+ intended padding value, which fills slots that contain no routed token copy.
 
 
 
98
 
99
  Packed positions are sourced from the authoritative upstream position_ids tensor
100
  rather than synthesized locally from arange(N). This preserves advanced positions
 
102
  unchanged when position_ids is the ordinary sequential token positions.
103
 
104
  Args:
105
+ entries: Mapping from string keys to (tensor, padding_value) pairs. Each
106
+ tensor has shape (B, N, ...) and is rearranged into expert-choice layout
107
+ (B, L, T, ...). The returned dict carries the same keys.
108
+ setup: Auxiliary payload returned by setup_packing().
109
  selected_heads: Routed head selections I of shape (B, N, K).
110
  num_experts: Total number of experts L.
111
+ packed_length: Static packed time dimension T. All per-expert buffers are
112
+ allocated to exactly this length. Use config.mosrah_packed_length as the
113
+ source of this value. Raises if any actual per-expert token count exceeds
114
+ this value.
 
115
 
116
  Returns:
117
  Tuple of:
118
+ - packed_entries: Dict with same keys as entries; each value is the
119
+ packed tensor of shape (B, L, T, ...).
120
+ - unpacking_mask: Boolean tensor of shape (B, L, T). True where a slot
121
+ contains any routed token copy, live or dead. Always has exactly
122
+ B*N*K True entries. Pass this to unpack_experts — not active_mask.
 
 
123
  """
124
+ batch_size, sequence_length, num_selected_heads = selected_heads.shape
125
+
126
+ flattened_selected_heads = setup["flattened_selected_heads"]
127
+ permutation = setup["permutation"]
128
 
129
  # -----------------------------------------------------------------------
130
  # Reconstruct routed local source-token indices in token-choice order.
131
  #
132
+ # The internal arange(N) is only the local source-row index object used to
133
+ # gather from the current chunk tensors. Flattening gives a (B, N*K) tensor
134
+ # aligned with H's token-major routed-copy order.
 
135
  # -----------------------------------------------------------------------
136
  source_token_indices = torch.arange(
137
  sequence_length,
138
+ device=flattened_selected_heads.device,
139
  dtype=torch.long,
140
  ).view(1, sequence_length, 1).expand(
141
  batch_size,
 
151
  # Reorder source-token indices into expert-major order.
152
  #
153
  # Applying Pi yields the local source-token rows in the packed expert-major
154
+ # order required by the paper. All entries are then gathered using these same
155
+ # reordered indices so they remain aligned under the exact same transformation.
 
 
156
  # -----------------------------------------------------------------------
157
  sorted_source_indices = flattened_source_indices.gather(
158
  dim=1,
159
  index=permutation,
160
  )
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
  # -----------------------------------------------------------------------
163
+ # Count how many routed copies land in each expert bucket and verify
164
+ # that no bucket exceeds the statically preallocated packed_length T.
165
  #
166
+ # S[b, l] is the number of routed token copies assigned to expert l in
167
+ # batch b. T (packed_length) is a static allocation derived from config,
168
+ # not a data-dependent maximum. Overflow is detected here and raises in
169
+ # both eager and compiled modes.
170
  # -----------------------------------------------------------------------
171
+ tokens_per_expert = _count_tokens_per_expert(flattened_selected_heads, num_experts)
172
+ max_count = tokens_per_expert.max().item()
173
+ no_overflow = max_count <= packed_length
174
+ _enforce_no_overflow(no_overflow)
 
175
 
176
  # -----------------------------------------------------------------------
177
+ # Construct the unpacking mask.
178
  #
179
  # Each expert bucket is left-justified: if S[b, l] = s, then slots
180
+ # t = 0, ..., s-1 are occupied and all later slots are padding. The mask
181
+ # marks slot occupancy regardless of outer token liveness, and always has
182
+ # exactly B*N*K True entries.
 
183
  # -----------------------------------------------------------------------
184
  time_axis = torch.arange(
185
+ packed_length,
186
+ device=flattened_selected_heads.device,
187
  dtype=torch.long,
188
+ ).view(1, 1, packed_length)
189
  unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)
190
 
191
  # -----------------------------------------------------------------------
192
+ # Materialize all entries into the packed expert-choice frame.
193
  #
194
+ # Each entry is gathered using the expert-major sorted source indices, then
195
+ # scattered into a padded buffer. The gather index is expanded to cover each
196
+ # tensor's trailing dimensions. Padding slots receive the caller-supplied fill
197
+ # value rather than an implicit zero.
198
  # -----------------------------------------------------------------------
199
+ packed_entries: dict[str, torch.Tensor] = {}
200
+ for key, (tensor, padding_value) in entries.items():
201
+ extra_shape = tensor.shape[2:]
202
+
203
+ # Expand gather index to cover trailing dimensions, if any.
204
+ idx = sorted_source_indices.view(
205
+ batch_size,
206
+ sequence_length * num_selected_heads,
207
+ *(1,) * len(extra_shape),
208
+ ).expand(-1, -1, *extra_shape)
209
+ sorted_tensor = tensor.gather(dim=1, index=idx)
 
 
 
 
 
 
 
210
 
211
+ packed_tensor = tensor.new_full(
212
+ (batch_size, num_experts, packed_length, *extra_shape),
213
+ fill_value=padding_value,
214
+ )
215
+ packed_tensor[unpacking_mask] = sorted_tensor.reshape(-1, *extra_shape)
216
+ packed_entries[key] = packed_tensor
217
 
218
+ return packed_entries, unpacking_mask
219
 
220
 
221
  # ---------------------------------------------------------------------------
 
224
 
225
  def unpack_experts(
226
  expert_outputs: torch.Tensor,
227
+ setup: dict[str, torch.Tensor],
228
  unpacking_mask: torch.Tensor,
229
+ selected_heads: torch.Tensor,
230
  ) -> torch.Tensor:
231
  """Restore token-choice ordering from BEA expert-choice output.
232
 
 
243
 
244
  Args:
245
  expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
246
+ setup: Auxiliary payload returned by setup_packing().
247
  unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
248
  occupied packed slots regardless of outer token liveness.
249
+ selected_heads: Routed head selections I of shape (B, N, K).
250
 
251
  Returns:
252
  Restored token-choice tensor y_tilde of shape (B, N, K, d).
253
  """
254
+ inverse_permutation = setup["inverse_permutation"]
255
+
256
  batch_size, sequence_length, num_selected_heads = selected_heads.shape
257
  hidden_dim = expert_outputs.shape[-1]
258
 
 
279
  # Helpers
280
  # ---------------------------------------------------------------------------
281
 
282
+ def _enforce_no_overflow(condition: bool) -> None:
283
+ """Enforce that no expert bucket exceeds the preallocated packed length.
 
 
 
284
 
285
+ This check fires when the number of tokens assigned to any expert in any
286
+ batch item exceeds mosrah_packed_length. When that limit is exceeded, the
287
+ packed buffer is too small to hold all assignments and data would be dropped.
288
+ Increase mosrah_overallocation_factor in ShramConfig to resolve.
289
 
290
+ The caller must derive condition via .item() on the max count tensor so that
291
+ dynamo captures a SymInt and the comparison produces a SymBool. Passing a
292
+ tensor comparison result directly bypasses the SymInt mechanism and prevents
293
+ the check from firing at compiled runtime.
294
 
295
+ Args:
296
+ condition: True means no overflow has occurred; False means at least one
297
+ expert bucket exceeds packed_length. In compiled mode this is a SymBool
298
+ produced by comparing a SymInt against the static packed_length.
299
+ """
300
+ if torch.compiler.is_compiling():
301
+ torch._check(condition)
302
+ else:
303
+ if not condition:
304
+ raise RuntimeError(
305
+ "Expert packing overflow: at least one expert bucket contains more "
306
+ "tokens than mosrah_packed_length allows. Increase "
307
+ "mosrah_overallocation_factor in ShramConfig to resolve."
308
+ )
309
+
310
+
311
+ def _count_tokens_per_expert(
312
+ flattened_selected_heads: torch.Tensor,
313
+ num_experts: int,
314
+ ) -> torch.Tensor:
315
+ """Count how many routed token copies are assigned to each expert per batch item.
316
 
317
+ Uses scatter_add into a pre-sized (B, num_experts) zero buffer, producing a
318
+ statically-shaped output that compiles without graph breaks. Each position in
319
+ flattened_selected_heads contributes one count to the corresponding expert slot.
320
 
321
  Args:
322
+ flattened_selected_heads: Expert assignments of shape (B, N*K) with values
323
+ in [0, num_experts).
324
+ num_experts: Total number of experts L.
325
 
326
  Returns:
327
+ Counts tensor of shape (B, num_experts).
328
  """
329
+ batch_size = flattened_selected_heads.shape[0]
330
+ counts = torch.zeros(
 
331
  batch_size,
332
+ num_experts,
333
+ device=flattened_selected_heads.device,
334
+ dtype=flattened_selected_heads.dtype,
335
+ )
336
+ counts.scatter_add_(
337
+ dim=1,
338
+ index=flattened_selected_heads,
339
+ src=torch.ones_like(flattened_selected_heads),
340
  )
341
+ return counts
__attention__mosrah.py CHANGED
@@ -40,6 +40,7 @@ class MoSRAHLayer(nn.Module):
40
  def __init__(self, config: ShramConfig) -> None:
41
  super().__init__()
42
  self.num_experts = config.num_mosrah_heads
 
43
 
44
  self.router = MoSRAHRouter(config)
45
  self.positions = SparseMoSRAHPositions(config)
@@ -91,18 +92,16 @@ class MoSRAHLayer(nn.Module):
91
  hidden_states, active_mask
92
  )
93
 
94
- flattened_selected_heads, permutation, inverse_permutation = setup_packing(
95
- selected_heads
96
- )
97
- packed_hidden_states, packed_positions, unpacking_mask, active_mask = pack_experts(
98
- hidden_states=hidden_states,
99
- position_ids=position_ids,
100
- selected_heads=selected_heads,
101
- num_experts=self.num_experts,
102
- flattened_selected_heads=flattened_selected_heads,
103
- permutation=permutation,
104
- outer_active_mask=active_mask,
105
- )
106
 
107
  # -------------------------------------------------------------------
108
  # Sparse attention runs entirely in the packed expert-choice frame, so
@@ -114,6 +113,7 @@ class MoSRAHLayer(nn.Module):
114
  # -------------------------------------------------------------------
115
  bea_positions = self.positions(
116
  packed_positions=packed_positions,
 
117
  cache=cache,
118
  )
119
  packed_outputs = self.bea(
@@ -133,9 +133,9 @@ class MoSRAHLayer(nn.Module):
133
  # -------------------------------------------------------------------
134
  token_choice_outputs = unpack_experts(
135
  expert_outputs=packed_outputs,
136
- selected_heads=selected_heads,
137
  unpacking_mask=unpacking_mask,
138
- inverse_permutation=inverse_permutation,
139
  )
140
  final_output = (
141
  token_choice_outputs * routing_probs.unsqueeze(-1)
 
40
  def __init__(self, config: ShramConfig) -> None:
41
  super().__init__()
42
  self.num_experts = config.num_mosrah_heads
43
+ self.packed_length = config.mosrah_packed_length
44
 
45
  self.router = MoSRAHRouter(config)
46
  self.positions = SparseMoSRAHPositions(config)
 
92
  hidden_states, active_mask
93
  )
94
 
95
+ setup = setup_packing(selected_heads)
96
+ entries = {
97
+ "hidden_states": (hidden_states, 0.0),
98
+ "position_ids": (position_ids, 0),
99
+ "active_mask": (active_mask, False),
100
+ }
101
+ packed, unpacking_mask = pack_experts(entries, setup, selected_heads, self.num_experts, self.packed_length)
102
+ packed_hidden_states = packed["hidden_states"]
103
+ packed_positions = packed["position_ids"]
104
+ active_mask = packed["active_mask"]
 
 
105
 
106
  # -------------------------------------------------------------------
107
  # Sparse attention runs entirely in the packed expert-choice frame, so
 
113
  # -------------------------------------------------------------------
114
  bea_positions = self.positions(
115
  packed_positions=packed_positions,
116
+ active_mask=active_mask,
117
  cache=cache,
118
  )
119
  packed_outputs = self.bea(
 
133
  # -------------------------------------------------------------------
134
  token_choice_outputs = unpack_experts(
135
  expert_outputs=packed_outputs,
136
+ setup=setup,
137
  unpacking_mask=unpacking_mask,
138
+ selected_heads=selected_heads,
139
  )
140
  final_output = (
141
  token_choice_outputs * routing_probs.unsqueeze(-1)
__attention__positions_converter.py CHANGED
@@ -32,12 +32,17 @@ class SparseMoSRAHPositions(nn.Module):
32
  def forward(
33
  self,
34
  packed_positions: torch.Tensor,
 
35
  cache: MoSRAHCache | None,
36
  ) -> torch.Tensor:
37
  """Compute the packed position tensor P consumed by BEA.
38
 
39
  Args:
40
  packed_positions: Packed original-token positions J' of shape (B, L, T).
 
 
 
 
41
  cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
42
  mode, the current per-head occupancies offset the local packed sequence.
43
 
@@ -45,14 +50,15 @@ class SparseMoSRAHPositions(nn.Module):
45
  Packed position tensor P of shape (B, L, T).
46
  """
47
  if self.rope_mode == "main_sequence":
48
- return self._main_sequence_positions(packed_positions)
49
-
50
- if self.rope_mode == "semantic_sequence":
51
- return self._semantic_sequence_positions(packed_positions, cache)
52
-
53
- raise NotImplementedError(
54
- f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
55
- )
 
56
 
57
  def _main_sequence_positions(
58
  self,
 
32
  def forward(
33
  self,
34
  packed_positions: torch.Tensor,
35
+ active_mask: torch.Tensor,
36
  cache: MoSRAHCache | None,
37
  ) -> torch.Tensor:
38
  """Compute the packed position tensor P consumed by BEA.
39
 
40
  Args:
41
  packed_positions: Packed original-token positions J' of shape (B, L, T).
42
+ active_mask: Boolean active-token mask of shape (B, L, T). Inactive
43
+ positions are zeroed in the returned tensor regardless of mode —
44
+ their position value is semantically irrelevant and 0 is guaranteed
45
+ to be within any valid RoPE table.
46
  cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
47
  mode, the current per-head occupancies offset the local packed sequence.
48
 
 
50
  Packed position tensor P of shape (B, L, T).
51
  """
52
  if self.rope_mode == "main_sequence":
53
+ positions = self._main_sequence_positions(packed_positions)
54
+ elif self.rope_mode == "semantic_sequence":
55
+ positions = self._semantic_sequence_positions(packed_positions, cache)
56
+ else:
57
+ raise NotImplementedError(
58
+ f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
59
+ )
60
+
61
+ return torch.where(active_mask, positions, torch.zeros_like(positions))
62
 
63
  def _main_sequence_positions(
64
  self,
__attention__router.py CHANGED
@@ -57,10 +57,11 @@ class MoSRAHRouter(nn.Module):
57
  super().__init__()
58
  self.num_mosrah_heads = config.num_mosrah_heads
59
  self.num_selected_heads = config.num_selected_heads
 
60
 
61
  # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
62
  self.routing_projection = nn.Linear(
63
- config.hidden_size, config.num_mosrah_heads, bias=False
64
  )
65
 
66
  # b: learned per-head bias for load balancing. Initialized to zero so that all
@@ -117,25 +118,31 @@ class MoSRAHRouter(nn.Module):
117
  gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
118
  routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
119
 
120
- # Routing frequency f_l: fraction of active (batch, token, head_slot) triples
121
- # assigned to each head. Dead tokens are excluded by zeroing their rows in the
122
- # assignment mask before reduction. Normalization uses the active assignment
123
- # count so frequencies remain properly scaled regardless of how many tokens
124
- # are live in this chunk.
125
  assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
126
  assignment_mask.scatter_(-1, selected_heads, 1.0)
127
  active_assignments = assignment_mask * active_mask.unsqueeze(-1)
128
- num_active_assignments = active_mask.sum() * K
129
- routing_freqs = active_assignments.sum(dim=(0, 1)) / num_active_assignments # f, (L,)
 
 
 
 
 
 
 
 
130
 
131
  # Load balance loss via custom autograd. expert_bias is an input so PyTorch
132
  # registers it as a graph node; the custom backward writes the DeepSeek-style
133
  # correction gradient to expert_bias.grad for the optimizer to consume.
134
  load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
135
 
136
- # MaxVio is a detached monitoring scalar derived from routing_freqs. It must
137
- # not contribute gradients under any circumstance, so it is detached at the
138
- # point of computation rather than left to callers to detach.
139
  max_vio = self._compute_max_vio(routing_freqs, L)
140
 
141
  return selected_heads, routing_probs, load_balance_loss, max_vio
@@ -145,15 +152,16 @@ class MoSRAHRouter(nn.Module):
145
  """Compute the MaxVio routing-imbalance scalar.
146
 
147
  MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
148
- head l and 1/L is the perfectly balanced target. A value of zero indicates
149
- perfect balance; a value of 1 means the most overloaded head received exactly
150
- double its fair share.
 
151
 
152
  The result is detached from the autograd graph — MaxVio is a monitoring scalar
153
  and must never contribute gradients to any parameter.
154
 
155
  Args:
156
- routing_freqs: Per-head routing frequencies of shape (L,). Sums to 1.
157
  num_heads: Total number of MoSRAH heads L.
158
 
159
  Returns:
 
57
  super().__init__()
58
  self.num_mosrah_heads = config.num_mosrah_heads
59
  self.num_selected_heads = config.num_selected_heads
60
+ self.load_balance_p = config.load_balance_p
61
 
62
  # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
63
  self.routing_projection = nn.Linear(
64
+ config.embedding_width, config.num_mosrah_heads, bias=False
65
  )
66
 
67
  # b: learned per-head bias for load balancing. Initialized to zero so that all
 
118
  gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
119
  routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
120
 
121
+ # Per-item routing frequencies f_{b,l}: for each batch item b and head l, what
122
+ # fraction of that item's active K assignments over all tokens go to head l.
123
+ # Dead tokens are excluded before reduction. Normalization is per batch item so
124
+ # each item's frequencies sum to 1 independently of other items in the batch.
 
125
  assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
126
  assignment_mask.scatter_(-1, selected_heads, 1.0)
127
  active_assignments = assignment_mask * active_mask.unsqueeze(-1)
128
+ per_item_counts = active_assignments.sum(dim=1) # (B, L)
129
+ per_item_total = active_mask.sum(dim=1, keepdim=True) * K # (B, 1)
130
+ per_item_freqs = per_item_counts / per_item_total # (B, L)
131
+
132
+ # p-mean of per_item_freqs over the batch dimension produces routing_freqs (L,).
133
+ # p-mean weights aggregation toward the worst-case batch item relative to
134
+ # arithmetic mean, making the load balance signal sensitive to per-item spikes
135
+ # that cause packing overflow.
136
+ p = self.load_balance_p
137
+ routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
138
 
139
  # Load balance loss via custom autograd. expert_bias is an input so PyTorch
140
  # registers it as a graph node; the custom backward writes the DeepSeek-style
141
  # correction gradient to expert_bias.grad for the optimizer to consume.
142
  load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
143
 
144
+ # MaxVio is a detached monitoring scalar following the paper's formula
145
+ # L · max_l(f_l 1/L) applied to routing_freqs. Must not contribute gradients.
 
146
  max_vio = self._compute_max_vio(routing_freqs, L)
147
 
148
  return selected_heads, routing_probs, load_balance_loss, max_vio
 
152
  """Compute the MaxVio routing-imbalance scalar.
153
 
154
  MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
155
+ head l and 1/L is the perfectly balanced target. Follows the paper's definition
156
+ (Wang et al.) applied to routing_freqs. A value of zero indicates perfect
157
+ balance; a value of 0.5 means the most overloaded head received 50% more routed
158
+ tokens than ideal.
159
 
160
  The result is detached from the autograd graph — MaxVio is a monitoring scalar
161
  and must never contribute gradients to any parameter.
162
 
163
  Args:
164
+ routing_freqs: Per-head routing frequencies of shape (L,).
165
  num_heads: Total number of MoSRAH heads L.
166
 
167
  Returns:
__attention__shram.py CHANGED
@@ -64,19 +64,6 @@ class SHRAMHybridLayer(nn.Module):
64
  max_vio: Detached scalar routing-imbalance summary. Passed through
65
  unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
66
  """
67
- # ------------------------------------------------
68
- # It is not possible, due to how bea constructs its block mask,
69
- # for the model to process a sequence that does not start at zero
70
- # without a cache to track the per-head offsets
71
- # ------------------------------------------------
72
-
73
- if cache is None and torch.any(position_ids[:, 0] != 0):
74
- raise ValueError(
75
- "Uncached SHRAMHybridLayer does not support nonzero starting positions. "
76
- "Either provide a matching ShramLayerCache populated by the prefix for "
77
- "continued decoding, or rebase the uncached sequence to start at 0."
78
- )
79
-
80
  # -------------------------------------------------------------------
81
  # The hybrid layer's first responsibility is cache dispatch. The layer
82
  # cache already owns the concrete sub-cache objects required by each
 
64
  max_vio: Detached scalar routing-imbalance summary. Passed through
65
  unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
66
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  # -------------------------------------------------------------------
68
  # The hybrid layer's first responsibility is cache dispatch. The layer
69
  # cache already owns the concrete sub-cache objects required by each
__attention__sliding_window_attention.py CHANGED
@@ -44,7 +44,7 @@ class SlidingWindowAttention(nn.Module):
44
  def __init__(self, config: ShramConfig) -> None:
45
  super().__init__()
46
 
47
- self.hidden_size = config.hidden_size
48
  self.num_heads = config.num_sliding_window_heads
49
  self.head_dim = config.head_dim
50
  self.window_size = config.window_size
@@ -69,6 +69,7 @@ class SlidingWindowAttention(nn.Module):
69
  mode="default",
70
  head_dim=self.head_dim,
71
  theta=config.local_rope_theta,
 
72
  )
73
 
74
  def forward(
 
44
  def __init__(self, config: ShramConfig) -> None:
45
  super().__init__()
46
 
47
+ self.hidden_size = config.embedding_width
48
  self.num_heads = config.num_sliding_window_heads
49
  self.head_dim = config.head_dim
50
  self.window_size = config.window_size
 
69
  mode="default",
70
  head_dim=self.head_dim,
71
  theta=config.local_rope_theta,
72
+ maximum_sequence_length=config.inference_sequence_length,
73
  )
74
 
75
  def forward(
__cache__mosrah_cache.py CHANGED
@@ -61,12 +61,13 @@ class MoSRAHCache(CacheLayerMixin):
61
  batch_size: Number of sequences in the batch. Determines the first dimension
62
  of all storage tensors.
63
  device: Device on which to allocate all tensors. Should match the model device.
64
- initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
65
- when any slot overflows. Defaults to 64 to avoid repeated reallocation
66
- during prompt processing.
 
67
  """
68
 
69
- is_compileable = False
70
  is_sliding = False
71
 
72
  def __init__(
@@ -75,22 +76,23 @@ class MoSRAHCache(CacheLayerMixin):
75
  head_dim: int,
76
  batch_size: int,
77
  device: torch.device,
78
- initial_buffer_size: int = 64,
79
  ) -> None:
80
  super().__init__()
81
  self.num_mosrah_heads = num_mosrah_heads
82
  self.head_dim = head_dim
83
  self.batch_size = batch_size
84
  self.device = device
 
85
 
86
  # Allocate primary storage into the mixin-standard self.keys / self.values so
87
  # that inherited methods (offload, prefetch) operate on real tensors. _counts
88
  # tracks valid occupancy per (batch, head) slot.
89
  self.keys: torch.Tensor = torch.zeros(
90
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
91
  )
92
  self.values: torch.Tensor = torch.zeros(
93
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
94
  )
95
  self._counts: torch.Tensor = torch.zeros(
96
  batch_size, num_mosrah_heads, dtype=torch.long, device=device
@@ -107,8 +109,8 @@ class MoSRAHCache(CacheLayerMixin):
107
  def buffer_capacity(self) -> int:
108
  """Current number of slots allocated per (batch, head) pair.
109
 
110
- Derived directly from self.keys rather than tracked separately, so it is
111
- always consistent with the actual buffer after expansion.
112
  """
113
  return self.keys.shape[2]
114
 
@@ -129,10 +131,11 @@ class MoSRAHCache(CacheLayerMixin):
129
  active_mask is (B, L, T) bool with True marking real tokens. Only active
130
  positions are written; inactive positions are ignored.
131
 
132
- Uses a cumsum construction to derive the absolute buffer position for each
133
- active token without any Python loops. For a given (batch, head) slot,
134
- positions are assigned in the order tokens appear along the T dimension,
135
- preserving causal ordering.
 
136
 
137
  Returns the full accumulated (keys, values, active_mask) across the cached
138
  sparse sequence. The returned active_mask is True exactly for slots t <
@@ -150,35 +153,36 @@ class MoSRAHCache(CacheLayerMixin):
150
 
151
  Returns:
152
  Tuple of (keys, values, active_mask):
153
- keys: (B, L, T, u) float — full key buffer including junk slots.
154
- values: (B, L, T, u) float — full value buffer including junk slots.
155
- active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
156
  """
157
  incoming_delta = active_mask.long().sum(dim=2) # (B, L)
158
 
159
- if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
160
- self._expand()
161
-
162
- # Cumulative count of active positions along T for each (b, l) slot. Entry
163
- # [b, l, t] is the 1-based rank of position t among all active positions in
164
- # that slot. Subtract 1 for a zero-based within-slot index. Inactive positions
165
- # produce a negative value, which is excluded by the mask gate below.
166
- within_slot = active_mask.long().cumsum(dim=2) - 1 # (B, L, T)
167
-
168
- # Add the pre-update count to get the absolute buffer position for each
169
- # active token.
170
- abs_pos = within_slot + self._counts.unsqueeze(-1) # (B, L, T)
171
-
172
- # Scatter key and value vectors at all active positions.
173
- b_idx, l_idx, t_idx = torch.where(active_mask)
174
- self.keys[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
175
- key_states[b_idx, l_idx, t_idx]
176
- )
177
- self.values[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
178
- value_states[b_idx, l_idx, t_idx]
179
  )
 
 
180
 
181
- self._counts += incoming_delta
 
 
 
 
 
 
182
 
183
  return self.keys, self.values, self._make_active_mask()
184
 
@@ -303,10 +307,13 @@ class MoSRAHCache(CacheLayerMixin):
303
  )
304
 
305
  def get_max_cache_shape(self) -> int: # type: ignore[override]
306
- """Not supported MoSRAHCache is dynamic and unbounded."""
307
- raise NotImplementedError(
308
- "MoSRAHCache is unbounded; get_max_cache_shape() is not supported."
309
- )
 
 
 
310
 
311
  def get_mask_sizes( # type: ignore[override]
312
  self,
@@ -335,25 +342,26 @@ class MoSRAHCache(CacheLayerMixin):
335
  < self._counts.unsqueeze(-1)
336
  )
337
 
338
- def _expand(self) -> None:
339
- """Double the buffer capacity, preserving existing data.
 
340
 
341
- Called by update() when an incoming batch of tokens would cause any
342
- (batch, head) slot to exceed the current buffer capacity. All existing
343
- key and value data is copied into the low half of the new buffer; the
344
- high half is zero-initialised and will be filled by subsequent writes.
345
- After reassignment, buffer_capacity reflects the new size automatically.
 
 
 
346
  """
347
- old_cap = self.buffer_capacity
348
- new_cap = old_cap * 2
349
- dev = self.keys.device
350
- new_keys = torch.zeros(
351
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
352
- )
353
- new_values = torch.zeros(
354
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
355
- )
356
- new_keys[:, :, :old_cap, :] = self.keys
357
- new_values[:, :, :old_cap, :] = self.values
358
- self.keys = new_keys
359
- self.values = new_values
 
61
  batch_size: Number of sequences in the batch. Determines the first dimension
62
  of all storage tensors.
63
  device: Device on which to allocate all tensors. Should match the model device.
64
+ mosrah_cache_length: Static sequence capacity per (batch, head) slot. Equal to
65
+ config.mosrah_cache_length. The buffer never grows; if any slot would exceed
66
+ this capacity, update() raises in both eager and compiled modes. Increase
67
+ mosrah_overallocation_factor in ShramConfig to resolve an overflow.
68
  """
69
 
70
+ is_compileable = True
71
  is_sliding = False
72
 
73
  def __init__(
 
76
  head_dim: int,
77
  batch_size: int,
78
  device: torch.device,
79
+ mosrah_cache_length: int,
80
  ) -> None:
81
  super().__init__()
82
  self.num_mosrah_heads = num_mosrah_heads
83
  self.head_dim = head_dim
84
  self.batch_size = batch_size
85
  self.device = device
86
+ self.mosrah_cache_length = mosrah_cache_length
87
 
88
  # Allocate primary storage into the mixin-standard self.keys / self.values so
89
  # that inherited methods (offload, prefetch) operate on real tensors. _counts
90
  # tracks valid occupancy per (batch, head) slot.
91
  self.keys: torch.Tensor = torch.zeros(
92
+ batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
93
  )
94
  self.values: torch.Tensor = torch.zeros(
95
+ batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
96
  )
97
  self._counts: torch.Tensor = torch.zeros(
98
  batch_size, num_mosrah_heads, dtype=torch.long, device=device
 
109
  def buffer_capacity(self) -> int:
110
  """Current number of slots allocated per (batch, head) pair.
111
 
112
+ Equal to mosrah_cache_length as supplied at construction. Derived from
113
+ self.keys so it remains consistent with the actual buffer shape.
114
  """
115
  return self.keys.shape[2]
116
 
 
131
  active_mask is (B, L, T) bool with True marking real tokens. Only active
132
  positions are written; inactive positions are ignored.
133
 
134
+ Uses a fixed-shape destination mask constructed from per-slot write intervals
135
+ to transfer active tokens into the buffer without any data-dependent shape
136
+ operations. Active tokens are left-justified within each packed slot by the
137
+ packing machinery, so the destination positions are a contiguous range
138
+ starting at the current slot count — no cumsum or torch.where needed.
139
 
140
  Returns the full accumulated (keys, values, active_mask) across the cached
141
  sparse sequence. The returned active_mask is True exactly for slots t <
 
153
 
154
  Returns:
155
  Tuple of (keys, values, active_mask):
156
+ keys: (B, L, mosrah_cache_length, u) float — full key buffer including junk slots.
157
+ values: (B, L, mosrah_cache_length, u) float — full value buffer including junk slots.
158
+ active_mask: (B, L, mosrah_cache_length) bool — True iff slot t has been written.
159
  """
160
  incoming_delta = active_mask.long().sum(dim=2) # (B, L)
161
 
162
+ post_counts = self._counts + incoming_delta
163
+ self._check_no_overflow(post_counts.max(), self.mosrah_cache_length)
164
+
165
+ # Build a fixed-shape destination mask in cache space. Active tokens within
166
+ # each (b, l) slot are left-justified by the packing machinery, so they occupy
167
+ # positions 0..s-1 in their packed slot. The corresponding cache positions are
168
+ # write_start[b,l]..write_start[b,l]+write_count[b,l]-1. Broadcasting a
169
+ # time arange against these per-slot intervals selects exactly the target
170
+ # positions without any data-dependent shape query.
171
+ write_start = self._counts.unsqueeze(-1) # cache position where new tokens begin
172
+ write_count = incoming_delta.unsqueeze(-1) # number of new tokens arriving per slot
173
+ time_arange = torch.arange(
174
+ self.mosrah_cache_length, device=active_mask.device
 
 
 
 
 
 
 
175
  )
176
+ dest_mask = (time_arange >= write_start) & (time_arange < write_start + write_count)
177
+ # dest_mask: (B, L, mosrah_cache_length)
178
 
179
+ # Transfer key and value vectors. Left-justification guarantees that
180
+ # dest_mask and active_mask have equal True counts per (b, l) slot, so the
181
+ # boolean-mask transfer is correct without any explicit count verification.
182
+ self.keys[dest_mask] = key_states[active_mask]
183
+ self.values[dest_mask] = value_states[active_mask]
184
+
185
+ self._counts = post_counts
186
 
187
  return self.keys, self.values, self._make_active_mask()
188
 
 
307
  )
308
 
309
  def get_max_cache_shape(self) -> int: # type: ignore[override]
310
+ """Return the static per-(batch, head) slot capacity of this cache.
311
+
312
+ Equal to mosrah_cache_length as supplied at construction, which is derived
313
+ from config.mosrah_cache_length. Required by the HuggingFace static cache
314
+ contract; generation machinery uses this to size attention masks.
315
+ """
316
+ return self.mosrah_cache_length
317
 
318
  def get_mask_sizes( # type: ignore[override]
319
  self,
 
342
  < self._counts.unsqueeze(-1)
343
  )
344
 
345
+ @staticmethod
346
+ def _check_no_overflow(max_count: torch.Tensor, capacity: int) -> None:
347
+ """Raise if any (batch, head) slot would exceed the static buffer capacity.
348
 
349
+ Uses the 19.F.1 pattern: branches on whether the graph is being compiled.
350
+ In compiled mode, `.item()` folds into the graph when capture_scalar_outputs=True
351
+ and `torch._check` issues a compile-time assertion. In eager mode, a plain
352
+ RuntimeError is raised with a descriptive message.
353
+
354
+ Args:
355
+ max_count: Scalar tensor — the maximum post-update count across all slots.
356
+ capacity: The static buffer capacity (mosrah_cache_length).
357
  """
358
+ if torch.compiler.is_compiling():
359
+ torch._check(max_count.item() <= capacity)
360
+ else:
361
+ if max_count.item() > capacity:
362
+ raise RuntimeError(
363
+ f"MoSRAHCache overflow: a (batch, head) slot would reach "
364
+ f"{max_count.item()} tokens but the static buffer capacity is "
365
+ f"{capacity}. Increase mosrah_overallocation_factor in ShramConfig."
366
+ )
367
+
 
 
 
__cache__shram_cache.py CHANGED
@@ -21,6 +21,7 @@ what HuggingFace generation reads through get_seq_length().
21
  import torch
22
  from transformers.cache_utils import Cache
23
 
 
24
  from .__cache__shram_layer_cache import ShramLayerCache
25
 
26
 
@@ -36,44 +37,28 @@ class ShramCache(Cache):
36
  via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
37
 
38
  Args:
39
- num_hidden_layers: Number of SHRAM decoder layers. Determines how many
40
- ShramLayerCache objects are constructed.
41
- sliding_window: Token window size passed to each layer's LocalSlidingWindowLayerCache.
42
- num_local_heads: Number of local attention heads per layer.
43
- local_head_dim: Per-head embedding width for the local path.
44
- num_mosrah_heads: Total number of MoSRAH expert heads (L) per layer.
45
- mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
46
  batch_size: Number of sequences in the batch.
47
  device: Device on which to allocate cache tensors.
48
- initial_buffer_size: Initial per-(batch, head) capacity for each MoSRAHCache.
49
- Doubled when any slot overflows. Defaults to 64 to avoid repeated reallocation
50
- during prompt processing.
51
  """
52
 
 
 
53
  def __init__(
54
  self,
55
- num_hidden_layers: int,
56
- sliding_window: int,
57
- num_local_heads: int,
58
- local_head_dim: int,
59
- num_mosrah_heads: int,
60
- mosrah_head_dim: int,
61
  batch_size: int,
62
  device: torch.device,
63
- initial_buffer_size: int = 64,
64
  ) -> None:
65
  layers = [
66
  ShramLayerCache(
67
- sliding_window=sliding_window,
68
- num_local_heads=num_local_heads,
69
- local_head_dim=local_head_dim,
70
- num_mosrah_heads=num_mosrah_heads,
71
- mosrah_head_dim=mosrah_head_dim,
72
  batch_size=batch_size,
73
  device=device,
74
- initial_buffer_size=initial_buffer_size,
75
  )
76
- for _ in range(num_hidden_layers)
77
  ]
78
  super().__init__(layers=layers)
79
 
@@ -133,9 +118,10 @@ class ShramCache(Cache):
133
 
134
  @property
135
  def max_cache_len(self) -> int:
136
- """Not supported ShramCache has no single maximum cache length.
137
 
138
- The sliding-window side is bounded by sliding_window; the MoSRAH side is unbounded.
139
- No truthful scalar maximum represents the composite.
 
140
  """
141
- raise NotImplementedError("ShramCache does not expose max_cache_len.")
 
21
  import torch
22
  from transformers.cache_utils import Cache
23
 
24
+ from .configuration import ShramConfig
25
  from .__cache__shram_layer_cache import ShramLayerCache
26
 
27
 
 
37
  via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
38
 
39
  Args:
40
+ config: ShramConfig instance. All layer counts, buffer sizes, and sub-cache
41
+ dimensions are derived from config so that a single source of truth governs
42
+ every buffer size across the full cache stack.
 
 
 
 
43
  batch_size: Number of sequences in the batch.
44
  device: Device on which to allocate cache tensors.
 
 
 
45
  """
46
 
47
+ is_compileable = True
48
+
49
  def __init__(
50
  self,
51
+ config: ShramConfig,
 
 
 
 
 
52
  batch_size: int,
53
  device: torch.device,
 
54
  ) -> None:
55
  layers = [
56
  ShramLayerCache(
57
+ config=config,
 
 
 
 
58
  batch_size=batch_size,
59
  device=device,
 
60
  )
61
+ for _ in range(config.num_decoder_layers)
62
  ]
63
  super().__init__(layers=layers)
64
 
 
118
 
119
  @property
120
  def max_cache_len(self) -> int:
121
+ """Return the maximum sequence length the cache can serve.
122
 
123
+ Delegates to layers[0].get_max_cache_shape(), which returns
124
+ config.inference_sequence_length. HuggingFace's static-cache machinery reads
125
+ this value to size generation loops and verify compileable cache contracts.
126
  """
127
+ return self.layers[0].get_max_cache_shape()
__cache__shram_layer_cache.py CHANGED
@@ -21,6 +21,7 @@ quantity HuggingFace generation reads through get_seq_length().
21
  import torch
22
  from transformers.cache_utils import CacheLayerMixin
23
 
 
24
  from .__cache__mosrah_cache import MoSRAHCache
25
  from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
26
 
@@ -40,46 +41,36 @@ class ShramLayerCache(CacheLayerMixin):
40
  tracks the cumulative count of token positions processed across all update() calls.
41
 
42
  Args:
43
- sliding_window: Number of tokens retained by the local sliding-window cache.
44
- num_local_heads: Number of local attention heads.
45
- local_head_dim: Per-head embedding width for the local path.
46
- num_mosrah_heads: Total number of MoSRAH expert heads (L).
47
- mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
48
  batch_size: Number of sequences in the batch.
49
  device: Device on which to allocate cache tensors.
50
- initial_buffer_size: Initial per-(batch, head) capacity for MoSRAHCache. Doubled
51
- when any slot overflows. Defaults to 64 to avoid repeated reallocation during
52
- prompt processing.
53
  """
54
 
55
- is_compileable = False
56
  is_sliding = False
57
 
58
  def __init__(
59
  self,
60
- sliding_window: int,
61
- num_local_heads: int,
62
- local_head_dim: int,
63
- num_mosrah_heads: int,
64
- mosrah_head_dim: int,
65
  batch_size: int,
66
  device: torch.device,
67
- initial_buffer_size: int = 64,
68
  ) -> None:
69
  super().__init__()
 
70
  self.sliding_window_cache = LocalSlidingWindowLayerCache(
71
- sliding_window=sliding_window,
72
- num_heads=num_local_heads,
73
- head_dim=local_head_dim,
74
  batch_size=batch_size,
75
  device=device,
76
  )
77
  self.mosrah_cache = MoSRAHCache(
78
- num_mosrah_heads=num_mosrah_heads,
79
- head_dim=mosrah_head_dim,
80
  batch_size=batch_size,
81
  device=device,
82
- initial_buffer_size=initial_buffer_size,
83
  )
84
 
85
  # ---------------------------------------------------------------------------
@@ -208,26 +199,23 @@ class ShramLayerCache(CacheLayerMixin):
208
  )
209
 
210
  def get_max_cache_shape(self) -> int: # type: ignore[override]
211
- """Not supported — the composite cache has no single maximum shape.
212
 
213
- The sliding-window cache is bounded by sliding_window; the MoSRAH cache is
214
- unbounded. No truthful scalar maximum represents the composite.
 
 
215
  """
216
- raise NotImplementedError(
217
- "ShramLayerCache has no single maximum cache shape. "
218
- "Query sliding_window_cache or mosrah_cache directly."
219
- )
220
 
221
  def get_mask_sizes( # type: ignore[override]
222
  self,
223
  cache_position: torch.Tensor,
224
  ) -> tuple[int, int]:
225
- """Not supported ShramLayerCache does not participate in HF mask construction.
226
 
227
- The two sub-caches have different mask semantics and their respective attention
228
- paths handle masking directly.
 
229
  """
230
- raise NotImplementedError(
231
- "ShramLayerCache does not support get_mask_sizes(). "
232
- "Query sliding_window_cache or mosrah_cache directly."
233
- )
 
21
  import torch
22
  from transformers.cache_utils import CacheLayerMixin
23
 
24
+ from .configuration import ShramConfig
25
  from .__cache__mosrah_cache import MoSRAHCache
26
  from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
27
 
 
41
  tracks the cumulative count of token positions processed across all update() calls.
42
 
43
  Args:
44
+ config: ShramConfig instance. All sub-cache dimensions and capacities are derived
45
+ from config so that a single source of truth governs every buffer size.
 
 
 
46
  batch_size: Number of sequences in the batch.
47
  device: Device on which to allocate cache tensors.
 
 
 
48
  """
49
 
50
+ is_compileable = True
51
  is_sliding = False
52
 
53
  def __init__(
54
  self,
55
+ config: ShramConfig,
 
 
 
 
56
  batch_size: int,
57
  device: torch.device,
 
58
  ) -> None:
59
  super().__init__()
60
+ self._inference_sequence_length = config.inference_sequence_length
61
  self.sliding_window_cache = LocalSlidingWindowLayerCache(
62
+ sliding_window=config.window_size,
63
+ num_heads=config.num_sliding_window_heads,
64
+ head_dim=config.head_dim,
65
  batch_size=batch_size,
66
  device=device,
67
  )
68
  self.mosrah_cache = MoSRAHCache(
69
+ num_mosrah_heads=config.num_mosrah_heads,
70
+ head_dim=config.head_dim,
71
  batch_size=batch_size,
72
  device=device,
73
+ mosrah_cache_length=config.mosrah_cache_length,
74
  )
75
 
76
  # ---------------------------------------------------------------------------
 
199
  )
200
 
201
  def get_max_cache_shape(self) -> int: # type: ignore[override]
202
+ """Return the maximum sequence length this layer cache can serve.
203
 
204
+ The authoritative upper bound is ``config.inference_sequence_length``, which
205
+ governs the full accumulated token history the model is configured to handle.
206
+ HuggingFace's static-cache machinery reads this value to determine whether the
207
+ cache is compileable and to size generation loops.
208
  """
209
+ return self._inference_sequence_length
 
 
 
210
 
211
  def get_mask_sizes( # type: ignore[override]
212
  self,
213
  cache_position: torch.Tensor,
214
  ) -> tuple[int, int]:
215
+ """Return the KV dimensions for HuggingFace causal mask construction.
216
 
217
+ Returns (inference_sequence_length, 0): the full static cache capacity as
218
+ kv_length and zero offset. HuggingFace reads these values to size the causal
219
+ attention mask when is_compileable is True.
220
  """
221
+ return self._inference_sequence_length, 0
 
 
 
__cache__sliding_window_cache.py CHANGED
@@ -39,7 +39,7 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
39
  device: Device on which to allocate cache storage.
40
  """
41
 
42
- is_compileable = False
43
  is_sliding = True
44
 
45
  def __init__(
 
39
  device: Device on which to allocate cache storage.
40
  """
41
 
42
+ is_compileable = True
43
  is_sliding = True
44
 
45
  def __init__(
__cache__slow_mosrah_cache.py CHANGED
@@ -41,9 +41,9 @@ class SlowMoSRAHCache(CacheLayerMixin):
41
  batch_size: Number of sequences in the batch. Determines the first dimension
42
  of all storage tensors.
43
  device: Device on which to allocate all tensors. Should match the model device.
44
- initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
45
- when any slot overflows. Defaults to 64 to avoid repeated reallocation
46
- during prompt processing.
47
  """
48
 
49
  is_compileable = False
@@ -55,22 +55,23 @@ class SlowMoSRAHCache(CacheLayerMixin):
55
  head_dim: int,
56
  batch_size: int,
57
  device: torch.device,
58
- initial_buffer_size: int = 64,
59
  ) -> None:
60
  super().__init__()
61
  self.num_mosrah_heads = num_mosrah_heads
62
  self.head_dim = head_dim
63
  self.batch_size = batch_size
64
  self.device = device
 
65
 
66
  # Allocate primary storage into the mixin-standard self.keys / self.values so
67
  # that inherited methods (offload, prefetch) operate on real tensors. _counts
68
  # tracks valid occupancy per (batch, head) slot.
69
  self.keys: torch.Tensor = torch.zeros(
70
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
71
  )
72
  self.values: torch.Tensor = torch.zeros(
73
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
74
  )
75
  self._counts: torch.Tensor = torch.zeros(
76
  batch_size, num_mosrah_heads, dtype=torch.long, device=device
@@ -87,8 +88,8 @@ class SlowMoSRAHCache(CacheLayerMixin):
87
  def buffer_capacity(self) -> int:
88
  """Current number of slots allocated per (batch, head) pair.
89
 
90
- Derived directly from self.keys rather than tracked separately, so it is
91
- always consistent with the actual buffer after expansion.
92
  """
93
  return self.keys.shape[2]
94
 
@@ -111,8 +112,8 @@ class SlowMoSRAHCache(CacheLayerMixin):
111
  because the t dimension is traversed from 0 to T-1 and counts are updated
112
  immediately after each write.
113
 
114
- Buffer expansion (doubling buffer_capacity) is triggered before any writes if
115
- the incoming tokens would cause any slot to overflow the current capacity.
116
 
117
  Args:
118
  key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
@@ -122,17 +123,19 @@ class SlowMoSRAHCache(CacheLayerMixin):
122
 
123
  Returns:
124
  Tuple of (keys, values, active_mask):
125
- keys: (B, L, T, u) float — full key buffer including junk slots.
126
- values: (B, L, T, u) float — full value buffer including junk slots.
127
- active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
128
  """
129
  B, L, T = active_mask.shape
130
 
131
- # Expansion check uses the total active tokens per slot, same as the
132
- # vectorized implementation, so both expand under identical conditions.
133
  incoming_delta = active_mask.long().sum(dim=2) # (B, L)
134
- if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
135
- self._expand()
 
 
 
 
136
 
137
  # Write each active position into the next available slot for its (batch, head)
138
  # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
@@ -297,25 +300,3 @@ class SlowMoSRAHCache(CacheLayerMixin):
297
  < self._counts.unsqueeze(-1)
298
  )
299
 
300
- def _expand(self) -> None:
301
- """Double the buffer capacity, preserving existing data.
302
-
303
- Called by update() when an incoming batch of tokens would cause any
304
- (batch, head) slot to exceed the current buffer capacity. All existing
305
- key and value data is copied into the low half of the new buffer; the
306
- high half is zero-initialised and will be filled by subsequent writes.
307
- After reassignment, buffer_capacity reflects the new size automatically.
308
- """
309
- old_cap = self.buffer_capacity
310
- new_cap = old_cap * 2
311
- dev = self.keys.device
312
- new_keys = torch.zeros(
313
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
314
- )
315
- new_values = torch.zeros(
316
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
317
- )
318
- new_keys[:, :, :old_cap, :] = self.keys
319
- new_values[:, :, :old_cap, :] = self.values
320
- self.keys = new_keys
321
- self.values = new_values
 
41
  batch_size: Number of sequences in the batch. Determines the first dimension
42
  of all storage tensors.
43
  device: Device on which to allocate all tensors. Should match the model device.
44
+ mosrah_cache_length: Static sequence capacity per (batch, head) slot. Equal to
45
+ config.mosrah_cache_length. The buffer never grows; if any slot would exceed
46
+ this capacity, update() raises a RuntimeError.
47
  """
48
 
49
  is_compileable = False
 
55
  head_dim: int,
56
  batch_size: int,
57
  device: torch.device,
58
+ mosrah_cache_length: int,
59
  ) -> None:
60
  super().__init__()
61
  self.num_mosrah_heads = num_mosrah_heads
62
  self.head_dim = head_dim
63
  self.batch_size = batch_size
64
  self.device = device
65
+ self.mosrah_cache_length = mosrah_cache_length
66
 
67
  # Allocate primary storage into the mixin-standard self.keys / self.values so
68
  # that inherited methods (offload, prefetch) operate on real tensors. _counts
69
  # tracks valid occupancy per (batch, head) slot.
70
  self.keys: torch.Tensor = torch.zeros(
71
+ batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
72
  )
73
  self.values: torch.Tensor = torch.zeros(
74
+ batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
75
  )
76
  self._counts: torch.Tensor = torch.zeros(
77
  batch_size, num_mosrah_heads, dtype=torch.long, device=device
 
88
  def buffer_capacity(self) -> int:
89
  """Current number of slots allocated per (batch, head) pair.
90
 
91
+ Equal to mosrah_cache_length as supplied at construction. Derived from
92
+ self.keys so it remains consistent with the actual buffer shape.
93
  """
94
  return self.keys.shape[2]
95
 
 
112
  because the t dimension is traversed from 0 to T-1 and counts are updated
113
  immediately after each write.
114
 
115
+ Raises RuntimeError before any writes if the incoming tokens would cause any
116
+ slot to exceed the static mosrah_cache_length capacity.
117
 
118
  Args:
119
  key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
 
123
 
124
  Returns:
125
  Tuple of (keys, values, active_mask):
126
+ keys: (B, L, mosrah_cache_length, u) float — full key buffer including junk slots.
127
+ values: (B, L, mosrah_cache_length, u) float — full value buffer including junk slots.
128
+ active_mask: (B, L, mosrah_cache_length) bool — True iff slot t has been written.
129
  """
130
  B, L, T = active_mask.shape
131
 
 
 
132
  incoming_delta = active_mask.long().sum(dim=2) # (B, L)
133
+ if (self._counts + incoming_delta).max().item() > self.mosrah_cache_length:
134
+ raise RuntimeError(
135
+ f"SlowMoSRAHCache overflow: a (batch, head) slot would exceed the "
136
+ f"static buffer capacity of {self.mosrah_cache_length}. Increase "
137
+ f"mosrah_overallocation_factor in ShramConfig."
138
+ )
139
 
140
  # Write each active position into the next available slot for its (batch, head)
141
  # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
 
300
  < self._counts.unsqueeze(-1)
301
  )
302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.json CHANGED
@@ -6,14 +6,16 @@
6
  "AutoModelForCausalLM": "huggingface.ShramForCausalLM"
7
  },
8
  "beta": 32.0,
 
9
  "head_dim": 16,
10
- "hidden_size": 512,
11
  "inference_sequence_length": 1024,
12
- "intermediate_size": 1366,
13
  "local_rope_theta": 10000.0,
 
14
  "model_type": "shram",
 
15
  "mosrah_rope_theta": 10000.0,
16
- "num_hidden_layers": 12,
17
  "num_mosrah_heads": 16,
18
  "num_selected_heads": 16,
19
  "num_sliding_window_heads": 16,
@@ -21,7 +23,7 @@
21
  "rope_mode": "main_sequence",
22
  "tie_word_embeddings": false,
23
  "training_sequence_length": 1024,
24
- "transformers_version": "5.8.0",
25
  "use_cache": true,
26
  "vocab_size": 50277,
27
  "window_size": 128
 
6
  "AutoModelForCausalLM": "huggingface.ShramForCausalLM"
7
  },
8
  "beta": 32.0,
9
+ "embedding_width": 512,
10
  "head_dim": 16,
 
11
  "inference_sequence_length": 1024,
12
+ "load_balance_p": 2.0,
13
  "local_rope_theta": 10000.0,
14
+ "mlp_width": 1366,
15
  "model_type": "shram",
16
+ "mosrah_overallocation_factor": 2.0,
17
  "mosrah_rope_theta": 10000.0,
18
+ "num_decoder_layers": 12,
19
  "num_mosrah_heads": 16,
20
  "num_selected_heads": 16,
21
  "num_sliding_window_heads": 16,
 
23
  "rope_mode": "main_sequence",
24
  "tie_word_embeddings": false,
25
  "training_sequence_length": 1024,
26
+ "transformers_version": "5.8.1",
27
  "use_cache": true,
28
  "vocab_size": 50277,
29
  "window_size": 128
configuration.py CHANGED
@@ -11,6 +11,8 @@ parameters directly and constructs its own RotaryEmbedding instance explicitly
11
  HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
  """
13
 
 
 
14
  from transformers import PretrainedConfig
15
 
16
 
@@ -77,6 +79,15 @@ class ShramConfig(PretrainedConfig):
77
  use_cache: Whether to return past_key_values for KV caching.
78
  output_hidden_states: Whether to return hidden states after each layer.
79
  tie_word_embeddings: Whether input embedding and LM head share weights.
 
 
 
 
 
 
 
 
 
80
  """
81
 
82
  model_type = "shram"
@@ -109,7 +120,9 @@ class ShramConfig(PretrainedConfig):
109
  use_cache: bool = True,
110
  output_hidden_states: bool = False,
111
  tie_word_embeddings: bool = False,
112
- **kwargs,
 
 
113
  ):
114
  if head_dim % 2 != 0:
115
  raise ValueError(
@@ -137,10 +150,22 @@ class ShramConfig(PretrainedConfig):
137
  f"got {inference_sequence_length}."
138
  )
139
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  self.vocab_size = vocab_size
141
- self.hidden_size = embedding_width
142
- self.intermediate_size = mlp_width
143
- self.num_hidden_layers = num_decoder_layers
144
  self.num_sliding_window_heads = num_sliding_window_heads
145
  self.num_mosrah_heads = num_mosrah_heads
146
  self.num_selected_heads = num_selected_heads
@@ -154,13 +179,15 @@ class ShramConfig(PretrainedConfig):
154
  self.inference_sequence_length = inference_sequence_length
155
  self.alpha = alpha
156
  self.beta = beta
 
 
157
  self.attention_dropout = attention_dropout
158
  self.use_cache = use_cache
159
 
160
  super().__init__(
161
  tie_word_embeddings=tie_word_embeddings,
162
  output_hidden_states=output_hidden_states,
163
- **kwargs,
164
  )
165
 
166
  # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
@@ -176,3 +203,47 @@ class ShramConfig(PretrainedConfig):
176
  """
177
  return self.inference_sequence_length / self.training_sequence_length
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
  """
13
 
14
+ import math
15
+
16
  from transformers import PretrainedConfig
17
 
18
 
 
79
  use_cache: Whether to return past_key_values for KV caching.
80
  output_hidden_states: Whether to return hidden states after each layer.
81
  tie_word_embeddings: Whether input embedding and LM head share weights.
82
+ mosrah_overallocation_factor: Overallocation multiplier for the expert packing
83
+ buffer. ``mosrah_packed_length`` = ceil(training_sequence_length *
84
+ num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
85
+ Must be > 1.0 to guarantee a buffer larger than the balanced-routing
86
+ baseline. Default 2.0.
87
+ load_balance_p: Exponent p for the p-mean aggregation of per-item routing
88
+ frequencies into the load balance signal. Higher p weights aggregation
89
+ toward the worst-case batch item, making the correction signal more
90
+ sensitive to per-item allocation spikes. Must be positive. Default 2.0.
91
  """
92
 
93
  model_type = "shram"
 
120
  use_cache: bool = True,
121
  output_hidden_states: bool = False,
122
  tie_word_embeddings: bool = False,
123
+ mosrah_overallocation_factor: float = 2.0,
124
+ load_balance_p: float = 2.0,
125
+ **kwargs
126
  ):
127
  if head_dim % 2 != 0:
128
  raise ValueError(
 
150
  f"got {inference_sequence_length}."
151
  )
152
 
153
+ if mosrah_overallocation_factor <= 1.0:
154
+ raise ValueError(
155
+ f"mosrah_overallocation_factor must be > 1.0 to guarantee a packed "
156
+ f"buffer larger than the balanced-routing baseline. "
157
+ f"Got {mosrah_overallocation_factor}."
158
+ )
159
+
160
+ if load_balance_p <= 0.0:
161
+ raise ValueError(
162
+ f"load_balance_p must be positive, got {load_balance_p}."
163
+ )
164
+
165
  self.vocab_size = vocab_size
166
+ self.embedding_width = embedding_width
167
+ self.mlp_width = mlp_width
168
+ self.num_decoder_layers = num_decoder_layers
169
  self.num_sliding_window_heads = num_sliding_window_heads
170
  self.num_mosrah_heads = num_mosrah_heads
171
  self.num_selected_heads = num_selected_heads
 
179
  self.inference_sequence_length = inference_sequence_length
180
  self.alpha = alpha
181
  self.beta = beta
182
+ self.mosrah_overallocation_factor = mosrah_overallocation_factor
183
+ self.load_balance_p = load_balance_p
184
  self.attention_dropout = attention_dropout
185
  self.use_cache = use_cache
186
 
187
  super().__init__(
188
  tie_word_embeddings=tie_word_embeddings,
189
  output_hidden_states=output_hidden_states,
190
+ **kwargs
191
  )
192
 
193
  # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
 
203
  """
204
  return self.inference_sequence_length / self.training_sequence_length
205
 
206
+ @property
207
+ def mosrah_packed_length(self) -> int:
208
+ """Static packed time dimension T for expert packing.
209
+
210
+ The expected tokens per expert under perfectly balanced routing is
211
+ ``training_sequence_length * num_selected_heads / num_mosrah_heads``.
212
+ Multiplying by ``mosrah_overallocation_factor`` provides a buffer above
213
+ that baseline. The ceiling ensures T is always an integer >= 1.
214
+
215
+ All consumers of the packed buffer size must read this property rather
216
+ than deriving T independently.
217
+ """
218
+ return math.ceil(
219
+ self.training_sequence_length
220
+ * self.num_selected_heads
221
+ / self.num_mosrah_heads
222
+ * self.mosrah_overallocation_factor
223
+ )
224
+
225
+ @property
226
+ def mosrah_cache_length(self) -> int:
227
+ """Static per-(batch, head) slot capacity for the MoSRAH inference cache.
228
+
229
+ The expected tokens per expert over the full inference context under perfectly
230
+ balanced routing is ``inference_sequence_length * num_selected_heads /
231
+ num_mosrah_heads``. Multiplying by ``mosrah_overallocation_factor`` provides
232
+ a buffer above that baseline. The ceiling ensures the result is always an
233
+ integer >= 1.
234
+
235
+ Distinct from ``mosrah_packed_length``, which sizes the training packing buffer
236
+ using ``training_sequence_length``. This property uses
237
+ ``inference_sequence_length`` because the cache must hold the full accumulated
238
+ token history across the entire inference run.
239
+
240
+ All consumers of the MoSRAH cache buffer size must read this property rather
241
+ than deriving the capacity independently.
242
+ """
243
+ return math.ceil(
244
+ self.inference_sequence_length
245
+ * self.num_selected_heads
246
+ / self.num_mosrah_heads
247
+ * self.mosrah_overallocation_factor
248
+ )
249
+
decoder_layer.py CHANGED
@@ -46,8 +46,8 @@ class DecoderLayer(nn.Module):
46
 
47
  def __init__(self, config: ShramConfig) -> None:
48
  super().__init__()
49
- self.attn_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
50
- self.mlp_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
51
  self.attention = SHRAMHybridLayer(config)
52
  self.mlp = SwiGLUMLP(config)
53
 
 
46
 
47
  def __init__(self, config: ShramConfig) -> None:
48
  super().__init__()
49
+ self.attn_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
50
+ self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
51
  self.attention = SHRAMHybridLayer(config)
52
  self.mlp = SwiGLUMLP(config)
53
 
huggingface.py CHANGED
@@ -74,9 +74,9 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
74
 
75
  def __init__(self, config: ShramConfig) -> None:
76
  super().__init__(config)
77
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
78
  self.model = ShramModel(config)
79
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
80
  self._configure_tied_embeddings()
81
  self.post_init()
82
 
@@ -127,12 +127,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
127
  ) -> ShramCache:
128
  """Construct a fresh top-level SHRAM cache."""
129
  return ShramCache(
130
- num_hidden_layers=self.config.num_hidden_layers,
131
- sliding_window=self.config.window_size,
132
- num_local_heads=self.config.num_sliding_window_heads,
133
- local_head_dim=self.config.head_dim,
134
- num_mosrah_heads=self.config.num_mosrah_heads,
135
- mosrah_head_dim=self.config.head_dim,
136
  batch_size=batch_size,
137
  device=device,
138
  )
@@ -231,6 +226,26 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
231
  past_key_values.reorder_cache(beam_idx)
232
  return past_key_values
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
235
  """Validate token IDs at the wrapper boundary."""
236
  if input_ids.ndim != 2:
@@ -352,6 +367,63 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
352
  f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
353
  )
354
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  def _standardize_full_attention_mask(
356
  self,
357
  input_ids: torch.Tensor,
@@ -449,6 +521,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
449
  # This keeps the main sequence readable while ensuring invalid states
450
  # fail before they can silently contaminate backbone execution.
451
  # ------------------------------------------------------------------
 
452
  self._validate_input_ids(input_ids)
453
  self._validate_attention_mask(input_ids, attention_mask)
454
  self._validate_position_ids(input_ids, position_ids)
@@ -487,6 +560,10 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
487
  )
488
  shram_cache: ShramCache | None = past_key_values if use_cache else None
489
 
 
 
 
 
490
  # ------------------------------------------------------------------
491
  # Core wrapper responsibilities.
492
  #
 
74
 
75
  def __init__(self, config: ShramConfig) -> None:
76
  super().__init__(config)
77
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_width)
78
  self.model = ShramModel(config)
79
+ self.lm_head = nn.Linear(config.embedding_width, config.vocab_size, bias=False)
80
  self._configure_tied_embeddings()
81
  self.post_init()
82
 
 
127
  ) -> ShramCache:
128
  """Construct a fresh top-level SHRAM cache."""
129
  return ShramCache(
130
+ config=self.config,
 
 
 
 
 
131
  batch_size=batch_size,
132
  device=device,
133
  )
 
226
  past_key_values.reorder_cache(beam_idx)
227
  return past_key_values
228
 
229
+ @staticmethod
230
+ def create_masks_for_generate(
231
+ config: Any,
232
+ inputs_embeds: torch.Tensor,
233
+ attention_mask: torch.Tensor | None,
234
+ past_key_values: Cache | None,
235
+ position_ids: torch.Tensor | None = None,
236
+ **kwargs: Any,
237
+ ) -> torch.Tensor | None:
238
+ """Return the 2D attention_mask unchanged.
239
+
240
+ HuggingFace calls this during compiled generation to convert the 2D
241
+ attention mask into a 4D causal additive-bias mask. SHRAM uses flex
242
+ attention with custom masking and constructs causality internally; the
243
+ 4D format is incompatible with the SHRAM masking contract. Overriding
244
+ as a no-op restores symmetry between compiled and non-compiled pathways
245
+ without any loss of correctness or performance (see Unit 19.G.4).
246
+ """
247
+ return attention_mask
248
+
249
  def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
250
  """Validate token IDs at the wrapper boundary."""
251
  if input_ids.ndim != 2:
 
367
  f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
368
  )
369
 
370
+ @staticmethod
371
+ def _enforce_uncached_starting_position(condition: torch.Tensor) -> None:
372
+ """Enforce that an uncached forward pass begins at position 0.
373
+
374
+ An uncached forward has no prior KV state. Nonzero starting positions
375
+ produce silently incorrect RoPE encoding and attention outputs with no
376
+ downstream diagnostic. This method intercepts that misuse at the
377
+ outermost boundary before any backbone computation runs.
378
+
379
+ To resolve a violation: either supply a ShramCache populated with the
380
+ prefix (for continued decoding), or rebase the sequence so positions
381
+ start at 0.
382
+
383
+ Args:
384
+ condition: Scalar bool tensor. True = all batch items start at 0
385
+ (valid); False = at least one batch item starts nonzero
386
+ (violated).
387
+ """
388
+ if torch.compiler.is_compiling():
389
+ # bool.item() is not captured as a SymBool by dynamo; converting to
390
+ # int first produces a SymInt, and the Python comparison (!=0) then
391
+ # yields a SymBool that torch._check folds into the compiled graph.
392
+ condition_as_int = condition.to(torch.int).item()
393
+ torch._check(condition_as_int != 0)
394
+ else:
395
+ if not condition.item():
396
+ raise RuntimeError(
397
+ "Uncached ShramForCausalLM forward does not support nonzero "
398
+ "starting positions. Either provide a ShramCache populated "
399
+ "with the prefix for continued decoding, or rebase the "
400
+ "uncached sequence to start at 0.",
401
+ )
402
+
403
+ @staticmethod
404
+ def _enforce_capture_scalar_outputs() -> None:
405
+ """Enforce that capture_scalar_outputs is enabled when compiling.
406
+
407
+ The safety checks in this model (e.g. position-zero constraint, packing
408
+ overflow detection) rely on torch._check folding into the compiled graph,
409
+ which requires torch._dynamo.config.capture_scalar_outputs = True. Without
410
+ it those checks are silently absent in the compiled model while appearing
411
+ to work in eager mode — a misconfiguration with no diagnostic output.
412
+
413
+ This method fires during dynamo tracing so the missing flag is surfaced
414
+ immediately at compile time rather than discovered from downstream failures.
415
+ """
416
+ if torch.compiler.is_compiling():
417
+ torch._check(
418
+ torch._dynamo.config.capture_scalar_outputs,
419
+ lambda: RuntimeError(
420
+ "ShramForCausalLM requires torch._dynamo.config.capture_scalar_outputs = True "
421
+ "when compiled. Without it, runtime safety checks (position constraints, "
422
+ "overflow detection) are silently absent in the compiled model. Set the flag "
423
+ "before calling torch.compile()."
424
+ ),
425
+ )
426
+
427
  def _standardize_full_attention_mask(
428
  self,
429
  input_ids: torch.Tensor,
 
521
  # This keeps the main sequence readable while ensuring invalid states
522
  # fail before they can silently contaminate backbone execution.
523
  # ------------------------------------------------------------------
524
+ self._enforce_capture_scalar_outputs()
525
  self._validate_input_ids(input_ids)
526
  self._validate_attention_mask(input_ids, attention_mask)
527
  self._validate_position_ids(input_ids, position_ids)
 
560
  )
561
  shram_cache: ShramCache | None = past_key_values if use_cache else None
562
 
563
+ if shram_cache is None:
564
+ positions_start_sane = torch.all(current_position_ids[:, 0] == 0)
565
+ self._enforce_uncached_starting_position(positions_start_sane)
566
+
567
  # ------------------------------------------------------------------
568
  # Core wrapper responsibilities.
569
  #
mlp.py CHANGED
@@ -36,9 +36,9 @@ class SwiGLUMLP(nn.Module):
36
 
37
  def __init__(self, config: PretrainedConfig) -> None:
38
  super().__init__()
39
- self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
40
- self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
41
- self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
42
 
43
  def forward(self, x: torch.Tensor) -> torch.Tensor:
44
  """Apply the SwiGLU feed-forward transformation.
 
36
 
37
  def __init__(self, config: PretrainedConfig) -> None:
38
  super().__init__()
39
+ self.gate_proj = nn.Linear(config.embedding_width, config.mlp_width, bias=False)
40
+ self.up_proj = nn.Linear(config.embedding_width, config.mlp_width, bias=False)
41
+ self.down_proj = nn.Linear(config.mlp_width, config.embedding_width, bias=False)
42
 
43
  def forward(self, x: torch.Tensor) -> torch.Tensor:
44
  """Apply the SwiGLU feed-forward transformation.
model.py CHANGED
@@ -58,9 +58,9 @@ class ShramModel(nn.Module):
58
  super().__init__()
59
  self.config = config
60
  self.layers = nn.ModuleList(
61
- [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
62
  )
63
- self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
64
 
65
  def num_mosrah_parameters(self) -> int:
66
  """Return the total number of trainable MoSRAH parameters across all decoder layers."""
 
58
  super().__init__()
59
  self.config = config
60
  self.layers = nn.ModuleList(
61
+ [DecoderLayer(config) for _ in range(config.num_decoder_layers)]
62
  )
63
+ self.norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
64
 
65
  def num_mosrah_parameters(self) -> int:
66
  """Return the total number of trainable MoSRAH parameters across all decoder layers."""
rope.py CHANGED
@@ -26,10 +26,10 @@ Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explic
26
  parameters — no shared instance, no config reading. See Unit 5.A design decisions.
27
 
28
  Cache sharing: all instances with identical parameters share one cos/sin table via a
29
- class-level registry. The first instance that needs a particular (parameters, seq_len,
30
- device, dtype) combination builds the table; all subsequent instances reference it
31
- directly. This avoids redundant builds across the num_hidden_layers instances that
32
- share the same parametrisation.
33
  """
34
 
35
  import math
@@ -66,17 +66,21 @@ class RotaryEmbedding(nn.Module):
66
  h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
67
  config object is read inside this module.
68
 
69
- The cos/sin cache is built lazily on the first forward call and extended
70
- automatically when a longer sequence is encountered. Instances with identical
71
- parameters share one cache via the class-level ``_cache`` registry,
72
- avoiding redundant computation across decoder layers.
 
 
73
 
74
  Args:
75
  mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
76
  head_dim: Per-head embedding dimension ``u``. Must be even.
77
  theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
78
- initial_seq_length: ``C_train`` context length the model was trained at.
79
- Required for ``mode="yarn"``.
 
 
80
  dilation: Scale factor ``s = C_target / C_train`` — how much the context
81
  window is extended beyond training length. Required for ``mode="yarn"``.
82
  When ``dilation=1.0``, YaRN reduces to standard RoPE.
@@ -88,11 +92,11 @@ class RotaryEmbedding(nn.Module):
88
 
89
  Raises:
90
  NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
91
- ValueError: If ``mode="yarn"`` and any of ``initial_seq_length``,
92
- ``dilation``, ``alpha``, ``beta`` are absent.
93
  """
94
 
95
- # Maps (freq_key, seq_len, device_str, dtype_str) → (cos_table, sin_table).
96
  # Shared across all RotaryEmbedding instances in the process. Keys include device
97
  # and dtype so that tables built on different devices or in different precisions
98
  # are stored independently.
@@ -103,7 +107,7 @@ class RotaryEmbedding(nn.Module):
103
  mode: str,
104
  head_dim: int,
105
  theta: float,
106
- initial_seq_length: int | None = None,
107
  dilation: float | None = None,
108
  alpha: float | None = None,
109
  beta: float | None = None,
@@ -112,8 +116,9 @@ class RotaryEmbedding(nn.Module):
112
  super().__init__()
113
 
114
  self._validate_mode(mode)
115
- self._validate_yarn_params(mode, initial_seq_length, dilation, alpha, beta)
116
  self.mode = mode
 
117
 
118
  # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
119
  # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
@@ -128,9 +133,14 @@ class RotaryEmbedding(nn.Module):
128
  else: # yarn
129
  s = dilation
130
 
 
 
 
 
 
131
  # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
132
  # function to classify each dimension into one of three regimes.
133
- normalized_freqs = initial_seq_length * base_freqs / (2.0 * math.pi)
134
 
135
  # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
136
  # linear blend between α and β.
@@ -142,16 +152,13 @@ class RotaryEmbedding(nn.Module):
142
  # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
143
  self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2
144
 
145
- # freq_key uniquely identifies the parameter set that produced rotation_freqs.
146
- # Used as the primary component of the cache registry key.
 
147
  if mode == "default":
148
- self._freq_key: tuple = ("default", head_dim, float(theta))
149
  else:
150
- self._freq_key = (
151
- "yarn", head_dim, float(theta),
152
- int(initial_seq_length), float(dilation),
153
- float(alpha), float(beta),
154
- )
155
 
156
  # rotation_freqs is a non-persistent buffer so it moves with the model across
157
  # devices via .to() / .cuda() without appearing in saved checkpoints.
@@ -167,6 +174,11 @@ class RotaryEmbedding(nn.Module):
167
  self._cos_cached: torch.Tensor | None = None
168
  self._sin_cached: torch.Tensor | None = None
169
 
 
 
 
 
 
170
  # ---------------------------------------------------------------------------
171
  # Validation helpers
172
  # ---------------------------------------------------------------------------
@@ -182,7 +194,6 @@ class RotaryEmbedding(nn.Module):
182
  @staticmethod
183
  def _validate_yarn_params(
184
  mode: str,
185
- initial_seq_length: int | None,
186
  dilation: float | None,
187
  alpha: float | None,
188
  beta: float | None,
@@ -192,7 +203,6 @@ class RotaryEmbedding(nn.Module):
192
  return
193
  missing = [
194
  name for name, val in [
195
- ("initial_seq_length", initial_seq_length),
196
  ("dilation", dilation),
197
  ("alpha", alpha),
198
  ("beta", beta),
@@ -206,20 +216,23 @@ class RotaryEmbedding(nn.Module):
206
  # Cache management
207
  # ---------------------------------------------------------------------------
208
 
209
- def _extend_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
210
- """Build the cos/sin table to cover positions [0, seq_len).
211
 
212
  Checks the class-level registry first. If a table already exists for this
213
- exact (parameters, seq_len, device, dtype) combination it is reused directly;
214
  otherwise it is computed and stored. The instance attributes are pointed at
215
  the registry entry so that all layers sharing the same parametrisation
216
  reference the same tensor.
217
  """
218
- cache_key = (self._freq_key, seq_len, str(device), str(dtype))
219
 
220
  if cache_key not in RotaryEmbedding._cache:
221
- positions = torch.arange(seq_len, device=device, dtype=torch.float32)
222
- # outer product → (seq_len, head_dim // 2); duplicate to (seq_len, head_dim)
 
 
 
223
  freqs = torch.outer(
224
  positions,
225
  self.rotation_freqs.to(device=device, dtype=torch.float32),
@@ -240,11 +253,12 @@ class RotaryEmbedding(nn.Module):
240
  ) -> tuple[torch.Tensor, torch.Tensor, float]:
241
  """Apply rotary embeddings to query and key tensors.
242
 
243
- The cos/sin cache is extended lazily when position_ids reference positions
244
- beyond its current length, or when the device or dtype has changed.
 
245
 
246
- ``position_ids`` may be any integer tensor shape. Its values are valid
247
- position indices into the cos/sin cache:
248
 
249
  - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
250
  - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).
@@ -262,18 +276,11 @@ class RotaryEmbedding(nn.Module):
262
  1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
263
  apply to attention logits before softmax.
264
  """
265
- seq_len = int(position_ids.max().item()) + 1
266
-
267
- # The cache is valid when it exists, covers all positions referenced by
268
- # position_ids, and matches q's dtype and device. Each condition is named
269
- # separately so the rebuild trigger is readable rather than a compound predicate.
270
- cache_missing = self._cos_cached is None
271
- cache_too_short = not cache_missing and seq_len > self._cos_cached.shape[0]
272
- wrong_dtype = not cache_missing and self._cos_cached.dtype != q.dtype
273
- wrong_device = not cache_missing and self._cos_cached.device != q.device
274
-
275
- if cache_missing or cache_too_short or wrong_dtype or wrong_device:
276
- self._extend_cache(seq_len, device=q.device, dtype=q.dtype)
277
 
278
  cos = self._cos_cached[position_ids]
279
  sin = self._sin_cached[position_ids]
 
26
  parameters — no shared instance, no config reading. See Unit 5.A design decisions.
27
 
28
  Cache sharing: all instances with identical parameters share one cos/sin table via a
29
+ class-level registry. The first instance that needs a particular (parameters, device,
30
+ dtype) combination builds the table; all subsequent instances reference it directly.
31
+ This avoids redundant builds across the num_hidden_layers instances that share the
32
+ same parametrisation.
33
  """
34
 
35
  import math
 
66
  h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
67
  config object is read inside this module.
68
 
69
+ The cos/sin table is built at construction time to cover all positions in
70
+ ``[0, maximum_sequence_length)``. In forward, the table is rebuilt only if
71
+ the query tensor's dtype or device has changed since construction.
72
+
73
+ Instances with identical parameters share one cos/sin table via the class-level
74
+ ``_cache`` registry, avoiding redundant computation across decoder layers.
75
 
76
  Args:
77
  mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
78
  head_dim: Per-head embedding dimension ``u``. Must be even.
79
  theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
80
+ maximum_sequence_length: Maximum number of positions the table must cover.
81
+ The cos/sin table is preallocated to this length at construction time.
82
+ For ``mode="yarn"``, the training context length C_train is derived
83
+ internally as ``round(maximum_sequence_length / dilation)``.
84
  dilation: Scale factor ``s = C_target / C_train`` — how much the context
85
  window is extended beyond training length. Required for ``mode="yarn"``.
86
  When ``dilation=1.0``, YaRN reduces to standard RoPE.
 
92
 
93
  Raises:
94
  NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
95
+ ValueError: If ``mode="yarn"`` and any of ``dilation``, ``alpha``,
96
+ ``beta`` are absent.
97
  """
98
 
99
+ # Maps (freq_key, device_str, dtype_str) → (cos_table, sin_table).
100
  # Shared across all RotaryEmbedding instances in the process. Keys include device
101
  # and dtype so that tables built on different devices or in different precisions
102
  # are stored independently.
 
107
  mode: str,
108
  head_dim: int,
109
  theta: float,
110
+ maximum_sequence_length: int,
111
  dilation: float | None = None,
112
  alpha: float | None = None,
113
  beta: float | None = None,
 
116
  super().__init__()
117
 
118
  self._validate_mode(mode)
119
+ self._validate_yarn_params(mode, dilation, alpha, beta)
120
  self.mode = mode
121
+ self._maximum_sequence_length = maximum_sequence_length
122
 
123
  # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
124
  # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
 
133
  else: # yarn
134
  s = dilation
135
 
136
+ # C_train is the training context length, recovered from the inference
137
+ # context length and the dilation factor. round() guards against floating
138
+ # point error since both underlying quantities are integers.
139
+ c_train: int = round(maximum_sequence_length / dilation)
140
+
141
  # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
142
  # function to classify each dimension into one of three regimes.
143
+ normalized_freqs = c_train * base_freqs / (2.0 * math.pi)
144
 
145
  # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
146
  # linear blend between α and β.
 
152
  # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
153
  self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2
154
 
155
+ # freq_key uniquely identifies the parameter set that produced rotation_freqs,
156
+ # including maximum_sequence_length so instances with different table sizes
157
+ # do not collide in the registry.
158
  if mode == "default":
159
+ self._freq_key: tuple = ("default", head_dim, theta, maximum_sequence_length)
160
  else:
161
+ self._freq_key = ("yarn", head_dim, theta, maximum_sequence_length, dilation, alpha, beta)
 
 
 
 
162
 
163
  # rotation_freqs is a non-persistent buffer so it moves with the model across
164
  # devices via .to() / .cuda() without appearing in saved checkpoints.
 
174
  self._cos_cached: torch.Tensor | None = None
175
  self._sin_cached: torch.Tensor | None = None
176
 
177
+ # Build the table at construction time. Forward rebuilds only on dtype or
178
+ # device change. If no device is specified, build on CPU as the default.
179
+ build_device = device if device is not None else torch.device("cpu")
180
+ self._build_cache(device=build_device, dtype=torch.float32)
181
+
182
  # ---------------------------------------------------------------------------
183
  # Validation helpers
184
  # ---------------------------------------------------------------------------
 
194
  @staticmethod
195
  def _validate_yarn_params(
196
  mode: str,
 
197
  dilation: float | None,
198
  alpha: float | None,
199
  beta: float | None,
 
203
  return
204
  missing = [
205
  name for name, val in [
 
206
  ("dilation", dilation),
207
  ("alpha", alpha),
208
  ("beta", beta),
 
216
  # Cache management
217
  # ---------------------------------------------------------------------------
218
 
219
+ def _build_cache(self, device: torch.device, dtype: torch.dtype) -> None:
220
+ """Build the cos/sin table to cover positions [0, maximum_sequence_length).
221
 
222
  Checks the class-level registry first. If a table already exists for this
223
+ exact (parameters, device, dtype) combination it is reused directly;
224
  otherwise it is computed and stored. The instance attributes are pointed at
225
  the registry entry so that all layers sharing the same parametrisation
226
  reference the same tensor.
227
  """
228
+ cache_key = (self._freq_key, str(device), str(dtype))
229
 
230
  if cache_key not in RotaryEmbedding._cache:
231
+ positions = torch.arange(
232
+ self._maximum_sequence_length, device=device, dtype=torch.float32
233
+ )
234
+ # outer product → (maximum_sequence_length, head_dim // 2);
235
+ # duplicate to (maximum_sequence_length, head_dim)
236
  freqs = torch.outer(
237
  positions,
238
  self.rotation_freqs.to(device=device, dtype=torch.float32),
 
253
  ) -> tuple[torch.Tensor, torch.Tensor, float]:
254
  """Apply rotary embeddings to query and key tensors.
255
 
256
+ The cos/sin table is built at construction time. It is rebuilt here only
257
+ if ``q``'s dtype or device differs from the cached table for example,
258
+ after moving the model to a different device via ``.cuda()``.
259
 
260
+ ``position_ids`` may be any integer tensor shape. Its values must be in
261
+ ``[0, maximum_sequence_length)``:
262
 
263
  - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
264
  - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).
 
276
  1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
277
  apply to attention logits before softmax.
278
  """
279
+ wrong_dtype = self._cos_cached.dtype != q.dtype
280
+ wrong_device = self._cos_cached.device != q.device
281
+
282
+ if wrong_dtype or wrong_device:
283
+ self._build_cache(device=q.device, dtype=q.dtype)
 
 
 
 
 
 
 
284
 
285
  cos = self._cos_cached[position_ids]
286
  sin = self._sin_cached[position_ids]