smithblack-0 committed on
Commit
a15e620
·
verified ·
1 Parent(s): e56b8fd

Update architecture and tokenizer

Browse files
Files changed (4) hide show
  1. README.md +0 -4
  2. config.json +1 -5
  3. configuration.py +243 -284
  4. huggingface.py +0 -0
README.md CHANGED
@@ -82,12 +82,8 @@ contains no weights. All values are overridable via kwargs.
82
  | `embedding_width` | 512 |
83
  | `head_dim` | 16 |
84
  | `inference_sequence_length` | 1024 |
85
- | `load_balance_loss_type` | causal_overcapacity |
86
  | `local_rope_theta` | 10000.0 |
87
- | `max_bid_rounds` | 10 |
88
- | `maximum_expert_overclaim` | 20 |
89
  | `mlp_width` | 1366 |
90
- | `mosrah_overallocation_factor` | 2.0 |
91
  | `mosrah_rope_theta` | 10000.0 |
92
  | `num_decoder_layers` | 12 |
93
  | `num_mosrah_heads` | 16 |
 
82
  | `embedding_width` | 512 |
83
  | `head_dim` | 16 |
84
  | `inference_sequence_length` | 1024 |
 
85
  | `local_rope_theta` | 10000.0 |
 
 
86
  | `mlp_width` | 1366 |
 
87
  | `mosrah_rope_theta` | 10000.0 |
88
  | `num_decoder_layers` | 12 |
89
  | `num_mosrah_heads` | 16 |
config.json CHANGED
@@ -9,13 +9,9 @@
9
  "embedding_width": 512,
10
  "head_dim": 16,
11
  "inference_sequence_length": 1024,
12
- "load_balance_loss_type": "causal_overcapacity",
13
  "local_rope_theta": 10000.0,
14
- "max_bid_rounds": 10,
15
- "maximum_expert_overclaim": 20,
16
  "mlp_width": 1366,
17
  "model_type": "shram",
18
- "mosrah_overallocation_factor": 2.0,
19
  "mosrah_rope_theta": 10000.0,
20
  "num_decoder_layers": 12,
21
  "num_mosrah_heads": 16,
@@ -25,7 +21,7 @@
25
  "rope_mode": "main_sequence",
26
  "tie_word_embeddings": false,
27
  "training_sequence_length": 1024,
28
- "transformers_version": "5.11.0",
29
  "use_cache": true,
30
  "vocab_size": 50277,
31
  "window_size": 128
 
9
  "embedding_width": 512,
10
  "head_dim": 16,
11
  "inference_sequence_length": 1024,
 
12
  "local_rope_theta": 10000.0,
 
 
13
  "mlp_width": 1366,
14
  "model_type": "shram",
 
15
  "mosrah_rope_theta": 10000.0,
16
  "num_decoder_layers": 12,
17
  "num_mosrah_heads": 16,
 
21
  "rope_mode": "main_sequence",
22
  "tie_word_embeddings": false,
23
  "training_sequence_length": 1024,
24
+ "transformers_version": "5.12.0",
25
  "use_cache": true,
26
  "vocab_size": 50277,
27
  "window_size": 128
configuration.py CHANGED
@@ -1,284 +1,243 @@
1
- """Configuration for the SHRAM transformer.
2
-
3
- All architectural parameters that vary across model scales or are meaningful research
4
- variables are expressed here. Architectural constants (no bias in linear layers,
5
- SwiGLU activation with SiLU gate) are implemented in the relevant modules and
6
- documented at the point of use — they are not config parameters because they do not
7
- vary and changing them produces a different architecture, not a different scale.
8
-
9
- RoPE configuration is owned entirely by this config. Each attention path reads its
10
- parameters directly and constructs its own RotaryEmbedding instance explicitly — no
11
- HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
- """
13
-
14
- import math
15
-
16
- from transformers import PretrainedConfig
17
-
18
-
19
- class ShramConfig(PretrainedConfig):
20
- """Configuration class for the SHRAM decoder-only transformer.
21
-
22
- SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
23
- attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
24
- local sliding-window causal attention path and h_s is the MoSRAH sparse routed
25
- path. All other components follow the Llama 3 baseline.
26
-
27
- This config is the single source of truth for every architectural dimension of the
28
- model. Nothing in the architecture may use a literal number that belongs here.
29
-
30
- Two independent RoPE configurations exist — one per attention path:
31
-
32
- - h_l always uses standard RoPE with ``local_rope_theta``.
33
- - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
34
- ``inference_sequence_length``, ``alpha``, and ``beta``. When
35
- ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
36
- ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
37
- and the correct setting for experiments that do not require context extension.
38
-
39
- Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
40
-
41
- config = AutoConfig.from_pretrained(
42
- "your-namespace/advanced-transformers-lib",
43
- trust_remote_code=True,
44
- num_decoder_layers=12,
45
- )
46
- model = AutoModelForCausalLM.from_config(config)
47
-
48
- Args:
49
- vocab_size: Vocabulary size. Controls the embedding table and output logits
50
- dimension. Must match the tokenizer.
51
- embedding_width: Model width ``d``. The dimension of the residual stream.
52
- mlp_width: FFN hidden dimension.
53
- num_decoder_layers: Number of transformer blocks stacked in sequence.
54
- num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
55
- num_mosrah_heads: Total MoSRAH expert heads available ``L``.
56
- num_selected_heads: MoSRAH heads each token selects ``K``.
57
- head_dim: Per-head dimension, shared by both attention paths. Must be even
58
- (RoPE rotates dimensions in pairs). Paper uses 16.
59
- window_size: Sliding window size for h_l. Paper uses 128.
60
- rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
61
- original sequence positions; ``"semantic_sequence"`` supplies local slot
62
- indices. Both are required; experimentally correct mode is undetermined
63
- (paper §4). Default ``"main_sequence"``.
64
- rms_norm_eps: Epsilon for RMSNorm layers.
65
- local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
66
- Paper uses b=10000.
67
- mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
68
- b=10000.
69
- training_sequence_length: Context length ``C_train`` the model was or will be
70
- trained at. Used to compute the YaRN scale factor for BEA.
71
- inference_sequence_length: Context length ``C_target`` the model must support
72
- at inference. Optional; defaults to ``training_sequence_length`` so that
73
- ``scale=1`` and YaRN reduces to standard RoPE unless explicitly extended.
74
- alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
75
- ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
76
- beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
77
- ``r(d) > beta`` are left unscaled. Paper value: 32.0.
78
- attention_dropout: Dropout probability on attention weights. Default 0.0.
79
- use_cache: Whether to return past_key_values for KV caching.
80
- output_hidden_states: Whether to return hidden states after each layer.
81
- tie_word_embeddings: Whether input embedding and LM head share weights.
82
- mosrah_overallocation_factor: Overallocation multiplier for the expert packing
83
- buffer. ``mosrah_packed_length`` = ceil(training_sequence_length *
84
- num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
85
- Must be > 1.0 to guarantee a buffer larger than the balanced-routing
86
- baseline. Default 2.0.
87
- max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
88
- solver in ``balance_capacity``. 10 covers convergence at approximately
89
- the 98th percentile of routing densities; the top 2% of extreme-density
90
- cases are not expected under normal training. The bound exists as a
91
- correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
92
- Default 10.
93
- load_balance_loss_type: Formula used for the load-balance auxiliary loss.
94
- One of ``"gshard"``, ``"ce"``, ``"bce"``, ``"temporal_overcapacity"``, or
95
- ``"causal_overcapacity"``. ``"causal_overcapacity"`` (the default) attributes
96
- violations to the causal trajectory that produced them — each expert
97
- accumulates a running mean of its selection log-probability and the loss
98
- penalises the gap between overloaded and typical trajectories. Like
99
- ``"temporal_overcapacity"``, it fires only when a violation exists and shuts
100
- off automatically, making it safe to weight strongly. Default
101
- ``"causal_overcapacity"``.
102
- maximum_expert_overclaim: Maximum number of tokens an expert may receive above
103
- its ideal allocation trajectory before either overcapacity loss fires.
104
- A value of 0 means violations trigger immediately at any imbalance.
105
- Larger values permit short-lived semantic specialization before correction.
106
- Used by both ``"temporal_overcapacity"`` and ``"causal_overcapacity"``.
107
- Must be non-negative. Default 20.
108
- """
109
-
110
- model_type = "shram"
111
-
112
- auto_map = {
113
- "AutoConfig": "configuration.ShramConfig",
114
- "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
115
- }
116
-
117
- def __init__(
118
- self,
119
- vocab_size: int = 50277,
120
- embedding_width: int = 512,
121
- mlp_width: int = 1366,
122
- num_decoder_layers: int = 12,
123
- num_sliding_window_heads: int = 16,
124
- num_mosrah_heads: int = 16,
125
- num_selected_heads: int = 16,
126
- head_dim: int = 16,
127
- window_size: int = 128,
128
- rope_mode: str = "main_sequence",
129
- rms_norm_eps: float = 1e-5,
130
- local_rope_theta: float = 10000.0,
131
- mosrah_rope_theta: float = 10000.0,
132
- training_sequence_length: int = 1024,
133
- inference_sequence_length: int | None = None,
134
- alpha: float = 1.0,
135
- beta: float = 32.0,
136
- attention_dropout: float = 0.0,
137
- use_cache: bool = True,
138
- output_hidden_states: bool = False,
139
- tie_word_embeddings: bool = False,
140
- mosrah_overallocation_factor: float = 2.0,
141
- max_bid_rounds: int = 10,
142
- load_balance_loss_type: str = "causal_overcapacity",
143
- maximum_expert_overclaim: int = 20,
144
- **kwargs
145
- ):
146
- if head_dim % 2 != 0:
147
- raise ValueError(
148
- f"head_dim must be even (RoPE rotates dimensions in pairs). "
149
- f"Got head_dim={head_dim}."
150
- )
151
-
152
- if rope_mode not in {"main_sequence", "semantic_sequence"}:
153
- raise ValueError(
154
- f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
155
- f"got '{rope_mode}'."
156
- )
157
-
158
- if training_sequence_length <= 0:
159
- raise ValueError(
160
- f"training_sequence_length must be positive, "
161
- f"got {training_sequence_length}."
162
- )
163
-
164
- if inference_sequence_length is None:
165
- inference_sequence_length = training_sequence_length
166
- if inference_sequence_length <= 0:
167
- raise ValueError(
168
- f"inference_sequence_length must be positive, "
169
- f"got {inference_sequence_length}."
170
- )
171
-
172
- if mosrah_overallocation_factor <= 1.0:
173
- raise ValueError(
174
- f"mosrah_overallocation_factor must be > 1.0 to guarantee a packed "
175
- f"buffer larger than the balanced-routing baseline. "
176
- f"Got {mosrah_overallocation_factor}."
177
- )
178
-
179
- if max_bid_rounds < 1:
180
- raise ValueError(
181
- f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
182
- )
183
-
184
- if maximum_expert_overclaim < 0:
185
- raise ValueError(
186
- f"maximum_expert_overclaim must be non-negative, "
187
- f"got {maximum_expert_overclaim}."
188
- )
189
-
190
- _supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity", "causal_overcapacity"}
191
- if load_balance_loss_type not in _supported_loss_types:
192
- supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
193
- raise ValueError(
194
- f"load_balance_loss_type must be one of {supported}, "
195
- f"got {load_balance_loss_type!r}."
196
- )
197
-
198
- self.vocab_size = vocab_size
199
- self.embedding_width = embedding_width
200
- self.mlp_width = mlp_width
201
- self.num_decoder_layers = num_decoder_layers
202
- self.num_sliding_window_heads = num_sliding_window_heads
203
- self.num_mosrah_heads = num_mosrah_heads
204
- self.num_selected_heads = num_selected_heads
205
- self.head_dim = head_dim
206
- self.window_size = window_size
207
- self.rope_mode = rope_mode
208
- self.rms_norm_eps = rms_norm_eps
209
- self.local_rope_theta = local_rope_theta
210
- self.mosrah_rope_theta = mosrah_rope_theta
211
- self.training_sequence_length = training_sequence_length
212
- self.inference_sequence_length = inference_sequence_length
213
- self.alpha = alpha
214
- self.beta = beta
215
- self.mosrah_overallocation_factor = mosrah_overallocation_factor
216
- self.max_bid_rounds = max_bid_rounds
217
- self.load_balance_loss_type = load_balance_loss_type
218
- self.maximum_expert_overclaim = maximum_expert_overclaim
219
- self.attention_dropout = attention_dropout
220
- self.use_cache = use_cache
221
-
222
- super().__init__(
223
- tie_word_embeddings=tie_word_embeddings,
224
- output_hidden_states=output_hidden_states,
225
- **kwargs
226
- )
227
-
228
- # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
229
- # serialises it into config.json.
230
- self.auto_map = type(self).auto_map
231
-
232
- @property
233
- def scale(self) -> float:
234
- """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
235
-
236
- When scale == 1.0, YaRN reduces exactly to standard RoPE all frequency
237
- adjustments cancel and A_rope = 1. This is the default state.
238
- """
239
- return self.inference_sequence_length / self.training_sequence_length
240
-
241
- @property
242
- def mosrah_packed_length(self) -> int:
243
- """Static packed time dimension T for expert packing.
244
-
245
- The expected tokens per expert under perfectly balanced routing is
246
- ``training_sequence_length * num_selected_heads / num_mosrah_heads``.
247
- Multiplying by ``mosrah_overallocation_factor`` provides a buffer above
248
- that baseline. The ceiling ensures T is always an integer >= 1.
249
-
250
- All consumers of the packed buffer size must read this property rather
251
- than deriving T independently.
252
- """
253
- return math.ceil(
254
- self.training_sequence_length
255
- * self.num_selected_heads
256
- / self.num_mosrah_heads
257
- * self.mosrah_overallocation_factor
258
- )
259
-
260
- @property
261
- def mosrah_cache_length(self) -> int:
262
- """Static per-(batch, head) slot capacity for the MoSRAH inference cache.
263
-
264
- The expected tokens per expert over the full inference context under perfectly
265
- balanced routing is ``inference_sequence_length * num_selected_heads /
266
- num_mosrah_heads``. Multiplying by ``mosrah_overallocation_factor`` provides
267
- a buffer above that baseline. The ceiling ensures the result is always an
268
- integer >= 1.
269
-
270
- Distinct from ``mosrah_packed_length``, which sizes the training packing buffer
271
- using ``training_sequence_length``. This property uses
272
- ``inference_sequence_length`` because the cache must hold the full accumulated
273
- token history across the entire inference run.
274
-
275
- All consumers of the MoSRAH cache buffer size must read this property rather
276
- than deriving the capacity independently.
277
- """
278
- return math.ceil(
279
- self.inference_sequence_length
280
- * self.num_selected_heads
281
- / self.num_mosrah_heads
282
- * self.mosrah_overallocation_factor
283
- )
284
-
 
1
+ """Configuration for the SHRAM transformer.
2
+
3
+ All architectural parameters that vary across model scales or are meaningful research
4
+ variables are expressed here. Architectural constants (no bias in linear layers,
5
+ SwiGLU activation with SiLU gate) are implemented in the relevant modules and
6
+ documented at the point of use — they are not config parameters because they do not
7
+ vary and changing them produces a different architecture, not a different scale.
8
+
9
+ RoPE configuration is owned entirely by this config. Each attention path reads its
10
+ parameters directly and constructs its own RotaryEmbedding instance explicitly — no
11
+ HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
+ """
13
+
14
+ import math
15
+
16
+ from transformers import PretrainedConfig
17
+
18
+
19
+ class ShramConfig(PretrainedConfig):
20
+ """Configuration class for the SHRAM decoder-only transformer.
21
+
22
+ SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
23
+ attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
24
+ local sliding-window causal attention path and h_s is the MoSRAH sparse routed
25
+ path. All other components follow the Llama 3 baseline.
26
+
27
+ This config is the single source of truth for every architectural dimension of the
28
+ model. Nothing in the architecture may use a literal number that belongs here.
29
+
30
+ Two independent RoPE configurations exist — one per attention path:
31
+
32
+ - h_l always uses standard RoPE with ``local_rope_theta``.
33
+ - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
34
+ ``inference_sequence_length``, ``alpha``, and ``beta``. When
35
+ ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
36
+ ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
37
+ and the correct setting for experiments that do not require context extension.
38
+
39
+ Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
40
+
41
+ config = AutoConfig.from_pretrained(
42
+ "your-namespace/advanced-transformers-lib",
43
+ trust_remote_code=True,
44
+ num_decoder_layers=12,
45
+ )
46
+ model = AutoModelForCausalLM.from_config(config)
47
+
48
+ Args:
49
+ vocab_size: Vocabulary size. Controls the embedding table and output logits
50
+ dimension. Must match the tokenizer.
51
+ embedding_width: Model width ``d``. The dimension of the residual stream.
52
+ mlp_width: FFN hidden dimension.
53
+ num_decoder_layers: Number of transformer blocks stacked in sequence.
54
+ num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
55
+ num_mosrah_heads: Total MoSRAH expert heads available ``L``.
56
+ num_selected_heads: MoSRAH heads each token selects ``K``.
57
+ head_dim: Per-head dimension, shared by both attention paths. Must be even
58
+ (RoPE rotates dimensions in pairs). Paper uses 16.
59
+ window_size: Sliding window size for h_l. Paper uses 128.
60
+ rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
61
+ original sequence positions; ``"semantic_sequence"`` supplies local slot
62
+ indices. Both are required; experimentally correct mode is undetermined
63
+ (paper §4). Default ``"main_sequence"``.
64
+ rms_norm_eps: Epsilon for RMSNorm layers.
65
+ local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
66
+ Paper uses b=10000.
67
+ mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
68
+ b=10000.
69
+ training_sequence_length: Context length ``C_train`` the model was or will be
70
+ trained at. Used to compute the YaRN scale factor for BEA.
71
+ inference_sequence_length: Context length ``C_target`` the model must support
72
+ at inference. Optional; defaults to ``training_sequence_length`` so that
73
+ ``scale=1`` and YaRN reduces to standard RoPE unless explicitly extended.
74
+ alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
75
+ ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
76
+ beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
77
+ ``r(d) > beta`` are left unscaled. Paper value: 32.0.
78
+ attention_dropout: Dropout probability on attention weights. Default 0.0.
79
+ use_cache: Whether to return past_key_values for KV caching.
80
+ output_hidden_states: Whether to return hidden states after each layer.
81
+ tie_word_embeddings: Whether input embedding and LM head share weights.
82
+ """
83
+
84
+ model_type = "shram"
85
+
86
+ auto_map = {
87
+ "AutoConfig": "configuration.ShramConfig",
88
+ "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
89
+ }
90
+
91
+ def __init__(
92
+ self,
93
+ vocab_size: int = 50277,
94
+ embedding_width: int = 512,
95
+ mlp_width: int = 1366,
96
+ num_decoder_layers: int = 12,
97
+ num_sliding_window_heads: int = 16,
98
+ num_mosrah_heads: int = 16,
99
+ num_selected_heads: int = 16,
100
+ head_dim: int = 16,
101
+ window_size: int = 128,
102
+ rope_mode: str = "main_sequence",
103
+ rms_norm_eps: float = 1e-5,
104
+ local_rope_theta: float = 10000.0,
105
+ mosrah_rope_theta: float = 10000.0,
106
+ training_sequence_length: int = 1024,
107
+ inference_sequence_length: int | None = None,
108
+ alpha: float = 1.0,
109
+ beta: float = 32.0,
110
+ attention_dropout: float = 0.0,
111
+ use_cache: bool = True,
112
+ output_hidden_states: bool = False,
113
+ tie_word_embeddings: bool = False,
114
+ **kwargs
115
+ ):
116
+ if head_dim % 2 != 0:
117
+ raise ValueError(
118
+ f"head_dim must be even (RoPE rotates dimensions in pairs). "
119
+ f"Got head_dim={head_dim}."
120
+ )
121
+
122
+ if rope_mode not in {"main_sequence", "semantic_sequence"}:
123
+ raise ValueError(
124
+ f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
125
+ f"got '{rope_mode}'."
126
+ )
127
+
128
+ if training_sequence_length <= 0:
129
+ raise ValueError(
130
+ f"training_sequence_length must be positive, "
131
+ f"got {training_sequence_length}."
132
+ )
133
+
134
+ if inference_sequence_length is None:
135
+ inference_sequence_length = training_sequence_length
136
+ if inference_sequence_length <= 0:
137
+ raise ValueError(
138
+ f"inference_sequence_length must be positive, "
139
+ f"got {inference_sequence_length}."
140
+ )
141
+
142
+ if num_mosrah_heads % num_selected_heads != 0:
143
+ raise ValueError(
144
+ f"num_mosrah_heads must be exactly divisible by num_selected_heads. "
145
+ f"Mechanical load balancing partitions the sequence into blocks of "
146
+ f"W = num_mosrah_heads // num_selected_heads tokens; each block covers "
147
+ f"every expert exactly once, which requires an integer W. "
148
+ f"Got num_mosrah_heads={num_mosrah_heads}, num_selected_heads={num_selected_heads}."
149
+ )
150
+
151
+ self.vocab_size = vocab_size
152
+ self.embedding_width = embedding_width
153
+ self.mlp_width = mlp_width
154
+ self.num_decoder_layers = num_decoder_layers
155
+ self.num_sliding_window_heads = num_sliding_window_heads
156
+ self.num_mosrah_heads = num_mosrah_heads
157
+ self.num_selected_heads = num_selected_heads
158
+ self.head_dim = head_dim
159
+ self.window_size = window_size
160
+ self.rope_mode = rope_mode
161
+ self.rms_norm_eps = rms_norm_eps
162
+ self.local_rope_theta = local_rope_theta
163
+ self.mosrah_rope_theta = mosrah_rope_theta
164
+ self.training_sequence_length = training_sequence_length
165
+ self.inference_sequence_length = inference_sequence_length
166
+ self.alpha = alpha
167
+ self.beta = beta
168
+ self.attention_dropout = attention_dropout
169
+ self.use_cache = use_cache
170
+
171
+ super().__init__(
172
+ tie_word_embeddings=tie_word_embeddings,
173
+ output_hidden_states=output_hidden_states,
174
+ **kwargs
175
+ )
176
+
177
+ # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
178
+ # serialises it into config.json.
179
+ self.auto_map = type(self).auto_map
180
+
181
+ @property
182
+ def scale(self) -> float:
183
+ """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
184
+
185
+ When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
186
+ adjustments cancel and A_rope = 1. This is the default state.
187
+ """
188
+ return self.inference_sequence_length / self.training_sequence_length
189
+
190
+ @property
191
+ def mosrah_packed_length(self) -> int:
192
+ """Static packed time dimension T for expert packing.
193
+
194
+ Mechanical load balancing guarantees exactly
195
+ ``training_sequence_length * num_selected_heads / num_mosrah_heads``
196
+ tokens per expert. The ceiling handles non-integer results when
197
+ training_sequence_length is not divisible by the block length W; one extra block length (``block_length``) is then added to the result.
198
+
199
+ All consumers of the packed buffer size must read this property rather
200
+ than deriving T independently.
201
+ """
202
+ return math.ceil(
203
+ self.training_sequence_length
204
+ * self.num_selected_heads
205
+ / self.num_mosrah_heads
206
+ ) + self.block_length
207
+
208
+ @property
209
+ def mosrah_cache_length(self) -> int:
210
+ """Static per-(batch, head) slot capacity for the MoSRAH inference cache.
211
+
212
+ Mechanical load balancing guarantees exactly
213
+ ``inference_sequence_length * num_selected_heads / num_mosrah_heads``
214
+ tokens per expert over the full inference context. The ceiling handles
215
+ non-integer results when inference_sequence_length is not divisible by
216
+ the block length W; one extra block length (``block_length``) is then added to the result.
217
+
218
+ Distinct from ``mosrah_packed_length``, which sizes the training packing
219
+ buffer using ``training_sequence_length``. This property uses
220
+ ``inference_sequence_length`` because the cache must hold the full
221
+ accumulated token history across the entire inference run.
222
+
223
+ All consumers of the MoSRAH cache buffer size must read this property
224
+ rather than deriving the capacity independently.
225
+ """
226
+ return math.ceil(
227
+ self.inference_sequence_length
228
+ * self.num_selected_heads
229
+ / self.num_mosrah_heads
230
+ ) + self.block_length
231
+
232
+ @property
233
+ def block_length(self) -> int:
234
+ """Routing block length W = num_mosrah_heads // num_selected_heads.
235
+
236
+ Within each block of W consecutive tokens every expert is used exactly once,
237
+ giving perfect load balance by construction. The ``num_mosrah_heads % num_selected_heads == 0`` constraint
238
+ enforced at construction guarantees W is an exact integer.
239
+
240
+ All consumers of the routing block length must read this property rather
241
+ than deriving W independently.
242
+ """
243
+ return self.num_mosrah_heads // self.num_selected_heads
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
huggingface.py CHANGED
The diff for this file is too large to render. See raw diff