smithblack-0 commited on
Commit
4fedccb
·
verified ·
1 Parent(s): 28f3eff

Update architecture and tokenizer

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. configuration.py +262 -262
  3. huggingface.py +164 -203
README.md CHANGED
@@ -48,7 +48,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
48
  config = AutoConfig.from_pretrained(
49
  "smithblack-0/SHRAM-dev",
50
  trust_remote_code=True,
51
- num_hidden_layers=16, # example override
52
  num_mosrah_heads=32, # example override
53
  )
54
 
 
48
  config = AutoConfig.from_pretrained(
49
  "smithblack-0/SHRAM-dev",
50
  trust_remote_code=True,
51
+ num_decoder_layers=16, # example override
52
  num_mosrah_heads=32, # example override
53
  )
54
 
configuration.py CHANGED
@@ -1,262 +1,262 @@
1
- """Configuration for the SHRAM transformer.
2
-
3
- All architectural parameters that vary across model scales or are meaningful research
4
- variables are expressed here. Architectural constants (no bias in linear layers,
5
- SwiGLU activation with SiLU gate) are implemented in the relevant modules and
6
- documented at the point of use — they are not config parameters because they do not
7
- vary and changing them produces a different architecture, not a different scale.
8
-
9
- RoPE configuration is owned entirely by this config. Each attention path reads its
10
- parameters directly and constructs its own RotaryEmbedding instance explicitly — no
11
- HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
- """
13
-
14
- import math
15
-
16
- from transformers import PretrainedConfig
17
-
18
-
19
- class ShramConfig(PretrainedConfig):
20
- """Configuration class for the SHRAM decoder-only transformer.
21
-
22
- SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
23
- attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
24
- local sliding-window causal attention path and h_s is the MoSRAH sparse routed
25
- path. All other components follow the Llama 3 baseline.
26
-
27
- This config is the single source of truth for every architectural dimension of the
28
- model. Nothing in the architecture may use a literal number that belongs here.
29
-
30
- Two independent RoPE configurations exist — one per attention path:
31
-
32
- - h_l always uses standard RoPE with ``local_rope_theta``.
33
- - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
34
- ``inference_sequence_length``, ``alpha``, and ``beta``. When
35
- ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
36
- ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
37
- and the correct setting for experiments that do not require context extension.
38
-
39
- Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
40
-
41
- config = AutoConfig.from_pretrained(
42
- "your-namespace/advanced-transformers-lib",
43
- trust_remote_code=True,
44
- num_hidden_layers=12,
45
- )
46
- model = AutoModelForCausalLM.from_config(config)
47
-
48
- Args:
49
- vocab_size: Vocabulary size. Controls the embedding table and output logits
50
- dimension. Must match the tokenizer.
51
- embedding_width: Model width ``d``. The dimension of the residual stream.
52
- mlp_width: FFN hidden dimension.
53
- num_decoder_layers: Number of transformer blocks stacked in sequence.
54
- num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
55
- num_mosrah_heads: Total MoSRAH expert heads available ``L``.
56
- num_selected_heads: MoSRAH heads each token selects ``K``.
57
- head_dim: Per-head dimension, shared by both attention paths. Must be even
58
- (RoPE rotates dimensions in pairs). Paper uses 16.
59
- window_size: Sliding window size for h_l. Paper uses 128.
60
- rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
61
- original sequence positions; ``"semantic_sequence"`` supplies local slot
62
- indices. Both are required; experimentally correct mode is undetermined
63
- (paper §4). Default ``"main_sequence"``.
64
- rms_norm_eps: Epsilon for RMSNorm layers.
65
- local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
66
- Paper uses b=10000.
67
- mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
68
- b=10000.
69
- training_sequence_length: Context length ``C_train`` the model was or will be
70
- trained at. Used to compute the YaRN scale factor for BEA.
71
- inference_sequence_length: Context length ``C_target`` the model must support
72
- at inference. Optional; defaults to ``training_sequence_length`` so that
73
- ``scale=1`` and YaRN reduces to standard RoPE unless explicitly extended.
74
- alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
75
- ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
76
- beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
77
- ``r(d) > beta`` are left unscaled. Paper value: 32.0.
78
- attention_dropout: Dropout probability on attention weights. Default 0.0.
79
- use_cache: Whether to return past_key_values for KV caching.
80
- output_hidden_states: Whether to return hidden states after each layer.
81
- tie_word_embeddings: Whether input embedding and LM head share weights.
82
- mosrah_overallocation_factor: Overallocation multiplier for the expert packing
83
- buffer. ``mosrah_packed_length`` = ceil(training_sequence_length *
84
- num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
85
- Must be > 1.0 to guarantee a buffer larger than the balanced-routing
86
- baseline. Default 2.0.
87
- load_balance_p: Exponent p for the p-mean aggregation of per-item routing
88
- frequencies into the load balance signal. Higher p weights aggregation
89
- toward the worst-case batch item, making the correction signal more
90
- sensitive to per-item allocation spikes. Must be positive. Default 2.0.
91
- max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
92
- solver in ``balance_capacity``. 10 covers convergence at approximately
93
- the 98th percentile of routing densities; the top 2% of extreme-density
94
- cases are not expected under normal training. The bound exists as a
95
- correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
96
- Default 10.
97
- """
98
-
99
- model_type = "shram"
100
-
101
- auto_map = {
102
- "AutoConfig": "configuration.ShramConfig",
103
- "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
104
- }
105
-
106
- def __init__(
107
- self,
108
- vocab_size: int = 50277,
109
- embedding_width: int = 512,
110
- mlp_width: int = 1366,
111
- num_decoder_layers: int = 12,
112
- num_sliding_window_heads: int = 16,
113
- num_mosrah_heads: int = 16,
114
- num_selected_heads: int = 16,
115
- head_dim: int = 16,
116
- window_size: int = 128,
117
- rope_mode: str = "main_sequence",
118
- rms_norm_eps: float = 1e-5,
119
- local_rope_theta: float = 10000.0,
120
- mosrah_rope_theta: float = 10000.0,
121
- training_sequence_length: int = 1024,
122
- inference_sequence_length: int | None = None,
123
- alpha: float = 1.0,
124
- beta: float = 32.0,
125
- attention_dropout: float = 0.0,
126
- use_cache: bool = True,
127
- output_hidden_states: bool = False,
128
- tie_word_embeddings: bool = False,
129
- mosrah_overallocation_factor: float = 2.0,
130
- load_balance_p: float = 2.0,
131
- max_bid_rounds: int = 10,
132
- **kwargs
133
- ):
134
- if head_dim % 2 != 0:
135
- raise ValueError(
136
- f"head_dim must be even (RoPE rotates dimensions in pairs). "
137
- f"Got head_dim={head_dim}."
138
- )
139
-
140
- if rope_mode not in {"main_sequence", "semantic_sequence"}:
141
- raise ValueError(
142
- f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
143
- f"got '{rope_mode}'."
144
- )
145
-
146
- if training_sequence_length <= 0:
147
- raise ValueError(
148
- f"training_sequence_length must be positive, "
149
- f"got {training_sequence_length}."
150
- )
151
-
152
- if inference_sequence_length is None:
153
- inference_sequence_length = training_sequence_length
154
- if inference_sequence_length <= 0:
155
- raise ValueError(
156
- f"inference_sequence_length must be positive, "
157
- f"got {inference_sequence_length}."
158
- )
159
-
160
- if mosrah_overallocation_factor <= 1.0:
161
- raise ValueError(
162
- f"mosrah_overallocation_factor must be > 1.0 to guarantee a packed "
163
- f"buffer larger than the balanced-routing baseline. "
164
- f"Got {mosrah_overallocation_factor}."
165
- )
166
-
167
- if load_balance_p <= 0.0:
168
- raise ValueError(
169
- f"load_balance_p must be positive, got {load_balance_p}."
170
- )
171
-
172
- if max_bid_rounds < 1:
173
- raise ValueError(
174
- f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
175
- )
176
-
177
- self.vocab_size = vocab_size
178
- self.embedding_width = embedding_width
179
- self.mlp_width = mlp_width
180
- self.num_decoder_layers = num_decoder_layers
181
- self.num_sliding_window_heads = num_sliding_window_heads
182
- self.num_mosrah_heads = num_mosrah_heads
183
- self.num_selected_heads = num_selected_heads
184
- self.head_dim = head_dim
185
- self.window_size = window_size
186
- self.rope_mode = rope_mode
187
- self.rms_norm_eps = rms_norm_eps
188
- self.local_rope_theta = local_rope_theta
189
- self.mosrah_rope_theta = mosrah_rope_theta
190
- self.training_sequence_length = training_sequence_length
191
- self.inference_sequence_length = inference_sequence_length
192
- self.alpha = alpha
193
- self.beta = beta
194
- self.mosrah_overallocation_factor = mosrah_overallocation_factor
195
- self.load_balance_p = load_balance_p
196
- self.max_bid_rounds = max_bid_rounds
197
- self.attention_dropout = attention_dropout
198
- self.use_cache = use_cache
199
-
200
- super().__init__(
201
- tie_word_embeddings=tie_word_embeddings,
202
- output_hidden_states=output_hidden_states,
203
- **kwargs
204
- )
205
-
206
- # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
207
- # serialises it into config.json.
208
- self.auto_map = type(self).auto_map
209
-
210
- @property
211
- def scale(self) -> float:
212
- """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
213
-
214
- When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
215
- adjustments cancel and A_rope = 1. This is the default state.
216
- """
217
- return self.inference_sequence_length / self.training_sequence_length
218
-
219
- @property
220
- def mosrah_packed_length(self) -> int:
221
- """Static packed time dimension T for expert packing.
222
-
223
- The expected tokens per expert under perfectly balanced routing is
224
- ``training_sequence_length * num_selected_heads / num_mosrah_heads``.
225
- Multiplying by ``mosrah_overallocation_factor`` provides a buffer above
226
- that baseline. The ceiling ensures T is always an integer >= 1.
227
-
228
- All consumers of the packed buffer size must read this property rather
229
- than deriving T independently.
230
- """
231
- return math.ceil(
232
- self.training_sequence_length
233
- * self.num_selected_heads
234
- / self.num_mosrah_heads
235
- * self.mosrah_overallocation_factor
236
- )
237
-
238
- @property
239
- def mosrah_cache_length(self) -> int:
240
- """Static per-(batch, head) slot capacity for the MoSRAH inference cache.
241
-
242
- The expected tokens per expert over the full inference context under perfectly
243
- balanced routing is ``inference_sequence_length * num_selected_heads /
244
- num_mosrah_heads``. Multiplying by ``mosrah_overallocation_factor`` provides
245
- a buffer above that baseline. The ceiling ensures the result is always an
246
- integer >= 1.
247
-
248
- Distinct from ``mosrah_packed_length``, which sizes the training packing buffer
249
- using ``training_sequence_length``. This property uses
250
- ``inference_sequence_length`` because the cache must hold the full accumulated
251
- token history across the entire inference run.
252
-
253
- All consumers of the MoSRAH cache buffer size must read this property rather
254
- than deriving the capacity independently.
255
- """
256
- return math.ceil(
257
- self.inference_sequence_length
258
- * self.num_selected_heads
259
- / self.num_mosrah_heads
260
- * self.mosrah_overallocation_factor
261
- )
262
-
 
1
+ """Configuration for the SHRAM transformer.
2
+
3
+ All architectural parameters that vary across model scales or are meaningful research
4
+ variables are expressed here. Architectural constants (no bias in linear layers,
5
+ SwiGLU activation with SiLU gate) are implemented in the relevant modules and
6
+ documented at the point of use — they are not config parameters because they do not
7
+ vary and changing them produces a different architecture, not a different scale.
8
+
9
+ RoPE configuration is owned entirely by this config. Each attention path reads its
10
+ parameters directly and constructs its own RotaryEmbedding instance explicitly — no
11
+ HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
+ """
13
+
14
+ import math
15
+
16
+ from transformers import PretrainedConfig
17
+
18
+
19
+ class ShramConfig(PretrainedConfig):
20
+ """Configuration class for the SHRAM decoder-only transformer.
21
+
22
+ SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
23
+ attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
24
+ local sliding-window causal attention path and h_s is the MoSRAH sparse routed
25
+ path. All other components follow the Llama 3 baseline.
26
+
27
+ This config is the single source of truth for every architectural dimension of the
28
+ model. Nothing in the architecture may use a literal number that belongs here.
29
+
30
+ Two independent RoPE configurations exist — one per attention path:
31
+
32
+ - h_l always uses standard RoPE with ``local_rope_theta``.
33
+ - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
34
+ ``inference_sequence_length``, ``alpha``, and ``beta``. When
35
+ ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
36
+ ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
37
+ and the correct setting for experiments that do not require context extension.
38
+
39
+ Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
40
+
41
+ config = AutoConfig.from_pretrained(
42
+ "your-namespace/advanced-transformers-lib",
43
+ trust_remote_code=True,
44
+ num_decoder_layers=12,
45
+ )
46
+ model = AutoModelForCausalLM.from_config(config)
47
+
48
+ Args:
49
+ vocab_size: Vocabulary size. Controls the embedding table and output logits
50
+ dimension. Must match the tokenizer.
51
+ embedding_width: Model width ``d``. The dimension of the residual stream.
52
+ mlp_width: FFN hidden dimension.
53
+ num_decoder_layers: Number of transformer blocks stacked in sequence.
54
+ num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
55
+ num_mosrah_heads: Total MoSRAH expert heads available ``L``.
56
+ num_selected_heads: MoSRAH heads each token selects ``K``.
57
+ head_dim: Per-head dimension, shared by both attention paths. Must be even
58
+ (RoPE rotates dimensions in pairs). Paper uses 16.
59
+ window_size: Sliding window size for h_l. Paper uses 128.
60
+ rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
61
+ original sequence positions; ``"semantic_sequence"`` supplies local slot
62
+ indices. Both are required; experimentally correct mode is undetermined
63
+ (paper §4). Default ``"main_sequence"``.
64
+ rms_norm_eps: Epsilon for RMSNorm layers.
65
+ local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
66
+ Paper uses b=10000.
67
+ mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
68
+ b=10000.
69
+ training_sequence_length: Context length ``C_train`` the model was or will be
70
+ trained at. Used to compute the YaRN scale factor for BEA.
71
+ inference_sequence_length: Context length ``C_target`` the model must support
72
+ at inference. Optional; defaults to ``training_sequence_length`` so that
73
+ ``scale=1`` and YaRN reduces to standard RoPE unless explicitly extended.
74
+ alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
75
+ ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
76
+ beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
77
+ ``r(d) > beta`` are left unscaled. Paper value: 32.0.
78
+ attention_dropout: Dropout probability on attention weights. Default 0.0.
79
+ use_cache: Whether to return past_key_values for KV caching.
80
+ output_hidden_states: Whether to return hidden states after each layer.
81
+ tie_word_embeddings: Whether input embedding and LM head share weights.
82
+ mosrah_overallocation_factor: Overallocation multiplier for the expert packing
83
+ buffer. ``mosrah_packed_length`` = ceil(training_sequence_length *
84
+ num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
85
+ Must be > 1.0 to guarantee a buffer larger than the balanced-routing
86
+ baseline. Default 2.0.
87
+ load_balance_p: Exponent p for the p-mean aggregation of per-item routing
88
+ frequencies into the load balance signal. Higher p weights aggregation
89
+ toward the worst-case batch item, making the correction signal more
90
+ sensitive to per-item allocation spikes. Must be positive. Default 2.0.
91
+ max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
92
+ solver in ``balance_capacity``. 10 covers convergence at approximately
93
+ the 98th percentile of routing densities; the top 2% of extreme-density
94
+ cases are not expected under normal training. The bound exists as a
95
+ correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
96
+ Default 10.
97
+ """
98
+
99
+ model_type = "shram"
100
+
101
+ auto_map = {
102
+ "AutoConfig": "configuration.ShramConfig",
103
+ "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
104
+ }
105
+
106
+ def __init__(
107
+ self,
108
+ vocab_size: int = 50277,
109
+ embedding_width: int = 512,
110
+ mlp_width: int = 1366,
111
+ num_decoder_layers: int = 12,
112
+ num_sliding_window_heads: int = 16,
113
+ num_mosrah_heads: int = 16,
114
+ num_selected_heads: int = 16,
115
+ head_dim: int = 16,
116
+ window_size: int = 128,
117
+ rope_mode: str = "main_sequence",
118
+ rms_norm_eps: float = 1e-5,
119
+ local_rope_theta: float = 10000.0,
120
+ mosrah_rope_theta: float = 10000.0,
121
+ training_sequence_length: int = 1024,
122
+ inference_sequence_length: int | None = None,
123
+ alpha: float = 1.0,
124
+ beta: float = 32.0,
125
+ attention_dropout: float = 0.0,
126
+ use_cache: bool = True,
127
+ output_hidden_states: bool = False,
128
+ tie_word_embeddings: bool = False,
129
+ mosrah_overallocation_factor: float = 2.0,
130
+ load_balance_p: float = 2.0,
131
+ max_bid_rounds: int = 10,
132
+ **kwargs
133
+ ):
134
+ if head_dim % 2 != 0:
135
+ raise ValueError(
136
+ f"head_dim must be even (RoPE rotates dimensions in pairs). "
137
+ f"Got head_dim={head_dim}."
138
+ )
139
+
140
+ if rope_mode not in {"main_sequence", "semantic_sequence"}:
141
+ raise ValueError(
142
+ f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
143
+ f"got '{rope_mode}'."
144
+ )
145
+
146
+ if training_sequence_length <= 0:
147
+ raise ValueError(
148
+ f"training_sequence_length must be positive, "
149
+ f"got {training_sequence_length}."
150
+ )
151
+
152
+ if inference_sequence_length is None:
153
+ inference_sequence_length = training_sequence_length
154
+ if inference_sequence_length <= 0:
155
+ raise ValueError(
156
+ f"inference_sequence_length must be positive, "
157
+ f"got {inference_sequence_length}."
158
+ )
159
+
160
+ if mosrah_overallocation_factor <= 1.0:
161
+ raise ValueError(
162
+ f"mosrah_overallocation_factor must be > 1.0 to guarantee a packed "
163
+ f"buffer larger than the balanced-routing baseline. "
164
+ f"Got {mosrah_overallocation_factor}."
165
+ )
166
+
167
+ if load_balance_p <= 0.0:
168
+ raise ValueError(
169
+ f"load_balance_p must be positive, got {load_balance_p}."
170
+ )
171
+
172
+ if max_bid_rounds < 1:
173
+ raise ValueError(
174
+ f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
175
+ )
176
+
177
+ self.vocab_size = vocab_size
178
+ self.embedding_width = embedding_width
179
+ self.mlp_width = mlp_width
180
+ self.num_decoder_layers = num_decoder_layers
181
+ self.num_sliding_window_heads = num_sliding_window_heads
182
+ self.num_mosrah_heads = num_mosrah_heads
183
+ self.num_selected_heads = num_selected_heads
184
+ self.head_dim = head_dim
185
+ self.window_size = window_size
186
+ self.rope_mode = rope_mode
187
+ self.rms_norm_eps = rms_norm_eps
188
+ self.local_rope_theta = local_rope_theta
189
+ self.mosrah_rope_theta = mosrah_rope_theta
190
+ self.training_sequence_length = training_sequence_length
191
+ self.inference_sequence_length = inference_sequence_length
192
+ self.alpha = alpha
193
+ self.beta = beta
194
+ self.mosrah_overallocation_factor = mosrah_overallocation_factor
195
+ self.load_balance_p = load_balance_p
196
+ self.max_bid_rounds = max_bid_rounds
197
+ self.attention_dropout = attention_dropout
198
+ self.use_cache = use_cache
199
+
200
+ super().__init__(
201
+ tie_word_embeddings=tie_word_embeddings,
202
+ output_hidden_states=output_hidden_states,
203
+ **kwargs
204
+ )
205
+
206
+ # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
207
+ # serialises it into config.json.
208
+ self.auto_map = type(self).auto_map
209
+
210
+ @property
211
+ def scale(self) -> float:
212
+ """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
213
+
214
+ When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
215
+ adjustments cancel and A_rope = 1. This is the default state.
216
+ """
217
+ return self.inference_sequence_length / self.training_sequence_length
218
+
219
+ @property
220
+ def mosrah_packed_length(self) -> int:
221
+ """Static packed time dimension T for expert packing.
222
+
223
+ The expected tokens per expert under perfectly balanced routing is
224
+ ``training_sequence_length * num_selected_heads / num_mosrah_heads``.
225
+ Multiplying by ``mosrah_overallocation_factor`` provides a buffer above
226
+ that baseline. The ceiling ensures T is always an integer >= 1.
227
+
228
+ All consumers of the packed buffer size must read this property rather
229
+ than deriving T independently.
230
+ """
231
+ return math.ceil(
232
+ self.training_sequence_length
233
+ * self.num_selected_heads
234
+ / self.num_mosrah_heads
235
+ * self.mosrah_overallocation_factor
236
+ )
237
+
238
+ @property
239
+ def mosrah_cache_length(self) -> int:
240
+ """Static per-(batch, head) slot capacity for the MoSRAH inference cache.
241
+
242
+ The expected tokens per expert over the full inference context under perfectly
243
+ balanced routing is ``inference_sequence_length * num_selected_heads /
244
+ num_mosrah_heads``. Multiplying by ``mosrah_overallocation_factor`` provides
245
+ a buffer above that baseline. The ceiling ensures the result is always an
246
+ integer >= 1.
247
+
248
+ Distinct from ``mosrah_packed_length``, which sizes the training packing buffer
249
+ using ``training_sequence_length``. This property uses
250
+ ``inference_sequence_length`` because the cache must hold the full accumulated
251
+ token history across the entire inference run.
252
+
253
+ All consumers of the MoSRAH cache buffer size must read this property rather
254
+ than deriving the capacity independently.
255
+ """
256
+ return math.ceil(
257
+ self.inference_sequence_length
258
+ * self.num_selected_heads
259
+ / self.num_mosrah_heads
260
+ * self.mosrah_overallocation_factor
261
+ )
262
+
huggingface.py CHANGED
@@ -128,7 +128,7 @@ class ShramConfig(PretrainedConfig):
128
  config = AutoConfig.from_pretrained(
129
  "your-namespace/advanced-transformers-lib",
130
  trust_remote_code=True,
131
- num_hidden_layers=12,
132
  )
133
  model = AutoModelForCausalLM.from_config(config)
134
 
@@ -725,17 +725,21 @@ class MoSRAHCache(CacheLayerMixin):
725
  def _check_no_overflow(max_count: torch.Tensor, capacity: int) -> None:
726
  """Raise if any (batch, head) slot would exceed the static buffer capacity.
727
 
728
- Uses the 19.F.1 pattern: branches on whether the graph is being compiled.
729
- In compiled mode, `.item()` folds into the graph when capture_scalar_outputs=True
730
- and `torch._check` issues a compile-time assertion. In eager mode, a plain
731
- RuntimeError is raised with a descriptive message.
732
 
733
  Args:
734
  max_count: Scalar tensor — the maximum post-update count across all slots.
735
  capacity: The static buffer capacity (mosrah_cache_length).
736
  """
737
  if torch.compiler.is_compiling():
738
- torch._check(max_count.item() <= capacity)
 
 
 
 
739
  else:
740
  if max_count.item() > capacity:
741
  raise RuntimeError(
@@ -856,7 +860,7 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
856
  # Cumulative count of all token positions presented through update() for
857
  # this cache instance. This is the quantity HuggingFace generation reads
858
  # through get_seq_length() to track how far along the sequence we are.
859
- self._total_processed: int = 0
860
 
861
  def update( # type: ignore[override]
862
  self,
@@ -996,7 +1000,7 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
996
  generation reads to track sequence progress and is not the same as active-token
997
  count or current window occupancy.
998
  """
999
- return self._total_processed
1000
 
1001
  def get_max_cache_shape(self) -> int:
1002
  return self.sliding_window
@@ -2299,29 +2303,24 @@ class BottleneckedEnsembleAttention(nn.Module):
2299
  # -----------
2300
  """Expert packing and unpacking for the MoSRAH path.
2301
 
2302
- This module implements the low-level token-choice -> expert-choice -> token-choice
2303
- conversion boundary specified in the paper. The externally visible behavior is fixed:
2304
-
2305
- - setup_packing() prepares the auxiliary ordering data and returns it as a dict
2306
- payload forwarded whole to pack_experts and unpack_experts.
2307
- - pack_experts() converts a dict of routed token-choice tensors into packed
2308
- expert-choice form. Each entry is paired with its intended padding value; all
2309
- entries undergo the same expert-major gather-scatter so they remain aligned.
2310
- - unpack_experts() restores token-choice ordering afterward.
2311
-
2312
- Stable sort is a correctness requirement. It preserves causal ordering inside each
2313
- expert bucket, which is the foundation on which BEA's later triangular causal mask
2314
- is correct.
2315
-
2316
- pack_experts() returns the packed entries dict together with a separate unpacking_mask.
2317
- Two masks serve different roles and must not be interchanged:
2318
-
2319
- - unpacking_mask: marks every packed slot that contains a routed token copy,
2320
- live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
2321
- so its reshape invariant holds regardless of outer token liveness.
2322
- - active_mask (caller-supplied entry): marks only the packed slots whose source
2323
- token was semantically live. This is what BEA consumes for attention gating.
2324
- Dead outer tokens must not influence sparse attention outputs.
2325
  """
2326
 
2327
 
@@ -2337,23 +2336,13 @@ def setup_packing(
2337
  ) -> dict[str, torch.Tensor]:
2338
  """Prepare the auxiliary ordering data used by pack/unpack.
2339
 
2340
- Routing produces token-choice state I of shape (B, N, K): for each token, which
2341
- K experts were selected. Packing needs the same routed token copies reordered into
2342
- expert-major order so each expert bucket becomes contiguous.
2343
-
2344
- The paper's setup step does this by flattening (N, K) into one axis to produce
2345
- H in token-major order, then computing a stable argsort permutation Pi over the
2346
- expert indices stored in H. Applying Pi reorders the flattened routed copies into
2347
- expert-major order while preserving their original token order *within* each expert
2348
- bucket. That preservation is why stable sort is required for causality.
2349
-
2350
  Args:
2351
  selected_heads: Routed token-choice head selections I of shape (B, N, K).
2352
 
2353
  Returns:
2354
  Auxiliary payload dict with keys:
2355
  - "flattened_selected_heads": H of shape (B, N*K)
2356
- - "permutation": stable expert-major permutation Pi of shape (B, N*K)
2357
  - "inverse_permutation": inverse permutation Pi^{-1} of shape (B, N*K)
2358
  This dict is forwarded whole to pack_experts and unpack_experts.
2359
  """
@@ -2362,7 +2351,14 @@ def setup_packing(
2362
  batch_size,
2363
  sequence_length * num_selected_heads,
2364
  )
2365
- num_elements = batch_size*sequence_length*num_selected_heads
 
 
 
 
 
 
 
2366
  permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
2367
  inverse_permutation = torch.argsort(permutation, dim=-1)
2368
 
@@ -2370,7 +2366,6 @@ def setup_packing(
2370
  "flattened_selected_heads": flattened_selected_heads,
2371
  "permutation": permutation,
2372
  "inverse_permutation": inverse_permutation,
2373
- "num_elements" : num_elements,
2374
  }
2375
 
2376
 
@@ -2387,20 +2382,6 @@ def pack_experts(
2387
  ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
2388
  """Pack token-choice tensors into expert-choice padded form.
2389
 
2390
- The paper's packing path has two jobs:
2391
-
2392
- 1. Convert routed token-choice copies into expert-major order.
2393
- 2. Materialize that expert-major order into a padded tensor layout BEA can consume.
2394
-
2395
- All entries in the provided dict undergo the same expert-major gather-scatter so
2396
- they remain mutually aligned in the packed frame. Each entry is paired with its
2397
- intended padding value, which fills slots that contain no routed token copy.
2398
-
2399
- Packed positions are sourced from the authoritative upstream position_ids tensor
2400
- rather than synthesized locally from arange(N). This preserves advanced positions
2401
- correctly during cached inference while leaving training/full-sequence behavior
2402
- unchanged when position_ids is the ordinary sequential token positions.
2403
-
2404
  Args:
2405
  entries: Mapping from string keys to (tensor, padding_value) pairs. Each
2406
  tensor has shape (B, N, ...) and is rearranged into expert-choice layout
@@ -2409,29 +2390,40 @@ def pack_experts(
2409
  selected_heads: Routed head selections I of shape (B, N, K).
2410
  num_experts: Total number of experts L.
2411
  packed_length: Static packed time dimension T. All per-expert buffers are
2412
- allocated to exactly this length. Use config.mosrah_packed_length as the
2413
- source of this value. Raises if any actual per-expert token count exceeds
2414
- this value.
2415
 
2416
  Returns:
2417
  Tuple of:
2418
  - packed_entries: Dict with same keys as entries; each value is the
2419
  packed tensor of shape (B, L, T, ...).
2420
- - unpacking_mask: Boolean tensor of shape (B, L, T). True where a slot
2421
- contains any routed token copy, live or dead. Always has exactly
2422
- B*N*K True entries. Pass this to unpack_experts — not active_mask.
2423
  """
2424
  batch_size, sequence_length, num_selected_heads = selected_heads.shape
 
 
2425
 
2426
  flattened_selected_heads = setup["flattened_selected_heads"]
2427
  permutation = setup["permutation"]
2428
 
2429
  # -----------------------------------------------------------------------
2430
- # Reconstruct routed local source-token indices in token-choice order.
 
 
 
 
 
 
 
 
 
 
2431
  #
2432
- # The internal arange(N) is only the local source-row index object used to
2433
- # gather from the current chunk tensors. Flattening gives a (B, N*K) tensor
2434
- # aligned with H's token-major routed-copy order.
2435
  # -----------------------------------------------------------------------
2436
  source_token_indices = torch.arange(
2437
  sequence_length,
@@ -2442,81 +2434,91 @@ def pack_experts(
2442
  sequence_length,
2443
  num_selected_heads,
2444
  )
2445
- flattened_source_indices = source_token_indices.reshape(
2446
  batch_size,
2447
- sequence_length * num_selected_heads,
2448
  )
2449
-
2450
- # -----------------------------------------------------------------------
2451
- # Reorder source-token indices into expert-major order.
2452
- #
2453
- # Applying Pi yields the local source-token rows in the packed expert-major
2454
- # order required by the paper. All entries are then gathered using these same
2455
- # reordered indices so they remain aligned under the exact same transformation.
2456
- # -----------------------------------------------------------------------
2457
- sorted_source_indices = flattened_source_indices.gather(
2458
  dim=1,
2459
  index=permutation,
2460
  )
2461
 
2462
  # -----------------------------------------------------------------------
2463
- # Count how many routed copies land in each expert bucket and verify
2464
- # that no bucket exceeds the statically preallocated packed_length T.
2465
  #
2466
- # S[b, l] is the number of routed token copies assigned to expert l in
2467
- # batch b. T (packed_length) is a static allocation derived from config,
2468
- # not a data-dependent maximum. Overflow is detected here and raises in
2469
- # both eager and compiled modes.
2470
  # -----------------------------------------------------------------------
2471
  tokens_per_expert = _count_tokens_per_expert(flattened_selected_heads, num_experts)
2472
- max_count = tokens_per_expert.max().item()
2473
- no_overflow = max_count <= packed_length
2474
- _enforce_no_overflow(no_overflow, tokens_per_expert, packed_length)
2475
 
2476
  # -----------------------------------------------------------------------
2477
- # Construct the unpacking mask.
2478
  #
2479
- # Each expert bucket is left-justified: if S[b, l] = s, then slots
2480
- # t = 0, ..., s-1 are occupied and all later slots are padding. The mask
2481
- # marks slot occupancy regardless of outer token liveness, and always has
2482
- # exactly B*N*K True entries.
 
2483
  # -----------------------------------------------------------------------
2484
- time_axis = torch.arange(
2485
- packed_length,
 
 
 
 
 
 
 
 
 
 
 
2486
  device=flattened_selected_heads.device,
2487
  dtype=torch.long,
2488
- ).view(1, 1, packed_length)
2489
- unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)
 
 
2490
 
2491
  # -----------------------------------------------------------------------
2492
- # Materialize all entries into the packed expert-choice frame.
2493
  #
2494
- # Each entry is gathered using the expert-major sorted source indices, then
2495
- # scattered into a padded buffer. The gather index is expanded to cover each
2496
- # tensor's trailing dimensions. Padding slots receive the caller-supplied fill
2497
- # value rather than an implicit zero.
2498
  # -----------------------------------------------------------------------
2499
  packed_entries: dict[str, torch.Tensor] = {}
2500
  for key, (tensor, padding_value) in entries.items():
2501
  extra_shape = tensor.shape[2:]
2502
 
2503
- # Expand gather index to cover trailing dimensions, if any.
2504
- idx = sorted_source_indices.view(
 
 
2505
  batch_size,
2506
- sequence_length * num_selected_heads,
2507
  *(1,) * len(extra_shape),
2508
  ).expand(-1, -1, *extra_shape)
2509
- sorted_tensor = tensor.gather(dim=1, index=idx)
2510
 
2511
  packed_tensor = tensor.new_full(
2512
- (batch_size, num_experts, packed_length, *extra_shape),
2513
  fill_value=padding_value,
2514
  )
 
 
 
 
 
 
 
 
 
 
2515
 
2516
- packed_tensor[unpacking_mask] = sorted_tensor.reshape(-1, *extra_shape)
2517
- packed_entries[key] = packed_tensor
2518
-
2519
- return packed_entries, unpacking_mask
2520
 
2521
 
2522
  # ---------------------------------------------------------------------------
@@ -2526,27 +2528,17 @@ def pack_experts(
2526
  def unpack_experts(
2527
  expert_outputs: torch.Tensor,
2528
  setup: dict[str, torch.Tensor],
2529
- unpacking_mask: torch.Tensor,
2530
  selected_heads: torch.Tensor,
2531
  ) -> torch.Tensor:
2532
  """Restore token-choice ordering from BEA expert-choice output.
2533
 
2534
- Unpacking inverts the packing path only on occupied entries. Padding does not
2535
- participate: the output tensor is first filtered by unpacking_mask to recover
2536
- only the real routed-token copies in expert-major order, then Pi^{-1} restores
2537
- the original token-choice ordering, and finally the tensor is reshaped back to
2538
- (B, N, K, d).
2539
-
2540
- The unpacking_mask — not active_mask — must be used here. Even copies of dead
2541
- outer tokens occupy slots and must be un-scattered correctly for the inverse
2542
- permutation to hold. The total True entry count in unpacking_mask is always
2543
- B*N*K, which is exactly what the reshape to (B, N*K, d) requires.
2544
-
2545
  Args:
2546
  expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
2547
  setup: Auxiliary payload returned by setup_packing().
2548
- unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
2549
- occupied packed slots regardless of outer token liveness.
 
2550
  selected_heads: Routed head selections I of shape (B, N, K).
2551
 
2552
  Returns:
@@ -2555,22 +2547,22 @@ def unpack_experts(
2555
  inverse_permutation = setup["inverse_permutation"]
2556
 
2557
  batch_size, sequence_length, num_selected_heads = selected_heads.shape
 
2558
  hidden_dim = expert_outputs.shape[-1]
2559
 
2560
- coords = torch.nonzero_static(
2561
- unpacking_mask,
2562
- size=setup["num_elements"],
2563
- ) # shape: (B*N*K, 3)
2564
-
2565
- active_outputs = expert_outputs[
2566
- coords[:, 0],
2567
- coords[:, 1],
2568
- coords[:, 2],
2569
- ] # shape: (B*N*K, d)
2570
 
2571
- sorted_token_choice_outputs = active_outputs.reshape(
2572
  batch_size,
2573
- sequence_length * num_selected_heads,
2574
  hidden_dim,
2575
  )
2576
  restored_outputs = sorted_token_choice_outputs.gather(
@@ -2589,34 +2581,34 @@ def unpack_experts(
2589
  # Helpers
2590
  # ---------------------------------------------------------------------------
2591
 
2592
- def _enforce_no_overflow(condition: bool, tokens_per_expert, max_length) -> None:
2593
- """Enforce that no expert bucket exceeds the preallocated packed length.
2594
 
2595
- This check fires when the number of tokens assigned to any expert in any
2596
- batch item exceeds mosrah_packed_length. When that limit is exceeded, the
2597
- packed buffer is too small to hold all assignments and data would be dropped.
2598
- Increase mosrah_overallocation_factor in ShramConfig to resolve.
2599
 
2600
- The caller must derive condition via .item() on the max count tensor so that
2601
- dynamo captures a SymInt and the comparison produces a SymBool. Passing a
2602
- tensor comparison result directly bypasses the SymInt mechanism and prevents
2603
- the check from firing at compiled runtime.
2604
 
2605
  Args:
2606
- condition: True means no overflow has occurred; False means at least one
2607
- expert bucket exceeds packed_length. In compiled mode this is a SymBool
2608
- produced by comparing a SymInt against the static packed_length.
2609
  """
2610
  if torch.compiler.is_compiling():
2611
- torch._check(condition)
 
 
 
 
2612
  else:
2613
- if not condition:
 
2614
  raise RuntimeError(
2615
  "Expert packing overflow: at least one expert bucket contains more "
2616
  "tokens than mosrah_packed_length allows. Increase "
2617
  "mosrah_overallocation_factor in ShramConfig to resolve.\n"
2618
- f"Supported lengths were:\n {max_length}\n"
2619
- f"head lengths were:\n {tokens_per_expert}\n"
2620
  )
2621
 
2622
 
@@ -2626,8 +2618,7 @@ def _count_tokens_per_expert(
2626
  ) -> torch.Tensor:
2627
  """Count how many routed token copies are assigned to each expert per batch item.
2628
 
2629
- Uses scatter_add into a pre-sized (B, num_experts) zero buffer, producing a
2630
- statically-shaped output that compiles without graph breaks. Each position in
2631
  flattened_selected_heads contributes one count to the corresponding expert slot.
2632
 
2633
  Args:
@@ -2639,19 +2630,18 @@ def _count_tokens_per_expert(
2639
  Counts tensor of shape (B, num_experts).
2640
  """
2641
  batch_size = flattened_selected_heads.shape[0]
2642
- counts = torch.zeros(
2643
  batch_size,
2644
  num_experts,
2645
  device=flattened_selected_heads.device,
2646
- dtype=flattened_selected_heads.dtype,
2647
  )
2648
- counts.scatter_add_(
2649
  dim=1,
2650
  index=flattened_selected_heads,
2651
- src=torch.ones_like(flattened_selected_heads),
2652
  )
2653
- return counts
2654
-
2655
  # -----------
2656
  # Inlined from: router.py
2657
  # -----------
@@ -2825,7 +2815,7 @@ class MoSRAHRouter(nn.Module):
2825
  self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
2826
 
2827
  @staticmethod
2828
- def get_mask(
2829
  tensor: torch.Tensor,
2830
  dim: int,
2831
  n: int | torch.Tensor,
@@ -2958,7 +2948,7 @@ class MoSRAHRouter(nn.Module):
2958
  choices_deficit = (min_choices - accepted_per_token).clamp_min(0)
2959
 
2960
  unproposed_logits = logits.masked_fill(proposals, float('-inf'))
2961
- new_proposals = cls.get_mask(
2962
  unproposed_logits, dim=-1, n=choices_deficit, capacity_scalar=min_choices,
2963
  )
2964
  proposals = proposals | new_proposals
@@ -2969,7 +2959,7 @@ class MoSRAHRouter(nn.Module):
2969
  # Acceptances are recomputed from scratch each round so that a
2970
  # stronger new proposal can displace a weaker prior one.
2971
  proposed_logits = logits.masked_fill(~proposals, float('-inf'))
2972
- acceptances = cls.get_mask(
2973
  proposed_logits, dim=-2, n=remaining_capacity, capacity_scalar=capacity_scalar,
2974
  )
2975
 
@@ -3351,7 +3341,7 @@ class MoSRAHLayer(nn.Module):
3351
  "position_ids": (position_ids, 0),
3352
  "active_mask": (active_mask, False),
3353
  }
3354
- packed, unpacking_mask = pack_experts(entries, setup, selected_heads, self.num_experts, self.packed_length)
3355
  packed_hidden_states = packed["hidden_states"]
3356
  packed_positions = packed["position_ids"]
3357
  active_mask = packed["active_mask"]
@@ -3387,7 +3377,7 @@ class MoSRAHLayer(nn.Module):
3387
  token_choice_outputs = unpack_experts(
3388
  expert_outputs=packed_outputs,
3389
  setup=setup,
3390
- unpacking_mask=unpacking_mask,
3391
  selected_heads=selected_heads,
3392
  )
3393
  final_output = (
@@ -3886,11 +3876,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
3886
 
3887
  @staticmethod
3888
  def create_masks_for_generate(
3889
- config: Any,
3890
- inputs_embeds: torch.Tensor,
3891
  attention_mask: torch.Tensor | None,
3892
- past_key_values: Cache | None,
3893
- position_ids: torch.Tensor | None = None,
3894
  **kwargs: Any,
3895
  ) -> torch.Tensor | None:
3896
  """Return the 2D attention_mask unchanged.
@@ -3944,7 +3930,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
3944
  raise ValueError(
3945
  "position_ids must match the current input_ids shape exactly."
3946
  )
3947
- if input_ids.dtype != torch.long:
3948
  raise TypeError("position_ids must be an long tensor.")
3949
 
3950
  def _validate_labels(
@@ -3959,7 +3945,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
3959
  raise ValueError("labels must have shape (batch, seq_len).")
3960
  if labels.shape != input_ids.shape:
3961
  raise ValueError("labels must have the same shape as input_ids.")
3962
- if input_ids.dtype != torch.long:
3963
  raise TypeError("labels must be a long tensor.")
3964
 
3965
  def _validate_cache_inputs(
@@ -4044,11 +4030,11 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4044
  (violated).
4045
  """
4046
  if torch.compiler.is_compiling():
4047
- # bool.item() is not captured as a SymBool by dynamo; converting to
4048
- # int first produces a SymInt, and the Python comparison (!=0) then
4049
- # yields a SymBool that torch._check folds into the compiled graph.
4050
- condition_as_int = condition.to(torch.int).item()
4051
- torch._check(condition_as_int != 0)
4052
  else:
4053
  if not condition.item():
4054
  raise RuntimeError(
@@ -4058,30 +4044,6 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4058
  "uncached sequence to start at 0.",
4059
  )
4060
 
4061
- @staticmethod
4062
- def _enforce_capture_scalar_outputs() -> None:
4063
- """Enforce that capture_scalar_outputs is enabled when compiling.
4064
-
4065
- The safety checks in this model (e.g. position-zero constraint, packing
4066
- overflow detection) rely on torch._check folding into the compiled graph,
4067
- which requires torch._dynamo.config.capture_scalar_outputs = True. Without
4068
- it those checks are silently absent in the compiled model while appearing
4069
- to work in eager mode — a misconfiguration with no diagnostic output.
4070
-
4071
- This method fires during dynamo tracing so the missing flag is surfaced
4072
- immediately at compile time rather than discovered from downstream failures.
4073
- """
4074
- if torch.compiler.is_compiling():
4075
- torch._check(
4076
- torch._dynamo.config.capture_scalar_outputs,
4077
- lambda: RuntimeError(
4078
- "ShramForCausalLM requires torch._dynamo.config.capture_scalar_outputs = True "
4079
- "when compiled. Without it, runtime safety checks (position constraints, "
4080
- "overflow detection) are silently absent in the compiled model. Set the flag "
4081
- "before calling torch.compile()."
4082
- ),
4083
- )
4084
-
4085
  def _standardize_full_attention_mask(
4086
  self,
4087
  input_ids: torch.Tensor,
@@ -4179,7 +4141,6 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4179
  # This keeps the main sequence readable while ensuring invalid states
4180
  # fail before they can silently contaminate backbone execution.
4181
  # ------------------------------------------------------------------
4182
- self._enforce_capture_scalar_outputs()
4183
  self._validate_input_ids(input_ids)
4184
  self._validate_attention_mask(input_ids, attention_mask)
4185
  self._validate_position_ids(input_ids, position_ids)
 
128
  config = AutoConfig.from_pretrained(
129
  "your-namespace/advanced-transformers-lib",
130
  trust_remote_code=True,
131
+ num_decoder_layers=12,
132
  )
133
  model = AutoModelForCausalLM.from_config(config)
134
 
 
725
  def _check_no_overflow(max_count: torch.Tensor, capacity: int) -> None:
726
  """Raise if any (batch, head) slot would exceed the static buffer capacity.
727
 
728
+ Branches on whether the graph is being compiled. In compiled mode,
729
+ torch._assert_async fires asynchronously on the GPU when the condition
730
+ tensor is False. In eager mode, a plain RuntimeError is raised with a
731
+ descriptive message.
732
 
733
  Args:
734
  max_count: Scalar tensor — the maximum post-update count across all slots.
735
  capacity: The static buffer capacity (mosrah_cache_length).
736
  """
737
  if torch.compiler.is_compiling():
738
+ torch._assert_async(
739
+ max_count <= capacity,
740
+ "MoSRAHCache overflow: buffer capacity exceeded. "
741
+ "Increase mosrah_overallocation_factor in ShramConfig.",
742
+ )
743
  else:
744
  if max_count.item() > capacity:
745
  raise RuntimeError(
 
860
  # Cumulative count of all token positions presented through update() for
861
  # this cache instance. This is the quantity HuggingFace generation reads
862
  # through get_seq_length() to track how far along the sequence we are.
863
+ self._total_processed = torch.tensor(0)
864
 
865
  def update( # type: ignore[override]
866
  self,
 
1000
  generation reads to track sequence progress and is not the same as active-token
1001
  count or current window occupancy.
1002
  """
1003
+ return int(self._total_processed)
1004
 
1005
  def get_max_cache_shape(self) -> int:
1006
  return self.sliding_window
 
2303
  # -----------
2304
  """Expert packing and unpacking for the MoSRAH path.
2305
 
2306
+ This module owns the token-choice -> expert-choice -> token-choice conversion
2307
+ boundary used by the sparse routed attention path. Its public behavior is fixed:
2308
+
2309
+ - setup_packing() prepares the auxiliary ordering data forwarded through packing
2310
+ and unpacking.
2311
+ - pack_experts() converts routed token-choice tensors into padded expert-choice
2312
+ tensors.
2313
+ - unpack_experts() restores token-choice ordering from padded expert-choice output.
2314
+
2315
+ Packed expert-choice tensors are expert-major and left-justified. For each expert,
2316
+ routed token copies occupy the prefix of that expert's packed block; padding occupies
2317
+ the suffix. Every packed entry uses the same ordering and transfer artifact, so
2318
+ hidden states, positions, masks, and probabilities remain aligned across the boundary.
2319
+
2320
+ pack_experts() returns a flat transfer index together with the packed entries. This
2321
+ index replaces the old boolean unpacking artifact as the source of truth for
2322
+ pack/unpack data movement: packing writes to those flat packed slots, and unpacking
2323
+ reads from those same slots.
 
 
 
 
 
2324
  """
2325
 
2326
 
 
2336
  ) -> dict[str, torch.Tensor]:
2337
  """Prepare the auxiliary ordering data used by pack/unpack.
2338
 
 
 
 
 
 
 
 
 
 
 
2339
  Args:
2340
  selected_heads: Routed token-choice head selections I of shape (B, N, K).
2341
 
2342
  Returns:
2343
  Auxiliary payload dict with keys:
2344
  - "flattened_selected_heads": H of shape (B, N*K)
2345
+ - "permutation": expert-major permutation Pi of shape (B, N*K)
2346
  - "inverse_permutation": inverse permutation Pi^{-1} of shape (B, N*K)
2347
  This dict is forwarded whole to pack_experts and unpack_experts.
2348
  """
 
2351
  batch_size,
2352
  sequence_length * num_selected_heads,
2353
  )
2354
+
2355
+ # -----------------------------------------------------------------------
2356
+ # Establish the expert-major ordering invariant.
2357
+ #
2358
+ # BEA later applies a triangular causal mask inside each expert bucket. That
2359
+ # mask is only meaningful if routed copies for the same expert preserve their
2360
+ # source-token order. Stable sorting by selected head establishes that order.
2361
+ # -----------------------------------------------------------------------
2362
  permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
2363
  inverse_permutation = torch.argsort(permutation, dim=-1)
2364
 
 
2366
  "flattened_selected_heads": flattened_selected_heads,
2367
  "permutation": permutation,
2368
  "inverse_permutation": inverse_permutation,
 
2369
  }
2370
 
2371
 
 
2382
  ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
2383
  """Pack token-choice tensors into expert-choice padded form.
2384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2385
  Args:
2386
  entries: Mapping from string keys to (tensor, padding_value) pairs. Each
2387
  tensor has shape (B, N, ...) and is rearranged into expert-choice layout
 
2390
  selected_heads: Routed head selections I of shape (B, N, K).
2391
  num_experts: Total number of experts L.
2392
  packed_length: Static packed time dimension T. All per-expert buffers are
2393
+ allocated to exactly this length. Raises if any actual per-expert token
2394
+ count exceeds this value.
 
2395
 
2396
  Returns:
2397
  Tuple of:
2398
  - packed_entries: Dict with same keys as entries; each value is the
2399
  packed tensor of shape (B, L, T, ...).
2400
+ - flat_packed_transfer_indices: Long tensor of shape (B*N*K,). Each value
2401
+ is the flattened padded expert-choice slot occupied by the corresponding
2402
+ routed-copy row. Pass this to unpack_experts().
2403
  """
2404
  batch_size, sequence_length, num_selected_heads = selected_heads.shape
2405
+ num_routed_copies_per_batch = sequence_length * num_selected_heads
2406
+ num_routed_copies = batch_size * num_routed_copies_per_batch
2407
 
2408
  flattened_selected_heads = setup["flattened_selected_heads"]
2409
  permutation = setup["permutation"]
2410
 
2411
  # -----------------------------------------------------------------------
2412
+ # Algorithm overview.
2413
+ #
2414
+ # Packing first builds one routed-copy row for each selected token/expert
2415
+ # pair, ordered by the stable expert-major permutation. Those rows contain
2416
+ # no padding. The final packed tensor reserves packed_length slots per expert.
2417
+ # The flat transfer index bridges those layouts by adding back the cumulative
2418
+ # padding skipped before each expert block.
2419
+ # -----------------------------------------------------------------------
2420
+
2421
+ # -----------------------------------------------------------------------
2422
+ # Build the shared routed-copy source rows.
2423
  #
2424
+ # This tensor identifies the source token row for each selected token/expert
2425
+ # pair after the stable expert-major permutation. Every packed entry uses this
2426
+ # same row plan, so all entries remain aligned before padded materialization.
2427
  # -----------------------------------------------------------------------
2428
  source_token_indices = torch.arange(
2429
  sequence_length,
 
2434
  sequence_length,
2435
  num_selected_heads,
2436
  )
2437
+ flattened_source_token_indices = source_token_indices.reshape(
2438
  batch_size,
2439
+ num_routed_copies_per_batch,
2440
  )
2441
+ sorted_source_token_indices = flattened_source_token_indices.gather(
 
 
 
 
 
 
 
 
2442
  dim=1,
2443
  index=permutation,
2444
  )
2445
 
2446
  # -----------------------------------------------------------------------
2447
+ # Establish packed expert occupancy and capacity.
 
2448
  #
2449
+ # tokens_per_expert tells how many routed-copy rows occupy the prefix of each
2450
+ # expert block. The padded layout is valid only when every prefix fits inside
2451
+ # the configured packed_length.
 
2452
  # -----------------------------------------------------------------------
2453
  tokens_per_expert = _count_tokens_per_expert(flattened_selected_heads, num_experts)
2454
+ _enforce_no_overflow(tokens_per_expert, packed_length)
 
 
2455
 
2456
  # -----------------------------------------------------------------------
2457
+ # Build the flat insertion points for the padded expert frame.
2458
  #
2459
+ # Routed-copy rows omit padding, while the packed frame reserves packed_length
2460
+ # slots for every expert. The transfer index adds back the cumulative padding
2461
+ # skipped before each expert block, producing one flat destination slot for
2462
+ # every routed-copy row. This tensor is forwarded to unpack_experts so removal
2463
+ # uses the same positions that insertion used.
2464
  # -----------------------------------------------------------------------
2465
+ flat_tokens_per_expert = tokens_per_expert.reshape(-1)
2466
+ flat_padding_per_expert = packed_length - flat_tokens_per_expert
2467
+ flat_padding_before_expert = (
2468
+ flat_padding_per_expert.cumsum(dim=0) - flat_padding_per_expert
2469
+ )
2470
+
2471
+ flat_padding_for_routed_rows = torch.repeat_interleave(
2472
+ flat_padding_before_expert,
2473
+ flat_tokens_per_expert,
2474
+ output_size=num_routed_copies,
2475
+ )
2476
+ flat_routed_row_indices = torch.arange(
2477
+ num_routed_copies,
2478
  device=flattened_selected_heads.device,
2479
  dtype=torch.long,
2480
+ )
2481
+ flat_packed_transfer_indices = (
2482
+ flat_routed_row_indices + flat_padding_for_routed_rows
2483
+ )
2484
 
2485
  # -----------------------------------------------------------------------
2486
+ # Materialize each entry through the shared routing and transfer artifacts.
2487
  #
2488
+ # Each entry first gathers into the shared routed-copy order. The flat packed
2489
+ # allocation supplies padding, and the transfer index writes each routed-copy
2490
+ # row into its padded expert slot before the public shape is restored.
 
2491
  # -----------------------------------------------------------------------
2492
  packed_entries: dict[str, torch.Tensor] = {}
2493
  for key, (tensor, padding_value) in entries.items():
2494
  extra_shape = tensor.shape[2:]
2495
 
2496
+ # The sorted source index is shared across all entries; expanding it over
2497
+ # trailing dimensions lets the same routing/order plan apply to hidden
2498
+ # states, positions, masks, probabilities, and any other packed tensor.
2499
+ sorted_gather_indices = sorted_source_token_indices.view(
2500
  batch_size,
2501
+ num_routed_copies_per_batch,
2502
  *(1,) * len(extra_shape),
2503
  ).expand(-1, -1, *extra_shape)
2504
+ sorted_tensor = tensor.gather(dim=1, index=sorted_gather_indices)
2505
 
2506
  packed_tensor = tensor.new_full(
2507
+ (batch_size * num_experts * packed_length, *extra_shape),
2508
  fill_value=padding_value,
2509
  )
2510
+ packed_tensor[flat_packed_transfer_indices] = sorted_tensor.reshape(
2511
+ num_routed_copies,
2512
+ *extra_shape,
2513
+ )
2514
+ packed_entries[key] = packed_tensor.reshape(
2515
+ batch_size,
2516
+ num_experts,
2517
+ packed_length,
2518
+ *extra_shape,
2519
+ )
2520
 
2521
+ return packed_entries, flat_packed_transfer_indices
 
 
 
2522
 
2523
 
2524
  # ---------------------------------------------------------------------------
 
2528
  def unpack_experts(
2529
  expert_outputs: torch.Tensor,
2530
  setup: dict[str, torch.Tensor],
2531
+ flat_packed_transfer_indices: torch.Tensor,
2532
  selected_heads: torch.Tensor,
2533
  ) -> torch.Tensor:
2534
  """Restore token-choice ordering from BEA expert-choice output.
2535
 
 
 
 
 
 
 
 
 
 
 
 
2536
  Args:
2537
  expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
2538
  setup: Auxiliary payload returned by setup_packing().
2539
+ flat_packed_transfer_indices: Transfer index returned by pack_experts().
2540
+ Each value identifies a routed-copy slot in the flattened padded
2541
+ expert-choice frame.
2542
  selected_heads: Routed head selections I of shape (B, N, K).
2543
 
2544
  Returns:
 
2547
  inverse_permutation = setup["inverse_permutation"]
2548
 
2549
  batch_size, sequence_length, num_selected_heads = selected_heads.shape
2550
+ num_routed_copies_per_batch = sequence_length * num_selected_heads
2551
  hidden_dim = expert_outputs.shape[-1]
2552
 
2553
+ # -----------------------------------------------------------------------
2554
+ # Recover routed-copy rows from the same packed slots used at insertion.
2555
+ #
2556
+ # Packing writes into the forwarded flat slots, and unpacking reads from those
2557
+ # same slots before applying the inverse routing permutation back to
2558
+ # token-choice order.
2559
+ # -----------------------------------------------------------------------
2560
+ flat_expert_outputs = expert_outputs.reshape(-1, hidden_dim)
2561
+ flat_routed_copy_outputs = flat_expert_outputs[flat_packed_transfer_indices]
 
2562
 
2563
+ sorted_token_choice_outputs = flat_routed_copy_outputs.reshape(
2564
  batch_size,
2565
+ num_routed_copies_per_batch,
2566
  hidden_dim,
2567
  )
2568
  restored_outputs = sorted_token_choice_outputs.gather(
 
2581
  # Helpers
2582
  # ---------------------------------------------------------------------------
2583
 
 
 
2584
 
2585
+ def _enforce_no_overflow(tokens_per_expert: torch.Tensor, packed_length: int) -> None:
2586
+ """Enforce that no expert bucket exceeds the preallocated packed length.
 
 
2587
 
2588
+ This check fires when the number of tokens assigned to any expert in any batch
2589
+ item exceeds mosrah_packed_length. When that limit is exceeded, the packed buffer
2590
+ is too small to hold all assignments and data would be dropped. Increase
2591
+ mosrah_overallocation_factor in ShramConfig to resolve.
2592
 
2593
  Args:
2594
+ tokens_per_expert: Per-expert token counts, shape (B, num_experts).
2595
+ packed_length: The preallocated packed time dimension.
 
2596
  """
2597
  if torch.compiler.is_compiling():
2598
+ torch._assert_async(
2599
+ tokens_per_expert.max() <= packed_length,
2600
+ "Expert packing overflow: expert bucket exceeds mosrah_packed_length. "
2601
+ "Increase mosrah_overallocation_factor in ShramConfig.",
2602
+ )
2603
  else:
2604
+ max_count = tokens_per_expert.max().item()
2605
+ if max_count > packed_length:
2606
  raise RuntimeError(
2607
  "Expert packing overflow: at least one expert bucket contains more "
2608
  "tokens than mosrah_packed_length allows. Increase "
2609
  "mosrah_overallocation_factor in ShramConfig to resolve.\n"
2610
+ f"Packed length: {packed_length}\n"
2611
+ f"Head lengths: {tokens_per_expert}\n"
2612
  )
2613
 
2614
 
 
2618
  ) -> torch.Tensor:
2619
  """Count how many routed token copies are assigned to each expert per batch item.
2620
 
2621
+ Uses scatter_add into a pre-sized (B, num_experts) buffer. Each position in
 
2622
  flattened_selected_heads contributes one count to the corresponding expert slot.
2623
 
2624
  Args:
 
2630
  Counts tensor of shape (B, num_experts).
2631
  """
2632
  batch_size = flattened_selected_heads.shape[0]
2633
+ tokens_per_expert = torch.zeros(
2634
  batch_size,
2635
  num_experts,
2636
  device=flattened_selected_heads.device,
2637
+ dtype=torch.long,
2638
  )
2639
+ tokens_per_expert.scatter_add_(
2640
  dim=1,
2641
  index=flattened_selected_heads,
2642
+ src=torch.ones_like(flattened_selected_heads, dtype=torch.long),
2643
  )
2644
+ return tokens_per_expert
 
2645
  # -----------
2646
  # Inlined from: router.py
2647
  # -----------
 
2815
  self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
2816
 
2817
  @staticmethod
2818
+ def get_best_proposals(
2819
  tensor: torch.Tensor,
2820
  dim: int,
2821
  n: int | torch.Tensor,
 
2948
  choices_deficit = (min_choices - accepted_per_token).clamp_min(0)
2949
 
2950
  unproposed_logits = logits.masked_fill(proposals, float('-inf'))
2951
+ new_proposals = cls.get_best_proposals(
2952
  unproposed_logits, dim=-1, n=choices_deficit, capacity_scalar=min_choices,
2953
  )
2954
  proposals = proposals | new_proposals
 
2959
  # Acceptances are recomputed from scratch each round so that a
2960
  # stronger new proposal can displace a weaker prior one.
2961
  proposed_logits = logits.masked_fill(~proposals, float('-inf'))
2962
+ acceptances = cls.get_best_proposals(
2963
  proposed_logits, dim=-2, n=remaining_capacity, capacity_scalar=capacity_scalar,
2964
  )
2965
 
 
3341
  "position_ids": (position_ids, 0),
3342
  "active_mask": (active_mask, False),
3343
  }
3344
+ packed, unpacking_map = pack_experts(entries, setup, selected_heads, self.num_experts, self.packed_length)
3345
  packed_hidden_states = packed["hidden_states"]
3346
  packed_positions = packed["position_ids"]
3347
  active_mask = packed["active_mask"]
 
3377
  token_choice_outputs = unpack_experts(
3378
  expert_outputs=packed_outputs,
3379
  setup=setup,
3380
+ flat_packed_transfer_indices=unpacking_map,
3381
  selected_heads=selected_heads,
3382
  )
3383
  final_output = (
 
3876
 
3877
  @staticmethod
3878
  def create_masks_for_generate(
 
 
3879
  attention_mask: torch.Tensor | None,
 
 
3880
  **kwargs: Any,
3881
  ) -> torch.Tensor | None:
3882
  """Return the 2D attention_mask unchanged.
 
3930
  raise ValueError(
3931
  "position_ids must match the current input_ids shape exactly."
3932
  )
3933
+ if position_ids.dtype != torch.long:
3934
  raise TypeError("position_ids must be an long tensor.")
3935
 
3936
  def _validate_labels(
 
3945
  raise ValueError("labels must have shape (batch, seq_len).")
3946
  if labels.shape != input_ids.shape:
3947
  raise ValueError("labels must have the same shape as input_ids.")
3948
+ if labels.dtype != torch.long:
3949
  raise TypeError("labels must be a long tensor.")
3950
 
3951
  def _validate_cache_inputs(
 
4030
  (violated).
4031
  """
4032
  if torch.compiler.is_compiling():
4033
+ torch._assert_async(
4034
+ condition,
4035
+ "Uncached ShramForCausalLM: nonzero starting positions. "
4036
+ "Supply a ShramCache with prefix or rebase sequence to start at 0.",
4037
+ )
4038
  else:
4039
  if not condition.item():
4040
  raise RuntimeError(
 
4044
  "uncached sequence to start at 0.",
4045
  )
4046
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4047
  def _standardize_full_attention_mask(
4048
  self,
4049
  input_ids: torch.Tensor,
 
4141
  # This keeps the main sequence readable while ensuring invalid states
4142
  # fail before they can silently contaminate backbone execution.
4143
  # ------------------------------------------------------------------
 
4144
  self._validate_input_ids(input_ids)
4145
  self._validate_attention_mask(input_ids, attention_mask)
4146
  self._validate_position_ids(input_ids, position_ids)