smithblack-0 commited on
Commit
7bf638f
·
verified ·
1 Parent(s): af7974e

Update architecture and tokenizer

Browse files
README.md CHANGED
@@ -1,103 +1,104 @@
1
- ---
2
- language:
3
- - en
4
- license: mit
5
- library_name: transformers
6
- pipeline_tag: text-generation
7
- tags:
8
- - pytorch
9
- - research
10
- - sparse-attention
11
- - mixture-of-experts
12
- ---
13
-
14
- # SHRAM — Sparse Hybrid Token Routed Attention Mixture
15
-
16
- A research baseline implementing the SHRAM architecture from "An Examination of Sparse
17
- Attention for Long Context Purposes." No pretrained weights — pull the architecture from
18
- the Hub and instantiate a freshly initialised model from config. Every parameter is
19
- overridable at instantiation time via kwargs.
20
-
21
- > **Important:** `trust_remote_code=True` is required. It downloads the architecture
22
- > source files from the Hub and imports them into your Python process. Review the
23
- > source at [smithblack-0/SHRAM](https://huggingface.co/smithblack-0/SHRAM) before use.
24
-
25
- ## Architecture
26
-
27
- SHRAM replaces every standard attention layer with a hybrid layer `H(x) = h_l(x) + h_s(x)`:
28
-
29
- - **h_l** — local sliding-window causal attention path.
30
- - **h_s** — MoSRAH sparse routed path. Each token selects K of L available expert heads
31
- via token-choice routing. Bottlenecked Ensemble Attention (BEA) is applied per head.
32
-
33
- All other components follow the Llama 3 baseline (RMSNorm, SwiGLU FFN, RoPE).
34
-
35
- ## Usage
36
-
37
- This repository contains no pretrained weights. The intended workflow is: pull the
38
- architecture config from the Hub, instantiate a model with fresh random weights, then
39
- train it yourself.
40
-
41
- ```python
42
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
43
-
44
- # Step 1: pull the architecture config from the Hub.
45
- # AutoConfig.from_pretrained downloads config.json only no weights are loaded.
46
- # Override any parameter via kwargs.
47
- config = AutoConfig.from_pretrained(
48
- "smithblack-0/SHRAM",
49
- trust_remote_code=True,
50
- num_hidden_layers=16, # example override
51
- num_mosrah_heads=32, # example override
52
- )
53
-
54
- # Step 2: instantiate with fresh random weights.
55
- # from_config never loads a checkpoint it always produces a randomly initialised model.
56
- model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
57
-
58
- # Step 3: load the tokenizer.
59
- tokenizer = AutoTokenizer.from_pretrained("smithblack-0/SHRAM")
60
- ```
61
-
62
- After training your own checkpoint, save and reload it in the standard way:
63
-
64
- ```python
65
- model.save_pretrained("./my-checkpoint")
66
- model = AutoModelForCausalLM.from_pretrained("./my-checkpoint", trust_remote_code=True)
67
- ```
68
-
69
- ## Constructor Defaults
70
-
71
- The values below are the defaults you get if you call `AutoConfig.from_pretrained` with
72
- no overrides. They are not the parameters of a pretrained model this repository
73
- contains no weights. All values are overridable via kwargs.
74
-
75
- | Parameter | Default |
76
- |-----------|---------|
77
- | `alpha` | 1.0 |
78
- | `attention_dropout` | 0.0 |
79
- | `beta` | 32.0 |
80
- | `dtype` | None |
81
- | `head_dim` | 16 |
82
- | `hidden_size` | 512 |
83
- | `inference_sequence_length` | 1024 |
84
- | `intermediate_size` | 1366 |
85
- | `local_rope_theta` | 10000.0 |
86
- | `mosrah_rope_theta` | 10000.0 |
87
- | `num_hidden_layers` | 12 |
88
- | `num_mosrah_heads` | 16 |
89
- | `num_selected_heads` | 16 |
90
- | `num_sliding_window_heads` | 16 |
91
- | `output_hidden_states` | False |
92
- | `rms_norm_eps` | 1e-05 |
93
- | `rope_mode` | main_sequence |
94
- | `tie_word_embeddings` | False |
95
- | `training_sequence_length` | 1024 |
96
- | `use_cache` | True |
97
- | `vocab_size` | 50277 |
98
- | `window_size` | 128 |
99
-
100
- ## License
101
-
102
- MIT. Clean-room synthesis informed by the reference paper. Tokenizer is GPT-NeoX
103
- (`EleutherAI/gpt-neox-20b`, Apache 2.0).
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: mit
5
+ library_name: transformers
6
+ pipeline_tag: text-generation
7
+ tags:
8
+ - pytorch
9
+ - research
10
+ - sparse-attention
11
+ - mixture-of-experts
12
+ ---
13
+
14
+ # SHRAM — Sparse Hybrid Token Routed Attention Mixture
15
+
16
+ A research baseline implementing the SHRAM architecture from "An Examination of Sparse
17
+ Attention for Long Context Purposes." No pretrained weights — pull the architecture from
18
+ the Hub and instantiate a freshly initialised model from config. Every parameter is
19
+ overridable at instantiation time via kwargs.
20
+
21
+ > **Important:** `trust_remote_code=True` is required. It downloads the architecture
22
+ > source files from the Hub and imports them into your Python process. Review the
23
+ > source at [smithblack-0/SHRAM](https://huggingface.co/smithblack-0/SHRAM) before use. Those interested can also
24
+ > clone the git repository at https://github.com/smithblack-0/advanced-transformers-lib
25
+
26
+ ## Architecture
27
+
28
+ SHRAM replaces every standard attention layer with a hybrid layer `H(x) = h_l(x) + h_s(x)`:
29
+
30
+ - **h_l** — local sliding-window causal attention path.
31
+ - **h_s** — MoSRAH sparse routed path. Each token selects K of L available expert heads
32
+ via token-choice routing. Bottlenecked Ensemble Attention (BEA) is applied per head.
33
+
34
+ All other components follow the Llama 3 baseline (RMSNorm, SwiGLU FFN, RoPE).
35
+
36
+ ## Usage
37
+
38
+ This repository contains no pretrained weights. The intended workflow is: pull the
39
+ architecture config from the Hub, instantiate a model with fresh random weights, then
40
+ train it yourself.
41
+
42
+ ```python
43
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
44
+
45
+ # Step 1: pull the architecture config from the Hub.
46
+ # AutoConfig.from_pretrained downloads config.json only — no weights are loaded.
47
+ # Override any parameter via kwargs.
48
+ config = AutoConfig.from_pretrained(
49
+ "smithblack-0/SHRAM",
50
+ trust_remote_code=True,
51
+ num_hidden_layers=16, # example override
52
+ num_mosrah_heads=32, # example override
53
+ )
54
+
55
+ # Step 2: instantiate with fresh random weights.
56
+ # from_config never loads a checkpoint — it always produces a randomly initialised model.
57
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
58
+
59
+ # Step 3: load the tokenizer.
60
+ tokenizer = AutoTokenizer.from_pretrained("smithblack-0/SHRAM")
61
+ ```
62
+
63
+ After training your own checkpoint, save and reload it in the standard way:
64
+
65
+ ```python
66
+ model.save_pretrained("./my-checkpoint")
67
+ model = AutoModelForCausalLM.from_pretrained("./my-checkpoint", trust_remote_code=True)
68
+ ```
69
+
70
+ ## Constructor Defaults
71
+
72
+ The values below are the defaults you get if you call `AutoConfig.from_pretrained` with
73
+ no overrides. They are not the parameters of a pretrained model — this repository
74
+ contains no weights. All values are overridable via kwargs.
75
+
76
+ | Parameter | Default |
77
+ |-----------|---------|
78
+ | `alpha` | 1.0 |
79
+ | `attention_dropout` | 0.0 |
80
+ | `beta` | 32.0 |
81
+ | `dtype` | None |
82
+ | `head_dim` | 16 |
83
+ | `hidden_size` | 512 |
84
+ | `inference_sequence_length` | 1024 |
85
+ | `intermediate_size` | 1366 |
86
+ | `local_rope_theta` | 10000.0 |
87
+ | `mosrah_rope_theta` | 10000.0 |
88
+ | `num_hidden_layers` | 12 |
89
+ | `num_mosrah_heads` | 16 |
90
+ | `num_selected_heads` | 16 |
91
+ | `num_sliding_window_heads` | 16 |
92
+ | `output_hidden_states` | False |
93
+ | `rms_norm_eps` | 1e-05 |
94
+ | `rope_mode` | main_sequence |
95
+ | `tie_word_embeddings` | False |
96
+ | `training_sequence_length` | 1024 |
97
+ | `use_cache` | True |
98
+ | `vocab_size` | 50277 |
99
+ | `window_size` | 128 |
100
+
101
+ ## License
102
+
103
+ MIT. Clean-room synthesis informed by the reference paper. Tokenizer is GPT-NeoX
104
+ (`EleutherAI/gpt-neox-20b`, Apache 2.0).
__attention__bottlenecked_ensemble_attention.py CHANGED
@@ -1,252 +1,252 @@
1
- """Bottlenecked Ensemble Attention (BEA) for the MoSRAH sparse path.
2
-
3
- BEA is the packed expert-choice attention operator over the MoSRAH sparse path.
4
- It consumes packed expert-choice tensors, a supplied position tensor, an active
5
- token mask, and an optional layer-local MoSRAH cache. It returns outputs in the
6
- same packed expert-choice space expected by later unpacking.
7
-
8
- BEA does not compute positions and does not choose packed-position semantics.
9
- Those are supplied by the caller. If caching is used, BEA stores post-RoPE keys
10
- (K̃) and raw values (V) into the sparse cache and attends against the
11
- accumulated cached state returned by that cache.
12
- """
13
-
14
- import math
15
-
16
- import torch
17
- import torch.nn as nn
18
- from torch.nn.attention.flex_attention import create_block_mask, flex_attention
19
-
20
- from .configuration import ShramConfig
21
- from .__cache__mosrah_cache import MoSRAHCache
22
- from .rope import RotaryEmbedding
23
-
24
-
25
- class BottleneckedEnsembleAttention(nn.Module):
26
- """
27
- Packed expert-choice attention operator for the MoSRAH sparse path.
28
- Operates per-head independently on an ensemble of tokens.
29
- FlexAttention saves flops on dead tokens.
30
-
31
- Architectural properties:
32
- - consumes packed expert-choice tensors of shape (B, L, T, d)
33
- - uses independent per-head Q/K/V/O projection parameters
34
- - applies YaRN-capable RoPE using supplied position_ids
35
- - stores post-RoPE K̃ and raw V in MoSRAHCache when caching is enabled
36
- - uses a fast fused attention path
37
- - returns outputs in the same packed expert-choice space (B, L, T, d)
38
-
39
- Args:
40
- config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
41
- `head_dim`, `mosrah_rope_theta`, `training_sequence_length`,
42
- `inference_sequence_length`, `alpha`, and `beta`.
43
- """
44
-
45
- def __init__(self, config: ShramConfig) -> None:
46
- super().__init__()
47
-
48
- self.hidden_size = config.hidden_size
49
- self.num_heads = config.num_mosrah_heads
50
- self.head_dim = config.head_dim
51
-
52
- # Independent per-head projections. No cross-head parameter sharing.
53
- self.q_proj = nn.Parameter(
54
- torch.empty(self.num_heads, self.hidden_size, self.head_dim)
55
- )
56
- self.k_proj = nn.Parameter(
57
- torch.empty(self.num_heads, self.hidden_size, self.head_dim)
58
- )
59
- self.v_proj = nn.Parameter(
60
- torch.empty(self.num_heads, self.hidden_size, self.head_dim)
61
- )
62
- self.o_proj = nn.Parameter(
63
- torch.empty(self.num_heads, self.head_dim, self.hidden_size)
64
- )
65
-
66
- self._reset_parameters()
67
-
68
- # BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
69
- # this unit only consumes it. In training modes, dilation will be 1.0 and so
70
- # no yarn dilation occurs.
71
- self.rope = RotaryEmbedding(
72
- mode="yarn",
73
- head_dim=self.head_dim,
74
- theta=config.mosrah_rope_theta,
75
- initial_seq_length=config.training_sequence_length,
76
- dilation=config.scale,
77
- alpha=config.alpha,
78
- beta=config.beta,
79
- )
80
-
81
- def forward(
82
- self,
83
- packed_embeddings: torch.Tensor,
84
- position_ids: torch.Tensor,
85
- active_mask: torch.Tensor,
86
- cache: MoSRAHCache | None = None,
87
- ) -> torch.Tensor:
88
- """Apply BEA to packed expert-choice tensors.
89
-
90
- Args:
91
- packed_embeddings: Packed expert-choice hidden states of shape (B, L, T, d).
92
- position_ids: Supplied packed positions of shape (B, L, T).
93
- active_mask: Boolean active-token mask of shape (B, L, T).
94
- cache: Optional layer-local MoSRAH cache.
95
-
96
- Returns:
97
- Packed expert-choice output tensor of shape (B, L, T, d).
98
- """
99
- batch_size, _, query_length, _ = packed_embeddings.shape
100
- self._validate_tensor_shape(packed_embeddings)
101
- self._validate_position_shape(packed_embeddings, position_ids)
102
- self._validate_active_mask_shape(packed_embeddings, active_mask)
103
-
104
- # Independent per-head projections:
105
- # (B, L, T, d) x (L, d, u) -> (B, L, T, u)
106
- query_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.q_proj)
107
- key_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.k_proj)
108
- value_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.v_proj)
109
-
110
- rotated_query_states, rotated_key_states, attention_scaling = self.rope(
111
- query_states,
112
- key_states,
113
- position_ids,
114
- )
115
-
116
- if cache is not None:
117
- # In cached execution, the current query tensor uses local tensor rows
118
- # 0..Q-1, but the key tensor returned by the cache is the full accumulated
119
- # packed sequence for each (batch, head) slot. The only additional data
120
- # needed to align those two views is the pre-update cached prefix length.
121
- # which will indicate how many queries were processed before now.
122
- num_tokens_processed = cache.get_heads_lengths().clone()
123
- key_states, value_states, key_active_mask = cache.update(
124
- rotated_key_states,
125
- value_states,
126
- active_mask,
127
- )
128
- else:
129
- num_tokens_processed = torch.zeros(
130
- batch_size,
131
- self.num_heads,
132
- dtype=torch.long,
133
- device=packed_embeddings.device,
134
- )
135
- key_states = rotated_key_states
136
- key_active_mask = active_mask
137
-
138
- block_mask = self._make_block_mask(
139
- query_active_mask=active_mask,
140
- key_active_mask=key_active_mask,
141
- num_tokens_processed=num_tokens_processed,
142
- query_length=query_length,
143
- key_length=key_states.shape[2],
144
- device=packed_embeddings.device,
145
- )
146
- attended_states = flex_attention(
147
- rotated_query_states,
148
- key_states,
149
- value_states,
150
- block_mask=block_mask,
151
- scale=attention_scaling / math.sqrt(self.head_dim),
152
- )
153
-
154
- # Project back to model width:
155
- # (B, L, T, u) x (L, u, d) -> (B, L, T, d)
156
- return torch.einsum("bltu,lud->bltd", attended_states, self.o_proj)
157
-
158
- def _reset_parameters(self) -> None:
159
- """Initialize per-head projection weights."""
160
- for weight in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
161
- nn.init.xavier_uniform_(weight)
162
-
163
- def _validate_tensor_shape(self, packed_embeddings: torch.Tensor) -> None:
164
- """Validate the local packed-embedding shape contract required by BEA."""
165
- if packed_embeddings.shape[1] != self.num_heads:
166
- raise ValueError(
167
- f"Expected packed_embeddings.shape[1] == num_mosrah_heads={self.num_heads}, "
168
- f"got {packed_embeddings.shape[1]}."
169
- )
170
-
171
- if packed_embeddings.shape[-1] != self.hidden_size:
172
- raise ValueError(
173
- f"Expected packed_embeddings last dim == hidden_size={self.hidden_size}, "
174
- f"got {packed_embeddings.shape[-1]}."
175
- )
176
-
177
- def _validate_position_shape(
178
- self,
179
- packed_embeddings: torch.Tensor,
180
- position_ids: torch.Tensor,
181
- ) -> None:
182
- """Validate the supplied packed-position tensor shape."""
183
- if position_ids.shape != packed_embeddings.shape[:3]:
184
- raise ValueError(
185
- f"position_ids must have shape {tuple(packed_embeddings.shape[:3])}, "
186
- f"got {tuple(position_ids.shape)}."
187
- )
188
-
189
- def _validate_active_mask_shape(
190
- self,
191
- packed_embeddings: torch.Tensor,
192
- active_mask: torch.Tensor,
193
- ) -> None:
194
- """Validate the supplied active-token mask shape."""
195
- if active_mask.shape != packed_embeddings.shape[:3]:
196
- raise ValueError(
197
- f"active_mask must have shape {tuple(packed_embeddings.shape[:3])}, "
198
- f"got {tuple(active_mask.shape)}."
199
- )
200
-
201
- def _make_block_mask(
202
- self,
203
- query_active_mask: torch.Tensor,
204
- key_active_mask: torch.Tensor,
205
- num_tokens_processed: torch.Tensor,
206
- query_length: int,
207
- key_length: int,
208
- device: torch.device,
209
- ):
210
- """Create the packed-sequence causal mask for FlexAttention.
211
-
212
- At the root, causality is still triangular. The only nuance is cached
213
- execution: query rows are indexed locally as 0..Q-1 inside the current
214
- query tensor, but the key tensor may already contain a cached prefix for
215
- that (batch, head) slot. The causal horizon for query tensor row q is
216
- therefore:
217
-
218
- cached_prefix_lengths[b, h] + q
219
-
220
- Query and key activity masks are then composed with that triangular rule
221
- so FlexAttention can skip padded query rows and ignore inactive key slots.
222
- """
223
- batch_size, num_heads, _ = query_active_mask.shape
224
-
225
- # Build the per-(batch, head, query_row) triangular horizon from a simple
226
- # arange over query rows plus the cached prefix lengths for each slot.
227
- relative_query_positions = torch.arange(
228
- query_length,
229
- device=device,
230
- dtype=torch.long,
231
- ).view(1, 1, query_length)
232
- causal_query_positions = num_tokens_processed.unsqueeze(-1) + relative_query_positions
233
-
234
- def packed_causal_mask(
235
- batch_idx: torch.Tensor,
236
- head_idx: torch.Tensor,
237
- query_idx: torch.Tensor,
238
- key_idx: torch.Tensor,
239
- ) -> torch.Tensor:
240
- query_is_active = query_active_mask[batch_idx, head_idx, query_idx]
241
- key_is_active = key_active_mask[batch_idx, head_idx, key_idx]
242
- is_causal = key_idx <= causal_query_positions[batch_idx, head_idx, query_idx]
243
- return query_is_active & key_is_active & is_causal
244
-
245
- return create_block_mask(
246
- packed_causal_mask,
247
- B=batch_size,
248
- H=num_heads,
249
- Q_LEN=query_length,
250
- KV_LEN=key_length,
251
- device=device,
252
  )
 
1
+ """Bottlenecked Ensemble Attention (BEA) for the MoSRAH sparse path.
2
+
3
+ BEA is the packed expert-choice attention operator over the MoSRAH sparse path.
4
+ It consumes packed expert-choice tensors, a supplied position tensor, an active
5
+ token mask, and an optional layer-local MoSRAH cache. It returns outputs in the
6
+ same packed expert-choice space expected by later unpacking.
7
+
8
+ BEA does not compute positions and does not choose packed-position semantics.
9
+ Those are supplied by the caller. If caching is used, BEA stores post-RoPE keys
10
+ (K̃) and raw values (V) into the sparse cache and attends against the
11
+ accumulated cached state returned by that cache.
12
+ """
13
+
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
19
+
20
+ from .configuration import ShramConfig
21
+ from .__cache__mosrah_cache import MoSRAHCache
22
+ from .rope import RotaryEmbedding
23
+
24
+
25
+ class BottleneckedEnsembleAttention(nn.Module):
26
+ """
27
+ Packed expert-choice attention operator for the MoSRAH sparse path.
28
+ Operates per-head independently on an ensemble of tokens.
29
+ FlexAttention saves flops on dead tokens.
30
+
31
+ Architectural properties:
32
+ - consumes packed expert-choice tensors of shape (B, L, T, d)
33
+ - uses independent per-head Q/K/V/O projection parameters
34
+ - applies YaRN-capable RoPE using supplied position_ids
35
+ - stores post-RoPE K̃ and raw V in MoSRAHCache when caching is enabled
36
+ - uses a fast fused attention path
37
+ - returns outputs in the same packed expert-choice space (B, L, T, d)
38
+
39
+ Args:
40
+ config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
41
+ `head_dim`, `mosrah_rope_theta`, `training_sequence_length`,
42
+ `inference_sequence_length`, `alpha`, and `beta`.
43
+ """
44
+
45
+ def __init__(self, config: ShramConfig) -> None:
46
+ super().__init__()
47
+
48
+ self.hidden_size = config.hidden_size
49
+ self.num_heads = config.num_mosrah_heads
50
+ self.head_dim = config.head_dim
51
+
52
+ # Independent per-head projections. No cross-head parameter sharing.
53
+ self.q_proj = nn.Parameter(
54
+ torch.empty(self.num_heads, self.hidden_size, self.head_dim)
55
+ )
56
+ self.k_proj = nn.Parameter(
57
+ torch.empty(self.num_heads, self.hidden_size, self.head_dim)
58
+ )
59
+ self.v_proj = nn.Parameter(
60
+ torch.empty(self.num_heads, self.hidden_size, self.head_dim)
61
+ )
62
+ self.o_proj = nn.Parameter(
63
+ torch.empty(self.num_heads, self.head_dim, self.hidden_size)
64
+ )
65
+
66
+ self._reset_parameters()
67
+
68
+ # BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
69
+ # this unit only consumes it. In training modes, dilation will be 1.0 and so
70
+ # no yarn dilation occurs.
71
+ self.rope = RotaryEmbedding(
72
+ mode="yarn",
73
+ head_dim=self.head_dim,
74
+ theta=config.mosrah_rope_theta,
75
+ initial_seq_length=config.training_sequence_length,
76
+ dilation=config.scale,
77
+ alpha=config.alpha,
78
+ beta=config.beta,
79
+ )
80
+
81
+ def forward(
82
+ self,
83
+ packed_embeddings: torch.Tensor,
84
+ position_ids: torch.Tensor,
85
+ active_mask: torch.Tensor,
86
+ cache: MoSRAHCache | None = None,
87
+ ) -> torch.Tensor:
88
+ """Apply BEA to packed expert-choice tensors.
89
+
90
+ Args:
91
+ packed_embeddings: Packed expert-choice hidden states of shape (B, L, T, d).
92
+ position_ids: Supplied packed positions of shape (B, L, T).
93
+ active_mask: Boolean active-token mask of shape (B, L, T).
94
+ cache: Optional layer-local MoSRAH cache.
95
+
96
+ Returns:
97
+ Packed expert-choice output tensor of shape (B, L, T, d).
98
+ """
99
+ batch_size, _, query_length, _ = packed_embeddings.shape
100
+ self._validate_tensor_shape(packed_embeddings)
101
+ self._validate_position_shape(packed_embeddings, position_ids)
102
+ self._validate_active_mask_shape(packed_embeddings, active_mask)
103
+
104
+ # Independent per-head projections:
105
+ # (B, L, T, d) x (L, d, u) -> (B, L, T, u)
106
+ query_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.q_proj)
107
+ key_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.k_proj)
108
+ value_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.v_proj)
109
+
110
+ rotated_query_states, rotated_key_states, attention_scaling = self.rope(
111
+ query_states,
112
+ key_states,
113
+ position_ids,
114
+ )
115
+
116
+ if cache is not None:
117
+ # In cached execution, the current query tensor uses local tensor rows
118
+ # 0..Q-1, but the key tensor returned by the cache is the full accumulated
119
+ # packed sequence for each (batch, head) slot. The only additional data
120
+ # needed to align those two views is the pre-update cached prefix length.
121
+ # which will indicate how many queries were processed before now.
122
+ num_tokens_processed = cache.get_heads_lengths().clone()
123
+ key_states, value_states, key_active_mask = cache.update(
124
+ rotated_key_states,
125
+ value_states,
126
+ active_mask,
127
+ )
128
+ else:
129
+ num_tokens_processed = torch.zeros(
130
+ batch_size,
131
+ self.num_heads,
132
+ dtype=torch.long,
133
+ device=packed_embeddings.device,
134
+ )
135
+ key_states = rotated_key_states
136
+ key_active_mask = active_mask
137
+
138
+ block_mask = self._make_block_mask(
139
+ query_active_mask=active_mask,
140
+ key_active_mask=key_active_mask,
141
+ num_tokens_processed=num_tokens_processed,
142
+ query_length=query_length,
143
+ key_length=key_states.shape[2],
144
+ device=packed_embeddings.device,
145
+ )
146
+ attended_states = flex_attention(
147
+ rotated_query_states,
148
+ key_states,
149
+ value_states,
150
+ block_mask=block_mask,
151
+ scale=attention_scaling / math.sqrt(self.head_dim),
152
+ )
153
+
154
+ # Project back to model width:
155
+ # (B, L, T, u) x (L, u, d) -> (B, L, T, d)
156
+ return torch.einsum("bltu,lud->bltd", attended_states, self.o_proj)
157
+
158
+ def _reset_parameters(self) -> None:
159
+ """Initialize per-head projection weights."""
160
+ for weight in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
161
+ nn.init.xavier_uniform_(weight)
162
+
163
+ def _validate_tensor_shape(self, packed_embeddings: torch.Tensor) -> None:
164
+ """Validate the local packed-embedding shape contract required by BEA."""
165
+ if packed_embeddings.shape[1] != self.num_heads:
166
+ raise ValueError(
167
+ f"Expected packed_embeddings.shape[1] == num_mosrah_heads={self.num_heads}, "
168
+ f"got {packed_embeddings.shape[1]}."
169
+ )
170
+
171
+ if packed_embeddings.shape[-1] != self.hidden_size:
172
+ raise ValueError(
173
+ f"Expected packed_embeddings last dim == hidden_size={self.hidden_size}, "
174
+ f"got {packed_embeddings.shape[-1]}."
175
+ )
176
+
177
+ def _validate_position_shape(
178
+ self,
179
+ packed_embeddings: torch.Tensor,
180
+ position_ids: torch.Tensor,
181
+ ) -> None:
182
+ """Validate the supplied packed-position tensor shape."""
183
+ if position_ids.shape != packed_embeddings.shape[:3]:
184
+ raise ValueError(
185
+ f"position_ids must have shape {tuple(packed_embeddings.shape[:3])}, "
186
+ f"got {tuple(position_ids.shape)}."
187
+ )
188
+
189
+ def _validate_active_mask_shape(
190
+ self,
191
+ packed_embeddings: torch.Tensor,
192
+ active_mask: torch.Tensor,
193
+ ) -> None:
194
+ """Validate the supplied active-token mask shape."""
195
+ if active_mask.shape != packed_embeddings.shape[:3]:
196
+ raise ValueError(
197
+ f"active_mask must have shape {tuple(packed_embeddings.shape[:3])}, "
198
+ f"got {tuple(active_mask.shape)}."
199
+ )
200
+
201
+ def _make_block_mask(
202
+ self,
203
+ query_active_mask: torch.Tensor,
204
+ key_active_mask: torch.Tensor,
205
+ num_tokens_processed: torch.Tensor,
206
+ query_length: int,
207
+ key_length: int,
208
+ device: torch.device,
209
+ ):
210
+ """Create the packed-sequence causal mask for FlexAttention.
211
+
212
+ At the root, causality is still triangular. The only nuance is cached
213
+ execution: query rows are indexed locally as 0..Q-1 inside the current
214
+ query tensor, but the key tensor may already contain a cached prefix for
215
+ that (batch, head) slot. The causal horizon for query tensor row q is
216
+ therefore:
217
+
218
+ cached_prefix_lengths[b, h] + q
219
+
220
+ Query and key activity masks are then composed with that triangular rule
221
+ so FlexAttention can skip padded query rows and ignore inactive key slots.
222
+ """
223
+ batch_size, num_heads, _ = query_active_mask.shape
224
+
225
+ # Build the per-(batch, head, query_row) triangular horizon from a simple
226
+ # arange over query rows plus the cached prefix lengths for each slot.
227
+ relative_query_positions = torch.arange(
228
+ query_length,
229
+ device=device,
230
+ dtype=torch.long,
231
+ ).view(1, 1, query_length)
232
+ causal_query_positions = num_tokens_processed.unsqueeze(-1) + relative_query_positions
233
+
234
+ def packed_causal_mask(
235
+ batch_idx: torch.Tensor,
236
+ head_idx: torch.Tensor,
237
+ query_idx: torch.Tensor,
238
+ key_idx: torch.Tensor,
239
+ ) -> torch.Tensor:
240
+ query_is_active = query_active_mask[batch_idx, head_idx, query_idx]
241
+ key_is_active = key_active_mask[batch_idx, head_idx, key_idx]
242
+ is_causal = key_idx <= causal_query_positions[batch_idx, head_idx, query_idx]
243
+ return query_is_active & key_is_active & is_causal
244
+
245
+ return create_block_mask(
246
+ packed_causal_mask,
247
+ B=batch_size,
248
+ H=num_heads,
249
+ Q_LEN=query_length,
250
+ KV_LEN=key_length,
251
+ device=device,
252
  )
__attention__expert_packing.py CHANGED
@@ -1,335 +1,335 @@
1
- """Expert packing and unpacking for the MoSRAH path.
2
-
3
- This module implements the low-level token-choice -> expert-choice -> token-choice
4
- conversion boundary specified in the paper. The externally visible behavior is fixed:
5
-
6
- - setup_packing() prepares the auxiliary ordering data.
7
- - pack_experts() converts routed token-choice state into packed expert-choice state.
8
- - unpack_experts() restores token-choice ordering afterward.
9
-
10
- Stable sort is a correctness requirement. It preserves causal ordering inside each
11
- expert bucket, which is the foundation on which BEA's later triangular causal mask
12
- is correct.
13
-
14
- pack_experts() returns two distinct masks that serve different roles and must not
15
- be interchanged:
16
-
17
- - unpacking_mask: marks every packed slot that contains a routed token copy,
18
- live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
19
- so its reshape invariant holds regardless of outer token liveness.
20
- - active_mask: marks only the packed slots whose source token was semantically
21
- live. This is what BEA consumes for attention gating. Dead outer tokens must
22
- not influence sparse attention outputs.
23
- """
24
-
25
- import torch
26
-
27
-
28
- # ---------------------------------------------------------------------------
29
- # Setup
30
- # ---------------------------------------------------------------------------
31
-
32
- def setup_packing(
33
- selected_heads: torch.Tensor,
34
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
35
- """Prepare the auxiliary ordering data used by pack/unpack.
36
-
37
- Routing produces token-choice state I of shape (B, N, K): for each token, which
38
- K experts were selected. Packing needs the same routed token copies reordered into
39
- expert-major order so each expert bucket becomes contiguous.
40
-
41
- The paper's setup step does this by flattening (N, K) into one axis to produce
42
- H in token-major order, then computing a stable argsort permutation Pi over the
43
- expert indices stored in H. Applying Pi reorders the flattened routed copies into
44
- expert-major order while preserving their original token order *within* each expert
45
- bucket. That preservation is why stable sort is required for causality.
46
-
47
- Args:
48
- selected_heads: Routed token-choice head selections I of shape (B, N, K).
49
-
50
- Returns:
51
- Tuple of:
52
- - flattened_selected_heads: H of shape (B, N*K)
53
- - permutation: stable expert-major permutation Pi of shape (B, N*K)
54
- - inverse_permutation: inverse permutation Pi^{-1} of shape (B, N*K)
55
- """
56
- batch_size, sequence_length, num_selected_heads = selected_heads.shape
57
- flattened_selected_heads = selected_heads.reshape(
58
- batch_size,
59
- sequence_length * num_selected_heads,
60
- )
61
-
62
- permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
63
- inverse_permutation = torch.argsort(permutation, dim=-1)
64
-
65
- return flattened_selected_heads, permutation, inverse_permutation
66
-
67
-
68
- # ---------------------------------------------------------------------------
69
- # Packing
70
- # ---------------------------------------------------------------------------
71
-
72
- def pack_experts(
73
- hidden_states: torch.Tensor,
74
- position_ids: torch.Tensor,
75
- selected_heads: torch.Tensor,
76
- num_experts: int,
77
- flattened_selected_heads: torch.Tensor,
78
- permutation: torch.Tensor,
79
- outer_active_mask: torch.Tensor,
80
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
81
- """Pack token-choice hidden states into expert-choice padded form.
82
-
83
- The paper's packing path has two jobs:
84
-
85
- 1. Convert routed token-choice copies into expert-major order.
86
- 2. Materialize that expert-major order into a padded tensor layout BEA can consume.
87
-
88
- The routed hidden-state copies are not stored explicitly in token-choice form.
89
- Instead, the same token hidden state is conceptually copied once per selected expert.
90
- The packing step reconstructs those copies by expanding local source-token indices,
91
- reordering those indices with Pi, then gathering hidden states, positions, and outer
92
- liveness in that packed order. All three are carried through the same expert-major
93
- rearrangement so they remain aligned in the packed frame.
94
-
95
- Packed positions are sourced from the authoritative upstream position_ids tensor
96
- rather than synthesized locally from arange(N). This preserves advanced positions
97
- correctly during cached inference while leaving training/full-sequence behavior
98
- unchanged when position_ids is the ordinary sequential token positions.
99
-
100
- Args:
101
- hidden_states: Token-choice hidden states x of shape (B, N, d).
102
- position_ids: Authoritative upstream token positions J of shape (B, N).
103
- selected_heads: Routed head selections I of shape (B, N, K).
104
- num_experts: Total number of experts L.
105
- flattened_selected_heads: H from setup_packing(), shape (B, N*K).
106
- permutation: Pi from setup_packing(), shape (B, N*K).
107
- outer_active_mask: Current-chunk active mask of shape (B, N), where True
108
- means the token is semantically live. Dead tokens do not become
109
- semantically active in the packed sparse representation.
110
-
111
- Returns:
112
- Tuple of:
113
- - packed_hidden_states: x' of shape (B, L, T, d)
114
- - packed_positions: J' of shape (B, L, T)
115
- - unpacking_mask: of shape (B, L, T). True where a slot contains any
116
- routed token copy, live or dead. Always has exactly B*N*K True entries.
117
- Pass this to unpack_experts — not active_mask.
118
- - active_mask: of shape (B, L, T). True only where a slot contains a
119
- copy of a live outer token. Pass this to BEA for attention gating.
120
- """
121
- batch_size, sequence_length, hidden_dim = hidden_states.shape
122
- _, _, num_selected_heads = selected_heads.shape
123
-
124
- # -----------------------------------------------------------------------
125
- # Reconstruct routed local source-token indices in token-choice order.
126
- #
127
- # The internal arange(N) is no longer the packed position tensor. It is only
128
- # the local source-row index object used to gather from the current chunk
129
- # tensor x. Flattening this object gives a (B, N*K) tensor aligned with H's
130
- # token-major routed-copy order.
131
- # -----------------------------------------------------------------------
132
- source_token_indices = torch.arange(
133
- sequence_length,
134
- device=hidden_states.device,
135
- dtype=torch.long,
136
- ).view(1, sequence_length, 1).expand(
137
- batch_size,
138
- sequence_length,
139
- num_selected_heads,
140
- )
141
- flattened_source_indices = source_token_indices.reshape(
142
- batch_size,
143
- sequence_length * num_selected_heads,
144
- )
145
-
146
- # -----------------------------------------------------------------------
147
- # Reorder source-token indices into expert-major order.
148
- #
149
- # Applying Pi yields the local source-token rows in the packed expert-major
150
- # order required by the paper. Those same reordered source indices are then
151
- # used to gather hidden states, authoritative upstream positions, and outer
152
- # liveness so all three remain aligned under the exact same packing
153
- # transformation.
154
- # -----------------------------------------------------------------------
155
- sorted_source_indices = flattened_source_indices.gather(
156
- dim=1,
157
- index=permutation,
158
- )
159
- sorted_hidden_states = hidden_states.gather(
160
- dim=1,
161
- index=sorted_source_indices.unsqueeze(-1).expand(-1, -1, hidden_dim),
162
- )
163
- sorted_positions = position_ids.gather(
164
- dim=1,
165
- index=sorted_source_indices,
166
- )
167
- sorted_active_mask = outer_active_mask.gather(
168
- dim=1,
169
- index=sorted_source_indices,
170
- )
171
-
172
- # -----------------------------------------------------------------------
173
- # Count how many routed copies land in each expert bucket.
174
- #
175
- # S[b, l] is the number of routed token copies assigned to expert l in batch b.
176
- # T is the maximum such count across all batches and experts; it determines the
177
- # padded expert-length dimension of the packed representation.
178
- # -----------------------------------------------------------------------
179
- tokens_per_expert = _bincount_rows(
180
- values=flattened_selected_heads,
181
- num_bins=num_experts,
182
- )
183
- max_tokens_per_expert = int(tokens_per_expert.max().item())
184
-
185
- # -----------------------------------------------------------------------
186
- # Construct the active-token mask M.
187
- #
188
- # Each expert bucket is left-justified: if S[b, l] = s, then slots
189
- # t = 0, ..., s-1 are active and all later slots are padding. The resulting
190
- # mask therefore both identifies real packed tokens and enforces left-justified
191
- # packing. This is the unpacking_mask — it marks slot occupancy regardless of
192
- # outer token liveness, and always has exactly B*N*K True entries.
193
- # -----------------------------------------------------------------------
194
- time_axis = torch.arange(
195
- max_tokens_per_expert,
196
- device=hidden_states.device,
197
- dtype=torch.long,
198
- ).view(1, 1, max_tokens_per_expert)
199
- unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)
200
-
201
- # -----------------------------------------------------------------------
202
- # Materialize the padded packed tensors.
203
- #
204
- # The packed hidden states x', packed original-token positions J', and packed
205
- # active-token mask are allocated as zero-filled tensors. Active entries are
206
- # then written into those buffers in the expert-major order established above.
207
- # Padding remains zero / inactive.
208
- # -----------------------------------------------------------------------
209
- packed_hidden_states = hidden_states.new_zeros(
210
- batch_size,
211
- num_experts,
212
- max_tokens_per_expert,
213
- hidden_dim,
214
- )
215
- packed_positions = position_ids.new_zeros(
216
- batch_size,
217
- num_experts,
218
- max_tokens_per_expert,
219
- )
220
- active_mask = torch.zeros(
221
- batch_size,
222
- num_experts,
223
- max_tokens_per_expert,
224
- dtype=torch.bool,
225
- device=hidden_states.device,
226
- )
227
-
228
- packed_hidden_states[unpacking_mask] = sorted_hidden_states.reshape(-1, hidden_dim)
229
- packed_positions[unpacking_mask] = sorted_positions.reshape(-1)
230
- active_mask[unpacking_mask] = sorted_active_mask.reshape(-1)
231
-
232
- return packed_hidden_states, packed_positions, unpacking_mask, active_mask
233
-
234
-
235
- # ---------------------------------------------------------------------------
236
- # Unpacking
237
- # ---------------------------------------------------------------------------
238
-
239
- def unpack_experts(
240
- expert_outputs: torch.Tensor,
241
- selected_heads: torch.Tensor,
242
- unpacking_mask: torch.Tensor,
243
- inverse_permutation: torch.Tensor,
244
- ) -> torch.Tensor:
245
- """Restore token-choice ordering from BEA expert-choice output.
246
-
247
- Unpacking inverts the packing path only on occupied entries. Padding does not
248
- participate: the output tensor is first filtered by unpacking_mask to recover
249
- only the real routed-token copies in expert-major order, then Pi^{-1} restores
250
- the original token-choice ordering, and finally the tensor is reshaped back to
251
- (B, N, K, d).
252
-
253
- The unpacking_mask — not active_mask — must be used here. Even copies of dead
254
- outer tokens occupy slots and must be un-scattered correctly for the inverse
255
- permutation to hold. The total True entry count in unpacking_mask is always
256
- B*N*K, which is exactly what the reshape to (B, N*K, d) requires.
257
-
258
- Args:
259
- expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
260
- selected_heads: Routed head selections I of shape (B, N, K).
261
- unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
262
- occupied packed slots regardless of outer token liveness.
263
- inverse_permutation: Pi^{-1} from setup_packing(), shape (B, N*K).
264
-
265
- Returns:
266
- Restored token-choice tensor y_tilde of shape (B, N, K, d).
267
- """
268
- batch_size, sequence_length, num_selected_heads = selected_heads.shape
269
- hidden_dim = expert_outputs.shape[-1]
270
-
271
- active_outputs = expert_outputs[unpacking_mask]
272
- sorted_token_choice_outputs = active_outputs.reshape(
273
- batch_size,
274
- sequence_length * num_selected_heads,
275
- hidden_dim,
276
- )
277
- restored_outputs = sorted_token_choice_outputs.gather(
278
- dim=1,
279
- index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
280
- )
281
-
282
- return restored_outputs.reshape(
283
- batch_size,
284
- sequence_length,
285
- num_selected_heads,
286
- hidden_dim,
287
- )
288
-
289
-
290
- # ---------------------------------------------------------------------------
291
- # Helpers
292
- # ---------------------------------------------------------------------------
293
-
294
- def _bincount_rows(
295
- values: torch.Tensor,
296
- num_bins: int,
297
- ) -> torch.Tensor:
298
- """Count per-row integer occurrences for a 2D tensor.
299
-
300
- torch.bincount operates on a flat 1D vector, but the packing algorithm needs
301
- one bincount per batch row. The trick used here is to shift each row into its
302
- own disjoint bin range before flattening:
303
-
304
- row 0 uses bins [0, ..., num_bins - 1]
305
- row 1 uses bins [num_bins, ..., 2*num_bins - 1]
306
- row 2 uses bins [2*num_bins, ..., 3*num_bins - 1]
307
- ...
308
-
309
- After that shift, one global torch.bincount produces all row-local counts at
310
- once. Reshaping the result back to (B, num_bins) recovers the per-row bincount.
311
-
312
- This is a vectorized implementation detail only; externally visible behavior
313
- remains exactly the paper's S tensor of per-batch per-expert token counts.
314
-
315
- Args:
316
- values: Integer tensor of shape (B, M) with entries in [0, num_bins).
317
- num_bins: Number of bins.
318
-
319
- Returns:
320
- Counts tensor of shape (B, num_bins).
321
- """
322
- batch_size = values.shape[0]
323
-
324
- row_offsets = torch.arange(
325
- batch_size,
326
- device=values.device,
327
- dtype=values.dtype,
328
- ).unsqueeze(1) * num_bins
329
- shifted_values = values + row_offsets
330
-
331
- counts = torch.bincount(
332
- shifted_values.reshape(-1),
333
- minlength=batch_size * num_bins,
334
- )
335
- return counts.reshape(batch_size, num_bins)
 
1
+ """Expert packing and unpacking for the MoSRAH path.
2
+
3
+ This module implements the low-level token-choice -> expert-choice -> token-choice
4
+ conversion boundary specified in the paper. The externally visible behavior is fixed:
5
+
6
+ - setup_packing() prepares the auxiliary ordering data.
7
+ - pack_experts() converts routed token-choice state into packed expert-choice state.
8
+ - unpack_experts() restores token-choice ordering afterward.
9
+
10
+ Stable sort is a correctness requirement. It preserves causal ordering inside each
11
+ expert bucket, which is the foundation on which BEA's later triangular causal mask
12
+ is correct.
13
+
14
+ pack_experts() returns two distinct masks that serve different roles and must not
15
+ be interchanged:
16
+
17
+ - unpacking_mask: marks every packed slot that contains a routed token copy,
18
+ live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
19
+ so its reshape invariant holds regardless of outer token liveness.
20
+ - active_mask: marks only the packed slots whose source token was semantically
21
+ live. This is what BEA consumes for attention gating. Dead outer tokens must
22
+ not influence sparse attention outputs.
23
+ """
24
+
25
+ import torch
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Setup
30
+ # ---------------------------------------------------------------------------
31
+
32
+ def setup_packing(
33
+ selected_heads: torch.Tensor,
34
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
35
+ """Prepare the auxiliary ordering data used by pack/unpack.
36
+
37
+ Routing produces token-choice state I of shape (B, N, K): for each token, which
38
+ K experts were selected. Packing needs the same routed token copies reordered into
39
+ expert-major order so each expert bucket becomes contiguous.
40
+
41
+ The paper's setup step does this by flattening (N, K) into one axis to produce
42
+ H in token-major order, then computing a stable argsort permutation Pi over the
43
+ expert indices stored in H. Applying Pi reorders the flattened routed copies into
44
+ expert-major order while preserving their original token order *within* each expert
45
+ bucket. That preservation is why stable sort is required for causality.
46
+
47
+ Args:
48
+ selected_heads: Routed token-choice head selections I of shape (B, N, K).
49
+
50
+ Returns:
51
+ Tuple of:
52
+ - flattened_selected_heads: H of shape (B, N*K)
53
+ - permutation: stable expert-major permutation Pi of shape (B, N*K)
54
+ - inverse_permutation: inverse permutation Pi^{-1} of shape (B, N*K)
55
+ """
56
+ batch_size, sequence_length, num_selected_heads = selected_heads.shape
57
+ flattened_selected_heads = selected_heads.reshape(
58
+ batch_size,
59
+ sequence_length * num_selected_heads,
60
+ )
61
+
62
+ permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
63
+ inverse_permutation = torch.argsort(permutation, dim=-1)
64
+
65
+ return flattened_selected_heads, permutation, inverse_permutation
66
+
67
+
68
+ # ---------------------------------------------------------------------------
69
+ # Packing
70
+ # ---------------------------------------------------------------------------
71
+
72
+ def pack_experts(
73
+ hidden_states: torch.Tensor,
74
+ position_ids: torch.Tensor,
75
+ selected_heads: torch.Tensor,
76
+ num_experts: int,
77
+ flattened_selected_heads: torch.Tensor,
78
+ permutation: torch.Tensor,
79
+ outer_active_mask: torch.Tensor,
80
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
81
+ """Pack token-choice hidden states into expert-choice padded form.
82
+
83
+ The paper's packing path has two jobs:
84
+
85
+ 1. Convert routed token-choice copies into expert-major order.
86
+ 2. Materialize that expert-major order into a padded tensor layout BEA can consume.
87
+
88
+ The routed hidden-state copies are not stored explicitly in token-choice form.
89
+ Instead, the same token hidden state is conceptually copied once per selected expert.
90
+ The packing step reconstructs those copies by expanding local source-token indices,
91
+ reordering those indices with Pi, then gathering hidden states, positions, and outer
92
+ liveness in that packed order. All three are carried through the same expert-major
93
+ rearrangement so they remain aligned in the packed frame.
94
+
95
+ Packed positions are sourced from the authoritative upstream position_ids tensor
96
+ rather than synthesized locally from arange(N). This preserves advanced positions
97
+ correctly during cached inference while leaving training/full-sequence behavior
98
+ unchanged when position_ids is the ordinary sequential token positions.
99
+
100
+ Args:
101
+ hidden_states: Token-choice hidden states x of shape (B, N, d).
102
+ position_ids: Authoritative upstream token positions J of shape (B, N).
103
+ selected_heads: Routed head selections I of shape (B, N, K).
104
+ num_experts: Total number of experts L.
105
+ flattened_selected_heads: H from setup_packing(), shape (B, N*K).
106
+ permutation: Pi from setup_packing(), shape (B, N*K).
107
+ outer_active_mask: Current-chunk active mask of shape (B, N), where True
108
+ means the token is semantically live. Dead tokens do not become
109
+ semantically active in the packed sparse representation.
110
+
111
+ Returns:
112
+ Tuple of:
113
+ - packed_hidden_states: x' of shape (B, L, T, d)
114
+ - packed_positions: J' of shape (B, L, T)
115
+ - unpacking_mask: of shape (B, L, T). True where a slot contains any
116
+ routed token copy, live or dead. Always has exactly B*N*K True entries.
117
+ Pass this to unpack_experts — not active_mask.
118
+ - active_mask: of shape (B, L, T). True only where a slot contains a
119
+ copy of a live outer token. Pass this to BEA for attention gating.
120
+ """
121
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
122
+ _, _, num_selected_heads = selected_heads.shape
123
+
124
+ # -----------------------------------------------------------------------
125
+ # Reconstruct routed local source-token indices in token-choice order.
126
+ #
127
+ # The internal arange(N) is no longer the packed position tensor. It is only
128
+ # the local source-row index object used to gather from the current chunk
129
+ # tensor x. Flattening this object gives a (B, N*K) tensor aligned with H's
130
+ # token-major routed-copy order.
131
+ # -----------------------------------------------------------------------
132
+ source_token_indices = torch.arange(
133
+ sequence_length,
134
+ device=hidden_states.device,
135
+ dtype=torch.long,
136
+ ).view(1, sequence_length, 1).expand(
137
+ batch_size,
138
+ sequence_length,
139
+ num_selected_heads,
140
+ )
141
+ flattened_source_indices = source_token_indices.reshape(
142
+ batch_size,
143
+ sequence_length * num_selected_heads,
144
+ )
145
+
146
+ # -----------------------------------------------------------------------
147
+ # Reorder source-token indices into expert-major order.
148
+ #
149
+ # Applying Pi yields the local source-token rows in the packed expert-major
150
+ # order required by the paper. Those same reordered source indices are then
151
+ # used to gather hidden states, authoritative upstream positions, and outer
152
+ # liveness so all three remain aligned under the exact same packing
153
+ # transformation.
154
+ # -----------------------------------------------------------------------
155
+ sorted_source_indices = flattened_source_indices.gather(
156
+ dim=1,
157
+ index=permutation,
158
+ )
159
+ sorted_hidden_states = hidden_states.gather(
160
+ dim=1,
161
+ index=sorted_source_indices.unsqueeze(-1).expand(-1, -1, hidden_dim),
162
+ )
163
+ sorted_positions = position_ids.gather(
164
+ dim=1,
165
+ index=sorted_source_indices,
166
+ )
167
+ sorted_active_mask = outer_active_mask.gather(
168
+ dim=1,
169
+ index=sorted_source_indices,
170
+ )
171
+
172
+ # -----------------------------------------------------------------------
173
+ # Count how many routed copies land in each expert bucket.
174
+ #
175
+ # S[b, l] is the number of routed token copies assigned to expert l in batch b.
176
+ # T is the maximum such count across all batches and experts; it determines the
177
+ # padded expert-length dimension of the packed representation.
178
+ # -----------------------------------------------------------------------
179
+ tokens_per_expert = _bincount_rows(
180
+ values=flattened_selected_heads,
181
+ num_bins=num_experts,
182
+ )
183
+ max_tokens_per_expert = int(tokens_per_expert.max().item())
184
+
185
+ # -----------------------------------------------------------------------
186
+ # Construct the active-token mask M.
187
+ #
188
+ # Each expert bucket is left-justified: if S[b, l] = s, then slots
189
+ # t = 0, ..., s-1 are active and all later slots are padding. The resulting
190
+ # mask therefore both identifies real packed tokens and enforces left-justified
191
+ # packing. This is the unpacking_mask — it marks slot occupancy regardless of
192
+ # outer token liveness, and always has exactly B*N*K True entries.
193
+ # -----------------------------------------------------------------------
194
+ time_axis = torch.arange(
195
+ max_tokens_per_expert,
196
+ device=hidden_states.device,
197
+ dtype=torch.long,
198
+ ).view(1, 1, max_tokens_per_expert)
199
+ unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)
200
+
201
+ # -----------------------------------------------------------------------
202
+ # Materialize the padded packed tensors.
203
+ #
204
+ # The packed hidden states x', packed original-token positions J', and packed
205
+ # active-token mask are allocated as zero-filled tensors. Active entries are
206
+ # then written into those buffers in the expert-major order established above.
207
+ # Padding remains zero / inactive.
208
+ # -----------------------------------------------------------------------
209
+ packed_hidden_states = hidden_states.new_zeros(
210
+ batch_size,
211
+ num_experts,
212
+ max_tokens_per_expert,
213
+ hidden_dim,
214
+ )
215
+ packed_positions = position_ids.new_zeros(
216
+ batch_size,
217
+ num_experts,
218
+ max_tokens_per_expert,
219
+ )
220
+ active_mask = torch.zeros(
221
+ batch_size,
222
+ num_experts,
223
+ max_tokens_per_expert,
224
+ dtype=torch.bool,
225
+ device=hidden_states.device,
226
+ )
227
+
228
+ packed_hidden_states[unpacking_mask] = sorted_hidden_states.reshape(-1, hidden_dim)
229
+ packed_positions[unpacking_mask] = sorted_positions.reshape(-1)
230
+ active_mask[unpacking_mask] = sorted_active_mask.reshape(-1)
231
+
232
+ return packed_hidden_states, packed_positions, unpacking_mask, active_mask
233
+
234
+
235
+ # ---------------------------------------------------------------------------
236
+ # Unpacking
237
+ # ---------------------------------------------------------------------------
238
+
239
+ def unpack_experts(
240
+ expert_outputs: torch.Tensor,
241
+ selected_heads: torch.Tensor,
242
+ unpacking_mask: torch.Tensor,
243
+ inverse_permutation: torch.Tensor,
244
+ ) -> torch.Tensor:
245
+ """Restore token-choice ordering from BEA expert-choice output.
246
+
247
+ Unpacking inverts the packing path only on occupied entries. Padding does not
248
+ participate: the output tensor is first filtered by unpacking_mask to recover
249
+ only the real routed-token copies in expert-major order, then Pi^{-1} restores
250
+ the original token-choice ordering, and finally the tensor is reshaped back to
251
+ (B, N, K, d).
252
+
253
+ The unpacking_mask — not active_mask — must be used here. Even copies of dead
254
+ outer tokens occupy slots and must be un-scattered correctly for the inverse
255
+ permutation to hold. The total True entry count in unpacking_mask is always
256
+ B*N*K, which is exactly what the reshape to (B, N*K, d) requires.
257
+
258
+ Args:
259
+ expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
260
+ selected_heads: Routed head selections I of shape (B, N, K).
261
+ unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
262
+ occupied packed slots regardless of outer token liveness.
263
+ inverse_permutation: Pi^{-1} from setup_packing(), shape (B, N*K).
264
+
265
+ Returns:
266
+ Restored token-choice tensor y_tilde of shape (B, N, K, d).
267
+ """
268
+ batch_size, sequence_length, num_selected_heads = selected_heads.shape
269
+ hidden_dim = expert_outputs.shape[-1]
270
+
271
+ active_outputs = expert_outputs[unpacking_mask]
272
+ sorted_token_choice_outputs = active_outputs.reshape(
273
+ batch_size,
274
+ sequence_length * num_selected_heads,
275
+ hidden_dim,
276
+ )
277
+ restored_outputs = sorted_token_choice_outputs.gather(
278
+ dim=1,
279
+ index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
280
+ )
281
+
282
+ return restored_outputs.reshape(
283
+ batch_size,
284
+ sequence_length,
285
+ num_selected_heads,
286
+ hidden_dim,
287
+ )
288
+
289
+
290
+ # ---------------------------------------------------------------------------
291
+ # Helpers
292
+ # ---------------------------------------------------------------------------
293
+
294
+ def _bincount_rows(
295
+ values: torch.Tensor,
296
+ num_bins: int,
297
+ ) -> torch.Tensor:
298
+ """Count per-row integer occurrences for a 2D tensor.
299
+
300
+ torch.bincount operates on a flat 1D vector, but the packing algorithm needs
301
+ one bincount per batch row. The trick used here is to shift each row into its
302
+ own disjoint bin range before flattening:
303
+
304
+ row 0 uses bins [0, ..., num_bins - 1]
305
+ row 1 uses bins [num_bins, ..., 2*num_bins - 1]
306
+ row 2 uses bins [2*num_bins, ..., 3*num_bins - 1]
307
+ ...
308
+
309
+ After that shift, one global torch.bincount produces all row-local counts at
310
+ once. Reshaping the result back to (B, num_bins) recovers the per-row bincount.
311
+
312
+ This is a vectorized implementation detail only; externally visible behavior
313
+ remains exactly the paper's S tensor of per-batch per-expert token counts.
314
+
315
+ Args:
316
+ values: Integer tensor of shape (B, M) with entries in [0, num_bins).
317
+ num_bins: Number of bins.
318
+
319
+ Returns:
320
+ Counts tensor of shape (B, num_bins).
321
+ """
322
+ batch_size = values.shape[0]
323
+
324
+ row_offsets = torch.arange(
325
+ batch_size,
326
+ device=values.device,
327
+ dtype=values.dtype,
328
+ ).unsqueeze(1) * num_bins
329
+ shifted_values = values + row_offsets
330
+
331
+ counts = torch.bincount(
332
+ shifted_values.reshape(-1),
333
+ minlength=batch_size * num_bins,
334
+ )
335
+ return counts.reshape(batch_size, num_bins)
__attention__load_balance_loss.py CHANGED
@@ -1,88 +1,88 @@
1
- """Auxiliary-loss-free load balancing operator for MoSRAH routing.
2
-
3
- This module implements the custom autograd Function H(b, f) described in the paper's
4
- Implementation Concerns section. The operator bridges two requirements that are in
5
- tension: it must behave like a standard auxiliary loss (scalar output, scalable via
6
- multiplication) so that existing training loops remain compatible, while simultaneously
7
- implementing DeepSeek-style bias correction rather than the usual auxiliary-loss gradient
8
- path through the router weights.
9
-
10
- The resolution is a custom backward pass. The forward emits the load balance imbalance
11
- as a scalar loss. The backward, instead of differentiating that scalar with respect to
12
- its inputs, writes a bias-correction gradient directly to expert_bias. This gradient is
13
- then consumed by the main AdamW optimizer in the normal way, achieving DeepSeek-style
14
- correction without a standalone SGD update step.
15
-
16
- Paper ref: Appendix A.Implementation Concerns.
17
- """
18
-
19
- import torch
20
-
21
-
22
- class LoadBalanceLoss(torch.autograd.Function):
23
- """Custom autograd operator for DeepSeek-style auxiliary-loss-free load balancing.
24
-
25
- Forward computes the load balance imbalance:
26
-
27
- L_load_balance = H(b, f) = sum_l | f_l - 1/L |
28
-
29
- Backward emits a bias-correction gradient to expert_bias:
30
-
31
- grad_b = L_grad * sign(f_l - 1/L)
32
-
33
- expert_bias (b) is included as a forward input so PyTorch registers it as a node
34
- in the computation graph and routes gradients through it. routing_freqs (f) receives
35
- no gradient — its origin is the discrete TopK operation which has no gradient, so
36
- defining a gradient for f here would be mathematically incorrect.
37
-
38
- Paper ref: Appendix A.Implementation Concerns.
39
- """
40
-
41
- @staticmethod
42
- def forward(
43
- ctx: torch.autograd.function.FunctionCtx,
44
- expert_bias: torch.Tensor,
45
- routing_freqs: torch.Tensor,
46
- ) -> torch.Tensor:
47
- """Compute the load balance loss.
48
-
49
- Args:
50
- ctx: Autograd context for saving state needed in backward.
51
- expert_bias: Learned per-head bias b, shape (L,). Included as an input so
52
- PyTorch tracks it as a computation graph node needing a gradient.
53
- routing_freqs: Realized routing frequency f_l per head, shape (L,). Computed
54
- from the discrete TopK selection — not differentiable.
55
-
56
- Returns:
57
- Scalar loss equal to sum_l |f_l - 1/L|.
58
- """
59
- L = expert_bias.shape[0]
60
- # imbalance = f_l - 1/L for each head: positive means overloaded, negative means
61
- # underloaded. Saved for backward where sign(imbalance) determines the direction
62
- # of the bias-correction update.
63
- imbalance = routing_freqs - 1.0 / L
64
- ctx.save_for_backward(imbalance)
65
- return imbalance.abs().sum()
66
-
67
- @staticmethod
68
- def backward(
69
- ctx: torch.autograd.function.FunctionCtx,
70
- grad_output: torch.Tensor,
71
- ) -> tuple[torch.Tensor, None]:
72
- """Emit the DeepSeek-style bias-correction gradient.
73
-
74
- Args:
75
- ctx: Autograd context carrying imbalance saved in forward.
76
- grad_output: Incoming gradient L_grad (scalar). Any rescaling of the loss
77
- by the training loop arrives here and is propagated to grad_b, so the
78
- correction magnitude is proportional to the loss weight chosen by the
79
- consumer.
80
-
81
- Returns:
82
- Gradient for expert_bias: L_grad * sign(f_l - 1/L), shape (L,).
83
- None for routing_freqs: no gradient is defined for the discrete routing
84
- frequency.
85
- """
86
- (imbalance,) = ctx.saved_tensors
87
- grad_expert_bias = grad_output * imbalance.sign()
88
- return grad_expert_bias, None
 
1
+ """Auxiliary-loss-free load balancing operator for MoSRAH routing.
2
+
3
+ This module implements the custom autograd Function H(b, f) described in the paper's
4
+ Implementation Concerns section. The operator bridges two requirements that are in
5
+ tension: it must behave like a standard auxiliary loss (scalar output, scalable via
6
+ multiplication) so that existing training loops remain compatible, while simultaneously
7
+ implementing DeepSeek-style bias correction rather than the usual auxiliary-loss gradient
8
+ path through the router weights.
9
+
10
+ The resolution is a custom backward pass. The forward emits the load balance imbalance
11
+ as a scalar loss. The backward, instead of differentiating that scalar with respect to
12
+ its inputs, writes a bias-correction gradient directly to expert_bias. This gradient is
13
+ then consumed by the main AdamW optimizer in the normal way, achieving DeepSeek-style
14
+ correction without a standalone SGD update step.
15
+
16
+ Paper ref: Appendix A.Implementation Concerns.
17
+ """
18
+
19
+ import torch
20
+
21
+
22
+ class LoadBalanceLoss(torch.autograd.Function):
23
+ """Custom autograd operator for DeepSeek-style auxiliary-loss-free load balancing.
24
+
25
+ Forward computes the load balance imbalance:
26
+
27
+ L_load_balance = H(b, f) = sum_l | f_l - 1/L |
28
+
29
+ Backward emits a bias-correction gradient to expert_bias:
30
+
31
+ grad_b = L_grad * sign(f_l - 1/L)
32
+
33
+ expert_bias (b) is included as a forward input so PyTorch registers it as a node
34
+ in the computation graph and routes gradients through it. routing_freqs (f) receives
35
+ no gradient — its origin is the discrete TopK operation which has no gradient, so
36
+ defining a gradient for f here would be mathematically incorrect.
37
+
38
+ Paper ref: Appendix A.Implementation Concerns.
39
+ """
40
+
41
+ @staticmethod
42
+ def forward(
43
+ ctx: torch.autograd.function.FunctionCtx,
44
+ expert_bias: torch.Tensor,
45
+ routing_freqs: torch.Tensor,
46
+ ) -> torch.Tensor:
47
+ """Compute the load balance loss.
48
+
49
+ Args:
50
+ ctx: Autograd context for saving state needed in backward.
51
+ expert_bias: Learned per-head bias b, shape (L,). Included as an input so
52
+ PyTorch tracks it as a computation graph node needing a gradient.
53
+ routing_freqs: Realized routing frequency f_l per head, shape (L,). Computed
54
+ from the discrete TopK selection — not differentiable.
55
+
56
+ Returns:
57
+ Scalar loss equal to sum_l |f_l - 1/L|.
58
+ """
59
+ L = expert_bias.shape[0]
60
+ # imbalance = f_l - 1/L for each head: positive means overloaded, negative means
61
+ # underloaded. Saved for backward where sign(imbalance) determines the direction
62
+ # of the bias-correction update.
63
+ imbalance = routing_freqs - 1.0 / L
64
+ ctx.save_for_backward(imbalance)
65
+ return imbalance.abs().sum()
66
+
67
+ @staticmethod
68
+ def backward(
69
+ ctx: torch.autograd.function.FunctionCtx,
70
+ grad_output: torch.Tensor,
71
+ ) -> tuple[torch.Tensor, None]:
72
+ """Emit the DeepSeek-style bias-correction gradient.
73
+
74
+ Args:
75
+ ctx: Autograd context carrying imbalance saved in forward.
76
+ grad_output: Incoming gradient L_grad (scalar). Any rescaling of the loss
77
+ by the training loop arrives here and is propagated to grad_b, so the
78
+ correction magnitude is proportional to the loss weight chosen by the
79
+ consumer.
80
+
81
+ Returns:
82
+ Gradient for expert_bias: L_grad * sign(f_l - 1/L), shape (L,).
83
+ None for routing_freqs: no gradient is defined for the discrete routing
84
+ frequency.
85
+ """
86
+ (imbalance,) = ctx.saved_tensors
87
+ grad_expert_bias = grad_output * imbalance.sign()
88
+ return grad_expert_bias, None
__attention__mosrah.py CHANGED
@@ -1,140 +1,140 @@
1
- """Full MoSRAH sparse path for SHRAM.
2
-
3
- This module coordinates the routed sparse attention path used inside the SHRAM
4
- hybrid attention layer. The underlying mechanics already live in verified
5
- subunits. The responsibility here is to connect those subunits without
6
- corrupting their bridge contracts.
7
-
8
- In particular, this path must preserve three architectural distinctions:
9
-
10
- - selected head indices are not routing probabilities
11
- - packed position semantics are chosen before BEA, not inside it
12
- - weighted reduction must consume the router's unbiased renormalized
13
- probabilities after token-choice order has been restored
14
- """
15
-
16
- import torch
17
- from torch import nn
18
-
19
- from .__cache__mosrah_cache import MoSRAHCache
20
- from .configuration import ShramConfig
21
- from .__attention__bottlenecked_ensemble_attention import BottleneckedEnsembleAttention
22
- from .__attention__expert_packing import (
23
- pack_experts,
24
- setup_packing,
25
- unpack_experts,
26
- )
27
- from .__attention__router import MoSRAHRouter
28
- from .__attention__positions_converter import SparseMoSRAHPositions
29
-
30
-
31
- class MoSRAHLayer(nn.Module):
32
- """Full routed sparse attention path for SHRAM.
33
-
34
- The MoSRAH path consumes model-space hidden states together with
35
- authoritative per-token positions and returns the model-space sparse-path
36
- contribution, the router's load-balance loss, and the router's MaxVio
37
- routing-imbalance scalar.
38
- """
39
-
40
- def __init__(self, config: ShramConfig) -> None:
41
- super().__init__()
42
- self.num_experts = config.num_mosrah_heads
43
-
44
- self.router = MoSRAHRouter(config)
45
- self.positions = SparseMoSRAHPositions(config)
46
- self.bea = BottleneckedEnsembleAttention(config)
47
-
48
- def forward(
49
- self,
50
- hidden_states: torch.Tensor,
51
- position_ids: torch.Tensor,
52
- active_mask: torch.Tensor,
53
- cache: MoSRAHCache | None,
54
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
55
- """Run the full MoSRAH sparse path.
56
-
57
- Args:
58
- hidden_states: Model-space hidden states x of shape (B, N, d).
59
- position_ids: Authoritative per-token positions of shape (B, N).
60
- active_mask: Current-chunk active mask of shape (B, N), where True
61
- means the token is semantically live. Forwarded to the router
62
- so dead tokens are excluded from routing statistics, and to
63
- pack_experts so dead outer tokens do not become semantically
64
- active packed entries.
65
- cache: Optional layer-local MoSRAH cache. Pass None for uncached
66
- execution and the layer-local cache instance for cached execution.
67
-
68
- Returns:
69
- sparse_output: Model-space sparse-path output of shape (B, N, d).
70
- load_balance_loss: Scalar router load-balance loss.
71
- max_vio: Detached scalar routing-imbalance summary. Passed through
72
- unchanged from the router; see MoSRAHRouter for semantics.
73
- """
74
-
75
- # -------------------------------------------------------------------
76
- # The first transition moves from model-space token-choice input into
77
- # the packed expert-choice sparse-attention state. Routing decides both
78
- # which experts each token uses and which unbiased probabilities must be
79
- # reserved for the final reduction. The active mask is forwarded to the
80
- # router so dead tokens are excluded from routing statistics, and to
81
- # pack_experts so outer liveness is faithfully carried into the packed
82
- # frame. Packing returns both the unpacking mask (slot occupancy, always
83
- # B*N*K True entries) and the packed active mask (live slots only);
84
- # active_mask is rebound to the packed form after this point.
85
- # -------------------------------------------------------------------
86
- selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
87
- hidden_states, active_mask
88
- )
89
-
90
- flattened_selected_heads, permutation, inverse_permutation = setup_packing(
91
- selected_heads
92
- )
93
- packed_hidden_states, packed_positions, unpacking_mask, active_mask = pack_experts(
94
- hidden_states=hidden_states,
95
- position_ids=position_ids,
96
- selected_heads=selected_heads,
97
- num_experts=self.num_experts,
98
- flattened_selected_heads=flattened_selected_heads,
99
- permutation=permutation,
100
- outer_active_mask=active_mask,
101
- )
102
-
103
- # -------------------------------------------------------------------
104
- # Sparse attention runs entirely in the packed expert-choice frame, so
105
- # the RoPE position semantics must also be chosen in that frame. The
106
- # position layer therefore decides whether BEA should see packed
107
- # original-token positions or packed local-slot positions. BEA then
108
- # consumes that packed position tensor together with the packed hidden
109
- # states and the layer-local sparse cache, which it owns directly.
110
- # -------------------------------------------------------------------
111
- bea_positions = self.positions(
112
- packed_positions=packed_positions,
113
- cache=cache,
114
- )
115
- packed_outputs = self.bea(
116
- packed_embeddings=packed_hidden_states,
117
- position_ids=bea_positions,
118
- active_mask=active_mask,
119
- cache=cache,
120
- )
121
-
122
- # -------------------------------------------------------------------
123
- # The final transition restores token-choice meaning and only then
124
- # collapses the K routed copies back into model space. This ordering is
125
- # required because routing_probs live in token-choice space, whereas BEA
126
- # returns expert-choice packed outputs. The reduction must therefore
127
- # happen after unpacking, and it must use the router's unbiased
128
- # renormalized probabilities rather than any biased selection scores.
129
- # -------------------------------------------------------------------
130
- token_choice_outputs = unpack_experts(
131
- expert_outputs=packed_outputs,
132
- selected_heads=selected_heads,
133
- unpacking_mask=unpacking_mask,
134
- inverse_permutation=inverse_permutation,
135
- )
136
- final_output = (
137
- token_choice_outputs * routing_probs.unsqueeze(-1)
138
- ).sum(dim=2)
139
-
140
- return final_output, load_balance_loss, max_vio
 
1
+ """Full MoSRAH sparse path for SHRAM.
2
+
3
+ This module coordinates the routed sparse attention path used inside the SHRAM
4
+ hybrid attention layer. The underlying mechanics already live in verified
5
+ subunits. The responsibility here is to connect those subunits without
6
+ corrupting their bridge contracts.
7
+
8
+ In particular, this path must preserve three architectural distinctions:
9
+
10
+ - selected head indices are not routing probabilities
11
+ - packed position semantics are chosen before BEA, not inside it
12
+ - weighted reduction must consume the router's unbiased renormalized
13
+ probabilities after token-choice order has been restored
14
+ """
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from .__cache__mosrah_cache import MoSRAHCache
20
+ from .configuration import ShramConfig
21
+ from .__attention__bottlenecked_ensemble_attention import BottleneckedEnsembleAttention
22
+ from .__attention__expert_packing import (
23
+ pack_experts,
24
+ setup_packing,
25
+ unpack_experts,
26
+ )
27
+ from .__attention__router import MoSRAHRouter
28
+ from .__attention__positions_converter import SparseMoSRAHPositions
29
+
30
+
31
+ class MoSRAHLayer(nn.Module):
32
+ """Full routed sparse attention path for SHRAM.
33
+
34
+ The MoSRAH path consumes model-space hidden states together with
35
+ authoritative per-token positions and returns the model-space sparse-path
36
+ contribution, the router's load-balance loss, and the router's MaxVio
37
+ routing-imbalance scalar.
38
+ """
39
+
40
+ def __init__(self, config: ShramConfig) -> None:
41
+ super().__init__()
42
+ self.num_experts = config.num_mosrah_heads
43
+
44
+ self.router = MoSRAHRouter(config)
45
+ self.positions = SparseMoSRAHPositions(config)
46
+ self.bea = BottleneckedEnsembleAttention(config)
47
+
48
+ def forward(
49
+ self,
50
+ hidden_states: torch.Tensor,
51
+ position_ids: torch.Tensor,
52
+ active_mask: torch.Tensor,
53
+ cache: MoSRAHCache | None,
54
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
55
+ """Run the full MoSRAH sparse path.
56
+
57
+ Args:
58
+ hidden_states: Model-space hidden states x of shape (B, N, d).
59
+ position_ids: Authoritative per-token positions of shape (B, N).
60
+ active_mask: Current-chunk active mask of shape (B, N), where True
61
+ means the token is semantically live. Forwarded to the router
62
+ so dead tokens are excluded from routing statistics, and to
63
+ pack_experts so dead outer tokens do not become semantically
64
+ active packed entries.
65
+ cache: Optional layer-local MoSRAH cache. Pass None for uncached
66
+ execution and the layer-local cache instance for cached execution.
67
+
68
+ Returns:
69
+ sparse_output: Model-space sparse-path output of shape (B, N, d).
70
+ load_balance_loss: Scalar router load-balance loss.
71
+ max_vio: Detached scalar routing-imbalance summary. Passed through
72
+ unchanged from the router; see MoSRAHRouter for semantics.
73
+ """
74
+
75
+ # -------------------------------------------------------------------
76
+ # The first transition moves from model-space token-choice input into
77
+ # the packed expert-choice sparse-attention state. Routing decides both
78
+ # which experts each token uses and which unbiased probabilities must be
79
+ # reserved for the final reduction. The active mask is forwarded to the
80
+ # router so dead tokens are excluded from routing statistics, and to
81
+ # pack_experts so outer liveness is faithfully carried into the packed
82
+ # frame. Packing returns both the unpacking mask (slot occupancy, always
83
+ # B*N*K True entries) and the packed active mask (live slots only);
84
+ # active_mask is rebound to the packed form after this point.
85
+ # -------------------------------------------------------------------
86
+ selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
87
+ hidden_states, active_mask
88
+ )
89
+
90
+ flattened_selected_heads, permutation, inverse_permutation = setup_packing(
91
+ selected_heads
92
+ )
93
+ packed_hidden_states, packed_positions, unpacking_mask, active_mask = pack_experts(
94
+ hidden_states=hidden_states,
95
+ position_ids=position_ids,
96
+ selected_heads=selected_heads,
97
+ num_experts=self.num_experts,
98
+ flattened_selected_heads=flattened_selected_heads,
99
+ permutation=permutation,
100
+ outer_active_mask=active_mask,
101
+ )
102
+
103
+ # -------------------------------------------------------------------
104
+ # Sparse attention runs entirely in the packed expert-choice frame, so
105
+ # the RoPE position semantics must also be chosen in that frame. The
106
+ # position layer therefore decides whether BEA should see packed
107
+ # original-token positions or packed local-slot positions. BEA then
108
+ # consumes that packed position tensor together with the packed hidden
109
+ # states and the layer-local sparse cache, which it owns directly.
110
+ # -------------------------------------------------------------------
111
+ bea_positions = self.positions(
112
+ packed_positions=packed_positions,
113
+ cache=cache,
114
+ )
115
+ packed_outputs = self.bea(
116
+ packed_embeddings=packed_hidden_states,
117
+ position_ids=bea_positions,
118
+ active_mask=active_mask,
119
+ cache=cache,
120
+ )
121
+
122
+ # -------------------------------------------------------------------
123
+ # The final transition restores token-choice meaning and only then
124
+ # collapses the K routed copies back into model space. This ordering is
125
+ # required because routing_probs live in token-choice space, whereas BEA
126
+ # returns expert-choice packed outputs. The reduction must therefore
127
+ # happen after unpacking, and it must use the router's unbiased
128
+ # renormalized probabilities rather than any biased selection scores.
129
+ # -------------------------------------------------------------------
130
+ token_choice_outputs = unpack_experts(
131
+ expert_outputs=packed_outputs,
132
+ selected_heads=selected_heads,
133
+ unpacking_mask=unpacking_mask,
134
+ inverse_permutation=inverse_permutation,
135
+ )
136
+ final_output = (
137
+ token_choice_outputs * routing_probs.unsqueeze(-1)
138
+ ).sum(dim=2)
139
+
140
+ return final_output, load_balance_loss, max_vio
__attention__positions_converter.py CHANGED
@@ -1,105 +1,105 @@
1
- """Position computation for the MoSRAH sparse path.
2
-
3
- This layer computes the packed position tensor P consumed by BEA.
4
-
5
- - In main-sequence mode, P is the packed original-token position tensor from the
6
- packing path.
7
- - In semantic-sequence mode, P is a per-expert local sequence over the packed
8
- expert-choice layout, optionally offset by the current sparse-cache occupancies
9
- during cached inference.
10
- """
11
-
12
- import torch
13
- from torch import nn
14
-
15
- from .configuration import ShramConfig
16
- from .__cache__mosrah_cache import MoSRAHCache
17
-
18
-
19
- class SparseMoSRAHPositions(nn.Module):
20
- """Compute the packed RoPE position tensor for the MoSRAH sparse path.
21
-
22
- This layer operates in the packed expert-choice frame used by BEA. The input
23
- packed_positions tensor is always the packed original-token position tensor
24
- produced by the packing path. The configured rope_mode determines whether that
25
- tensor is forwarded directly or replaced by a semantic local-slot sequence.
26
- """
27
-
28
- def __init__(self, config: ShramConfig) -> None:
29
- super().__init__()
30
- self.rope_mode = config.rope_mode
31
-
32
- def forward(
33
- self,
34
- packed_positions: torch.Tensor,
35
- cache: MoSRAHCache | None,
36
- ) -> torch.Tensor:
37
- """Compute the packed position tensor P consumed by BEA.
38
-
39
- Args:
40
- packed_positions: Packed original-token positions J' of shape (B, L, T).
41
- cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
42
- mode, the current per-head occupancies offset the local packed sequence.
43
-
44
- Returns:
45
- Packed position tensor P of shape (B, L, T).
46
- """
47
- if self.rope_mode == "main_sequence":
48
- return self._main_sequence_positions(packed_positions)
49
-
50
- if self.rope_mode == "semantic_sequence":
51
- return self._semantic_sequence_positions(packed_positions, cache)
52
-
53
- raise NotImplementedError(
54
- f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
55
- )
56
-
57
- def _main_sequence_positions(
58
- self,
59
- packed_positions: torch.Tensor,
60
- ) -> torch.Tensor:
61
- """Forward packed original-token positions unchanged."""
62
- return packed_positions
63
-
64
- def _semantic_sequence_positions(
65
- self,
66
- packed_positions: torch.Tensor,
67
- cache: MoSRAHCache | None,
68
- ) -> torch.Tensor:
69
- """Compute semantic-sequence packed positions in expert-choice space.
70
-
71
- Without a sparse cache, semantic positions are the local packed sequence
72
- 0, 1, 2, ... over the expert-local T dimension. With a sparse cache, that
73
- same local sequence is offset by the current per-(batch, expert) occupancies
74
- returned by get_heads_lengths().
75
- """
76
- batch_size, num_experts, packed_length = packed_positions.shape
77
-
78
- # -------------------------------------------------------------------
79
- # Construct the local packed sequence 0, 1, 2, ... over the expert-local
80
- # sequence dimension T. This is then broadcast across batch and experts.
81
- # -------------------------------------------------------------------
82
- local_positions = torch.arange(
83
- packed_length,
84
- device=packed_positions.device,
85
- dtype=packed_positions.dtype,
86
- ).view(1, 1, packed_length).expand(
87
- batch_size,
88
- num_experts,
89
- packed_length,
90
- )
91
-
92
- # -------------------------------------------------------------------
93
- # In cached semantic-sequence mode, positions continue from the current
94
- # sparse-cache occupancies rather than restarting at zero for the local
95
- # chunk.
96
- # -------------------------------------------------------------------
97
- if cache is None:
98
- return local_positions
99
-
100
- cached_lengths = cache.get_heads_lengths().to(
101
- device=packed_positions.device,
102
- dtype=packed_positions.dtype,
103
- ).unsqueeze(-1)
104
-
105
  return local_positions + cached_lengths
 
1
+ """Position computation for the MoSRAH sparse path.
2
+
3
+ This layer computes the packed position tensor P consumed by BEA.
4
+
5
+ - In main-sequence mode, P is the packed original-token position tensor from the
6
+ packing path.
7
+ - In semantic-sequence mode, P is a per-expert local sequence over the packed
8
+ expert-choice layout, optionally offset by the current sparse-cache occupancies
9
+ during cached inference.
10
+ """
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from .configuration import ShramConfig
16
+ from .__cache__mosrah_cache import MoSRAHCache
17
+
18
+
19
+ class SparseMoSRAHPositions(nn.Module):
20
+ """Compute the packed RoPE position tensor for the MoSRAH sparse path.
21
+
22
+ This layer operates in the packed expert-choice frame used by BEA. The input
23
+ packed_positions tensor is always the packed original-token position tensor
24
+ produced by the packing path. The configured rope_mode determines whether that
25
+ tensor is forwarded directly or replaced by a semantic local-slot sequence.
26
+ """
27
+
28
+ def __init__(self, config: ShramConfig) -> None:
29
+ super().__init__()
30
+ self.rope_mode = config.rope_mode
31
+
32
+ def forward(
33
+ self,
34
+ packed_positions: torch.Tensor,
35
+ cache: MoSRAHCache | None,
36
+ ) -> torch.Tensor:
37
+ """Compute the packed position tensor P consumed by BEA.
38
+
39
+ Args:
40
+ packed_positions: Packed original-token positions J' of shape (B, L, T).
41
+ cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
42
+ mode, the current per-head occupancies offset the local packed sequence.
43
+
44
+ Returns:
45
+ Packed position tensor P of shape (B, L, T).
46
+ """
47
+ if self.rope_mode == "main_sequence":
48
+ return self._main_sequence_positions(packed_positions)
49
+
50
+ if self.rope_mode == "semantic_sequence":
51
+ return self._semantic_sequence_positions(packed_positions, cache)
52
+
53
+ raise NotImplementedError(
54
+ f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
55
+ )
56
+
57
+ def _main_sequence_positions(
58
+ self,
59
+ packed_positions: torch.Tensor,
60
+ ) -> torch.Tensor:
61
+ """Forward packed original-token positions unchanged."""
62
+ return packed_positions
63
+
64
+ def _semantic_sequence_positions(
65
+ self,
66
+ packed_positions: torch.Tensor,
67
+ cache: MoSRAHCache | None,
68
+ ) -> torch.Tensor:
69
+ """Compute semantic-sequence packed positions in expert-choice space.
70
+
71
+ Without a sparse cache, semantic positions are the local packed sequence
72
+ 0, 1, 2, ... over the expert-local T dimension. With a sparse cache, that
73
+ same local sequence is offset by the current per-(batch, expert) occupancies
74
+ returned by get_heads_lengths().
75
+ """
76
+ batch_size, num_experts, packed_length = packed_positions.shape
77
+
78
+ # -------------------------------------------------------------------
79
+ # Construct the local packed sequence 0, 1, 2, ... over the expert-local
80
+ # sequence dimension T. This is then broadcast across batch and experts.
81
+ # -------------------------------------------------------------------
82
+ local_positions = torch.arange(
83
+ packed_length,
84
+ device=packed_positions.device,
85
+ dtype=packed_positions.dtype,
86
+ ).view(1, 1, packed_length).expand(
87
+ batch_size,
88
+ num_experts,
89
+ packed_length,
90
+ )
91
+
92
+ # -------------------------------------------------------------------
93
+ # In cached semantic-sequence mode, positions continue from the current
94
+ # sparse-cache occupancies rather than restarting at zero for the local
95
+ # chunk.
96
+ # -------------------------------------------------------------------
97
+ if cache is None:
98
+ return local_positions
99
+
100
+ cached_lengths = cache.get_heads_lengths().to(
101
+ device=packed_positions.device,
102
+ dtype=packed_positions.dtype,
103
+ ).unsqueeze(-1)
104
+
105
  return local_positions + cached_lengths
__attention__router.py CHANGED
@@ -1,162 +1,162 @@
1
- """Token-choice router for the MoSRAH sparse attention path.
2
-
3
- This module implements the routing mechanism described in Appendix A.Routing of the
4
- paper. Given an input hidden state x, the router produces two outputs used downstream:
5
-
6
- - selected_heads (I): which K of the L available expert heads each token routes to,
7
- determined by TopK over biased routing scores.
8
- - routing_probs (P): the weights used for the weighted output reduction, gathered from
9
- *unbiased* routing scores at the selected indices and renormalized. The learned expert
10
- bias b must not influence P.
11
-
12
- This separation is architecturally critical: expert_bias drives selection (and thus load
13
- balancing) but does not corrupt the gradient path from the output through routing_probs
14
- back to the routing projection weights.
15
-
16
- The router also computes and returns the load balance loss via the LoadBalanceLoss custom
17
- autograd operator (see load_balance_loss.py). This loss is a scalar that the training
18
- loop can weight and add to the language modeling loss.
19
-
20
- The router additionally computes and returns MaxVio, a detached scalar summarising
21
- routing imbalance for the current forward pass:
22
-
23
- MaxVio = L · max_l(f_l − 1/L)
24
-
25
- where f_l is the realised routing frequency of head l and 1/L is the perfectly balanced
26
- target. MaxVio is a monitoring quantity only; it never contributes gradients.
27
-
28
- Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
29
- """
30
-
31
- import torch
32
- import torch.nn as nn
33
- import torch.nn.functional as F
34
-
35
- from .configuration import ShramConfig
36
- from .__attention__load_balance_loss import LoadBalanceLoss
37
-
38
-
39
- class MoSRAHRouter(nn.Module):
40
- """Token-choice router for MoSRAH sparse attention.
41
-
42
- Each input token independently selects K of the L available expert heads. Selection
43
- is driven by biased routing scores to enable load balancing, but the routing
44
- probabilities used for output reduction are computed from unbiased scores so that
45
- the expert bias does not interfere with the gradient path to the router weights.
46
-
47
- The routing projection W_r has no bias term — the paper specifies xW_r with no
48
- additional projection bias. The only bias-like parameter is expert_bias (b), which
49
- has an entirely separate role and update mechanism.
50
-
51
- Args:
52
- config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
53
- (L), and ``num_selected_heads`` (K).
54
- """
55
-
56
- def __init__(self, config: ShramConfig) -> None:
57
- super().__init__()
58
- self.num_mosrah_heads = config.num_mosrah_heads
59
- self.num_selected_heads = config.num_selected_heads
60
-
61
- # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
62
- self.routing_projection = nn.Linear(
63
- config.hidden_size, config.num_mosrah_heads, bias=False
64
- )
65
-
66
- # b: learned per-head bias for load balancing. Initialized to zero so that all
67
- # heads start with equal selection probability. Updated by the main optimizer
68
- # via the LoadBalanceLoss custom backward.
69
- self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
70
-
71
- def forward(
72
- self, x: torch.Tensor, active_mask: torch.Tensor
73
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
74
- """Route input tokens to K expert heads each and compute routing probabilities.
75
-
76
- Args:
77
- x: Input hidden states of shape (batch, seq_len, hidden_size).
78
- active_mask: Current-chunk active mask of shape (batch, seq_len), where
79
- True means the token is semantically live. Dead tokens do not
80
- contribute to routing frequencies, load_balance_loss, or max_vio.
81
-
82
- Returns:
83
- selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
84
- Each token's K selected head indices, determined by TopK on biased scores.
85
- routing_probs: Routing probabilities P of shape (batch, seq_len,
86
- num_selected_heads). Gathered from unbiased scores at selected_heads
87
- indices and renormalized to sum to 1 per token.
88
- load_balance_loss: Scalar load balance imbalance loss for this forward pass.
89
- Training loop scales this by a weight and adds it to the main loss.
90
- max_vio: Detached scalar routing-imbalance summary for this forward pass.
91
- Equal to L · max_l(f_l − 1/L). Zero means perfect balance. Not a loss;
92
- never contributes gradients.
93
- """
94
- B, N, _ = x.shape
95
- L = self.num_mosrah_heads
96
- K = self.num_selected_heads
97
-
98
- # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
99
- # compute routing_probs — expert_bias must not influence them.
100
- logits = self.routing_projection(x) # (B, N, L)
101
- routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
102
-
103
- # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
104
- # selection. expert_bias is added to logits before softmax so that the bias
105
- # shifts selection probability without rescaling the unbiased distribution.
106
- biased_routing_scores = F.softmax( # R̂, (B, N, L)
107
- logits + self.expert_bias, dim=-1
108
- )
109
-
110
- # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
111
- selected_heads = biased_routing_scores.topk(K, dim=-1).indices
112
-
113
- # Routing probabilities P: gathered from unbiased R at selected_heads indices,
114
- # then renormalized so they sum to 1 per token. Gathering from routing_scores
115
- # (not biased_routing_scores) is the invariant that keeps the gradient path from
116
- # the output back to the router weights free of expert_bias influence.
117
- gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
118
- routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
119
-
120
- # Routing frequency f_l: fraction of active (batch, token, head_slot) triples
121
- # assigned to each head. Dead tokens are excluded by zeroing their rows in the
122
- # assignment mask before reduction. Normalization uses the active assignment
123
- # count so frequencies remain properly scaled regardless of how many tokens
124
- # are live in this chunk.
125
- assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
126
- assignment_mask.scatter_(-1, selected_heads, 1.0)
127
- active_assignments = assignment_mask * active_mask.unsqueeze(-1)
128
- num_active_assignments = active_mask.sum() * K
129
- routing_freqs = active_assignments.sum(dim=(0, 1)) / num_active_assignments # f, (L,)
130
-
131
- # Load balance loss via custom autograd. expert_bias is an input so PyTorch
132
- # registers it as a graph node; the custom backward writes the DeepSeek-style
133
- # correction gradient to expert_bias.grad for the optimizer to consume.
134
- load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
135
-
136
- # MaxVio is a detached monitoring scalar derived from routing_freqs. It must
137
- # not contribute gradients under any circumstance, so it is detached at the
138
- # point of computation rather than left to callers to detach.
139
- max_vio = self._compute_max_vio(routing_freqs, L)
140
-
141
- return selected_heads, routing_probs, load_balance_loss, max_vio
142
-
143
- @staticmethod
144
- def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
145
- """Compute the MaxVio routing-imbalance scalar.
146
-
147
- MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
148
- head l and 1/L is the perfectly balanced target. A value of zero indicates
149
- perfect balance; a value of 1 means the most overloaded head received exactly
150
- double its fair share.
151
-
152
- The result is detached from the autograd graph — MaxVio is a monitoring scalar
153
- and must never contribute gradients to any parameter.
154
-
155
- Args:
156
- routing_freqs: Per-head routing frequencies of shape (L,). Sums to 1.
157
- num_heads: Total number of MoSRAH heads L.
158
-
159
- Returns:
160
- Detached scalar MaxVio tensor.
161
- """
162
- return (num_heads * (routing_freqs - 1.0 / num_heads).max()).detach()
 
1
+ """Token-choice router for the MoSRAH sparse attention path.
2
+
3
+ This module implements the routing mechanism described in Appendix A.Routing of the
4
+ paper. Given an input hidden state x, the router produces two outputs used downstream:
5
+
6
+ - selected_heads (I): which K of the L available expert heads each token routes to,
7
+ determined by TopK over biased routing scores.
8
+ - routing_probs (P): the weights used for the weighted output reduction, gathered from
9
+ *unbiased* routing scores at the selected indices and renormalized. The learned expert
10
+ bias b must not influence P.
11
+
12
+ This separation is architecturally critical: expert_bias drives selection (and thus load
13
+ balancing) but does not corrupt the gradient path from the output through routing_probs
14
+ back to the routing projection weights.
15
+
16
+ The router also computes and returns the load balance loss via the LoadBalanceLoss custom
17
+ autograd operator (see load_balance_loss.py). This loss is a scalar that the training
18
+ loop can weight and add to the language modeling loss.
19
+
20
+ The router additionally computes and returns MaxVio, a detached scalar summarising
21
+ routing imbalance for the current forward pass:
22
+
23
+ MaxVio = L · max_l(f_l − 1/L)
24
+
25
+ where f_l is the realised routing frequency of head l and 1/L is the perfectly balanced
26
+ target. MaxVio is a monitoring quantity only; it never contributes gradients.
27
+
28
+ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
29
+ """
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+ from .configuration import ShramConfig
36
+ from .__attention__load_balance_loss import LoadBalanceLoss
37
+
38
+
39
+ class MoSRAHRouter(nn.Module):
40
+ """Token-choice router for MoSRAH sparse attention.
41
+
42
+ Each input token independently selects K of the L available expert heads. Selection
43
+ is driven by biased routing scores to enable load balancing, but the routing
44
+ probabilities used for output reduction are computed from unbiased scores so that
45
+ the expert bias does not interfere with the gradient path to the router weights.
46
+
47
+ The routing projection W_r has no bias term — the paper specifies xW_r with no
48
+ additional projection bias. The only bias-like parameter is expert_bias (b), which
49
+ has an entirely separate role and update mechanism.
50
+
51
+ Args:
52
+ config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
53
+ (L), and ``num_selected_heads`` (K).
54
+ """
55
+
56
+ def __init__(self, config: ShramConfig) -> None:
57
+ super().__init__()
58
+ self.num_mosrah_heads = config.num_mosrah_heads
59
+ self.num_selected_heads = config.num_selected_heads
60
+
61
+ # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
62
+ self.routing_projection = nn.Linear(
63
+ config.hidden_size, config.num_mosrah_heads, bias=False
64
+ )
65
+
66
+ # b: learned per-head bias for load balancing. Initialized to zero so that all
67
+ # heads start with equal selection probability. Updated by the main optimizer
68
+ # via the LoadBalanceLoss custom backward.
69
+ self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
70
+
71
+ def forward(
72
+ self, x: torch.Tensor, active_mask: torch.Tensor
73
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
74
+ """Route input tokens to K expert heads each and compute routing probabilities.
75
+
76
+ Args:
77
+ x: Input hidden states of shape (batch, seq_len, hidden_size).
78
+ active_mask: Current-chunk active mask of shape (batch, seq_len), where
79
+ True means the token is semantically live. Dead tokens do not
80
+ contribute to routing frequencies, load_balance_loss, or max_vio.
81
+
82
+ Returns:
83
+ selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
84
+ Each token's K selected head indices, determined by TopK on biased scores.
85
+ routing_probs: Routing probabilities P of shape (batch, seq_len,
86
+ num_selected_heads). Gathered from unbiased scores at selected_heads
87
+ indices and renormalized to sum to 1 per token.
88
+ load_balance_loss: Scalar load balance imbalance loss for this forward pass.
89
+ Training loop scales this by a weight and adds it to the main loss.
90
+ max_vio: Detached scalar routing-imbalance summary for this forward pass.
91
+ Equal to L · max_l(f_l − 1/L). Zero means perfect balance. Not a loss;
92
+ never contributes gradients.
93
+ """
94
+ B, N, _ = x.shape
95
+ L = self.num_mosrah_heads
96
+ K = self.num_selected_heads
97
+
98
+ # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
99
+ # compute routing_probs — expert_bias must not influence them.
100
+ logits = self.routing_projection(x) # (B, N, L)
101
+ routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
102
+
103
+ # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
104
+ # selection. expert_bias is added to logits before softmax so that the bias
105
+ # shifts selection probability without rescaling the unbiased distribution.
106
+ biased_routing_scores = F.softmax( # R̂, (B, N, L)
107
+ logits + self.expert_bias, dim=-1
108
+ )
109
+
110
+ # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
111
+ selected_heads = biased_routing_scores.topk(K, dim=-1).indices
112
+
113
+ # Routing probabilities P: gathered from unbiased R at selected_heads indices,
114
+ # then renormalized so they sum to 1 per token. Gathering from routing_scores
115
+ # (not biased_routing_scores) is the invariant that keeps the gradient path from
116
+ # the output back to the router weights free of expert_bias influence.
117
+ gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
118
+ routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
119
+
120
+ # Routing frequency f_l: fraction of active (batch, token, head_slot) triples
121
+ # assigned to each head. Dead tokens are excluded by zeroing their rows in the
122
+ # assignment mask before reduction. Normalization uses the active assignment
123
+ # count so frequencies remain properly scaled regardless of how many tokens
124
+ # are live in this chunk.
125
+ assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
126
+ assignment_mask.scatter_(-1, selected_heads, 1.0)
127
+ active_assignments = assignment_mask * active_mask.unsqueeze(-1)
128
+ num_active_assignments = active_mask.sum() * K
129
+ routing_freqs = active_assignments.sum(dim=(0, 1)) / num_active_assignments # f, (L,)
130
+
131
+ # Load balance loss via custom autograd. expert_bias is an input so PyTorch
132
+ # registers it as a graph node; the custom backward writes the DeepSeek-style
133
+ # correction gradient to expert_bias.grad for the optimizer to consume.
134
+ load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
135
+
136
+ # MaxVio is a detached monitoring scalar derived from routing_freqs. It must
137
+ # not contribute gradients under any circumstance, so it is detached at the
138
+ # point of computation rather than left to callers to detach.
139
+ max_vio = self._compute_max_vio(routing_freqs, L)
140
+
141
+ return selected_heads, routing_probs, load_balance_loss, max_vio
142
+
143
+ @staticmethod
144
+ def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
145
+ """Compute the MaxVio routing-imbalance scalar.
146
+
147
+ MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
148
+ head l and 1/L is the perfectly balanced target. A value of zero indicates
149
+ perfect balance; a value of 1 means the most overloaded head received exactly
150
+ double its fair share.
151
+
152
+ The result is detached from the autograd graph — MaxVio is a monitoring scalar
153
+ and must never contribute gradients to any parameter.
154
+
155
+ Args:
156
+ routing_freqs: Per-head routing frequencies of shape (L,). Sums to 1.
157
+ num_heads: Total number of MoSRAH heads L.
158
+
159
+ Returns:
160
+ Detached scalar MaxVio tensor.
161
+ """
162
+ return (num_heads * (routing_freqs - 1.0 / num_heads).max()).detach()
__attention__shram.py CHANGED
@@ -1,116 +1,116 @@
1
- """SHRAM hybrid attention layer.
2
-
3
- This module implements the hybrid attention construction H(x) = h_l(x) + h_s(x)
4
- used at one decoder attention slot in SHRAM.
5
-
6
- The local sliding-window path and the MoSRAH sparse path are already verified
7
- independently. The responsibility here is therefore not to introduce new
8
- attention logic, but to preserve the bridge contracts between them: both paths
9
- must consume the same input hidden state, each path must receive the sub-cache
10
- it actually owns, the two model-space outputs must be summed directly, and the
11
- sparse-path load-balance loss must remain visible to the caller.
12
- """
13
-
14
- import torch
15
- from torch import nn
16
-
17
- from .__cache__shram_layer_cache import ShramLayerCache
18
- from .configuration import ShramConfig
19
- from .__attention__sliding_window_attention import SlidingWindowAttention
20
- from .__attention__mosrah import MoSRAHLayer
21
-
22
-
23
- class SHRAMHybridLayer(nn.Module):
24
- """Hybrid attention layer H(x) = h_l(x) + h_s(x) for one decoder slot.
25
-
26
- The local path preserves nearby-token behavior through sliding-window causal
27
- attention. The sparse path is the theorem-facing MoSRAH routed attention
28
- path. Both operate over the same model-space hidden state and return
29
- model-space outputs, so the hybrid composition is a direct sum in model
30
- space.
31
- """
32
-
33
- def __init__(self, config: ShramConfig) -> None:
34
- super().__init__()
35
- self.local_attention = SlidingWindowAttention(config)
36
- self.sparse_attention = MoSRAHLayer(config)
37
-
38
- def forward(
39
- self,
40
- hidden_states: torch.Tensor,
41
- position_ids: torch.Tensor,
42
- active_mask: torch.Tensor,
43
- cache: ShramLayerCache | None,
44
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
45
- """Apply the SHRAM hybrid attention layer.
46
-
47
- Args:
48
- hidden_states: Input hidden states of shape (B, N, d).
49
- position_ids: Authoritative token positions of shape (B, N).
50
- active_mask: Current-chunk active mask of shape (B, N), where True
51
- means the token is semantically live. Forwarded unchanged to
52
- both the local path and the sparse path.
53
- cache: Optional per-layer SHRAM cache. When provided, the owned
54
- sliding-window and MoSRAH sub-caches are dispatched directly to
55
- their corresponding attention paths.
56
-
57
- Returns:
58
- hybrid_output: Model-space hybrid attention output of shape (B, N, d).
59
- load_balance_loss: Scalar sparse-path load-balance loss.
60
- max_vio: Detached scalar routing-imbalance summary. Passed through
61
- unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
62
- """
63
- # ------------------------------------------------
64
- # It is not possible, due to how bea constructs its block mask,
65
- # for the model to process a sequence that does not start at zero
66
- # without a cache to track the per-head offsets
67
- # ------------------------------------------------
68
-
69
- if cache is None and torch.any(position_ids[:, 0] != 0):
70
- raise ValueError(
71
- "Uncached SHRAMHybridLayer does not support nonzero starting positions. "
72
- "Either provide a matching ShramLayerCache populated by the prefix for "
73
- "continued decoding, or rebase the uncached sequence to start at 0."
74
- )
75
-
76
- # -------------------------------------------------------------------
77
- # The hybrid layer's first responsibility is cache dispatch. The layer
78
- # cache already owns the concrete sub-cache objects required by each
79
- # path, so this unit should forward those exact references rather than
80
- # reinterpret cache ownership or invent a composite update protocol here.
81
- # -------------------------------------------------------------------
82
- if cache is None:
83
- sliding_window_cache = None
84
- mosrah_cache = None
85
- else:
86
- sliding_window_cache = cache.sliding_window_cache
87
- mosrah_cache = cache.mosrah_cache
88
-
89
- # -------------------------------------------------------------------
90
- # Both attention paths must see the same model-space hidden state for
91
- # the current decoder layer. The local path preserves short-range
92
- # structure, while the sparse path provides the routed long-range
93
- # contribution and emits the load-balance signal used by training.
94
- # -------------------------------------------------------------------
95
- local_output = self.local_attention(
96
- x=hidden_states,
97
- position_ids=position_ids,
98
- active_mask=active_mask,
99
- cache=sliding_window_cache,
100
- )
101
- sparse_output, load_balance_loss, max_vio = self.sparse_attention(
102
- hidden_states=hidden_states,
103
- position_ids=position_ids,
104
- active_mask=active_mask,
105
- cache=mosrah_cache,
106
- )
107
-
108
- # -------------------------------------------------------------------
109
- # The composition rule is intentionally simple at this boundary. Both
110
- # sublayers already return model-space tensors of matching shape, so the
111
- # correct hybrid behavior is their direct sum with no additional mixing
112
- # logic introduced here.
113
- # -------------------------------------------------------------------
114
- hybrid_output = local_output + sparse_output
115
-
116
  return hybrid_output, load_balance_loss, max_vio
 
1
+ """SHRAM hybrid attention layer.
2
+
3
+ This module implements the hybrid attention construction H(x) = h_l(x) + h_s(x)
4
+ used at one decoder attention slot in SHRAM.
5
+
6
+ The local sliding-window path and the MoSRAH sparse path are already verified
7
+ independently. The responsibility here is therefore not to introduce new
8
+ attention logic, but to preserve the bridge contracts between them: both paths
9
+ must consume the same input hidden state, each path must receive the sub-cache
10
+ it actually owns, the two model-space outputs must be summed directly, and the
11
+ sparse-path load-balance loss must remain visible to the caller.
12
+ """
13
+
14
+ import torch
15
+ from torch import nn
16
+
17
+ from .__cache__shram_layer_cache import ShramLayerCache
18
+ from .configuration import ShramConfig
19
+ from .__attention__sliding_window_attention import SlidingWindowAttention
20
+ from .__attention__mosrah import MoSRAHLayer
21
+
22
+
23
+ class SHRAMHybridLayer(nn.Module):
24
+ """Hybrid attention layer H(x) = h_l(x) + h_s(x) for one decoder slot.
25
+
26
+ The local path preserves nearby-token behavior through sliding-window causal
27
+ attention. The sparse path is the theorem-facing MoSRAH routed attention
28
+ path. Both operate over the same model-space hidden state and return
29
+ model-space outputs, so the hybrid composition is a direct sum in model
30
+ space.
31
+ """
32
+
33
+ def __init__(self, config: ShramConfig) -> None:
34
+ super().__init__()
35
+ self.local_attention = SlidingWindowAttention(config)
36
+ self.sparse_attention = MoSRAHLayer(config)
37
+
38
+ def forward(
39
+ self,
40
+ hidden_states: torch.Tensor,
41
+ position_ids: torch.Tensor,
42
+ active_mask: torch.Tensor,
43
+ cache: ShramLayerCache | None,
44
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
45
+ """Apply the SHRAM hybrid attention layer.
46
+
47
+ Args:
48
+ hidden_states: Input hidden states of shape (B, N, d).
49
+ position_ids: Authoritative token positions of shape (B, N).
50
+ active_mask: Current-chunk active mask of shape (B, N), where True
51
+ means the token is semantically live. Forwarded unchanged to
52
+ both the local path and the sparse path.
53
+ cache: Optional per-layer SHRAM cache. When provided, the owned
54
+ sliding-window and MoSRAH sub-caches are dispatched directly to
55
+ their corresponding attention paths.
56
+
57
+ Returns:
58
+ hybrid_output: Model-space hybrid attention output of shape (B, N, d).
59
+ load_balance_loss: Scalar sparse-path load-balance loss.
60
+ max_vio: Detached scalar routing-imbalance summary. Passed through
61
+ unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
62
+ """
63
+ # ------------------------------------------------
64
+ # It is not possible, due to how bea constructs its block mask,
65
+ # for the model to process a sequence that does not start at zero
66
+ # without a cache to track the per-head offsets
67
+ # ------------------------------------------------
68
+
69
+ if cache is None and torch.any(position_ids[:, 0] != 0):
70
+ raise ValueError(
71
+ "Uncached SHRAMHybridLayer does not support nonzero starting positions. "
72
+ "Either provide a matching ShramLayerCache populated by the prefix for "
73
+ "continued decoding, or rebase the uncached sequence to start at 0."
74
+ )
75
+
76
+ # -------------------------------------------------------------------
77
+ # The hybrid layer's first responsibility is cache dispatch. The layer
78
+ # cache already owns the concrete sub-cache objects required by each
79
+ # path, so this unit should forward those exact references rather than
80
+ # reinterpret cache ownership or invent a composite update protocol here.
81
+ # -------------------------------------------------------------------
82
+ if cache is None:
83
+ sliding_window_cache = None
84
+ mosrah_cache = None
85
+ else:
86
+ sliding_window_cache = cache.sliding_window_cache
87
+ mosrah_cache = cache.mosrah_cache
88
+
89
+ # -------------------------------------------------------------------
90
+ # Both attention paths must see the same model-space hidden state for
91
+ # the current decoder layer. The local path preserves short-range
92
+ # structure, while the sparse path provides the routed long-range
93
+ # contribution and emits the load-balance signal used by training.
94
+ # -------------------------------------------------------------------
95
+ local_output = self.local_attention(
96
+ x=hidden_states,
97
+ position_ids=position_ids,
98
+ active_mask=active_mask,
99
+ cache=sliding_window_cache,
100
+ )
101
+ sparse_output, load_balance_loss, max_vio = self.sparse_attention(
102
+ hidden_states=hidden_states,
103
+ position_ids=position_ids,
104
+ active_mask=active_mask,
105
+ cache=mosrah_cache,
106
+ )
107
+
108
+ # -------------------------------------------------------------------
109
+ # The composition rule is intentionally simple at this boundary. Both
110
+ # sublayers already return model-space tensors of matching shape, so the
111
+ # correct hybrid behavior is their direct sum with no additional mixing
112
+ # logic introduced here.
113
+ # -------------------------------------------------------------------
114
+ hybrid_output = local_output + sparse_output
115
+
116
  return hybrid_output, load_balance_loss, max_vio
__attention__sliding_window_attention.py CHANGED
@@ -1,233 +1,233 @@
1
- # src/shram/model/attention/sliding_window_attention.py
2
-
3
- """Local sliding-window attention path for SHRAM.
4
-
5
- This file defines `SlidingWindowAttention`, the local short-range attention path
6
- used inside the SHRAM hybrid layer.
7
-
8
- In the masked-continuation variant, the local cache no longer returns a
9
- semantically dense visible frame. Instead, `LocalSlidingWindowLayerCache`
10
- returns:
11
-
12
- - the retained local window memory concatenated with the current chunk
13
- - an aligned active mask over that returned frame
14
-
15
- This module consumes that returned frame directly and constructs effective local
16
- causal/window visibility from the mask. It does not own cache retention policy;
17
- it owns only local attention semantics.
18
- """
19
-
20
- import math
21
- from typing import Any
22
-
23
- import torch
24
- import torch.nn as nn
25
- from torch.nn.attention.flex_attention import create_block_mask, flex_attention
26
-
27
- from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
28
- from .configuration import ShramConfig
29
- from .rope import RotaryEmbedding
30
-
31
-
32
- class SlidingWindowAttention(nn.Module):
33
- """Causal local sliding-window attention for one SHRAM layer.
34
-
35
- Args:
36
- config: SHRAM config. Must expose `hidden_size`,
37
- `num_sliding_window_heads`, `head_dim`, `window_size`,
38
- `attention_dropout`, and `local_rope_theta`.
39
-
40
- Raises:
41
- NotImplementedError: If `attention_dropout != 0.0`.
42
- """
43
-
44
- def __init__(self, config: ShramConfig) -> None:
45
- super().__init__()
46
-
47
- self.hidden_size = config.hidden_size
48
- self.num_heads = config.num_sliding_window_heads
49
- self.head_dim = config.head_dim
50
- self.window_size = config.window_size
51
- self.attention_dropout = config.attention_dropout
52
-
53
- if self.attention_dropout != 0.0:
54
- raise NotImplementedError(
55
- "SlidingWindowAttention currently supports only "
56
- "attention_dropout == 0.0."
57
- )
58
-
59
- self.inner_dim = self.num_heads * self.head_dim
60
-
61
- # Standard MHA projections for the local path.
62
- self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
63
- self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
64
- self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
65
- self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
66
-
67
- # The local path always uses default-mode RoPE with its own theta.
68
- self.rope = RotaryEmbedding(
69
- mode="default",
70
- head_dim=self.head_dim,
71
- theta=config.local_rope_theta,
72
- )
73
-
74
- def forward(
75
- self,
76
- x: torch.Tensor,
77
- position_ids: torch.Tensor,
78
- active_mask: torch.Tensor,
79
- cache: LocalSlidingWindowLayerCache | None = None,
80
- ) -> torch.Tensor:
81
- """Apply local causal sliding-window attention.
82
-
83
- Args:
84
- x: Input tensor of shape `(B, N, hidden_size)`.
85
- position_ids: Position tensor of shape `(B, N)`.
86
- active_mask: Current-chunk active mask of shape `(B, N)`, where
87
- `True` means active.
88
- cache: Optional `LocalSlidingWindowLayerCache`.
89
-
90
- Returns:
91
- Output tensor of shape `(B, N, hidden_size)`.
92
- """
93
- batch_size, query_len, _ = x.shape
94
-
95
- self._validate_position_shape(x, position_ids)
96
- self._validate_active_mask_shape(x, active_mask)
97
-
98
- # (B, N, H*D) -> (B, H, N, D)
99
- q = self.q_proj(x).view(
100
- batch_size,
101
- query_len,
102
- self.num_heads,
103
- self.head_dim,
104
- ).transpose(1, 2)
105
- k = self.k_proj(x).view(
106
- batch_size,
107
- query_len,
108
- self.num_heads,
109
- self.head_dim,
110
- ).transpose(1, 2)
111
- v = self.v_proj(x).view(
112
- batch_size,
113
- query_len,
114
- self.num_heads,
115
- self.head_dim,
116
- ).transpose(1, 2)
117
-
118
- q, k, attention_scaling = self.rope(q, k, position_ids)
119
-
120
- # The cache returns the current-step visible local frame, not merely the
121
- # retained next-step cache buffer.
122
- if cache is not None:
123
- k_full, v_full, full_active_mask = cache.update(k, v, active_mask)
124
- else:
125
- k_full, v_full, full_active_mask = k, v, active_mask
126
-
127
- block_mask = self._make_block_mask(
128
- active_mask=full_active_mask,
129
- batch_size=batch_size,
130
- num_heads=self.num_heads,
131
- query_len=query_len,
132
- kv_len=k_full.shape[-2],
133
- window_size=self.window_size,
134
- device=x.device,
135
- )
136
-
137
- attn_output = flex_attention(
138
- q,
139
- k_full,
140
- v_full,
141
- block_mask=block_mask,
142
- scale=attention_scaling / math.sqrt(self.head_dim),
143
- )
144
-
145
- # (B, H, N, D) -> (B, N, H*D) -> (B, N, hidden_size)
146
- attn_output = (
147
- attn_output.transpose(1, 2)
148
- .contiguous()
149
- .view(batch_size, query_len, self.inner_dim)
150
- )
151
-
152
- return self.o_proj(attn_output)
153
-
154
- def _validate_position_shape(
155
- self,
156
- x: torch.Tensor,
157
- position_ids: torch.Tensor,
158
- ) -> None:
159
- """Validate the position tensor shape expected by local RoPE."""
160
- if position_ids.shape != x.shape[:2]:
161
- raise ValueError(
162
- f"position_ids must have shape {tuple(x.shape[:2])}, "
163
- f"got {tuple(position_ids.shape)}."
164
- )
165
-
166
- def _validate_active_mask_shape(
167
- self,
168
- x: torch.Tensor,
169
- active_mask: torch.Tensor,
170
- ) -> None:
171
- """Validate the current-chunk active-mask contract."""
172
- if active_mask.shape != x.shape[:2]:
173
- raise ValueError(
174
- f"active_mask must have shape {tuple(x.shape[:2])}, "
175
- f"got {tuple(active_mask.shape)}."
176
- )
177
- if active_mask.dtype != torch.bool:
178
- raise ValueError(
179
- f"active_mask must have dtype torch.bool, got {active_mask.dtype}."
180
- )
181
-
182
- def _make_block_mask(
183
- self,
184
- active_mask: torch.Tensor,
185
- batch_size: int,
186
- num_heads: int,
187
- query_len: int,
188
- kv_len: int,
189
- window_size: int,
190
- device: torch.device,
191
- ) -> Any:
192
- """Create the FlexAttention block mask for masked local continuation.
193
-
194
- The returned local frame is chronological in raw buffer order, but dead
195
- positions may remain inside it. Effective local order is therefore
196
- recovered from the active mask itself by taking a cumulative count over
197
- active positions.
198
-
199
- Queries still occupy the tail of the returned frame, so raw buffer order
200
- is used to locate query rows. Semantic active-token positions are then
201
- used to decide causality and sliding-window distance.
202
- """
203
- query_offset = kv_len - query_len
204
- semantic_positions = active_mask.long().cumsum(dim=-1) - 1
205
-
206
- def sliding_window_mask(
207
- batch_idx: torch.Tensor,
208
- head_idx: torch.Tensor,
209
- q_idx: torch.Tensor,
210
- kv_idx: torch.Tensor,
211
- ) -> torch.Tensor:
212
-
213
- q_abs = query_offset + q_idx
214
-
215
- query_is_active = active_mask[batch_idx, q_abs]
216
- key_is_active = active_mask[batch_idx, kv_idx]
217
-
218
- q_sem = semantic_positions[batch_idx, q_abs]
219
- k_sem = semantic_positions[batch_idx, kv_idx]
220
-
221
- is_causal = k_sem <= q_sem
222
- in_window = (q_sem - k_sem) < window_size
223
-
224
- return query_is_active & key_is_active & is_causal & in_window
225
-
226
- return create_block_mask(
227
- sliding_window_mask,
228
- B=batch_size,
229
- H=num_heads,
230
- Q_LEN=query_len,
231
- KV_LEN=kv_len,
232
- device=device,
233
  )
 
1
+ # src/shram/model/attention/sliding_window_attention.py
2
+
3
+ """Local sliding-window attention path for SHRAM.
4
+
5
+ This file defines `SlidingWindowAttention`, the local short-range attention path
6
+ used inside the SHRAM hybrid layer.
7
+
8
+ In the masked-continuation variant, the local cache no longer returns a
9
+ semantically dense visible frame. Instead, `LocalSlidingWindowLayerCache`
10
+ returns:
11
+
12
+ - the retained local window memory concatenated with the current chunk
13
+ - an aligned active mask over that returned frame
14
+
15
+ This module consumes that returned frame directly and constructs effective local
16
+ causal/window visibility from the mask. It does not own cache retention policy;
17
+ it owns only local attention semantics.
18
+ """
19
+
20
+ import math
21
+ from typing import Any
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
26
+
27
+ from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
28
+ from .configuration import ShramConfig
29
+ from .rope import RotaryEmbedding
30
+
31
+
32
+ class SlidingWindowAttention(nn.Module):
33
+ """Causal local sliding-window attention for one SHRAM layer.
34
+
35
+ Args:
36
+ config: SHRAM config. Must expose `hidden_size`,
37
+ `num_sliding_window_heads`, `head_dim`, `window_size`,
38
+ `attention_dropout`, and `local_rope_theta`.
39
+
40
+ Raises:
41
+ NotImplementedError: If `attention_dropout != 0.0`.
42
+ """
43
+
44
+ def __init__(self, config: ShramConfig) -> None:
45
+ super().__init__()
46
+
47
+ self.hidden_size = config.hidden_size
48
+ self.num_heads = config.num_sliding_window_heads
49
+ self.head_dim = config.head_dim
50
+ self.window_size = config.window_size
51
+ self.attention_dropout = config.attention_dropout
52
+
53
+ if self.attention_dropout != 0.0:
54
+ raise NotImplementedError(
55
+ "SlidingWindowAttention currently supports only "
56
+ "attention_dropout == 0.0."
57
+ )
58
+
59
+ self.inner_dim = self.num_heads * self.head_dim
60
+
61
+ # Standard MHA projections for the local path.
62
+ self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
63
+ self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
64
+ self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
65
+ self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
66
+
67
+ # The local path always uses default-mode RoPE with its own theta.
68
+ self.rope = RotaryEmbedding(
69
+ mode="default",
70
+ head_dim=self.head_dim,
71
+ theta=config.local_rope_theta,
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ x: torch.Tensor,
77
+ position_ids: torch.Tensor,
78
+ active_mask: torch.Tensor,
79
+ cache: LocalSlidingWindowLayerCache | None = None,
80
+ ) -> torch.Tensor:
81
+ """Apply local causal sliding-window attention.
82
+
83
+ Args:
84
+ x: Input tensor of shape `(B, N, hidden_size)`.
85
+ position_ids: Position tensor of shape `(B, N)`.
86
+ active_mask: Current-chunk active mask of shape `(B, N)`, where
87
+ `True` means active.
88
+ cache: Optional `LocalSlidingWindowLayerCache`.
89
+
90
+ Returns:
91
+ Output tensor of shape `(B, N, hidden_size)`.
92
+ """
93
+ batch_size, query_len, _ = x.shape
94
+
95
+ self._validate_position_shape(x, position_ids)
96
+ self._validate_active_mask_shape(x, active_mask)
97
+
98
+ # (B, N, H*D) -> (B, H, N, D)
99
+ q = self.q_proj(x).view(
100
+ batch_size,
101
+ query_len,
102
+ self.num_heads,
103
+ self.head_dim,
104
+ ).transpose(1, 2)
105
+ k = self.k_proj(x).view(
106
+ batch_size,
107
+ query_len,
108
+ self.num_heads,
109
+ self.head_dim,
110
+ ).transpose(1, 2)
111
+ v = self.v_proj(x).view(
112
+ batch_size,
113
+ query_len,
114
+ self.num_heads,
115
+ self.head_dim,
116
+ ).transpose(1, 2)
117
+
118
+ q, k, attention_scaling = self.rope(q, k, position_ids)
119
+
120
+ # The cache returns the current-step visible local frame, not merely the
121
+ # retained next-step cache buffer.
122
+ if cache is not None:
123
+ k_full, v_full, full_active_mask = cache.update(k, v, active_mask)
124
+ else:
125
+ k_full, v_full, full_active_mask = k, v, active_mask
126
+
127
+ block_mask = self._make_block_mask(
128
+ active_mask=full_active_mask,
129
+ batch_size=batch_size,
130
+ num_heads=self.num_heads,
131
+ query_len=query_len,
132
+ kv_len=k_full.shape[-2],
133
+ window_size=self.window_size,
134
+ device=x.device,
135
+ )
136
+
137
+ attn_output = flex_attention(
138
+ q,
139
+ k_full,
140
+ v_full,
141
+ block_mask=block_mask,
142
+ scale=attention_scaling / math.sqrt(self.head_dim),
143
+ )
144
+
145
+ # (B, H, N, D) -> (B, N, H*D) -> (B, N, hidden_size)
146
+ attn_output = (
147
+ attn_output.transpose(1, 2)
148
+ .contiguous()
149
+ .view(batch_size, query_len, self.inner_dim)
150
+ )
151
+
152
+ return self.o_proj(attn_output)
153
+
154
+ def _validate_position_shape(
155
+ self,
156
+ x: torch.Tensor,
157
+ position_ids: torch.Tensor,
158
+ ) -> None:
159
+ """Validate the position tensor shape expected by local RoPE."""
160
+ if position_ids.shape != x.shape[:2]:
161
+ raise ValueError(
162
+ f"position_ids must have shape {tuple(x.shape[:2])}, "
163
+ f"got {tuple(position_ids.shape)}."
164
+ )
165
+
166
+ def _validate_active_mask_shape(
167
+ self,
168
+ x: torch.Tensor,
169
+ active_mask: torch.Tensor,
170
+ ) -> None:
171
+ """Validate the current-chunk active-mask contract."""
172
+ if active_mask.shape != x.shape[:2]:
173
+ raise ValueError(
174
+ f"active_mask must have shape {tuple(x.shape[:2])}, "
175
+ f"got {tuple(active_mask.shape)}."
176
+ )
177
+ if active_mask.dtype != torch.bool:
178
+ raise ValueError(
179
+ f"active_mask must have dtype torch.bool, got {active_mask.dtype}."
180
+ )
181
+
182
+ def _make_block_mask(
183
+ self,
184
+ active_mask: torch.Tensor,
185
+ batch_size: int,
186
+ num_heads: int,
187
+ query_len: int,
188
+ kv_len: int,
189
+ window_size: int,
190
+ device: torch.device,
191
+ ) -> Any:
192
+ """Create the FlexAttention block mask for masked local continuation.
193
+
194
+ The returned local frame is chronological in raw buffer order, but dead
195
+ positions may remain inside it. Effective local order is therefore
196
+ recovered from the active mask itself by taking a cumulative count over
197
+ active positions.
198
+
199
+ Queries still occupy the tail of the returned frame, so raw buffer order
200
+ is used to locate query rows. Semantic active-token positions are then
201
+ used to decide causality and sliding-window distance.
202
+ """
203
+ query_offset = kv_len - query_len
204
+ semantic_positions = active_mask.long().cumsum(dim=-1) - 1
205
+
206
+ def sliding_window_mask(
207
+ batch_idx: torch.Tensor,
208
+ head_idx: torch.Tensor,
209
+ q_idx: torch.Tensor,
210
+ kv_idx: torch.Tensor,
211
+ ) -> torch.Tensor:
212
+
213
+ q_abs = query_offset + q_idx
214
+
215
+ query_is_active = active_mask[batch_idx, q_abs]
216
+ key_is_active = active_mask[batch_idx, kv_idx]
217
+
218
+ q_sem = semantic_positions[batch_idx, q_abs]
219
+ k_sem = semantic_positions[batch_idx, kv_idx]
220
+
221
+ is_causal = k_sem <= q_sem
222
+ in_window = (q_sem - k_sem) < window_size
223
+
224
+ return query_is_active & key_is_active & is_causal & in_window
225
+
226
+ return create_block_mask(
227
+ sliding_window_mask,
228
+ B=batch_size,
229
+ H=num_heads,
230
+ Q_LEN=query_len,
231
+ KV_LEN=kv_len,
232
+ device=device,
233
  )
__cache__mosrah_cache.py CHANGED
@@ -1,359 +1,359 @@
1
- """MoSRAH sparse KV cache — single-layer implementation.
2
-
3
- MoSRAH routes each token to K of L available expert heads, so its KV cache is indexed
4
- by head rather than by sequence position. The routing is dynamic and produces a ragged
5
- distribution of token counts across (batch, head) slots — different batch items may
6
- route different numbers of tokens to the same head, and different heads accumulate at
7
- different rates. DynamicCache cannot represent this correctly: it concatenates along
8
- the sequence dimension and assumes uniform token counts across the batch. MoSRAHCache
9
- therefore uses a custom buffer design.
10
-
11
- Keys and values are stored in the CacheLayerMixin-standard self.keys and self.values
12
- attributes as (B, L, T, u) tensors, where B is batch size, L is the number of expert
13
- heads (num_mosrah_heads), T is the current buffer capacity, and u is the bottlenecked
14
- head embedding width (head_dim). A (B, L) integer count tensor _counts tracks the
15
- valid occupancy of each (batch, head) slot. Buffer capacity is exposed as the
16
- buffer_capacity property and is derived directly from self.keys rather than tracked
17
- as a separate variable.
18
-
19
- The primary interface is update(key_states, value_states, active_mask), which accepts
20
- expert-choice layout, stores only active entries in causal order, and returns the full
21
- accumulated (keys, values, active_mask) for immediate use by BEA. The returned
22
- active_mask identifies valid cached positions; everything beyond each slot's count is
23
- junk data that downstream attention must exclude.
24
-
25
- BEA applies RoPE and calls update() with post-RoPE keys (K̃). The occupancy counts
26
- exposed by get_heads_lengths() must be read before update() if the caller needs the
27
- pre-update occupancy for position computation (Unit 10.A). update() increments counts
28
- in-place and the pre-update values are not recoverable afterward.
29
-
30
- All buffers are allocated at construction time. MoSRAHCache is constructed by
31
- ShramLayerCache, which has access to batch size, device, and all model config parameters
32
- needed to fully specify the storage layout upfront.
33
- """
34
-
35
- import torch
36
- from transformers.cache_utils import CacheLayerMixin
37
-
38
-
39
- class MoSRAHCache(CacheLayerMixin):
40
- """KV cache for the MoSRAH sparse attention path — single decoder layer.
41
-
42
- Subclasses CacheLayerMixin to satisfy the HuggingFace per-layer cache role.
43
- Stores keys and values in the mixin-standard self.keys and self.values attributes
44
- using a custom (B, L, T, u) layout rather than delegating to DynamicCache,
45
- which cannot represent MoSRAH's ragged per-(batch, head) token counts correctly.
46
-
47
- All storage is allocated at construction time and is_initialized is True
48
- immediately. The caller (ShramLayerCache) provides batch size, device, and model
49
- config parameters so no lazy allocation is needed.
50
-
51
- Input is expected in expert-choice layout: (B, L, T, u) key/value tensors with a
52
- (B, L, T) boolean active_mask. Only positions where active_mask is True are written.
53
- This matches the packed representation produced by expert packing in the MoSRAH
54
- forward pass, where BEA has already applied RoPE before calling update().
55
-
56
- Args:
57
- num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
58
- second dimension of all storage tensors.
59
- head_dim: Bottlenecked head embedding width (u). Determines the fourth
60
- dimension of all storage tensors.
61
- batch_size: Number of sequences in the batch. Determines the first dimension
62
- of all storage tensors.
63
- device: Device on which to allocate all tensors. Should match the model device.
64
- initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
65
- when any slot overflows. Defaults to 64 to avoid repeated reallocation
66
- during prompt processing.
67
- """
68
-
69
- is_compileable = False
70
- is_sliding = False
71
-
72
- def __init__(
73
- self,
74
- num_mosrah_heads: int,
75
- head_dim: int,
76
- batch_size: int,
77
- device: torch.device,
78
- initial_buffer_size: int = 64,
79
- ) -> None:
80
- super().__init__()
81
- self.num_mosrah_heads = num_mosrah_heads
82
- self.head_dim = head_dim
83
- self.batch_size = batch_size
84
- self.device = device
85
-
86
- # Allocate primary storage into the mixin-standard self.keys / self.values so
87
- # that inherited methods (offload, prefetch) operate on real tensors. _counts
88
- # tracks valid occupancy per (batch, head) slot.
89
- self.keys: torch.Tensor = torch.zeros(
90
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
91
- )
92
- self.values: torch.Tensor = torch.zeros(
93
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
94
- )
95
- self._counts: torch.Tensor = torch.zeros(
96
- batch_size, num_mosrah_heads, dtype=torch.long, device=device
97
- )
98
-
99
- # Storage is fully allocated at construction — the cache is initialized.
100
- self.is_initialized = True
101
-
102
- # ---------------------------------------------------------------------------
103
- # Properties
104
- # ---------------------------------------------------------------------------
105
-
106
- @property
107
- def buffer_capacity(self) -> int:
108
- """Current number of slots allocated per (batch, head) pair.
109
-
110
- Derived directly from self.keys rather than tracked separately, so it is
111
- always consistent with the actual buffer after expansion.
112
- """
113
- return self.keys.shape[2]
114
-
115
- # ---------------------------------------------------------------------------
116
- # Primary API
117
- # ---------------------------------------------------------------------------
118
-
119
- def update( # type: ignore[override]
120
- self,
121
- key_states: torch.Tensor,
122
- value_states: torch.Tensor,
123
- active_mask: torch.Tensor,
124
- cache_kwargs: dict | None = None,
125
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
126
- """Scatter active key/value states into the buffer and return the full cache state.
127
-
128
- Accepts expert-choice layout: key_states and value_states are (B, L, T, u);
129
- active_mask is (B, L, T) bool with True marking real tokens. Only active
130
- positions are written; inactive positions are ignored.
131
-
132
- Uses a cumsum construction to derive the absolute buffer position for each
133
- active token without any Python loops. For a given (batch, head) slot,
134
- positions are assigned in the order tokens appear along the T dimension,
135
- preserving causal ordering.
136
-
137
- Returns the full accumulated (keys, values, active_mask) across the cached
138
- sparse sequence. The returned active_mask is True exactly for slots t <
139
- counts[b, l]; everything beyond is junk data that BEA must exclude.
140
-
141
- Note: get_heads_lengths() must be called before update() if the caller needs
142
- the pre-update occupancy for position computation (Unit 10.A). update()
143
- increments counts in-place and the pre-update values are not recoverable.
144
-
145
- Args:
146
- key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
147
- value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
148
- active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
149
- cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
150
-
151
- Returns:
152
- Tuple of (keys, values, active_mask):
153
- keys: (B, L, T, u) float — full key buffer including junk slots.
154
- values: (B, L, T, u) float — full value buffer including junk slots.
155
- active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
156
- """
157
- incoming_delta = active_mask.long().sum(dim=2) # (B, L)
158
-
159
- if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
160
- self._expand()
161
-
162
- # Cumulative count of active positions along T for each (b, l) slot. Entry
163
- # [b, l, t] is the 1-based rank of position t among all active positions in
164
- # that slot. Subtract 1 for a zero-based within-slot index. Inactive positions
165
- # produce a negative value, which is excluded by the mask gate below.
166
- within_slot = active_mask.long().cumsum(dim=2) - 1 # (B, L, T)
167
-
168
- # Add the pre-update count to get the absolute buffer position for each
169
- # active token.
170
- abs_pos = within_slot + self._counts.unsqueeze(-1) # (B, L, T)
171
-
172
- # Scatter key and value vectors at all active positions.
173
- b_idx, l_idx, t_idx = torch.where(active_mask)
174
- self.keys[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
175
- key_states[b_idx, l_idx, t_idx]
176
- )
177
- self.values[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
178
- value_states[b_idx, l_idx, t_idx]
179
- )
180
-
181
- self._counts += incoming_delta
182
-
183
- return self.keys, self.values, self._make_active_mask()
184
-
185
- def get_heads_lengths(self) -> torch.Tensor:
186
- """Return the per-(batch, head) token count for this layer.
187
-
188
- This is the authoritative occupancy tensor consumed by BEA for attention
189
- masking and by position computation (Unit 10.A) for semantic-sequence
190
- position computation.
191
-
192
- Note: in the MoSRAH forward pass, this must be called before update() if the
193
- caller needs the pre-update occupancy. update() increments these counts in-place.
194
-
195
- Returns:
196
- Integer tensor of shape (B, L) where entry [b, h] is the number of valid
197
- tokens stored in the (b, h) slot. Zero for slots with no writes yet.
198
- """
199
- return self._counts
200
-
201
- # ---------------------------------------------------------------------------
202
- # CacheLayerMixin — overridden coordination methods
203
- # ---------------------------------------------------------------------------
204
-
205
- def reset(self) -> None:
206
- """Clear all cached key and value tensors.
207
-
208
- Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
209
- and is_initialized remains True — only the contents are cleared.
210
- """
211
- self.keys.zero_()
212
- self.values.zero_()
213
- self._counts.zero_()
214
-
215
- def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
216
- """Reorder the batch dimension of all cached tensors for beam search.
217
-
218
- Applied atomically across self.keys, self.values, and _counts. Beam search
219
- must reorder all three together or the occupancy counts and buffer contents
220
- will correspond to different beam hypotheses.
221
-
222
- Overrides the parent because the parent's implementation calls get_seq_length(),
223
- which is not supported for this cache.
224
-
225
- Args:
226
- beam_idx: Permutation indices of shape (batch,) produced by the beam
227
- search algorithm.
228
- """
229
- self.keys = self.keys[beam_idx]
230
- self.values = self.values[beam_idx]
231
- self._counts = self._counts[beam_idx]
232
-
233
- def batch_repeat_interleave(self, repeats: int) -> None:
234
- """Expand the batch dimension by repeating each entry repeats times.
235
-
236
- Used at beam search initialisation to expand the cache from batch size B to
237
- B * repeats, matching the expanded beam candidate batch. Applied atomically
238
- across keys, values, and _counts; batch_size is updated to reflect the new size.
239
-
240
- Args:
241
- repeats: Number of times to repeat each batch entry.
242
- """
243
- self.keys = self.keys.repeat_interleave(repeats, dim=0)
244
- self.values = self.values.repeat_interleave(repeats, dim=0)
245
- self._counts = self._counts.repeat_interleave(repeats, dim=0)
246
- self.batch_size = self.batch_size * repeats
247
-
248
- def batch_select_indices(self, indices: torch.Tensor) -> None:
249
- """Select a subset of batch entries by index.
250
-
251
- Used in contrastive search to retain only the selected candidate entries.
252
- Applied atomically across keys, values, and _counts; batch_size is updated
253
- to reflect the number of retained entries.
254
-
255
- Args:
256
- indices: 1-D integer tensor of batch indices to retain.
257
- """
258
- self.keys = self.keys[indices]
259
- self.values = self.values[indices]
260
- self._counts = self._counts[indices]
261
- self.batch_size = indices.shape[0]
262
-
263
- def offload(self) -> None:
264
- """Offload all cached tensors to CPU.
265
-
266
- Extends the parent to also offload _counts, which the parent does not know
267
- about. All three tensors are moved atomically so device state remains consistent.
268
- """
269
- super().offload()
270
- self._counts = self._counts.to("cpu", non_blocking=True)
271
-
272
- def prefetch(self) -> None:
273
- """Move all cached tensors back to the model device ahead of time.
274
-
275
- Extends the parent to also prefetch _counts, which the parent does not know
276
- about. _counts is synced to self.keys.device after the parent moves keys and
277
- values, so all three remain consistent.
278
- """
279
- super().prefetch()
280
- if self._counts.device != self.keys.device:
281
- self._counts = self._counts.to(self.keys.device, non_blocking=True)
282
-
283
- def lazy_initialization( # type: ignore[override]
284
- self, key_states: torch.Tensor, value_states: torch.Tensor
285
- ) -> None:
286
- """No-op — storage is fully allocated at construction time."""
287
- pass
288
-
289
- # ---------------------------------------------------------------------------
290
- # CacheLayerMixin — unsupported abstract methods
291
- # ---------------------------------------------------------------------------
292
-
293
- def get_seq_length(self) -> int: # type: ignore[override]
294
- """Not supported — no single sequence length represents this cache's state.
295
-
296
- MoSRAH heads accumulate independently; (batch, head) slots have different
297
- lengths depending on routing history. There is no meaningful scalar summary.
298
- Use get_heads_lengths() for per-head occupancy.
299
- """
300
- raise NotImplementedError(
301
- "MoSRAHCache has no single sequence length. "
302
- "Use get_heads_lengths() for per-head occupancy."
303
- )
304
-
305
- def get_max_cache_shape(self) -> int: # type: ignore[override]
306
- """Not supported — MoSRAHCache is dynamic and unbounded."""
307
- raise NotImplementedError(
308
- "MoSRAHCache is unbounded; get_max_cache_shape() is not supported."
309
- )
310
-
311
- def get_mask_sizes( # type: ignore[override]
312
- self,
313
- cache_position: torch.Tensor,
314
- ) -> tuple[int, int]:
315
- """Not supported — MoSRAHCache does not participate in HF mask construction."""
316
- raise NotImplementedError(
317
- "MoSRAHCache does not support get_mask_sizes()."
318
- )
319
-
320
- # ---------------------------------------------------------------------------
321
- # Internal helpers
322
- # ---------------------------------------------------------------------------
323
-
324
- def _make_active_mask(self) -> torch.Tensor:
325
- """Construct the (B, L, T) active mask from current counts.
326
-
327
- Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
328
- has been written. Positions at or beyond the count are junk and must be
329
- excluded by downstream attention.
330
- """
331
- cap = self.buffer_capacity
332
- return (
333
- torch.arange(cap, device=self.keys.device)
334
- .expand(self.batch_size, self.num_mosrah_heads, cap)
335
- < self._counts.unsqueeze(-1)
336
- )
337
-
338
- def _expand(self) -> None:
339
- """Double the buffer capacity, preserving existing data.
340
-
341
- Called by update() when an incoming batch of tokens would cause any
342
- (batch, head) slot to exceed the current buffer capacity. All existing
343
- key and value data is copied into the low half of the new buffer; the
344
- high half is zero-initialised and will be filled by subsequent writes.
345
- After reassignment, buffer_capacity reflects the new size automatically.
346
- """
347
- old_cap = self.buffer_capacity
348
- new_cap = old_cap * 2
349
- dev = self.keys.device
350
- new_keys = torch.zeros(
351
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
352
- )
353
- new_values = torch.zeros(
354
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
355
- )
356
- new_keys[:, :, :old_cap, :] = self.keys
357
- new_values[:, :, :old_cap, :] = self.values
358
- self.keys = new_keys
359
- self.values = new_values
 
1
+ """MoSRAH sparse KV cache — single-layer implementation.
2
+
3
+ MoSRAH routes each token to K of L available expert heads, so its KV cache is indexed
4
+ by head rather than by sequence position. The routing is dynamic and produces a ragged
5
+ distribution of token counts across (batch, head) slots — different batch items may
6
+ route different numbers of tokens to the same head, and different heads accumulate at
7
+ different rates. DynamicCache cannot represent this correctly: it concatenates along
8
+ the sequence dimension and assumes uniform token counts across the batch. MoSRAHCache
9
+ therefore uses a custom buffer design.
10
+
11
+ Keys and values are stored in the CacheLayerMixin-standard self.keys and self.values
12
+ attributes as (B, L, T, u) tensors, where B is batch size, L is the number of expert
13
+ heads (num_mosrah_heads), T is the current buffer capacity, and u is the bottlenecked
14
+ head embedding width (head_dim). A (B, L) integer count tensor _counts tracks the
15
+ valid occupancy of each (batch, head) slot. Buffer capacity is exposed as the
16
+ buffer_capacity property and is derived directly from self.keys rather than tracked
17
+ as a separate variable.
18
+
19
+ The primary interface is update(key_states, value_states, active_mask), which accepts
20
+ expert-choice layout, stores only active entries in causal order, and returns the full
21
+ accumulated (keys, values, active_mask) for immediate use by BEA. The returned
22
+ active_mask identifies valid cached positions; everything beyond each slot's count is
23
+ junk data that downstream attention must exclude.
24
+
25
+ BEA applies RoPE and calls update() with post-RoPE keys (K̃). The occupancy counts
26
+ exposed by get_heads_lengths() must be read before update() if the caller needs the
27
+ pre-update occupancy for position computation (Unit 10.A). update() increments counts
28
+ in-place and the pre-update values are not recoverable afterward.
29
+
30
+ All buffers are allocated at construction time. MoSRAHCache is constructed by
31
+ ShramLayerCache, which has access to batch size, device, and all model config parameters
32
+ needed to fully specify the storage layout upfront.
33
+ """
34
+
35
+ import torch
36
+ from transformers.cache_utils import CacheLayerMixin
37
+
38
+
39
+ class MoSRAHCache(CacheLayerMixin):
40
+ """KV cache for the MoSRAH sparse attention path — single decoder layer.
41
+
42
+ Subclasses CacheLayerMixin to satisfy the HuggingFace per-layer cache role.
43
+ Stores keys and values in the mixin-standard self.keys and self.values attributes
44
+ using a custom (B, L, T, u) layout rather than delegating to DynamicCache,
45
+ which cannot represent MoSRAH's ragged per-(batch, head) token counts correctly.
46
+
47
+ All storage is allocated at construction time and is_initialized is True
48
+ immediately. The caller (ShramLayerCache) provides batch size, device, and model
49
+ config parameters so no lazy allocation is needed.
50
+
51
+ Input is expected in expert-choice layout: (B, L, T, u) key/value tensors with a
52
+ (B, L, T) boolean active_mask. Only positions where active_mask is True are written.
53
+ This matches the packed representation produced by expert packing in the MoSRAH
54
+ forward pass, where BEA has already applied RoPE before calling update().
55
+
56
+ Args:
57
+ num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
58
+ second dimension of all storage tensors.
59
+ head_dim: Bottlenecked head embedding width (u). Determines the fourth
60
+ dimension of all storage tensors.
61
+ batch_size: Number of sequences in the batch. Determines the first dimension
62
+ of all storage tensors.
63
+ device: Device on which to allocate all tensors. Should match the model device.
64
+ initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
65
+ when any slot overflows. Defaults to 64 to avoid repeated reallocation
66
+ during prompt processing.
67
+ """
68
+
69
+ is_compileable = False
70
+ is_sliding = False
71
+
72
+ def __init__(
73
+ self,
74
+ num_mosrah_heads: int,
75
+ head_dim: int,
76
+ batch_size: int,
77
+ device: torch.device,
78
+ initial_buffer_size: int = 64,
79
+ ) -> None:
80
+ super().__init__()
81
+ self.num_mosrah_heads = num_mosrah_heads
82
+ self.head_dim = head_dim
83
+ self.batch_size = batch_size
84
+ self.device = device
85
+
86
+ # Allocate primary storage into the mixin-standard self.keys / self.values so
87
+ # that inherited methods (offload, prefetch) operate on real tensors. _counts
88
+ # tracks valid occupancy per (batch, head) slot.
89
+ self.keys: torch.Tensor = torch.zeros(
90
+ batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
91
+ )
92
+ self.values: torch.Tensor = torch.zeros(
93
+ batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
94
+ )
95
+ self._counts: torch.Tensor = torch.zeros(
96
+ batch_size, num_mosrah_heads, dtype=torch.long, device=device
97
+ )
98
+
99
+ # Storage is fully allocated at construction — the cache is initialized.
100
+ self.is_initialized = True
101
+
102
+ # ---------------------------------------------------------------------------
103
+ # Properties
104
+ # ---------------------------------------------------------------------------
105
+
106
+ @property
107
+ def buffer_capacity(self) -> int:
108
+ """Current number of slots allocated per (batch, head) pair.
109
+
110
+ Derived directly from self.keys rather than tracked separately, so it is
111
+ always consistent with the actual buffer after expansion.
112
+ """
113
+ return self.keys.shape[2]
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Primary API
117
+ # ---------------------------------------------------------------------------
118
+
119
+ def update( # type: ignore[override]
120
+ self,
121
+ key_states: torch.Tensor,
122
+ value_states: torch.Tensor,
123
+ active_mask: torch.Tensor,
124
+ cache_kwargs: dict | None = None,
125
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
126
+ """Scatter active key/value states into the buffer and return the full cache state.
127
+
128
+ Accepts expert-choice layout: key_states and value_states are (B, L, T, u);
129
+ active_mask is (B, L, T) bool with True marking real tokens. Only active
130
+ positions are written; inactive positions are ignored.
131
+
132
+ Uses a cumsum construction to derive the absolute buffer position for each
133
+ active token without any Python loops. For a given (batch, head) slot,
134
+ positions are assigned in the order tokens appear along the T dimension,
135
+ preserving causal ordering.
136
+
137
+ Returns the full accumulated (keys, values, active_mask) across the cached
138
+ sparse sequence. The returned active_mask is True exactly for slots t <
139
+ counts[b, l]; everything beyond is junk data that BEA must exclude.
140
+
141
+ Note: get_heads_lengths() must be called before update() if the caller needs
142
+ the pre-update occupancy for position computation (Unit 10.A). update()
143
+ increments counts in-place and the pre-update values are not recoverable.
144
+
145
+ Args:
146
+ key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
147
+ value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
148
+ active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
149
+ cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
150
+
151
+ Returns:
152
+ Tuple of (keys, values, active_mask):
153
+ keys: (B, L, T, u) float — full key buffer including junk slots.
154
+ values: (B, L, T, u) float — full value buffer including junk slots.
155
+ active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
156
+ """
157
+ incoming_delta = active_mask.long().sum(dim=2) # (B, L)
158
+
159
+ if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
160
+ self._expand()
161
+
162
+ # Cumulative count of active positions along T for each (b, l) slot. Entry
163
+ # [b, l, t] is the 1-based rank of position t among all active positions in
164
+ # that slot. Subtract 1 for a zero-based within-slot index. Inactive positions
165
+ # produce a negative value, which is excluded by the mask gate below.
166
+ within_slot = active_mask.long().cumsum(dim=2) - 1 # (B, L, T)
167
+
168
+ # Add the pre-update count to get the absolute buffer position for each
169
+ # active token.
170
+ abs_pos = within_slot + self._counts.unsqueeze(-1) # (B, L, T)
171
+
172
+ # Scatter key and value vectors at all active positions.
173
+ b_idx, l_idx, t_idx = torch.where(active_mask)
174
+ self.keys[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
175
+ key_states[b_idx, l_idx, t_idx]
176
+ )
177
+ self.values[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
178
+ value_states[b_idx, l_idx, t_idx]
179
+ )
180
+
181
+ self._counts += incoming_delta
182
+
183
+ return self.keys, self.values, self._make_active_mask()
184
+
185
+ def get_heads_lengths(self) -> torch.Tensor:
186
+ """Return the per-(batch, head) token count for this layer.
187
+
188
+ This is the authoritative occupancy tensor consumed by BEA for attention
189
+ masking and by position computation (Unit 10.A) for semantic-sequence
190
+ position computation.
191
+
192
+ Note: in the MoSRAH forward pass, this must be called before update() if the
193
+ caller needs the pre-update occupancy. update() increments these counts in-place.
194
+
195
+ Returns:
196
+ Integer tensor of shape (B, L) where entry [b, h] is the number of valid
197
+ tokens stored in the (b, h) slot. Zero for slots with no writes yet.
198
+ """
199
+ return self._counts
200
+
201
+ # ---------------------------------------------------------------------------
202
+ # CacheLayerMixin — overridden coordination methods
203
+ # ---------------------------------------------------------------------------
204
+
205
+ def reset(self) -> None:
206
+ """Clear all cached key and value tensors.
207
+
208
+ Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
209
+ and is_initialized remains True — only the contents are cleared.
210
+ """
211
+ self.keys.zero_()
212
+ self.values.zero_()
213
+ self._counts.zero_()
214
+
215
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
216
+ """Reorder the batch dimension of all cached tensors for beam search.
217
+
218
+ Applied atomically across self.keys, self.values, and _counts. Beam search
219
+ must reorder all three together or the occupancy counts and buffer contents
220
+ will correspond to different beam hypotheses.
221
+
222
+ Overrides the parent because the parent's implementation calls get_seq_length(),
223
+ which is not supported for this cache.
224
+
225
+ Args:
226
+ beam_idx: Permutation indices of shape (batch,) produced by the beam
227
+ search algorithm.
228
+ """
229
+ self.keys = self.keys[beam_idx]
230
+ self.values = self.values[beam_idx]
231
+ self._counts = self._counts[beam_idx]
232
+
233
+ def batch_repeat_interleave(self, repeats: int) -> None:
234
+ """Expand the batch dimension by repeating each entry repeats times.
235
+
236
+ Used at beam search initialisation to expand the cache from batch size B to
237
+ B * repeats, matching the expanded beam candidate batch. Applied atomically
238
+ across keys, values, and _counts; batch_size is updated to reflect the new size.
239
+
240
+ Args:
241
+ repeats: Number of times to repeat each batch entry.
242
+ """
243
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
244
+ self.values = self.values.repeat_interleave(repeats, dim=0)
245
+ self._counts = self._counts.repeat_interleave(repeats, dim=0)
246
+ self.batch_size = self.batch_size * repeats
247
+
248
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
249
+ """Select a subset of batch entries by index.
250
+
251
+ Used in contrastive search to retain only the selected candidate entries.
252
+ Applied atomically across keys, values, and _counts; batch_size is updated
253
+ to reflect the number of retained entries.
254
+
255
+ Args:
256
+ indices: 1-D integer tensor of batch indices to retain.
257
+ """
258
+ self.keys = self.keys[indices]
259
+ self.values = self.values[indices]
260
+ self._counts = self._counts[indices]
261
+ self.batch_size = indices.shape[0]
262
+
263
+ def offload(self) -> None:
264
+ """Offload all cached tensors to CPU.
265
+
266
+ Extends the parent to also offload _counts, which the parent does not know
267
+ about. All three tensors are moved atomically so device state remains consistent.
268
+ """
269
+ super().offload()
270
+ self._counts = self._counts.to("cpu", non_blocking=True)
271
+
272
+ def prefetch(self) -> None:
273
+ """Move all cached tensors back to the model device ahead of time.
274
+
275
+ Extends the parent to also prefetch _counts, which the parent does not know
276
+ about. _counts is synced to self.keys.device after the parent moves keys and
277
+ values, so all three remain consistent.
278
+ """
279
+ super().prefetch()
280
+ if self._counts.device != self.keys.device:
281
+ self._counts = self._counts.to(self.keys.device, non_blocking=True)
282
+
283
+ def lazy_initialization( # type: ignore[override]
284
+ self, key_states: torch.Tensor, value_states: torch.Tensor
285
+ ) -> None:
286
+ """No-op — storage is fully allocated at construction time."""
287
+ pass
288
+
289
+ # ---------------------------------------------------------------------------
290
+ # CacheLayerMixin — unsupported abstract methods
291
+ # ---------------------------------------------------------------------------
292
+
293
+ def get_seq_length(self) -> int: # type: ignore[override]
294
+ """Not supported — no single sequence length represents this cache's state.
295
+
296
+ MoSRAH heads accumulate independently; (batch, head) slots have different
297
+ lengths depending on routing history. There is no meaningful scalar summary.
298
+ Use get_heads_lengths() for per-head occupancy.
299
+ """
300
+ raise NotImplementedError(
301
+ "MoSRAHCache has no single sequence length. "
302
+ "Use get_heads_lengths() for per-head occupancy."
303
+ )
304
+
305
+ def get_max_cache_shape(self) -> int: # type: ignore[override]
306
+ """Not supported — MoSRAHCache is dynamic and unbounded."""
307
+ raise NotImplementedError(
308
+ "MoSRAHCache is unbounded; get_max_cache_shape() is not supported."
309
+ )
310
+
311
+ def get_mask_sizes( # type: ignore[override]
312
+ self,
313
+ cache_position: torch.Tensor,
314
+ ) -> tuple[int, int]:
315
+ """Not supported — MoSRAHCache does not participate in HF mask construction."""
316
+ raise NotImplementedError(
317
+ "MoSRAHCache does not support get_mask_sizes()."
318
+ )
319
+
320
+ # ---------------------------------------------------------------------------
321
+ # Internal helpers
322
+ # ---------------------------------------------------------------------------
323
+
324
+ def _make_active_mask(self) -> torch.Tensor:
325
+ """Construct the (B, L, T) active mask from current counts.
326
+
327
+ Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
328
+ has been written. Positions at or beyond the count are junk and must be
329
+ excluded by downstream attention.
330
+ """
331
+ cap = self.buffer_capacity
332
+ return (
333
+ torch.arange(cap, device=self.keys.device)
334
+ .expand(self.batch_size, self.num_mosrah_heads, cap)
335
+ < self._counts.unsqueeze(-1)
336
+ )
337
+
338
+ def _expand(self) -> None:
339
+ """Double the buffer capacity, preserving existing data.
340
+
341
+ Called by update() when an incoming batch of tokens would cause any
342
+ (batch, head) slot to exceed the current buffer capacity. All existing
343
+ key and value data is copied into the low half of the new buffer; the
344
+ high half is zero-initialised and will be filled by subsequent writes.
345
+ After reassignment, buffer_capacity reflects the new size automatically.
346
+ """
347
+ old_cap = self.buffer_capacity
348
+ new_cap = old_cap * 2
349
+ dev = self.keys.device
350
+ new_keys = torch.zeros(
351
+ self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
352
+ )
353
+ new_values = torch.zeros(
354
+ self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
355
+ )
356
+ new_keys[:, :, :old_cap, :] = self.keys
357
+ new_values[:, :, :old_cap, :] = self.values
358
+ self.keys = new_keys
359
+ self.values = new_values
__cache__shram_cache.py CHANGED
@@ -1,141 +1,141 @@
1
- """SHRAM top-level cache — model-wide owner for the full SHRAM decoder stack.
2
-
3
- The HuggingFace Cache protocol expects a single top-level Cache object that owns one
4
- CacheLayerMixin per decoder layer. The actual SHRAM caching responsibilities live one level
5
- lower in ShramLayerCache — each of which owns a LocalSlidingWindowLayerCache and a MoSRAHCache.
6
- ShramCache bridges those two levels: it constructs one ShramLayerCache per decoder layer,
7
- presents them through the Cache interface, and transparently forwards model-wide operations
8
- across all of them.
9
-
10
- ShramCache does not define a composite update() interface. The two attention paths inside each
11
- SHRAM layer have different update semantics, and neither the layer-level boundary (Unit 6.B)
12
- nor the model-level boundary here can meaningfully unify them. Callers must reach down to the
13
- relevant sub-cache directly. ShramCache's role is ownership, construction, and model-wide
14
- coordination of the layer caches — not routing attention inputs.
15
-
16
- Sequence length is reported by delegating to the local sliding-window sub-cache of the
17
- specified layer, which tracks the cumulative count of token positions processed. This is
18
- what HuggingFace generation reads through get_seq_length().
19
- """
20
-
21
- import torch
22
- from transformers.cache_utils import Cache
23
-
24
- from .__cache__shram_layer_cache import ShramLayerCache
25
-
26
-
27
- class ShramCache(Cache):
28
- """Top-level cache for the full SHRAM model.
29
-
30
- Owns one ShramLayerCache per decoder layer. Satisfies the HuggingFace top-level Cache
31
- role and transparently forwards reset, reorder, and sequence-length queries across all
32
- owned layer caches.
33
-
34
- No composite update() interface is provided. The two attention paths inside each SHRAM
35
- layer have materially different update semantics; callers must update sub-caches directly
36
- via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
37
-
38
- Args:
39
- num_hidden_layers: Number of SHRAM decoder layers. Determines how many
40
- ShramLayerCache objects are constructed.
41
- sliding_window: Token window size passed to each layer's LocalSlidingWindowLayerCache.
42
- num_local_heads: Number of local attention heads per layer.
43
- local_head_dim: Per-head embedding width for the local path.
44
- num_mosrah_heads: Total number of MoSRAH expert heads (L) per layer.
45
- mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
46
- batch_size: Number of sequences in the batch.
47
- device: Device on which to allocate cache tensors.
48
- initial_buffer_size: Initial per-(batch, head) capacity for each MoSRAHCache.
49
- Doubled when any slot overflows. Defaults to 64 to avoid repeated reallocation
50
- during prompt processing.
51
- """
52
-
53
- def __init__(
54
- self,
55
- num_hidden_layers: int,
56
- sliding_window: int,
57
- num_local_heads: int,
58
- local_head_dim: int,
59
- num_mosrah_heads: int,
60
- mosrah_head_dim: int,
61
- batch_size: int,
62
- device: torch.device,
63
- initial_buffer_size: int = 64,
64
- ) -> None:
65
- layers = [
66
- ShramLayerCache(
67
- sliding_window=sliding_window,
68
- num_local_heads=num_local_heads,
69
- local_head_dim=local_head_dim,
70
- num_mosrah_heads=num_mosrah_heads,
71
- mosrah_head_dim=mosrah_head_dim,
72
- batch_size=batch_size,
73
- device=device,
74
- initial_buffer_size=initial_buffer_size,
75
- )
76
- for _ in range(num_hidden_layers)
77
- ]
78
- super().__init__(layers=layers)
79
-
80
- # ---------------------------------------------------------------------------
81
- # Cache — composite-meaningful methods
82
- # ---------------------------------------------------------------------------
83
- #
84
- # reset(): Inherited. Iterates all layer caches and calls reset() on each.
85
- #
86
- # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
87
- #
88
- # is_initialized: Inherited property. True iff all layer caches are initialized.
89
- # Since ShramLayerCache.is_initialized is True from construction, this is True
90
- # immediately after ShramCache.__init__ returns.
91
-
92
- def get_seq_length(self, layer_idx: int = 0) -> int: # type: ignore[override]
93
- """Return the cumulative sequence length for the specified layer.
94
-
95
- Delegates to the layer cache at layer_idx, which in turn delegates to the
96
- local sliding-window sub-cache. That sub-cache is authoritative for sequence
97
- progress: it sees every token presented to the layer and accumulates a truthful
98
- total count. Defaults to layer 0, which is sufficient for HuggingFace generation.
99
- """
100
- return self.layers[layer_idx].get_seq_length()
101
-
102
- # ---------------------------------------------------------------------------
103
- # Cache — unsupported methods
104
- # ---------------------------------------------------------------------------
105
-
106
- def update( # type: ignore[override]
107
- self,
108
- key_states: torch.Tensor,
109
- value_states: torch.Tensor,
110
- layer_idx: int,
111
- cache_kwargs: dict | None = None,
112
- ) -> tuple[torch.Tensor, torch.Tensor]:
113
- """Not supported — ShramCache has no composite update interface.
114
-
115
- The two attention paths inside each SHRAM layer have different update semantics.
116
- Callers must update sub-caches directly:
117
- cache.layers[layer_idx].sliding_window_cache.update(key_states, value_states)
118
- cache.layers[layer_idx].mosrah_cache.update(key_states, value_states, active_mask)
119
- """
120
- raise NotImplementedError(
121
- "ShramCache has no composite update interface. "
122
- "Update sliding_window_cache or mosrah_cache on the relevant layer directly."
123
- )
124
-
125
- def crop(self, max_length: int) -> None:
126
- """Not supported — ShramCache layers do not implement crop()."""
127
- raise NotImplementedError("ShramCache does not support crop().")
128
-
129
- @property
130
- def max_batch_size(self) -> int:
131
- """Not supported — ShramCache does not track a uniform batch size across layers."""
132
- raise NotImplementedError("ShramCache does not expose max_batch_size.")
133
-
134
- @property
135
- def max_cache_len(self) -> int:
136
- """Not supported — ShramCache has no single maximum cache length.
137
-
138
- The sliding-window side is bounded by sliding_window; the MoSRAH side is unbounded.
139
- No truthful scalar maximum represents the composite.
140
- """
141
- raise NotImplementedError("ShramCache does not expose max_cache_len.")
 
1
+ """SHRAM top-level cache — model-wide owner for the full SHRAM decoder stack.
2
+
3
+ The HuggingFace Cache protocol expects a single top-level Cache object that owns one
4
+ CacheLayerMixin per decoder layer. The actual SHRAM caching responsibilities live one level
5
+ lower in ShramLayerCache — each of which owns a LocalSlidingWindowLayerCache and a MoSRAHCache.
6
+ ShramCache bridges those two levels: it constructs one ShramLayerCache per decoder layer,
7
+ presents them through the Cache interface, and transparently forwards model-wide operations
8
+ across all of them.
9
+
10
+ ShramCache does not define a composite update() interface. The two attention paths inside each
11
+ SHRAM layer have different update semantics, and neither the layer-level boundary (Unit 6.B)
12
+ nor the model-level boundary here can meaningfully unify them. Callers must reach down to the
13
+ relevant sub-cache directly. ShramCache's role is ownership, construction, and model-wide
14
+ coordination of the layer caches — not routing attention inputs.
15
+
16
+ Sequence length is reported by delegating to the local sliding-window sub-cache of the
17
+ specified layer, which tracks the cumulative count of token positions processed. This is
18
+ what HuggingFace generation reads through get_seq_length().
19
+ """
20
+
21
+ import torch
22
+ from transformers.cache_utils import Cache
23
+
24
+ from .__cache__shram_layer_cache import ShramLayerCache
25
+
26
+
27
+ class ShramCache(Cache):
28
+ """Top-level cache for the full SHRAM model.
29
+
30
+ Owns one ShramLayerCache per decoder layer. Satisfies the HuggingFace top-level Cache
31
+ role and transparently forwards reset, reorder, and sequence-length queries across all
32
+ owned layer caches.
33
+
34
+ No composite update() interface is provided. The two attention paths inside each SHRAM
35
+ layer have materially different update semantics; callers must update sub-caches directly
36
+ via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
37
+
38
+ Args:
39
+ num_hidden_layers: Number of SHRAM decoder layers. Determines how many
40
+ ShramLayerCache objects are constructed.
41
+ sliding_window: Token window size passed to each layer's LocalSlidingWindowLayerCache.
42
+ num_local_heads: Number of local attention heads per layer.
43
+ local_head_dim: Per-head embedding width for the local path.
44
+ num_mosrah_heads: Total number of MoSRAH expert heads (L) per layer.
45
+ mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
46
+ batch_size: Number of sequences in the batch.
47
+ device: Device on which to allocate cache tensors.
48
+ initial_buffer_size: Initial per-(batch, head) capacity for each MoSRAHCache.
49
+ Doubled when any slot overflows. Defaults to 64 to avoid repeated reallocation
50
+ during prompt processing.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ num_hidden_layers: int,
56
+ sliding_window: int,
57
+ num_local_heads: int,
58
+ local_head_dim: int,
59
+ num_mosrah_heads: int,
60
+ mosrah_head_dim: int,
61
+ batch_size: int,
62
+ device: torch.device,
63
+ initial_buffer_size: int = 64,
64
+ ) -> None:
65
+ layers = [
66
+ ShramLayerCache(
67
+ sliding_window=sliding_window,
68
+ num_local_heads=num_local_heads,
69
+ local_head_dim=local_head_dim,
70
+ num_mosrah_heads=num_mosrah_heads,
71
+ mosrah_head_dim=mosrah_head_dim,
72
+ batch_size=batch_size,
73
+ device=device,
74
+ initial_buffer_size=initial_buffer_size,
75
+ )
76
+ for _ in range(num_hidden_layers)
77
+ ]
78
+ super().__init__(layers=layers)
79
+
80
+ # ---------------------------------------------------------------------------
81
+ # Cache — composite-meaningful methods
82
+ # ---------------------------------------------------------------------------
83
+ #
84
+ # reset(): Inherited. Iterates all layer caches and calls reset() on each.
85
+ #
86
+ # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
87
+ #
88
+ # is_initialized: Inherited property. True iff all layer caches are initialized.
89
+ # Since ShramLayerCache.is_initialized is True from construction, this is True
90
+ # immediately after ShramCache.__init__ returns.
91
+
92
+ def get_seq_length(self, layer_idx: int = 0) -> int: # type: ignore[override]
93
+ """Return the cumulative sequence length for the specified layer.
94
+
95
+ Delegates to the layer cache at layer_idx, which in turn delegates to the
96
+ local sliding-window sub-cache. That sub-cache is authoritative for sequence
97
+ progress: it sees every token presented to the layer and accumulates a truthful
98
+ total count. Defaults to layer 0, which is sufficient for HuggingFace generation.
99
+ """
100
+ return self.layers[layer_idx].get_seq_length()
101
+
102
+ # ---------------------------------------------------------------------------
103
+ # Cache — unsupported methods
104
+ # ---------------------------------------------------------------------------
105
+
106
+ def update( # type: ignore[override]
107
+ self,
108
+ key_states: torch.Tensor,
109
+ value_states: torch.Tensor,
110
+ layer_idx: int,
111
+ cache_kwargs: dict | None = None,
112
+ ) -> tuple[torch.Tensor, torch.Tensor]:
113
+ """Not supported — ShramCache has no composite update interface.
114
+
115
+ The two attention paths inside each SHRAM layer have different update semantics.
116
+ Callers must update sub-caches directly:
117
+ cache.layers[layer_idx].sliding_window_cache.update(key_states, value_states)
118
+ cache.layers[layer_idx].mosrah_cache.update(key_states, value_states, active_mask)
119
+ """
120
+ raise NotImplementedError(
121
+ "ShramCache has no composite update interface. "
122
+ "Update sliding_window_cache or mosrah_cache on the relevant layer directly."
123
+ )
124
+
125
+ def crop(self, max_length: int) -> None:
126
+ """Not supported — ShramCache layers do not implement crop()."""
127
+ raise NotImplementedError("ShramCache does not support crop().")
128
+
129
+ @property
130
+ def max_batch_size(self) -> int:
131
+ """Not supported — ShramCache does not track a uniform batch size across layers."""
132
+ raise NotImplementedError("ShramCache does not expose max_batch_size.")
133
+
134
+ @property
135
+ def max_cache_len(self) -> int:
136
+ """Not supported — ShramCache has no single maximum cache length.
137
+
138
+ The sliding-window side is bounded by sliding_window; the MoSRAH side is unbounded.
139
+ No truthful scalar maximum represents the composite.
140
+ """
141
+ raise NotImplementedError("ShramCache does not expose max_cache_len.")
__cache__shram_layer_cache.py CHANGED
@@ -1,233 +1,233 @@
1
- """SHRAM per-layer cache — composite owner for one SHRAM decoder layer.
2
-
3
- A SHRAM decoder layer contains two distinct attention pathways at one attention slot: the
4
- local sliding-window path and the MoSRAH sparse path. Each path has its own cache with
5
- different semantics and a different downstream consumer. ShramLayerCache owns both, satisfies
6
- the HuggingFace per-layer cache role, and exposes each sub-cache directly so its attention
7
- path can interact with it without indirection.
8
-
9
- ShramLayerCache does not define a composite update() interface. The two paths have materially
10
- different update semantics — the local side uses chunk-local key/value/mask concatenation
11
- while the MoSRAH side uses expert-choice scatter with an active mask — and merging these
12
- behind a single update() would hide those differences behind a misleading abstraction. Instead,
13
- each attention path calls update() on the sub-cache it owns. ShramLayerCache acts as the
14
- ownership, coordination, and reset/reorder boundary for one decoder layer.
15
-
16
- Sequence length at this boundary is reported by delegating to the local sliding-window
17
- sub-cache, which tracks the cumulative count of token positions processed. This is the
18
- quantity HuggingFace generation reads through get_seq_length().
19
- """
20
-
21
- import torch
22
- from transformers.cache_utils import CacheLayerMixin
23
-
24
- from .__cache__mosrah_cache import MoSRAHCache
25
- from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
26
-
27
-
28
- class ShramLayerCache(CacheLayerMixin):
29
- """Cache subsystem for one SHRAM decoder layer.
30
-
31
- Owns and coordinates two sub-caches:
32
- - sliding_window_cache: LocalSlidingWindowLayerCache for the local sliding-window path.
33
- - mosrah_cache: MoSRAHCache for the MoSRAH sparse attention path.
34
-
35
- Satisfies the HuggingFace per-layer cache role (CacheLayerMixin). The two sub-caches are
36
- exposed directly for their downstream attention paths — no composite update() interface is
37
- provided, because the two paths have materially different update semantics.
38
-
39
- Sequence length is reported by delegating to the local sliding-window sub-cache, which
40
- tracks the cumulative count of token positions processed across all update() calls.
41
-
42
- Args:
43
- sliding_window: Number of tokens retained by the local sliding-window cache.
44
- num_local_heads: Number of local attention heads.
45
- local_head_dim: Per-head embedding width for the local path.
46
- num_mosrah_heads: Total number of MoSRAH expert heads (L).
47
- mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
48
- batch_size: Number of sequences in the batch.
49
- device: Device on which to allocate cache tensors.
50
- initial_buffer_size: Initial per-(batch, head) capacity for MoSRAHCache. Doubled
51
- when any slot overflows. Defaults to 64 to avoid repeated reallocation during
52
- prompt processing.
53
- """
54
-
55
- is_compileable = False
56
- is_sliding = False
57
-
58
- def __init__(
59
- self,
60
- sliding_window: int,
61
- num_local_heads: int,
62
- local_head_dim: int,
63
- num_mosrah_heads: int,
64
- mosrah_head_dim: int,
65
- batch_size: int,
66
- device: torch.device,
67
- initial_buffer_size: int = 64,
68
- ) -> None:
69
- super().__init__()
70
- self.sliding_window_cache = LocalSlidingWindowLayerCache(
71
- sliding_window=sliding_window,
72
- num_heads=num_local_heads,
73
- head_dim=local_head_dim,
74
- batch_size=batch_size,
75
- device=device,
76
- )
77
- self.mosrah_cache = MoSRAHCache(
78
- num_mosrah_heads=num_mosrah_heads,
79
- head_dim=mosrah_head_dim,
80
- batch_size=batch_size,
81
- device=device,
82
- initial_buffer_size=initial_buffer_size,
83
- )
84
-
85
- # ---------------------------------------------------------------------------
86
- # Properties
87
- # ---------------------------------------------------------------------------
88
-
89
- @property
90
- def is_initialized(self) -> bool:
91
- """True iff both sub-caches have allocated their storage.
92
-
93
- Both LocalSlidingWindowLayerCache and MoSRAHCache pre-allocate at construction,
94
- so this is True immediately after ShramLayerCache.__init__ returns.
95
- """
96
- return self.sliding_window_cache.is_initialized and self.mosrah_cache.is_initialized
97
-
98
- @is_initialized.setter
99
- def is_initialized(self, value: bool) -> None:
100
- # CacheLayerMixin.__init__ assigns self.is_initialized = False as an instance
101
- # attribute. Since property is a data descriptor it takes precedence, but Python
102
- # still routes the assignment through __set__. Absorb it silently — state is
103
- # derived from sub-caches, not stored here.
104
- pass
105
-
106
- # ---------------------------------------------------------------------------
107
- # CacheLayerMixin — composite-meaningful methods
108
- # ---------------------------------------------------------------------------
109
-
110
- def get_seq_length(self) -> int: # type: ignore[override]
111
- """Return the cumulative sequence length from the local sliding-window path.
112
-
113
- The local path is authoritative for sequence progress: it sees every token
114
- presented to this layer and accumulates a truthful total. Delegates to
115
- sliding_window_cache.get_seq_length().
116
- """
117
- return self.sliding_window_cache.get_seq_length()
118
-
119
- def reset(self) -> None:
120
- """Clear both sub-caches.
121
-
122
- Delegates reset to each sub-cache. Both are cleared atomically so the sliding-window
123
- state and MoSRAH sparse state remain consistent.
124
- """
125
- self.sliding_window_cache.reset()
126
- self.mosrah_cache.reset()
127
-
128
- def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
129
- """Reorder the batch dimension of both sub-caches for beam search.
130
-
131
- Delegates to each sub-cache. Both are reordered atomically so the sliding-window
132
- and MoSRAH state correspond to the same beam hypotheses after reordering.
133
-
134
- Args:
135
- beam_idx: Permutation indices of shape (batch,) produced by beam search.
136
- """
137
- self.sliding_window_cache.reorder_cache(beam_idx)
138
- self.mosrah_cache.reorder_cache(beam_idx)
139
-
140
- def batch_repeat_interleave(self, repeats: int) -> None:
141
- """Expand the batch dimension of both sub-caches for beam search initialisation.
142
-
143
- Delegates atomically to each sub-cache. Both must be expanded together so the
144
- sliding-window and MoSRAH state correspond to the same beam candidates.
145
-
146
- Args:
147
- repeats: Number of times to repeat each batch entry.
148
- """
149
- self.sliding_window_cache.batch_repeat_interleave(repeats)
150
- self.mosrah_cache.batch_repeat_interleave(repeats)
151
-
152
- def batch_select_indices(self, indices: torch.Tensor) -> None:
153
- """Select a subset of batch entries in both sub-caches for contrastive search.
154
-
155
- Delegates atomically to each sub-cache. Both must be trimmed together so the
156
- sliding-window and MoSRAH state remain consistent.
157
-
158
- Args:
159
- indices: 1-D integer tensor of batch indices to retain.
160
- """
161
- self.sliding_window_cache.batch_select_indices(indices)
162
- self.mosrah_cache.batch_select_indices(indices)
163
-
164
- def offload(self) -> None:
165
- """Offload both sub-caches to CPU.
166
-
167
- Delegates to each sub-cache's offload method. Does not call super() — ShramLayerCache
168
- does not own self.keys/self.values directly; all cached data lives in the sub-caches.
169
- """
170
- self.sliding_window_cache.offload()
171
- self.mosrah_cache.offload()
172
-
173
- def prefetch(self) -> None:
174
- """Move both sub-caches back to their model device ahead of time.
175
-
176
- Delegates to each sub-cache's prefetch method. Does not call super() — ShramLayerCache
177
- does not own self.keys/self.values directly; all cached data lives in the sub-caches.
178
- """
179
- self.sliding_window_cache.prefetch()
180
- self.mosrah_cache.prefetch()
181
-
182
- def lazy_initialization( # type: ignore[override]
183
- self, key_states: torch.Tensor, value_states: torch.Tensor
184
- ) -> None:
185
- """No-op — both sub-caches handle their own initialization."""
186
- pass
187
-
188
- # ---------------------------------------------------------------------------
189
- # CacheLayerMixin — unsupported abstract methods
190
- # ---------------------------------------------------------------------------
191
-
192
- def update( # type: ignore[override]
193
- self,
194
- key_states: torch.Tensor,
195
- value_states: torch.Tensor,
196
- cache_kwargs: dict | None = None,
197
- ) -> tuple[torch.Tensor, torch.Tensor]:
198
- """Not supported — ShramLayerCache has no composite update interface.
199
-
200
- The two sub-caches have materially different update semantics: the sliding-window
201
- side uses standard key/value concatenation while the MoSRAH side uses expert-choice
202
- scatter with an active mask. Callers must update each sub-cache directly via
203
- sliding_window_cache.update() or mosrah_cache.update().
204
- """
205
- raise NotImplementedError(
206
- "ShramLayerCache has no composite update interface. "
207
- "Update sliding_window_cache or mosrah_cache directly."
208
- )
209
-
210
- def get_max_cache_shape(self) -> int: # type: ignore[override]
211
- """Not supported — the composite cache has no single maximum shape.
212
-
213
- The sliding-window cache is bounded by sliding_window; the MoSRAH cache is
214
- unbounded. No truthful scalar maximum represents the composite.
215
- """
216
- raise NotImplementedError(
217
- "ShramLayerCache has no single maximum cache shape. "
218
- "Query sliding_window_cache or mosrah_cache directly."
219
- )
220
-
221
- def get_mask_sizes( # type: ignore[override]
222
- self,
223
- cache_position: torch.Tensor,
224
- ) -> tuple[int, int]:
225
- """Not supported — ShramLayerCache does not participate in HF mask construction.
226
-
227
- The two sub-caches have different mask semantics and their respective attention
228
- paths handle masking directly.
229
- """
230
- raise NotImplementedError(
231
- "ShramLayerCache does not support get_mask_sizes(). "
232
- "Query sliding_window_cache or mosrah_cache directly."
233
- )
 
1
+ """SHRAM per-layer cache — composite owner for one SHRAM decoder layer.
2
+
3
+ A SHRAM decoder layer contains two distinct attention pathways at one attention slot: the
4
+ local sliding-window path and the MoSRAH sparse path. Each path has its own cache with
5
+ different semantics and a different downstream consumer. ShramLayerCache owns both, satisfies
6
+ the HuggingFace per-layer cache role, and exposes each sub-cache directly so its attention
7
+ path can interact with it without indirection.
8
+
9
+ ShramLayerCache does not define a composite update() interface. The two paths have materially
10
+ different update semantics — the local side uses chunk-local key/value/mask concatenation
11
+ while the MoSRAH side uses expert-choice scatter with an active mask — and merging these
12
+ behind a single update() would hide those differences behind a misleading abstraction. Instead,
13
+ each attention path calls update() on the sub-cache it owns. ShramLayerCache acts as the
14
+ ownership, coordination, and reset/reorder boundary for one decoder layer.
15
+
16
+ Sequence length at this boundary is reported by delegating to the local sliding-window
17
+ sub-cache, which tracks the cumulative count of token positions processed. This is the
18
+ quantity HuggingFace generation reads through get_seq_length().
19
+ """
20
+
21
+ import torch
22
+ from transformers.cache_utils import CacheLayerMixin
23
+
24
+ from .__cache__mosrah_cache import MoSRAHCache
25
+ from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
26
+
27
+
28
+ class ShramLayerCache(CacheLayerMixin):
29
+ """Cache subsystem for one SHRAM decoder layer.
30
+
31
+ Owns and coordinates two sub-caches:
32
+ - sliding_window_cache: LocalSlidingWindowLayerCache for the local sliding-window path.
33
+ - mosrah_cache: MoSRAHCache for the MoSRAH sparse attention path.
34
+
35
+ Satisfies the HuggingFace per-layer cache role (CacheLayerMixin). The two sub-caches are
36
+ exposed directly for their downstream attention paths — no composite update() interface is
37
+ provided, because the two paths have materially different update semantics.
38
+
39
+ Sequence length is reported by delegating to the local sliding-window sub-cache, which
40
+ tracks the cumulative count of token positions processed across all update() calls.
41
+
42
+ Args:
43
+ sliding_window: Number of tokens retained by the local sliding-window cache.
44
+ num_local_heads: Number of local attention heads.
45
+ local_head_dim: Per-head embedding width for the local path.
46
+ num_mosrah_heads: Total number of MoSRAH expert heads (L).
47
+ mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
48
+ batch_size: Number of sequences in the batch.
49
+ device: Device on which to allocate cache tensors.
50
+ initial_buffer_size: Initial per-(batch, head) capacity for MoSRAHCache. Doubled
51
+ when any slot overflows. Defaults to 64 to avoid repeated reallocation during
52
+ prompt processing.
53
+ """
54
+
55
+ is_compileable = False
56
+ is_sliding = False
57
+
58
+ def __init__(
59
+ self,
60
+ sliding_window: int,
61
+ num_local_heads: int,
62
+ local_head_dim: int,
63
+ num_mosrah_heads: int,
64
+ mosrah_head_dim: int,
65
+ batch_size: int,
66
+ device: torch.device,
67
+ initial_buffer_size: int = 64,
68
+ ) -> None:
69
+ super().__init__()
70
+ self.sliding_window_cache = LocalSlidingWindowLayerCache(
71
+ sliding_window=sliding_window,
72
+ num_heads=num_local_heads,
73
+ head_dim=local_head_dim,
74
+ batch_size=batch_size,
75
+ device=device,
76
+ )
77
+ self.mosrah_cache = MoSRAHCache(
78
+ num_mosrah_heads=num_mosrah_heads,
79
+ head_dim=mosrah_head_dim,
80
+ batch_size=batch_size,
81
+ device=device,
82
+ initial_buffer_size=initial_buffer_size,
83
+ )
84
+
85
+ # ---------------------------------------------------------------------------
86
+ # Properties
87
+ # ---------------------------------------------------------------------------
88
+
89
+ @property
90
+ def is_initialized(self) -> bool:
91
+ """True iff both sub-caches have allocated their storage.
92
+
93
+ Both LocalSlidingWindowLayerCache and MoSRAHCache pre-allocate at construction,
94
+ so this is True immediately after ShramLayerCache.__init__ returns.
95
+ """
96
+ return self.sliding_window_cache.is_initialized and self.mosrah_cache.is_initialized
97
+
98
+ @is_initialized.setter
99
+ def is_initialized(self, value: bool) -> None:
100
+ # CacheLayerMixin.__init__ assigns self.is_initialized = False as an instance
101
+ # attribute. Since property is a data descriptor it takes precedence, but Python
102
+ # still routes the assignment through __set__. Absorb it silently — state is
103
+ # derived from sub-caches, not stored here.
104
+ pass
105
+
106
+ # ---------------------------------------------------------------------------
107
+ # CacheLayerMixin — composite-meaningful methods
108
+ # ---------------------------------------------------------------------------
109
+
110
+ def get_seq_length(self) -> int: # type: ignore[override]
111
+ """Return the cumulative sequence length from the local sliding-window path.
112
+
113
+ The local path is authoritative for sequence progress: it sees every token
114
+ presented to this layer and accumulates a truthful total. Delegates to
115
+ sliding_window_cache.get_seq_length().
116
+ """
117
+ return self.sliding_window_cache.get_seq_length()
118
+
119
+ def reset(self) -> None:
120
+ """Clear both sub-caches.
121
+
122
+ Delegates reset to each sub-cache. Both are cleared atomically so the sliding-window
123
+ state and MoSRAH sparse state remain consistent.
124
+ """
125
+ self.sliding_window_cache.reset()
126
+ self.mosrah_cache.reset()
127
+
128
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
129
+ """Reorder the batch dimension of both sub-caches for beam search.
130
+
131
+ Delegates to each sub-cache. Both are reordered atomically so the sliding-window
132
+ and MoSRAH state correspond to the same beam hypotheses after reordering.
133
+
134
+ Args:
135
+ beam_idx: Permutation indices of shape (batch,) produced by beam search.
136
+ """
137
+ self.sliding_window_cache.reorder_cache(beam_idx)
138
+ self.mosrah_cache.reorder_cache(beam_idx)
139
+
140
+ def batch_repeat_interleave(self, repeats: int) -> None:
141
+ """Expand the batch dimension of both sub-caches for beam search initialisation.
142
+
143
+ Delegates atomically to each sub-cache. Both must be expanded together so the
144
+ sliding-window and MoSRAH state correspond to the same beam candidates.
145
+
146
+ Args:
147
+ repeats: Number of times to repeat each batch entry.
148
+ """
149
+ self.sliding_window_cache.batch_repeat_interleave(repeats)
150
+ self.mosrah_cache.batch_repeat_interleave(repeats)
151
+
152
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
153
+ """Select a subset of batch entries in both sub-caches for contrastive search.
154
+
155
+ Delegates atomically to each sub-cache. Both must be trimmed together so the
156
+ sliding-window and MoSRAH state remain consistent.
157
+
158
+ Args:
159
+ indices: 1-D integer tensor of batch indices to retain.
160
+ """
161
+ self.sliding_window_cache.batch_select_indices(indices)
162
+ self.mosrah_cache.batch_select_indices(indices)
163
+
164
+ def offload(self) -> None:
165
+ """Offload both sub-caches to CPU.
166
+
167
+ Delegates to each sub-cache's offload method. Does not call super() — ShramLayerCache
168
+ does not own self.keys/self.values directly; all cached data lives in the sub-caches.
169
+ """
170
+ self.sliding_window_cache.offload()
171
+ self.mosrah_cache.offload()
172
+
173
+ def prefetch(self) -> None:
174
+ """Move both sub-caches back to their model device ahead of time.
175
+
176
+ Delegates to each sub-cache's prefetch method. Does not call super() — ShramLayerCache
177
+ does not own self.keys/self.values directly; all cached data lives in the sub-caches.
178
+ """
179
+ self.sliding_window_cache.prefetch()
180
+ self.mosrah_cache.prefetch()
181
+
182
+ def lazy_initialization( # type: ignore[override]
183
+ self, key_states: torch.Tensor, value_states: torch.Tensor
184
+ ) -> None:
185
+ """No-op — both sub-caches handle their own initialization."""
186
+ pass
187
+
188
+ # ---------------------------------------------------------------------------
189
+ # CacheLayerMixin — unsupported abstract methods
190
+ # ---------------------------------------------------------------------------
191
+
192
+ def update( # type: ignore[override]
193
+ self,
194
+ key_states: torch.Tensor,
195
+ value_states: torch.Tensor,
196
+ cache_kwargs: dict | None = None,
197
+ ) -> tuple[torch.Tensor, torch.Tensor]:
198
+ """Not supported — ShramLayerCache has no composite update interface.
199
+
200
+ The two sub-caches have materially different update semantics: the sliding-window
201
+ side uses standard key/value concatenation while the MoSRAH side uses expert-choice
202
+ scatter with an active mask. Callers must update each sub-cache directly via
203
+ sliding_window_cache.update() or mosrah_cache.update().
204
+ """
205
+ raise NotImplementedError(
206
+ "ShramLayerCache has no composite update interface. "
207
+ "Update sliding_window_cache or mosrah_cache directly."
208
+ )
209
+
210
+ def get_max_cache_shape(self) -> int: # type: ignore[override]
211
+ """Not supported — the composite cache has no single maximum shape.
212
+
213
+ The sliding-window cache is bounded by sliding_window; the MoSRAH cache is
214
+ unbounded. No truthful scalar maximum represents the composite.
215
+ """
216
+ raise NotImplementedError(
217
+ "ShramLayerCache has no single maximum cache shape. "
218
+ "Query sliding_window_cache or mosrah_cache directly."
219
+ )
220
+
221
+ def get_mask_sizes( # type: ignore[override]
222
+ self,
223
+ cache_position: torch.Tensor,
224
+ ) -> tuple[int, int]:
225
+ """Not supported — ShramLayerCache does not participate in HF mask construction.
226
+
227
+ The two sub-caches have different mask semantics and their respective attention
228
+ paths handle masking directly.
229
+ """
230
+ raise NotImplementedError(
231
+ "ShramLayerCache does not support get_mask_sizes(). "
232
+ "Query sliding_window_cache or mosrah_cache directly."
233
+ )
__cache__sliding_window_cache.py CHANGED
@@ -1,289 +1,289 @@
1
- # src/shram/model/cache/sliding_window_cache.py
2
-
3
- """Local sliding-window cache for the SHRAM local attention path.
4
-
5
- This file defines `LocalSlidingWindowLayerCache`, the local sub-cache owned by
6
- `ShramLayerCache` and consumed by `SlidingWindowAttention`.
7
-
8
- Its job is narrow:
9
-
10
- - accept the current chunk's local key/value tensors and active mask
11
- - return the current-step local frame consumed by local attention
12
- - separately retain the next-step sliding-window cache state
13
-
14
- It does not decide local causal visibility. That is owned by
15
- `SlidingWindowAttention`, which consumes the returned key/value/mask frame and
16
- constructs the effective local attention mask from it.
17
- """
18
-
19
- import torch
20
- from transformers.cache_utils import CacheLayerMixin
21
-
22
-
23
- class LocalSlidingWindowLayerCache(CacheLayerMixin):
24
- """Fixed-width local cache for one SHRAM decoder layer.
25
-
26
- The cache keeps a retained local sliding-window buffer and an aligned active
27
- mask. On update, it returns the current-step local frame formed by
28
- concatenating retained cache state with the new chunk, then remembers only
29
- the last `sliding_window` positions for the next step.
30
-
31
- Dead positions are allowed to remain in both the returned frame and the
32
- retained cache. Correctness is carried by the aligned active mask.
33
-
34
- Args:
35
- sliding_window: Width of the retained local sliding-window buffer.
36
- num_heads: Number of local attention heads.
37
- head_dim: Per-head embedding width for the local path.
38
- batch_size: Number of sequences in the batch.
39
- device: Device on which to allocate cache storage.
40
- """
41
-
42
- is_compileable = False
43
- is_sliding = True
44
-
45
- def __init__(
46
- self,
47
- sliding_window: int,
48
- num_heads: int,
49
- head_dim: int,
50
- batch_size: int,
51
- device: torch.device,
52
- ) -> None:
53
- super().__init__()
54
-
55
- if sliding_window < 1:
56
- raise ValueError(
57
- f"sliding_window must be >= 1, got {sliding_window}."
58
- )
59
- if num_heads < 1:
60
- raise ValueError(f"num_heads must be >= 1, got {num_heads}.")
61
- if head_dim < 1:
62
- raise ValueError(f"head_dim must be >= 1, got {head_dim}.")
63
- if batch_size < 1:
64
- raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
65
-
66
- self.sliding_window = sliding_window
67
- self.num_heads = num_heads
68
- self.head_dim = head_dim
69
- self.batch_size = batch_size
70
- self.device = device
71
-
72
- # Retained next-step local cache state. Storage is fixed-width from the
73
- # start; semantic validity is carried by `active_mask`.
74
- self.keys = torch.zeros(
75
- batch_size,
76
- num_heads,
77
- sliding_window,
78
- head_dim,
79
- device=device,
80
- )
81
- self.values = torch.zeros(
82
- batch_size,
83
- num_heads,
84
- sliding_window,
85
- head_dim,
86
- device=device,
87
- )
88
- self.active_mask = torch.zeros(
89
- batch_size,
90
- sliding_window,
91
- dtype=torch.bool,
92
- device=device,
93
- )
94
-
95
- self.is_initialized = True
96
-
97
- # Cumulative count of all token positions presented through update() for
98
- # this cache instance. This is the quantity HuggingFace generation reads
99
- # through get_seq_length() to track how far along the sequence we are.
100
- self._total_processed: int = 0
101
-
102
- def update( # type: ignore[override]
103
- self,
104
- key_states: torch.Tensor,
105
- value_states: torch.Tensor,
106
- active_mask: torch.Tensor,
107
- cache_kwargs: dict | None = None,
108
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
109
- """Return the current-step local frame and retain the next-step window.
110
-
111
- Args:
112
- key_states: Shape `(B, H, T_new, D)` local key vectors for the
113
- current chunk.
114
- value_states: Shape `(B, H, T_new, D)` local value vectors for the
115
- current chunk.
116
- active_mask: Shape `(B, T_new)` bool. `True` means the
117
- corresponding token position in the current chunk is active.
118
- cache_kwargs: Present only to satisfy the `CacheLayerMixin`
119
- interface. Unused by this cache.
120
-
121
- Returns:
122
- Tuple of:
123
- - visible_keys: `(B, H, sliding_window + T_new, D)`
124
- - visible_values: `(B, H, sliding_window + T_new, D)`
125
- - visible_active_mask: `(B, sliding_window + T_new)`
126
-
127
- These are the tensors the local attention path should consume
128
- directly for the current step.
129
- """
130
- self._ensure_state_compatibility(
131
- key_states=key_states,
132
- value_states=value_states,
133
- )
134
-
135
- # The current-step local frame is just retained cache state followed by
136
- # the current chunk in chronological order.
137
- composite_keys, composite_values, composite_mask = self._make_composite_frame(
138
- key_states=key_states,
139
- value_states=value_states,
140
- active_mask=active_mask,
141
- )
142
-
143
- # The cache remembers only the last raw sliding-window positions of that
144
- # composite frame for the next step. Dead positions are allowed to
145
- # survive; downstream local attention will ignore them using the mask.
146
- self._retain_next_window(
147
- composite_keys=composite_keys,
148
- composite_values=composite_values,
149
- composite_mask=composite_mask,
150
- )
151
-
152
- self._total_processed += key_states.shape[2]
153
-
154
- return composite_keys, composite_values, composite_mask
155
-
156
- def _ensure_state_compatibility(
157
- self,
158
- key_states: torch.Tensor,
159
- value_states: torch.Tensor,
160
- ) -> None:
161
- """Keep retained cache buffers compatible with the incoming update tensors.
162
-
163
- The cache is allocated eagerly for simplicity. If later updates arrive on
164
- a different device or in a different floating dtype, move the retained
165
- state to match while preserving its contents.
166
- """
167
- if self.keys.dtype != key_states.dtype or self.keys.device != key_states.device:
168
- self.keys = self.keys.to(
169
- device=key_states.device,
170
- dtype=key_states.dtype,
171
- )
172
-
173
- if (
174
- self.values.dtype != value_states.dtype
175
- or self.values.device != value_states.device
176
- ):
177
- self.values = self.values.to(
178
- device=value_states.device,
179
- dtype=value_states.dtype,
180
- )
181
-
182
- if self.active_mask.device != key_states.device:
183
- self.active_mask = self.active_mask.to(
184
- key_states.device,
185
- non_blocking=True,
186
- )
187
-
188
- def _make_composite_frame(
189
- self,
190
- key_states: torch.Tensor,
191
- value_states: torch.Tensor,
192
- active_mask: torch.Tensor,
193
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
194
- """Build the current-step local frame in chronological order."""
195
- return (
196
- torch.cat([self.keys, key_states], dim=-2),
197
- torch.cat([self.values, value_states], dim=-2),
198
- torch.cat([self.active_mask, active_mask], dim=-1),
199
- )
200
-
201
- def _retain_next_window(
202
- self,
203
- composite_keys: torch.Tensor,
204
- composite_values: torch.Tensor,
205
- composite_mask: torch.Tensor,
206
- ) -> None:
207
- """Remember the next-step retained local state.
208
-
209
- This is a raw positional trim to the last `sliding_window` positions, not
210
- a semantic live-token trim.
211
- """
212
- self.keys = composite_keys[:, :, -self.sliding_window :, :]
213
- self.values = composite_values[:, :, -self.sliding_window :, :]
214
- self.active_mask = composite_mask[:, -self.sliding_window :]
215
-
216
- def get_seq_length(self) -> int:
217
- """Return the cumulative number of token positions processed by this cache.
218
-
219
- This is the total count of token positions presented across all update()
220
- calls since construction or the last reset(). It is the quantity HuggingFace
221
- generation reads to track sequence progress and is not the same as active-token
222
- count or current window occupancy.
223
- """
224
- return self._total_processed
225
-
226
- def get_max_cache_shape(self) -> int:
227
- return self.sliding_window
228
-
229
- def get_mask_sizes( # type: ignore[override]
230
- self,
231
- cache_position: torch.Tensor,
232
- ) -> tuple[int, int]:
233
- raise NotImplementedError(
234
- "LocalSlidingWindowLayerCache does not support get_mask_sizes()."
235
- )
236
-
237
- def reset(self) -> None:
238
- """Restore fresh-cache behavior."""
239
- self.keys.zero_()
240
- self.values.zero_()
241
- self.active_mask.zero_()
242
- self._total_processed = 0
243
-
244
- def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
245
- """Reorder the batch dimension for beam search."""
246
- self.keys = self.keys[beam_idx]
247
- self.values = self.values[beam_idx]
248
- self.active_mask = self.active_mask[beam_idx]
249
-
250
- def batch_repeat_interleave(self, repeats: int) -> None:
251
- """Expand the batch dimension for beam-search initialisation."""
252
- self.keys = self.keys.repeat_interleave(repeats, dim=0)
253
- self.values = self.values.repeat_interleave(repeats, dim=0)
254
- self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0)
255
- self.batch_size = self.batch_size * repeats
256
-
257
- def batch_select_indices(self, indices: torch.Tensor) -> None:
258
- """Select a subset of batch entries for contrastive search."""
259
- self.keys = self.keys[indices]
260
- self.values = self.values[indices]
261
- self.active_mask = self.active_mask[indices]
262
- self.batch_size = int(indices.shape[0])
263
-
264
- def offload(self) -> None:
265
- """Offload cache tensors to CPU."""
266
- super().offload()
267
- self.active_mask = self.active_mask.to("cpu", non_blocking=True)
268
-
269
- def prefetch(self) -> None:
270
- """Move cache tensors back to the model device ahead of time."""
271
- super().prefetch()
272
- if self.active_mask.device != self.keys.device:
273
- self.active_mask = self.active_mask.to(
274
- self.keys.device,
275
- non_blocking=True,
276
- )
277
-
278
- def crop(self, max_length: int) -> None:
279
- raise NotImplementedError(
280
- "LocalSlidingWindowLayerCache does not support crop()."
281
- )
282
-
283
- def lazy_initialization(
284
- self,
285
- key_states: torch.Tensor,
286
- value_states: torch.Tensor,
287
- ) -> None:
288
- """No-op — this cache allocates its fixed buffers at construction time."""
289
  return
 
1
+ # src/shram/model/cache/sliding_window_cache.py
2
+
3
+ """Local sliding-window cache for the SHRAM local attention path.
4
+
5
+ This file defines `LocalSlidingWindowLayerCache`, the local sub-cache owned by
6
+ `ShramLayerCache` and consumed by `SlidingWindowAttention`.
7
+
8
+ Its job is narrow:
9
+
10
+ - accept the current chunk's local key/value tensors and active mask
11
+ - return the current-step local frame consumed by local attention
12
+ - separately retain the next-step sliding-window cache state
13
+
14
+ It does not decide local causal visibility. That is owned by
15
+ `SlidingWindowAttention`, which consumes the returned key/value/mask frame and
16
+ constructs the effective local attention mask from it.
17
+ """
18
+
19
+ import torch
20
+ from transformers.cache_utils import CacheLayerMixin
21
+
22
+
23
+ class LocalSlidingWindowLayerCache(CacheLayerMixin):
24
+ """Fixed-width local cache for one SHRAM decoder layer.
25
+
26
+ The cache keeps a retained local sliding-window buffer and an aligned active
27
+ mask. On update, it returns the current-step local frame formed by
28
+ concatenating retained cache state with the new chunk, then remembers only
29
+ the last `sliding_window` positions for the next step.
30
+
31
+ Dead positions are allowed to remain in both the returned frame and the
32
+ retained cache. Correctness is carried by the aligned active mask.
33
+
34
+ Args:
35
+ sliding_window: Width of the retained local sliding-window buffer.
36
+ num_heads: Number of local attention heads.
37
+ head_dim: Per-head embedding width for the local path.
38
+ batch_size: Number of sequences in the batch.
39
+ device: Device on which to allocate cache storage.
40
+ """
41
+
42
+ is_compileable = False
43
+ is_sliding = True
44
+
45
+ def __init__(
46
+ self,
47
+ sliding_window: int,
48
+ num_heads: int,
49
+ head_dim: int,
50
+ batch_size: int,
51
+ device: torch.device,
52
+ ) -> None:
53
+ super().__init__()
54
+
55
+ if sliding_window < 1:
56
+ raise ValueError(
57
+ f"sliding_window must be >= 1, got {sliding_window}."
58
+ )
59
+ if num_heads < 1:
60
+ raise ValueError(f"num_heads must be >= 1, got {num_heads}.")
61
+ if head_dim < 1:
62
+ raise ValueError(f"head_dim must be >= 1, got {head_dim}.")
63
+ if batch_size < 1:
64
+ raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
65
+
66
+ self.sliding_window = sliding_window
67
+ self.num_heads = num_heads
68
+ self.head_dim = head_dim
69
+ self.batch_size = batch_size
70
+ self.device = device
71
+
72
+ # Retained next-step local cache state. Storage is fixed-width from the
73
+ # start; semantic validity is carried by `active_mask`.
74
+ self.keys = torch.zeros(
75
+ batch_size,
76
+ num_heads,
77
+ sliding_window,
78
+ head_dim,
79
+ device=device,
80
+ )
81
+ self.values = torch.zeros(
82
+ batch_size,
83
+ num_heads,
84
+ sliding_window,
85
+ head_dim,
86
+ device=device,
87
+ )
88
+ self.active_mask = torch.zeros(
89
+ batch_size,
90
+ sliding_window,
91
+ dtype=torch.bool,
92
+ device=device,
93
+ )
94
+
95
+ self.is_initialized = True
96
+
97
+ # Cumulative count of all token positions presented through update() for
98
+ # this cache instance. This is the quantity HuggingFace generation reads
99
+ # through get_seq_length() to track how far along the sequence we are.
100
+ self._total_processed: int = 0
101
+
102
+ def update( # type: ignore[override]
103
+ self,
104
+ key_states: torch.Tensor,
105
+ value_states: torch.Tensor,
106
+ active_mask: torch.Tensor,
107
+ cache_kwargs: dict | None = None,
108
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
109
+ """Return the current-step local frame and retain the next-step window.
110
+
111
+ Args:
112
+ key_states: Shape `(B, H, T_new, D)` local key vectors for the
113
+ current chunk.
114
+ value_states: Shape `(B, H, T_new, D)` local value vectors for the
115
+ current chunk.
116
+ active_mask: Shape `(B, T_new)` bool. `True` means the
117
+ corresponding token position in the current chunk is active.
118
+ cache_kwargs: Present only to satisfy the `CacheLayerMixin`
119
+ interface. Unused by this cache.
120
+
121
+ Returns:
122
+ Tuple of:
123
+ - visible_keys: `(B, H, sliding_window + T_new, D)`
124
+ - visible_values: `(B, H, sliding_window + T_new, D)`
125
+ - visible_active_mask: `(B, sliding_window + T_new)`
126
+
127
+ These are the tensors the local attention path should consume
128
+ directly for the current step.
129
+ """
130
+ self._ensure_state_compatibility(
131
+ key_states=key_states,
132
+ value_states=value_states,
133
+ )
134
+
135
+ # The current-step local frame is just retained cache state followed by
136
+ # the current chunk in chronological order.
137
+ composite_keys, composite_values, composite_mask = self._make_composite_frame(
138
+ key_states=key_states,
139
+ value_states=value_states,
140
+ active_mask=active_mask,
141
+ )
142
+
143
+ # The cache remembers only the last raw sliding-window positions of that
144
+ # composite frame for the next step. Dead positions are allowed to
145
+ # survive; downstream local attention will ignore them using the mask.
146
+ self._retain_next_window(
147
+ composite_keys=composite_keys,
148
+ composite_values=composite_values,
149
+ composite_mask=composite_mask,
150
+ )
151
+
152
+ self._total_processed += key_states.shape[2]
153
+
154
+ return composite_keys, composite_values, composite_mask
155
+
156
+ def _ensure_state_compatibility(
157
+ self,
158
+ key_states: torch.Tensor,
159
+ value_states: torch.Tensor,
160
+ ) -> None:
161
+ """Keep retained cache buffers compatible with the incoming update tensors.
162
+
163
+ The cache is allocated eagerly for simplicity. If later updates arrive on
164
+ a different device or in a different floating dtype, move the retained
165
+ state to match while preserving its contents.
166
+ """
167
+ if self.keys.dtype != key_states.dtype or self.keys.device != key_states.device:
168
+ self.keys = self.keys.to(
169
+ device=key_states.device,
170
+ dtype=key_states.dtype,
171
+ )
172
+
173
+ if (
174
+ self.values.dtype != value_states.dtype
175
+ or self.values.device != value_states.device
176
+ ):
177
+ self.values = self.values.to(
178
+ device=value_states.device,
179
+ dtype=value_states.dtype,
180
+ )
181
+
182
+ if self.active_mask.device != key_states.device:
183
+ self.active_mask = self.active_mask.to(
184
+ key_states.device,
185
+ non_blocking=True,
186
+ )
187
+
188
+ def _make_composite_frame(
189
+ self,
190
+ key_states: torch.Tensor,
191
+ value_states: torch.Tensor,
192
+ active_mask: torch.Tensor,
193
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
194
+ """Build the current-step local frame in chronological order."""
195
+ return (
196
+ torch.cat([self.keys, key_states], dim=-2),
197
+ torch.cat([self.values, value_states], dim=-2),
198
+ torch.cat([self.active_mask, active_mask], dim=-1),
199
+ )
200
+
201
+ def _retain_next_window(
202
+ self,
203
+ composite_keys: torch.Tensor,
204
+ composite_values: torch.Tensor,
205
+ composite_mask: torch.Tensor,
206
+ ) -> None:
207
+ """Remember the next-step retained local state.
208
+
209
+ This is a raw positional trim to the last `sliding_window` positions, not
210
+ a semantic live-token trim.
211
+ """
212
+ self.keys = composite_keys[:, :, -self.sliding_window :, :]
213
+ self.values = composite_values[:, :, -self.sliding_window :, :]
214
+ self.active_mask = composite_mask[:, -self.sliding_window :]
215
+
216
+ def get_seq_length(self) -> int:
217
+ """Return the cumulative number of token positions processed by this cache.
218
+
219
+ This is the total count of token positions presented across all update()
220
+ calls since construction or the last reset(). It is the quantity HuggingFace
221
+ generation reads to track sequence progress and is not the same as active-token
222
+ count or current window occupancy.
223
+ """
224
+ return self._total_processed
225
+
226
+ def get_max_cache_shape(self) -> int:
227
+ return self.sliding_window
228
+
229
+ def get_mask_sizes( # type: ignore[override]
230
+ self,
231
+ cache_position: torch.Tensor,
232
+ ) -> tuple[int, int]:
233
+ raise NotImplementedError(
234
+ "LocalSlidingWindowLayerCache does not support get_mask_sizes()."
235
+ )
236
+
237
+ def reset(self) -> None:
238
+ """Restore fresh-cache behavior."""
239
+ self.keys.zero_()
240
+ self.values.zero_()
241
+ self.active_mask.zero_()
242
+ self._total_processed = 0
243
+
244
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
245
+ """Reorder the batch dimension for beam search."""
246
+ self.keys = self.keys[beam_idx]
247
+ self.values = self.values[beam_idx]
248
+ self.active_mask = self.active_mask[beam_idx]
249
+
250
+ def batch_repeat_interleave(self, repeats: int) -> None:
251
+ """Expand the batch dimension for beam-search initialisation."""
252
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
253
+ self.values = self.values.repeat_interleave(repeats, dim=0)
254
+ self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0)
255
+ self.batch_size = self.batch_size * repeats
256
+
257
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
258
+ """Select a subset of batch entries for contrastive search."""
259
+ self.keys = self.keys[indices]
260
+ self.values = self.values[indices]
261
+ self.active_mask = self.active_mask[indices]
262
+ self.batch_size = int(indices.shape[0])
263
+
264
+ def offload(self) -> None:
265
+ """Offload cache tensors to CPU."""
266
+ super().offload()
267
+ self.active_mask = self.active_mask.to("cpu", non_blocking=True)
268
+
269
+ def prefetch(self) -> None:
270
+ """Move cache tensors back to the model device ahead of time."""
271
+ super().prefetch()
272
+ if self.active_mask.device != self.keys.device:
273
+ self.active_mask = self.active_mask.to(
274
+ self.keys.device,
275
+ non_blocking=True,
276
+ )
277
+
278
+ def crop(self, max_length: int) -> None:
279
+ raise NotImplementedError(
280
+ "LocalSlidingWindowLayerCache does not support crop()."
281
+ )
282
+
283
+ def lazy_initialization(
284
+ self,
285
+ key_states: torch.Tensor,
286
+ value_states: torch.Tensor,
287
+ ) -> None:
288
+ """No-op — this cache allocates its fixed buffers at construction time."""
289
  return
__cache__slow_mosrah_cache.py CHANGED
@@ -1,321 +1,321 @@
1
- """Unvectorized reference implementation of the MoSRAH sparse KV cache.
2
-
3
- This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same
4
- interface and storage layout as MoSRAHCache but uses an explicit Python loop over
5
- (b, l, t) triples in update(). The loop is obviously correct by inspection: each active
6
- position's key and value are written to the next available slot for that (batch, head)
7
- pair, in the order positions appear along the T dimension, which directly enforces
8
- causal ordering without any index arithmetic to verify.
9
-
10
- SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
11
- trusted ground truth against which the vectorized MoSRAHCache.update() is validated in
12
- Unit 6.A tests, and as a reference for the Unit 10.A position decoder. Because the
13
- vectorized implementation is validated by asserting exact agreement with this one on all
14
- test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own test suite
15
- (test_slow_mosrah_cache.py) must establish it is trustworthy before it can be used as
16
- an oracle.
17
- """
18
-
19
- import torch
20
- from transformers.cache_utils import CacheLayerMixin
21
-
22
-
23
- class SlowMoSRAHCache(CacheLayerMixin):
24
- """Unvectorized reference implementation of the MoSRAH KV cache.
25
-
26
- Identical storage layout to MoSRAHCache: (B, L, T, u) tensors in the
27
- mixin-standard self.keys and self.values attributes, plus a (B, L) _counts tensor,
28
- with the same constructor signature and the same CacheLayerMixin protocol methods.
29
- The sole difference is update(), which uses an explicit Python loop over (b, l, t)
30
- triples rather than vectorized index arithmetic.
31
-
32
- This class is not used in the model path. It exists so that MoSRAHCache.update()
33
- can be validated by asserting exact agreement with this implementation on all test
34
- inputs. See module docstring for the trust chain this enables.
35
-
36
- Args:
37
- num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
38
- second dimension of all storage tensors.
39
- head_dim: Bottlenecked head embedding width (u). Determines the fourth
40
- dimension of all storage tensors.
41
- batch_size: Number of sequences in the batch. Determines the first dimension
42
- of all storage tensors.
43
- device: Device on which to allocate all tensors. Should match the model device.
44
- initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
45
- when any slot overflows. Defaults to 64 to avoid repeated reallocation
46
- during prompt processing.
47
- """
48
-
49
- is_compileable = False
50
- is_sliding = False
51
-
52
- def __init__(
53
- self,
54
- num_mosrah_heads: int,
55
- head_dim: int,
56
- batch_size: int,
57
- device: torch.device,
58
- initial_buffer_size: int = 64,
59
- ) -> None:
60
- super().__init__()
61
- self.num_mosrah_heads = num_mosrah_heads
62
- self.head_dim = head_dim
63
- self.batch_size = batch_size
64
- self.device = device
65
-
66
- # Allocate primary storage into the mixin-standard self.keys / self.values so
67
- # that inherited methods (offload, prefetch) operate on real tensors. _counts
68
- # tracks valid occupancy per (batch, head) slot.
69
- self.keys: torch.Tensor = torch.zeros(
70
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
71
- )
72
- self.values: torch.Tensor = torch.zeros(
73
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
74
- )
75
- self._counts: torch.Tensor = torch.zeros(
76
- batch_size, num_mosrah_heads, dtype=torch.long, device=device
77
- )
78
-
79
- # Storage is fully allocated at construction — the cache is initialized.
80
- self.is_initialized = True
81
-
82
- # ---------------------------------------------------------------------------
83
- # Properties
84
- # ---------------------------------------------------------------------------
85
-
86
- @property
87
- def buffer_capacity(self) -> int:
88
- """Current number of slots allocated per (batch, head) pair.
89
-
90
- Derived directly from self.keys rather than tracked separately, so it is
91
- always consistent with the actual buffer after expansion.
92
- """
93
- return self.keys.shape[2]
94
-
95
- # ---------------------------------------------------------------------------
96
- # Primary API
97
- # ---------------------------------------------------------------------------
98
-
99
- def update( # type: ignore[override]
100
- self,
101
- key_states: torch.Tensor,
102
- value_states: torch.Tensor,
103
- active_mask: torch.Tensor,
104
- cache_kwargs: dict | None = None,
105
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
106
- """Scatter active key/value states using an explicit loop; return full cache state.
107
-
108
- Iterates over every (b, l, t) triple. For each position where active_mask is
109
- True, the key and value are written to the next available slot for that
110
- (batch, head) pair and the count is incremented. Causal ordering is guaranteed
111
- because the t dimension is traversed from 0 to T-1 and counts are updated
112
- immediately after each write.
113
-
114
- Buffer expansion (doubling buffer_capacity) is triggered before any writes if
115
- the incoming tokens would cause any slot to overflow the current capacity.
116
-
117
- Args:
118
- key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
119
- value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
120
- active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
121
- cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
122
-
123
- Returns:
124
- Tuple of (keys, values, active_mask):
125
- keys: (B, L, T, u) float — full key buffer including junk slots.
126
- values: (B, L, T, u) float — full value buffer including junk slots.
127
- active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
128
- """
129
- B, L, T = active_mask.shape
130
-
131
- # Expansion check uses the total active tokens per slot, same as the
132
- # vectorized implementation, so both expand under identical conditions.
133
- incoming_delta = active_mask.long().sum(dim=2) # (B, L)
134
- if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
135
- self._expand()
136
-
137
- # Write each active position into the next available slot for its (batch, head)
138
- # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
139
- for b in range(B):
140
- for l in range(L):
141
- for t in range(T):
142
- if active_mask[b, l, t]:
143
- pos = self._counts[b, l].item()
144
- self.keys[b, l, pos, :] = key_states[b, l, t, :]
145
- self.values[b, l, pos, :] = value_states[b, l, t, :]
146
- self._counts[b, l] += 1
147
-
148
- return self.keys, self.values, self._make_active_mask()
149
-
150
- def get_heads_lengths(self) -> torch.Tensor:
151
- """Return the per-(batch, head) token count for this layer.
152
-
153
- This is the authoritative occupancy tensor consumed by BEA for attention
154
- masking and by position computation (Unit 10.A) for semantic-sequence
155
- position computation.
156
-
157
- Returns:
158
- Integer tensor of shape (B, L) where entry [b, h] is the number of valid
159
- tokens stored in the (b, h) slot. Zero for slots with no writes yet.
160
- """
161
- return self._counts
162
-
163
- # ---------------------------------------------------------------------------
164
- # CacheLayerMixin — overridden coordination methods
165
- # ---------------------------------------------------------------------------
166
-
167
- def reset(self) -> None:
168
- """Clear all cached key and value tensors.
169
-
170
- Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
171
- and is_initialized remains True — only the contents are cleared.
172
- """
173
- self.keys.zero_()
174
- self.values.zero_()
175
- self._counts.zero_()
176
-
177
- def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
178
- """Reorder the batch dimension of all cached tensors for beam search.
179
-
180
- Applied atomically across self.keys, self.values, and _counts. Beam search
181
- must reorder all three together or the occupancy counts and buffer contents
182
- will correspond to different beam hypotheses.
183
-
184
- Overrides the parent because the parent's implementation calls get_seq_length(),
185
- which is not supported for this cache.
186
-
187
- Args:
188
- beam_idx: Permutation indices of shape (batch,) produced by the beam
189
- search algorithm.
190
- """
191
- self.keys = self.keys[beam_idx]
192
- self.values = self.values[beam_idx]
193
- self._counts = self._counts[beam_idx]
194
-
195
- def batch_repeat_interleave(self, repeats: int) -> None:
196
- """Expand the batch dimension by repeating each entry repeats times.
197
-
198
- Used at beam search initialisation to expand the cache from batch size B to
199
- B * repeats, matching the expanded beam candidate batch. Applied atomically
200
- across keys, values, and _counts; batch_size is updated to reflect the new size.
201
-
202
- Args:
203
- repeats: Number of times to repeat each batch entry.
204
- """
205
- self.keys = self.keys.repeat_interleave(repeats, dim=0)
206
- self.values = self.values.repeat_interleave(repeats, dim=0)
207
- self._counts = self._counts.repeat_interleave(repeats, dim=0)
208
- self.batch_size = self.batch_size * repeats
209
-
210
- def batch_select_indices(self, indices: torch.Tensor) -> None:
211
- """Select a subset of batch entries by index.
212
-
213
- Used in contrastive search to retain only the selected candidate entries.
214
- Applied atomically across keys, values, and _counts; batch_size is updated
215
- to reflect the number of retained entries.
216
-
217
- Args:
218
- indices: 1-D integer tensor of batch indices to retain.
219
- """
220
- self.keys = self.keys[indices]
221
- self.values = self.values[indices]
222
- self._counts = self._counts[indices]
223
- self.batch_size = indices.shape[0]
224
-
225
- def offload(self) -> None:
226
- """Offload all cached tensors to CPU.
227
-
228
- Extends the parent to also offload _counts, which the parent does not know
229
- about. All three tensors are moved atomically so device state remains consistent.
230
- """
231
- super().offload()
232
- self._counts = self._counts.to("cpu", non_blocking=True)
233
-
234
- def prefetch(self) -> None:
235
- """Move all cached tensors back to the model device ahead of time.
236
-
237
- Extends the parent to also prefetch _counts, which the parent does not know
238
- about. _counts is synced to self.keys.device after the parent moves keys and
239
- values, so all three remain consistent.
240
- """
241
- super().prefetch()
242
- if self._counts.device != self.keys.device:
243
- self._counts = self._counts.to(self.keys.device, non_blocking=True)
244
-
245
- def lazy_initialization( # type: ignore[override]
246
- self, key_states: torch.Tensor, value_states: torch.Tensor
247
- ) -> None:
248
- """No-op — storage is fully allocated at construction time."""
249
- pass
250
-
251
- # ---------------------------------------------------------------------------
252
- # CacheLayerMixin — unsupported abstract methods
253
- # ---------------------------------------------------------------------------
254
-
255
- def get_seq_length(self) -> int: # type: ignore[override]
256
- """Not supported — no single sequence length represents this cache's state.
257
-
258
- MoSRAH heads accumulate independently; (batch, head) slots have different
259
- lengths depending on routing history. There is no meaningful scalar summary.
260
- Use get_heads_lengths() for per-head occupancy.
261
- """
262
- raise NotImplementedError(
263
- "SlowMoSRAHCache has no single sequence length. "
264
- "Use get_heads_lengths() for per-head occupancy."
265
- )
266
-
267
- def get_max_cache_shape(self) -> int: # type: ignore[override]
268
- """Not supported — SlowMoSRAHCache is dynamic and unbounded."""
269
- raise NotImplementedError(
270
- "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
271
- )
272
-
273
- def get_mask_sizes( # type: ignore[override]
274
- self,
275
- cache_position: torch.Tensor,
276
- ) -> tuple[int, int]:
277
- """Not supported — SlowMoSRAHCache does not participate in HF mask construction."""
278
- raise NotImplementedError(
279
- "SlowMoSRAHCache does not support get_mask_sizes()."
280
- )
281
-
282
- # ---------------------------------------------------------------------------
283
- # Internal helpers
284
- # ---------------------------------------------------------------------------
285
-
286
- def _make_active_mask(self) -> torch.Tensor:
287
- """Construct the (B, L, T) active mask from current counts.
288
-
289
- Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
290
- has been written. Positions at or beyond the count are junk and must be
291
- excluded by downstream attention.
292
- """
293
- cap = self.buffer_capacity
294
- return (
295
- torch.arange(cap, device=self.keys.device)
296
- .expand(self.batch_size, self.num_mosrah_heads, cap)
297
- < self._counts.unsqueeze(-1)
298
- )
299
-
300
- def _expand(self) -> None:
301
- """Double the buffer capacity, preserving existing data.
302
-
303
- Called by update() when an incoming batch of tokens would cause any
304
- (batch, head) slot to exceed the current buffer capacity. All existing
305
- key and value data is copied into the low half of the new buffer; the
306
- high half is zero-initialised and will be filled by subsequent writes.
307
- After reassignment, buffer_capacity reflects the new size automatically.
308
- """
309
- old_cap = self.buffer_capacity
310
- new_cap = old_cap * 2
311
- dev = self.keys.device
312
- new_keys = torch.zeros(
313
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
314
- )
315
- new_values = torch.zeros(
316
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
317
- )
318
- new_keys[:, :, :old_cap, :] = self.keys
319
- new_values[:, :, :old_cap, :] = self.values
320
- self.keys = new_keys
321
- self.values = new_values
 
1
+ """Unvectorized reference implementation of the MoSRAH sparse KV cache.
2
+
3
+ This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same
4
+ interface and storage layout as MoSRAHCache but uses an explicit Python loop over
5
+ (b, l, t) triples in update(). The loop is obviously correct by inspection: each active
6
+ position's key and value are written to the next available slot for that (batch, head)
7
+ pair, in the order positions appear along the T dimension, which directly enforces
8
+ causal ordering without any index arithmetic to verify.
9
+
10
+ SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
11
+ trusted ground truth against which the vectorized MoSRAHCache.update() is validated in
12
+ Unit 6.A tests, and as a reference for the Unit 10.A position decoder. Because the
13
+ vectorized implementation is validated by asserting exact agreement with this one on all
14
+ test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own test suite
15
+ (test_slow_mosrah_cache.py) must establish it is trustworthy before it can be used as
16
+ an oracle.
17
+ """
18
+
19
+ import torch
20
+ from transformers.cache_utils import CacheLayerMixin
21
+
22
+
23
+ class SlowMoSRAHCache(CacheLayerMixin):
24
+ """Unvectorized reference implementation of the MoSRAH KV cache.
25
+
26
+ Identical storage layout to MoSRAHCache: (B, L, T, u) tensors in the
27
+ mixin-standard self.keys and self.values attributes, plus a (B, L) _counts tensor,
28
+ with the same constructor signature and the same CacheLayerMixin protocol methods.
29
+ The sole difference is update(), which uses an explicit Python loop over (b, l, t)
30
+ triples rather than vectorized index arithmetic.
31
+
32
+ This class is not used in the model path. It exists so that MoSRAHCache.update()
33
+ can be validated by asserting exact agreement with this implementation on all test
34
+ inputs. See module docstring for the trust chain this enables.
35
+
36
+ Args:
37
+ num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
38
+ second dimension of all storage tensors.
39
+ head_dim: Bottlenecked head embedding width (u). Determines the fourth
40
+ dimension of all storage tensors.
41
+ batch_size: Number of sequences in the batch. Determines the first dimension
42
+ of all storage tensors.
43
+ device: Device on which to allocate all tensors. Should match the model device.
44
+ initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
45
+ when any slot overflows. Defaults to 64 to avoid repeated reallocation
46
+ during prompt processing.
47
+ """
48
+
49
+ is_compileable = False
50
+ is_sliding = False
51
+
52
+ def __init__(
53
+ self,
54
+ num_mosrah_heads: int,
55
+ head_dim: int,
56
+ batch_size: int,
57
+ device: torch.device,
58
+ initial_buffer_size: int = 64,
59
+ ) -> None:
60
+ super().__init__()
61
+ self.num_mosrah_heads = num_mosrah_heads
62
+ self.head_dim = head_dim
63
+ self.batch_size = batch_size
64
+ self.device = device
65
+
66
+ # Allocate primary storage into the mixin-standard self.keys / self.values so
67
+ # that inherited methods (offload, prefetch) operate on real tensors. _counts
68
+ # tracks valid occupancy per (batch, head) slot.
69
+ self.keys: torch.Tensor = torch.zeros(
70
+ batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
71
+ )
72
+ self.values: torch.Tensor = torch.zeros(
73
+ batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
74
+ )
75
+ self._counts: torch.Tensor = torch.zeros(
76
+ batch_size, num_mosrah_heads, dtype=torch.long, device=device
77
+ )
78
+
79
+ # Storage is fully allocated at construction — the cache is initialized.
80
+ self.is_initialized = True
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # Properties
84
+ # ---------------------------------------------------------------------------
85
+
86
+ @property
87
+ def buffer_capacity(self) -> int:
88
+ """Current number of slots allocated per (batch, head) pair.
89
+
90
+ Derived directly from self.keys rather than tracked separately, so it is
91
+ always consistent with the actual buffer after expansion.
92
+ """
93
+ return self.keys.shape[2]
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Primary API
97
+ # ---------------------------------------------------------------------------
98
+
99
+ def update( # type: ignore[override]
100
+ self,
101
+ key_states: torch.Tensor,
102
+ value_states: torch.Tensor,
103
+ active_mask: torch.Tensor,
104
+ cache_kwargs: dict | None = None,
105
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
106
+ """Scatter active key/value states using an explicit loop; return full cache state.
107
+
108
+ Iterates over every (b, l, t) triple. For each position where active_mask is
109
+ True, the key and value are written to the next available slot for that
110
+ (batch, head) pair and the count is incremented. Causal ordering is guaranteed
111
+ because the t dimension is traversed from 0 to T-1 and counts are updated
112
+ immediately after each write.
113
+
114
+ Buffer expansion (doubling buffer_capacity) is triggered before any writes if
115
+ the incoming tokens would cause any slot to overflow the current capacity.
116
+
117
+ Args:
118
+ key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
119
+ value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
120
+ active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
121
+ cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
122
+
123
+ Returns:
124
+ Tuple of (keys, values, active_mask):
125
+ keys: (B, L, T, u) float — full key buffer including junk slots.
126
+ values: (B, L, T, u) float — full value buffer including junk slots.
127
+ active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
128
+ """
129
+ B, L, T = active_mask.shape
130
+
131
+ # Expansion check uses the total active tokens per slot, same as the
132
+ # vectorized implementation, so both expand under identical conditions.
133
+ incoming_delta = active_mask.long().sum(dim=2) # (B, L)
134
+ if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
135
+ self._expand()
136
+
137
+ # Write each active position into the next available slot for its (batch, head)
138
+ # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
139
+ for b in range(B):
140
+ for l in range(L):
141
+ for t in range(T):
142
+ if active_mask[b, l, t]:
143
+ pos = self._counts[b, l].item()
144
+ self.keys[b, l, pos, :] = key_states[b, l, t, :]
145
+ self.values[b, l, pos, :] = value_states[b, l, t, :]
146
+ self._counts[b, l] += 1
147
+
148
+ return self.keys, self.values, self._make_active_mask()
149
+
150
+ def get_heads_lengths(self) -> torch.Tensor:
151
+ """Return the per-(batch, head) token count for this layer.
152
+
153
+ This is the authoritative occupancy tensor consumed by BEA for attention
154
+ masking and by position computation (Unit 10.A) for semantic-sequence
155
+ position computation.
156
+
157
+ Returns:
158
+ Integer tensor of shape (B, L) where entry [b, h] is the number of valid
159
+ tokens stored in the (b, h) slot. Zero for slots with no writes yet.
160
+ """
161
+ return self._counts
162
+
163
+ # ---------------------------------------------------------------------------
164
+ # CacheLayerMixin — overridden coordination methods
165
+ # ---------------------------------------------------------------------------
166
+
167
+ def reset(self) -> None:
168
+ """Clear all cached key and value tensors.
169
+
170
+ Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
171
+ and is_initialized remains True — only the contents are cleared.
172
+ """
173
+ self.keys.zero_()
174
+ self.values.zero_()
175
+ self._counts.zero_()
176
+
177
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
178
+ """Reorder the batch dimension of all cached tensors for beam search.
179
+
180
+ Applied atomically across self.keys, self.values, and _counts. Beam search
181
+ must reorder all three together or the occupancy counts and buffer contents
182
+ will correspond to different beam hypotheses.
183
+
184
+ Overrides the parent because the parent's implementation calls get_seq_length(),
185
+ which is not supported for this cache.
186
+
187
+ Args:
188
+ beam_idx: Permutation indices of shape (batch,) produced by the beam
189
+ search algorithm.
190
+ """
191
+ self.keys = self.keys[beam_idx]
192
+ self.values = self.values[beam_idx]
193
+ self._counts = self._counts[beam_idx]
194
+
195
+ def batch_repeat_interleave(self, repeats: int) -> None:
196
+ """Expand the batch dimension by repeating each entry repeats times.
197
+
198
+ Used at beam search initialisation to expand the cache from batch size B to
199
+ B * repeats, matching the expanded beam candidate batch. Applied atomically
200
+ across keys, values, and _counts; batch_size is updated to reflect the new size.
201
+
202
+ Args:
203
+ repeats: Number of times to repeat each batch entry.
204
+ """
205
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
206
+ self.values = self.values.repeat_interleave(repeats, dim=0)
207
+ self._counts = self._counts.repeat_interleave(repeats, dim=0)
208
+ self.batch_size = self.batch_size * repeats
209
+
210
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
211
+ """Select a subset of batch entries by index.
212
+
213
+ Used in contrastive search to retain only the selected candidate entries.
214
+ Applied atomically across keys, values, and _counts; batch_size is updated
215
+ to reflect the number of retained entries.
216
+
217
+ Args:
218
+ indices: 1-D integer tensor of batch indices to retain.
219
+ """
220
+ self.keys = self.keys[indices]
221
+ self.values = self.values[indices]
222
+ self._counts = self._counts[indices]
223
+ self.batch_size = indices.shape[0]
224
+
225
+ def offload(self) -> None:
226
+ """Offload all cached tensors to CPU.
227
+
228
+ Extends the parent to also offload _counts, which the parent does not know
229
+ about. All three tensors are moved atomically so device state remains consistent.
230
+ """
231
+ super().offload()
232
+ self._counts = self._counts.to("cpu", non_blocking=True)
233
+
234
+ def prefetch(self) -> None:
235
+ """Move all cached tensors back to the model device ahead of time.
236
+
237
+ Extends the parent to also prefetch _counts, which the parent does not know
238
+ about. _counts is synced to self.keys.device after the parent moves keys and
239
+ values, so all three remain consistent.
240
+ """
241
+ super().prefetch()
242
+ if self._counts.device != self.keys.device:
243
+ self._counts = self._counts.to(self.keys.device, non_blocking=True)
244
+
245
+ def lazy_initialization( # type: ignore[override]
246
+ self, key_states: torch.Tensor, value_states: torch.Tensor
247
+ ) -> None:
248
+ """No-op — storage is fully allocated at construction time."""
249
+ pass
250
+
251
+ # ---------------------------------------------------------------------------
252
+ # CacheLayerMixin — unsupported abstract methods
253
+ # ---------------------------------------------------------------------------
254
+
255
+ def get_seq_length(self) -> int: # type: ignore[override]
256
+ """Not supported — no single sequence length represents this cache's state.
257
+
258
+ MoSRAH heads accumulate independently; (batch, head) slots have different
259
+ lengths depending on routing history. There is no meaningful scalar summary.
260
+ Use get_heads_lengths() for per-head occupancy.
261
+ """
262
+ raise NotImplementedError(
263
+ "SlowMoSRAHCache has no single sequence length. "
264
+ "Use get_heads_lengths() for per-head occupancy."
265
+ )
266
+
267
+ def get_max_cache_shape(self) -> int: # type: ignore[override]
268
+ """Not supported — SlowMoSRAHCache is dynamic and unbounded."""
269
+ raise NotImplementedError(
270
+ "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
271
+ )
272
+
273
+ def get_mask_sizes( # type: ignore[override]
274
+ self,
275
+ cache_position: torch.Tensor,
276
+ ) -> tuple[int, int]:
277
+ """Not supported — SlowMoSRAHCache does not participate in HF mask construction."""
278
+ raise NotImplementedError(
279
+ "SlowMoSRAHCache does not support get_mask_sizes()."
280
+ )
281
+
282
+ # ---------------------------------------------------------------------------
283
+ # Internal helpers
284
+ # ---------------------------------------------------------------------------
285
+
286
+ def _make_active_mask(self) -> torch.Tensor:
287
+ """Construct the (B, L, T) active mask from current counts.
288
+
289
+ Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
290
+ has been written. Positions at or beyond the count are junk and must be
291
+ excluded by downstream attention.
292
+ """
293
+ cap = self.buffer_capacity
294
+ return (
295
+ torch.arange(cap, device=self.keys.device)
296
+ .expand(self.batch_size, self.num_mosrah_heads, cap)
297
+ < self._counts.unsqueeze(-1)
298
+ )
299
+
300
+ def _expand(self) -> None:
301
+ """Double the buffer capacity, preserving existing data.
302
+
303
+ Called by update() when an incoming batch of tokens would cause any
304
+ (batch, head) slot to exceed the current buffer capacity. All existing
305
+ key and value data is copied into the low half of the new buffer; the
306
+ high half is zero-initialised and will be filled by subsequent writes.
307
+ After reassignment, buffer_capacity reflects the new size automatically.
308
+ """
309
+ old_cap = self.buffer_capacity
310
+ new_cap = old_cap * 2
311
+ dev = self.keys.device
312
+ new_keys = torch.zeros(
313
+ self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
314
+ )
315
+ new_values = torch.zeros(
316
+ self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
317
+ )
318
+ new_keys[:, :, :old_cap, :] = self.keys
319
+ new_values[:, :, :old_cap, :] = self.values
320
+ self.keys = new_keys
321
+ self.values = new_values
__init__.py CHANGED
@@ -1,21 +1,21 @@
1
- from .configuration import ShramConfig
2
- from .decoder_layer import DecoderLayer
3
- from .huggingface import ShramForCausalLM
4
- from .__attention__load_balance_loss import LoadBalanceLoss
5
- from .mlp import SwiGLUMLP
6
- from .model import ShramModel
7
- from .rope import RotaryEmbedding
8
- from .__attention__router import MoSRAHRouter
9
- from .__cache__mosrah_cache import MoSRAHCache
10
-
11
- __all__ = [
12
- "DecoderLayer",
13
- "LoadBalanceLoss",
14
- "MoSRAHCache",
15
- "MoSRAHRouter",
16
- "ShramConfig",
17
- "ShramForCausalLM",
18
- "ShramModel",
19
- "RotaryEmbedding",
20
- "SwiGLUMLP",
21
- ]
 
1
+ from .configuration import ShramConfig
2
+ from .decoder_layer import DecoderLayer
3
+ from .huggingface import ShramForCausalLM
4
+ from .__attention__load_balance_loss import LoadBalanceLoss
5
+ from .mlp import SwiGLUMLP
6
+ from .model import ShramModel
7
+ from .rope import RotaryEmbedding
8
+ from .__attention__router import MoSRAHRouter
9
+ from .__cache__mosrah_cache import MoSRAHCache
10
+
11
+ __all__ = [
12
+ "DecoderLayer",
13
+ "LoadBalanceLoss",
14
+ "MoSRAHCache",
15
+ "MoSRAHRouter",
16
+ "ShramConfig",
17
+ "ShramForCausalLM",
18
+ "ShramModel",
19
+ "RotaryEmbedding",
20
+ "SwiGLUMLP",
21
+ ]
config.json CHANGED
@@ -1,28 +1,28 @@
1
- {
2
- "alpha": 1.0,
3
- "attention_dropout": 0.0,
4
- "auto_map": {
5
- "AutoConfig": "configuration.ShramConfig",
6
- "AutoModelForCausalLM": "huggingface.ShramForCausalLM"
7
- },
8
- "beta": 32.0,
9
- "head_dim": 16,
10
- "hidden_size": 512,
11
- "inference_sequence_length": 1024,
12
- "intermediate_size": 1366,
13
- "local_rope_theta": 10000.0,
14
- "model_type": "shram",
15
- "mosrah_rope_theta": 10000.0,
16
- "num_hidden_layers": 12,
17
- "num_mosrah_heads": 16,
18
- "num_selected_heads": 16,
19
- "num_sliding_window_heads": 16,
20
- "rms_norm_eps": 1e-05,
21
- "rope_mode": "main_sequence",
22
- "tie_word_embeddings": false,
23
- "training_sequence_length": 1024,
24
- "transformers_version": "5.3.0",
25
- "use_cache": true,
26
- "vocab_size": 50277,
27
- "window_size": 128
28
- }
 
1
+ {
2
+ "alpha": 1.0,
3
+ "attention_dropout": 0.0,
4
+ "auto_map": {
5
+ "AutoConfig": "configuration.ShramConfig",
6
+ "AutoModelForCausalLM": "huggingface.ShramForCausalLM"
7
+ },
8
+ "beta": 32.0,
9
+ "head_dim": 16,
10
+ "hidden_size": 512,
11
+ "inference_sequence_length": 1024,
12
+ "intermediate_size": 1366,
13
+ "local_rope_theta": 10000.0,
14
+ "model_type": "shram",
15
+ "mosrah_rope_theta": 10000.0,
16
+ "num_hidden_layers": 12,
17
+ "num_mosrah_heads": 16,
18
+ "num_selected_heads": 16,
19
+ "num_sliding_window_heads": 16,
20
+ "rms_norm_eps": 1e-05,
21
+ "rope_mode": "main_sequence",
22
+ "tie_word_embeddings": false,
23
+ "training_sequence_length": 1024,
24
+ "transformers_version": "5.7.0",
25
+ "use_cache": true,
26
+ "vocab_size": 50277,
27
+ "window_size": 128
28
+ }
configuration.py CHANGED
@@ -1,199 +1,178 @@
1
- """Configuration for the SHRAM transformer.
2
-
3
- All architectural parameters that vary across model scales or are meaningful research
4
- variables are expressed here. Architectural constants (no bias in linear layers,
5
- SwiGLU activation with SiLU gate) are implemented in the relevant modules and
6
- documented at the point of use — they are not config parameters because they do not
7
- vary and changing them produces a different architecture, not a different scale.
8
-
9
- RoPE configuration is owned entirely by this config. Each attention path reads its
10
- parameters directly and constructs its own RotaryEmbedding instance explicitly — no
11
- HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
- """
13
-
14
- from transformers import PretrainedConfig
15
-
16
-
17
- class ShramConfig(PretrainedConfig):
18
- """Configuration class for the SHRAM decoder-only transformer.
19
-
20
- SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
21
- attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
22
- local sliding-window causal attention path and h_s is the MoSRAH sparse routed
23
- path. All other components follow the Llama 3 baseline.
24
-
25
- This config is the single source of truth for every architectural dimension of the
26
- model. Nothing in the architecture may use a literal number that belongs here.
27
-
28
- Two independent RoPE configurations exist — one per attention path:
29
-
30
- - h_l always uses standard RoPE with ``local_rope_theta``.
31
- - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
32
- ``inference_sequence_length``, ``alpha``, and ``beta``. When
33
- ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
34
- ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
35
- and the correct setting for experiments that do not require context extension.
36
-
37
- Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
38
-
39
- config = AutoConfig.from_pretrained(
40
- "your-namespace/advanced-transformers-lib",
41
- trust_remote_code=True,
42
- num_hidden_layers=12,
43
- )
44
- model = AutoModelForCausalLM.from_config(config)
45
-
46
- Args:
47
- vocab_size: Vocabulary size. Controls the embedding table and output logits
48
- dimension. Must match the tokenizer.
49
- embedding_width: Model width ``d``. The dimension of the residual stream.
50
- mlp_width: FFN hidden dimension.
51
- num_decoder_layers: Number of transformer blocks stacked in sequence.
52
- num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
53
- num_mosrah_heads: Total MoSRAH expert heads available ``L``.
54
- num_selected_heads: MoSRAH heads each token selects ``K``.
55
- head_dim: Per-head dimension, shared by both attention paths. Must be even
56
- (RoPE rotates dimensions in pairs). Paper uses 16.
57
- window_size: Sliding window size for h_l. Paper uses 128.
58
- rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
59
- original sequence positions; ``"semantic_sequence"`` supplies local slot
60
- indices. Both are required; experimentally correct mode is undetermined
61
- (paper §4). Default ``"main_sequence"``.
62
- rms_norm_eps: Epsilon for RMSNorm layers.
63
- local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
64
- Paper uses b=10000.
65
- mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
66
- b=10000.
67
- training_sequence_length: Context length ``C_train`` the model was or will be
68
- trained at. Used to compute the YaRN scale factor for BEA.
69
- inference_sequence_length: Context length ``C_target`` the model must support
70
- at inference. When equal to ``training_sequence_length``, scale ``s=1``
71
- and YaRN reduces to standard RoPE.
72
- alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
73
- ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
74
- beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
75
- ``r(d) > beta`` are left unscaled. Paper value: 32.0.
76
- attention_dropout: Dropout probability on attention weights. Default 0.0.
77
- use_cache: Whether to return past_key_values for KV caching.
78
- output_hidden_states: Whether to return hidden states after each layer.
79
- tie_word_embeddings: Whether input embedding and LM head share weights.
80
- """
81
-
82
- model_type = "shram"
83
-
84
- auto_map = {
85
- "AutoConfig": "configuration.ShramConfig",
86
- "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
87
- }
88
-
89
- def __init__(
90
- self,
91
- vocab_size: int = 50277,
92
- embedding_width: int = 512,
93
- mlp_width: int = 1366,
94
- num_decoder_layers: int = 12,
95
- num_sliding_window_heads: int = 16,
96
- num_mosrah_heads: int = 16,
97
- num_selected_heads: int = 16,
98
- head_dim: int = 16,
99
- window_size: int = 128,
100
- rope_mode: str = "main_sequence",
101
- rms_norm_eps: float = 1e-5,
102
- local_rope_theta: float = 10000.0,
103
- mosrah_rope_theta: float = 10000.0,
104
- training_sequence_length: int = 1024,
105
- alpha: float = 1.0,
106
- beta: float = 32.0,
107
- attention_dropout: float = 0.0,
108
- use_cache: bool = True,
109
- output_hidden_states: bool = False,
110
- tie_word_embeddings: bool = False,
111
- **kwargs,
112
- ):
113
- if head_dim % 2 != 0:
114
- raise ValueError(
115
- f"head_dim must be even (RoPE rotates dimensions in pairs). "
116
- f"Got head_dim={head_dim}."
117
- )
118
-
119
- if rope_mode not in {"main_sequence", "semantic_sequence"}:
120
- raise ValueError(
121
- f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
122
- f"got '{rope_mode}'."
123
- )
124
-
125
- if training_sequence_length <= 0:
126
- raise ValueError(
127
- f"training_sequence_length must be positive, "
128
- f"got {training_sequence_length}."
129
- )
130
-
131
- # inference_sequence_length is not a constructor parameter. It defaults to
132
- # training_sequence_length (scale=1.0, standard RoPE). If a saved config
133
- # carries the field through kwargs (after set_inference_context() was called
134
- # before saving), restore it here with validation.
135
- saved_inference_length = kwargs.pop("inference_sequence_length", training_sequence_length)
136
- if saved_inference_length <= 0:
137
- raise ValueError(
138
- f"inference_sequence_length must be positive, "
139
- f"got {saved_inference_length}."
140
- )
141
-
142
- self.vocab_size = vocab_size
143
- self.hidden_size = embedding_width
144
- self.intermediate_size = mlp_width
145
- self.num_hidden_layers = num_decoder_layers
146
- self.num_sliding_window_heads = num_sliding_window_heads
147
- self.num_mosrah_heads = num_mosrah_heads
148
- self.num_selected_heads = num_selected_heads
149
- self.head_dim = head_dim
150
- self.window_size = window_size
151
- self.rope_mode = rope_mode
152
- self.rms_norm_eps = rms_norm_eps
153
- self.local_rope_theta = local_rope_theta
154
- self.mosrah_rope_theta = mosrah_rope_theta
155
- self.training_sequence_length = training_sequence_length
156
- self.inference_sequence_length = saved_inference_length
157
- self.alpha = alpha
158
- self.beta = beta
159
- self.attention_dropout = attention_dropout
160
- self.use_cache = use_cache
161
-
162
- super().__init__(
163
- tie_word_embeddings=tie_word_embeddings,
164
- output_hidden_states=output_hidden_states,
165
- **kwargs,
166
- )
167
-
168
- # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
169
- # serialises it into config.json.
170
- self.auto_map = type(self).auto_map
171
-
172
- @property
173
- def scale(self) -> float:
174
- """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
175
-
176
- When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
177
- adjustments cancel and A_rope = 1. This is the default state.
178
- """
179
- return self.inference_sequence_length / self.training_sequence_length
180
-
181
- def set_inference_context(self, inference_sequence_length: int) -> None:
182
- """Set the inference context length for YaRN context extension.
183
-
184
- This is the only supported way to set inference_sequence_length. At construction
185
- the inference context defaults to training_sequence_length (scale=1.0, standard
186
- RoPE). Call this method to configure a longer inference context, which causes
187
- YaRN to interpolate frequencies and extend the effective context window.
188
-
189
- Args:
190
- inference_sequence_length: Target inference context length. Must be positive.
191
- Values equal to training_sequence_length produce scale=1.0 (standard RoPE).
192
- Values greater than training_sequence_length enable YaRN extrapolation.
193
- """
194
- if inference_sequence_length <= 0:
195
- raise ValueError(
196
- f"inference_sequence_length must be positive, "
197
- f"got {inference_sequence_length}."
198
- )
199
- self.inference_sequence_length = inference_sequence_length
 
1
+ """Configuration for the SHRAM transformer.
2
+
3
+ All architectural parameters that vary across model scales or are meaningful research
4
+ variables are expressed here. Architectural constants (no bias in linear layers,
5
+ SwiGLU activation with SiLU gate) are implemented in the relevant modules and
6
+ documented at the point of use — they are not config parameters because they do not
7
+ vary and changing them produces a different architecture, not a different scale.
8
+
9
+ RoPE configuration is owned entirely by this config. Each attention path reads its
10
+ parameters directly and constructs its own RotaryEmbedding instance explicitly — no
11
+ HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
+ """
13
+
14
+ from transformers import PretrainedConfig
15
+
16
+
17
+ class ShramConfig(PretrainedConfig):
18
+ """Configuration class for the SHRAM decoder-only transformer.
19
+
20
+ SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
21
+ attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
22
+ local sliding-window causal attention path and h_s is the MoSRAH sparse routed
23
+ path. All other components follow the Llama 3 baseline.
24
+
25
+ This config is the single source of truth for every architectural dimension of the
26
+ model. Nothing in the architecture may use a literal number that belongs here.
27
+
28
+ Two independent RoPE configurations exist — one per attention path:
29
+
30
+ - h_l always uses standard RoPE with ``local_rope_theta``.
31
+ - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
32
+ ``inference_sequence_length``, ``alpha``, and ``beta``. When
33
+ ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
34
+ ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
35
+ and the correct setting for experiments that do not require context extension.
36
+
37
+ Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
38
+
39
+ config = AutoConfig.from_pretrained(
40
+ "your-namespace/advanced-transformers-lib",
41
+ trust_remote_code=True,
42
+ num_hidden_layers=12,
43
+ )
44
+ model = AutoModelForCausalLM.from_config(config)
45
+
46
+ Args:
47
+ vocab_size: Vocabulary size. Controls the embedding table and output logits
48
+ dimension. Must match the tokenizer.
49
+ embedding_width: Model width ``d``. The dimension of the residual stream.
50
+ mlp_width: FFN hidden dimension.
51
+ num_decoder_layers: Number of transformer blocks stacked in sequence.
52
+ num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
53
+ num_mosrah_heads: Total MoSRAH expert heads available ``L``.
54
+ num_selected_heads: MoSRAH heads each token selects ``K``.
55
+ head_dim: Per-head dimension, shared by both attention paths. Must be even
56
+ (RoPE rotates dimensions in pairs). Paper uses 16.
57
+ window_size: Sliding window size for h_l. Paper uses 128.
58
+ rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
59
+ original sequence positions; ``"semantic_sequence"`` supplies local slot
60
+ indices. Both are required; experimentally correct mode is undetermined
61
+ (paper §4). Default ``"main_sequence"``.
62
+ rms_norm_eps: Epsilon for RMSNorm layers.
63
+ local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
64
+ Paper uses b=10000.
65
+ mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
66
+ b=10000.
67
+ training_sequence_length: Context length ``C_train`` the model was or will be
68
+ trained at. Used to compute the YaRN scale factor for BEA.
69
+ inference_sequence_length: Context length ``C_target`` the model must support
70
+ at inference. Optional; defaults to ``training_sequence_length`` so that
71
+ ``scale=1`` and YaRN reduces to standard RoPE unless explicitly extended.
72
+ alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
73
+ ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
74
+ beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
75
+ ``r(d) > beta`` are left unscaled. Paper value: 32.0.
76
+ attention_dropout: Dropout probability on attention weights. Default 0.0.
77
+ use_cache: Whether to return past_key_values for KV caching.
78
+ output_hidden_states: Whether to return hidden states after each layer.
79
+ tie_word_embeddings: Whether input embedding and LM head share weights.
80
+ """
81
+
82
+ model_type = "shram"
83
+
84
+ auto_map = {
85
+ "AutoConfig": "configuration.ShramConfig",
86
+ "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
87
+ }
88
+
89
+ def __init__(
90
+ self,
91
+ vocab_size: int = 50277,
92
+ embedding_width: int = 512,
93
+ mlp_width: int = 1366,
94
+ num_decoder_layers: int = 12,
95
+ num_sliding_window_heads: int = 16,
96
+ num_mosrah_heads: int = 16,
97
+ num_selected_heads: int = 16,
98
+ head_dim: int = 16,
99
+ window_size: int = 128,
100
+ rope_mode: str = "main_sequence",
101
+ rms_norm_eps: float = 1e-5,
102
+ local_rope_theta: float = 10000.0,
103
+ mosrah_rope_theta: float = 10000.0,
104
+ training_sequence_length: int = 1024,
105
+ inference_sequence_length: int | None = None,
106
+ alpha: float = 1.0,
107
+ beta: float = 32.0,
108
+ attention_dropout: float = 0.0,
109
+ use_cache: bool = True,
110
+ output_hidden_states: bool = False,
111
+ tie_word_embeddings: bool = False,
112
+ **kwargs,
113
+ ):
114
+ if head_dim % 2 != 0:
115
+ raise ValueError(
116
+ f"head_dim must be even (RoPE rotates dimensions in pairs). "
117
+ f"Got head_dim={head_dim}."
118
+ )
119
+
120
+ if rope_mode not in {"main_sequence", "semantic_sequence"}:
121
+ raise ValueError(
122
+ f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
123
+ f"got '{rope_mode}'."
124
+ )
125
+
126
+ if training_sequence_length <= 0:
127
+ raise ValueError(
128
+ f"training_sequence_length must be positive, "
129
+ f"got {training_sequence_length}."
130
+ )
131
+
132
+ if inference_sequence_length is None:
133
+ inference_sequence_length = training_sequence_length
134
+ if inference_sequence_length <= 0:
135
+ raise ValueError(
136
+ f"inference_sequence_length must be positive, "
137
+ f"got {inference_sequence_length}."
138
+ )
139
+
140
+ self.vocab_size = vocab_size
141
+ self.hidden_size = embedding_width
142
+ self.intermediate_size = mlp_width
143
+ self.num_hidden_layers = num_decoder_layers
144
+ self.num_sliding_window_heads = num_sliding_window_heads
145
+ self.num_mosrah_heads = num_mosrah_heads
146
+ self.num_selected_heads = num_selected_heads
147
+ self.head_dim = head_dim
148
+ self.window_size = window_size
149
+ self.rope_mode = rope_mode
150
+ self.rms_norm_eps = rms_norm_eps
151
+ self.local_rope_theta = local_rope_theta
152
+ self.mosrah_rope_theta = mosrah_rope_theta
153
+ self.training_sequence_length = training_sequence_length
154
+ self.inference_sequence_length = inference_sequence_length
155
+ self.alpha = alpha
156
+ self.beta = beta
157
+ self.attention_dropout = attention_dropout
158
+ self.use_cache = use_cache
159
+
160
+ super().__init__(
161
+ tie_word_embeddings=tie_word_embeddings,
162
+ output_hidden_states=output_hidden_states,
163
+ **kwargs,
164
+ )
165
+
166
+ # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
167
+ # serialises it into config.json.
168
+ self.auto_map = type(self).auto_map
169
+
170
+ @property
171
+ def scale(self) -> float:
172
+ """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
173
+
174
+ When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
175
+ adjustments cancel and A_rope = 1. This is the default state.
176
+ """
177
+ return self.inference_sequence_length / self.training_sequence_length
178
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
decoder_layer.py CHANGED
@@ -1,87 +1,87 @@
1
- """Decoder layer — a single transformer block.
2
-
3
- Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
4
- residual connections around both sublayers:
5
-
6
- normed_attn = RMSNorm(x)
7
- attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
8
- h = x + attn_out
9
-
10
- normed_mlp = RMSNorm(h)
11
- mlp_out = SwiGLUMLP(normed_mlp)
12
- out = h + mlp_out
13
-
14
- Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
15
- through unnormalised residuals at depth, and each sublayer receives a stable,
16
- normalised view of the signal.
17
-
18
- Two independent RMSNorm instances are used — one before attention, one before
19
- MLP. They learn different scalings because they precede layers with different
20
- dynamic ranges. Sharing them would be wrong.
21
-
22
- torch.nn.RMSNorm is used directly (available from PyTorch 2.4+). It omits mean
23
- subtraction, is faster than LayerNorm, and proved more stable at scale.
24
- """
25
-
26
- import torch
27
- import torch.nn as nn
28
-
29
- from .__attention__shram import SHRAMHybridLayer
30
- from .__cache__shram_layer_cache import ShramLayerCache
31
- from .configuration import ShramConfig
32
- from .mlp import SwiGLUMLP
33
-
34
-
35
- class DecoderLayer(nn.Module):
36
- """A single pre-norm SHRAM decoder block.
37
-
38
- Composes SHRAMHybridLayer and SwiGLUMLP with residual connections and
39
- independent RMSNorm instances on each sublayer input.
40
-
41
- Args:
42
- config: SHRAM config. Must expose ``hidden_size`` and ``rms_norm_eps``
43
- in addition to the fields required by SHRAMHybridLayer and
44
- SwiGLUMLP.
45
- """
46
-
47
- def __init__(self, config: ShramConfig) -> None:
48
- super().__init__()
49
- self.attn_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
50
- self.mlp_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
51
- self.attention = SHRAMHybridLayer(config)
52
- self.mlp = SwiGLUMLP(config)
53
-
54
- def forward(
55
- self,
56
- x: torch.Tensor,
57
- position_ids: torch.Tensor,
58
- active_mask: torch.Tensor,
59
- cache: ShramLayerCache | None = None,
60
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
61
- """Apply one decoder block to the input.
62
-
63
- Args:
64
- x: Input of shape (batch, seq_len, hidden_size).
65
- position_ids: Authoritative positions of shape (batch, seq_len).
66
- active_mask: Current-chunk active mask of shape (batch, seq_len),
67
- where True means the token is semantically live. Forwarded
68
- unchanged to the hybrid attention layer.
69
- cache: Optional per-layer SHRAM cache passed through to the hybrid
70
- attention layer unchanged.
71
-
72
- Returns:
73
- output: Tensor of shape (batch, seq_len, hidden_size).
74
- load_balance_loss: Scalar sparse-path load-balance loss propagated
75
- from SHRAMHybridLayer.
76
- max_vio: Detached scalar routing-imbalance summary. Passed through
77
- unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics.
78
- """
79
- attn_out, load_balance_loss, max_vio = self.attention(
80
- hidden_states=self.attn_norm(x),
81
- position_ids=position_ids,
82
- active_mask=active_mask,
83
- cache=cache,
84
- )
85
- hidden_states = x + attn_out
86
- output = hidden_states + self.mlp(self.mlp_norm(hidden_states))
87
  return output, load_balance_loss, max_vio
 
1
+ """Decoder layer — a single transformer block.
2
+
3
+ Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
4
+ residual connections around both sublayers:
5
+
6
+ normed_attn = RMSNorm(x)
7
+ attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
8
+ h = x + attn_out
9
+
10
+ normed_mlp = RMSNorm(h)
11
+ mlp_out = SwiGLUMLP(normed_mlp)
12
+ out = h + mlp_out
13
+
14
+ Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
15
+ through unnormalised residuals at depth, and each sublayer receives a stable,
16
+ normalised view of the signal.
17
+
18
+ Two independent RMSNorm instances are used — one before attention, one before
19
+ MLP. They learn different scalings because they precede layers with different
20
+ dynamic ranges. Sharing them would be wrong.
21
+
22
+ torch.nn.RMSNorm is used directly (available from PyTorch 2.4+). It omits mean
23
+ subtraction, is faster than LayerNorm, and proved more stable at scale.
24
+ """
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+
29
+ from .__attention__shram import SHRAMHybridLayer
30
+ from .__cache__shram_layer_cache import ShramLayerCache
31
+ from .configuration import ShramConfig
32
+ from .mlp import SwiGLUMLP
33
+
34
+
35
+ class DecoderLayer(nn.Module):
36
+ """A single pre-norm SHRAM decoder block.
37
+
38
+ Composes SHRAMHybridLayer and SwiGLUMLP with residual connections and
39
+ independent RMSNorm instances on each sublayer input.
40
+
41
+ Args:
42
+ config: SHRAM config. Must expose ``hidden_size`` and ``rms_norm_eps``
43
+ in addition to the fields required by SHRAMHybridLayer and
44
+ SwiGLUMLP.
45
+ """
46
+
47
+ def __init__(self, config: ShramConfig) -> None:
48
+ super().__init__()
49
+ self.attn_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
50
+ self.mlp_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
51
+ self.attention = SHRAMHybridLayer(config)
52
+ self.mlp = SwiGLUMLP(config)
53
+
54
+ def forward(
55
+ self,
56
+ x: torch.Tensor,
57
+ position_ids: torch.Tensor,
58
+ active_mask: torch.Tensor,
59
+ cache: ShramLayerCache | None = None,
60
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
61
+ """Apply one decoder block to the input.
62
+
63
+ Args:
64
+ x: Input of shape (batch, seq_len, hidden_size).
65
+ position_ids: Authoritative positions of shape (batch, seq_len).
66
+ active_mask: Current-chunk active mask of shape (batch, seq_len),
67
+ where True means the token is semantically live. Forwarded
68
+ unchanged to the hybrid attention layer.
69
+ cache: Optional per-layer SHRAM cache passed through to the hybrid
70
+ attention layer unchanged.
71
+
72
+ Returns:
73
+ output: Tensor of shape (batch, seq_len, hidden_size).
74
+ load_balance_loss: Scalar sparse-path load-balance loss propagated
75
+ from SHRAMHybridLayer.
76
+ max_vio: Detached scalar routing-imbalance summary. Passed through
77
+ unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics.
78
+ """
79
+ attn_out, load_balance_loss, max_vio = self.attention(
80
+ hidden_states=self.attn_norm(x),
81
+ position_ids=position_ids,
82
+ active_mask=active_mask,
83
+ cache=cache,
84
+ )
85
+ hidden_states = x + attn_out
86
+ output = hidden_states + self.mlp(self.mlp_norm(hidden_states))
87
  return output, load_balance_loss, max_vio
huggingface.py CHANGED
@@ -1,518 +1,518 @@
1
- """HuggingFace causal-LM wrapper for SHRAM.
2
-
3
- ShramForCausalLM is the HuggingFace-facing language-model boundary for SHRAM.
4
- It owns token embedding lookup, LM-head projection, wrapper-level next-token
5
- cross-entropy loss, config-controlled tied embeddings, and generation/cache
6
- orchestration at the wrapper boundary.
7
-
8
- The backbone remains a pure transformer stack. ShramModel accepts pre-embedded
9
- hidden states together with current position IDs, a current active mask, and an
10
- optional ShramCache. It has no knowledge of token IDs, vocabulary projection,
11
- or causal-LM loss.
12
-
13
- HuggingFace generation reaches this wrapper with two different tensor
14
- conventions:
15
-
16
- - ``position_ids`` is a current-step tensor. GenerationMixin updates the total
17
- sequence state between steps, then slices position-bearing tensors back down
18
- before calling ``forward()``.
19
- - ``attention_mask`` is a full 2D mask over the total sequence so far. This
20
- wrapper slices its recent chunk to produce the current semantic liveness mask
21
- expected by the backbone.
22
-
23
- Generation-created caches are handled in ``_prepare_cache_for_generation``.
24
- That hook ensures HuggingFace generation uses ShramCache rather than a generic
25
- dynamic cache. The direct ``forward()`` path does not silently create caches;
26
- when ``use_cache=True`` it expects a truthful ShramCache to have been supplied.
27
- """
28
-
29
- from dataclasses import dataclass
30
- from typing import Any
31
-
32
- import torch
33
- import torch.nn as nn
34
- from transformers import GenerationMixin, PreTrainedModel
35
- from transformers.cache_utils import Cache
36
- from transformers.generation.configuration_utils import GenerationMode
37
- from transformers.modeling_outputs import CausalLMOutputWithPast
38
-
39
- from .__cache__shram_cache import ShramCache
40
- from .configuration import ShramConfig
41
- from .model import ShramModel
42
-
43
-
44
- @dataclass
45
- class ShramCausalLMOutput(CausalLMOutputWithPast):
46
- """SHRAM causal-LM wrapper output.
47
-
48
- This subclasses HuggingFace's standard ``CausalLMOutputWithPast``.
49
- Dataclass inheritance is sufficient here: all standard causal-LM fields and
50
- ModelOutput behavior are inherited from the parent, and this subclass adds
51
- only the SHRAM-specific wrapper outputs.
52
- """
53
-
54
- ce_loss: torch.FloatTensor | None = None
55
- load_balance_loss: torch.FloatTensor | None = None
56
- max_vio: torch.FloatTensor | None = None
57
-
58
-
59
- class ShramForCausalLM(PreTrainedModel, GenerationMixin):
60
- """HuggingFace-facing causal language model wrapper for SHRAM.
61
-
62
- Owns token embeddings, LM-head projection, wrapper-level shifted CE loss,
63
- tied embedding configuration, and generation/cache boundary behavior.
64
- Delegates all transformer computation to ``ShramModel``.
65
-
66
- Args:
67
- config: SHRAM model configuration.
68
- """
69
-
70
- config_class = ShramConfig
71
- base_model_prefix = "model"
72
- _no_split_modules = ["DecoderLayer"]
73
- supports_gradient_checkpointing = True
74
-
75
- def __init__(self, config: ShramConfig) -> None:
76
- super().__init__(config)
77
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
78
- self.model = ShramModel(config)
79
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
80
- self._configure_tied_embeddings()
81
- self.post_init()
82
-
83
- def _configure_tied_embeddings(self) -> None:
84
- """Apply config-controlled tied embedding behavior on this instance."""
85
- if self.config.tie_word_embeddings:
86
- self.lm_head.weight = self.embed_tokens.weight
87
- self._tied_weights_keys = {
88
- "lm_head.weight": "embed_tokens.weight",
89
- }
90
- else:
91
- self._tied_weights_keys = {}
92
-
93
- def get_input_embeddings(self) -> nn.Embedding:
94
- """Return the token embedding matrix."""
95
- return self.embed_tokens
96
-
97
- def set_input_embeddings(self, value: nn.Embedding) -> None:
98
- """Replace the token embedding matrix."""
99
- self.embed_tokens = value
100
- self._configure_tied_embeddings()
101
-
102
- def get_output_embeddings(self) -> nn.Linear:
103
- """Return the LM head."""
104
- return self.lm_head
105
-
106
- def set_output_embeddings(self, value: nn.Linear) -> None:
107
- """Replace the LM head."""
108
- self.lm_head = value
109
- self._configure_tied_embeddings()
110
-
111
- def _build_shram_cache(
112
- self,
113
- batch_size: int,
114
- device: torch.device,
115
- ) -> ShramCache:
116
- """Construct a fresh top-level SHRAM cache."""
117
- return ShramCache(
118
- num_hidden_layers=self.config.num_hidden_layers,
119
- sliding_window=self.config.window_size,
120
- num_local_heads=self.config.num_sliding_window_heads,
121
- local_head_dim=self.config.head_dim,
122
- num_mosrah_heads=self.config.num_mosrah_heads,
123
- mosrah_head_dim=self.config.head_dim,
124
- batch_size=batch_size,
125
- device=device,
126
- )
127
-
128
- def _validate_generation_cache_request(
129
- self,
130
- generation_config: Any,
131
- model_kwargs: dict[str, Any],
132
- generation_mode: GenerationMode,
133
- ) -> None:
134
- """Validate SHRAM's generation-side cache policy."""
135
- if generation_mode in {
136
- GenerationMode.ASSISTED_GENERATION,
137
- GenerationMode.CONTRASTIVE_SEARCH,
138
- }:
139
- raise NotImplementedError(
140
- "ShramForCausalLM does not currently support assisted generation "
141
- "or contrastive search because ShramCache does not support crop()."
142
- )
143
-
144
- user_defined_cache = model_kwargs.get("past_key_values")
145
- if user_defined_cache is not None:
146
- if generation_config.cache_implementation is not None:
147
- raise ValueError(
148
- "Passing both `cache_implementation` and `past_key_values` "
149
- "is unsupported. Please use only one."
150
- )
151
- if isinstance(user_defined_cache, tuple):
152
- raise ValueError(
153
- "Passing a tuple of `past_key_values` is not supported. "
154
- "Please use a `ShramCache` instance."
155
- )
156
- if not isinstance(user_defined_cache, ShramCache):
157
- raise TypeError(
158
- "ShramForCausalLM requires `past_key_values` to be a "
159
- "`ShramCache` instance."
160
- )
161
-
162
- if (
163
- user_defined_cache is None
164
- and generation_config.use_cache
165
- and generation_config.cache_implementation is not None
166
- ):
167
- raise ValueError(
168
- "ShramForCausalLM does not support `cache_implementation`. "
169
- "Generation-created caches must be `ShramCache` objects."
170
- )
171
-
172
- def _prepare_cache_for_generation(
173
- self,
174
- generation_config: Any,
175
- model_kwargs: dict[str, Any],
176
- generation_mode: GenerationMode,
177
- batch_size: int,
178
- max_cache_length: int,
179
- ) -> None:
180
- """Ensure HuggingFace generation uses ShramCache.
181
-
182
- This is the SHRAM-specific generation hook. The rest of the default
183
- generation plumbing is kept intact as much as possible.
184
-
185
- Args:
186
- generation_config: Active generation configuration.
187
- model_kwargs: Generation kwargs, updated in place.
188
- generation_mode: HuggingFace generation mode.
189
- batch_size: Effective generation batch size.
190
- max_cache_length: Requested cache length. Accepted but unused here.
191
- """
192
- self._validate_generation_cache_request(
193
- generation_config=generation_config,
194
- model_kwargs=model_kwargs,
195
- generation_mode=generation_mode,
196
- )
197
-
198
- if model_kwargs.get("past_key_values") is not None:
199
- return
200
-
201
- if not generation_config.use_cache:
202
- return
203
-
204
- num_repeats = max(
205
- generation_config.num_beams or 1,
206
- generation_config.num_return_sequences or 1,
207
- )
208
- model_kwargs["past_key_values"] = self._build_shram_cache(
209
- batch_size=batch_size*num_repeats,
210
- device=self.embed_tokens.weight.device,
211
- )
212
-
213
- def _reorder_cache(
214
- self,
215
- past_key_values: Cache,
216
- beam_idx: torch.Tensor,
217
- ) -> Cache:
218
- """Reorder the cache in place for beam search."""
219
- past_key_values.reorder_cache(beam_idx)
220
- return past_key_values
221
-
222
- def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
223
- """Validate token IDs at the wrapper boundary."""
224
- if input_ids.ndim != 2:
225
- raise ValueError("input_ids must have shape (batch, seq_len).")
226
- if input_ids.shape[1] == 0:
227
- raise ValueError("input_ids sequence length must be nonzero.")
228
- if input_ids.dtype != torch.long:
229
- raise TypeError("input_ids must be an long int tensor.")
230
-
231
- def _validate_attention_mask(
232
- self,
233
- input_ids: torch.Tensor,
234
- attention_mask: torch.Tensor | None,
235
- ) -> None:
236
- """Validate the full-sequence attention mask."""
237
- if attention_mask is None:
238
- return
239
- if attention_mask.ndim != 2:
240
- raise ValueError("attention_mask must have shape (batch, total_seq_len).")
241
- if attention_mask.shape[0] != input_ids.shape[0]:
242
- raise ValueError("attention_mask batch dimension must match input_ids.")
243
- if attention_mask.shape[1] < input_ids.shape[1]:
244
- raise ValueError(
245
- "attention_mask must be at least as long as the current input_ids chunk."
246
- )
247
-
248
- def _validate_position_ids(
249
- self,
250
- input_ids: torch.Tensor,
251
- position_ids: torch.Tensor | None,
252
- ) -> None:
253
- """Validate current-step position IDs."""
254
- if position_ids is None:
255
- return
256
- if position_ids.ndim != 2:
257
- raise ValueError("position_ids must have shape (batch, seq_len).")
258
- if position_ids.shape != input_ids.shape:
259
- raise ValueError(
260
- "position_ids must match the current input_ids shape exactly."
261
- )
262
- if input_ids.dtype != torch.long:
263
- raise TypeError("position_ids must be an long tensor.")
264
-
265
- def _validate_labels(
266
- self,
267
- input_ids: torch.Tensor,
268
- labels: torch.Tensor | None,
269
- ) -> None:
270
- """Validate label shape at the wrapper boundary."""
271
- if labels is None:
272
- return
273
- if labels.ndim != 2:
274
- raise ValueError("labels must have shape (batch, seq_len).")
275
- if labels.shape != input_ids.shape:
276
- raise ValueError("labels must have the same shape as input_ids.")
277
- if input_ids.dtype != torch.long:
278
- raise TypeError("labels must be a long tensor.")
279
-
280
- def _validate_cache_inputs(
281
- self,
282
- use_cache: bool,
283
- past_key_values: Cache | None,
284
- ) -> None:
285
- """Validate cache policy for direct wrapper calls."""
286
- if use_cache:
287
- if past_key_values is None:
288
- raise ValueError(
289
- "use_cache=True requires an explicit ShramCache. During "
290
- "generate(), HuggingFace should supply this through "
291
- "_prepare_cache_for_generation()."
292
- )
293
- if not isinstance(past_key_values, ShramCache):
294
- raise TypeError(
295
- "past_key_values must be a ShramCache when use_cache=True."
296
- )
297
- return
298
-
299
- if past_key_values is not None:
300
- raise ValueError("past_key_values was provided while use_cache=False.")
301
-
302
- def _validate_position_sources(
303
- self,
304
- use_cache: bool,
305
- attention_mask: torch.Tensor | None,
306
- position_ids: torch.Tensor | None,
307
- ) -> None:
308
- """Validate that cached forward has a truthful source of positions."""
309
- if use_cache and attention_mask is None and position_ids is None:
310
- raise ValueError(
311
- "Cached forward requires either position_ids or attention_mask."
312
- )
313
-
314
- def _validate_hf_boundary(
315
- self,
316
- output_attentions: bool | None,
317
- return_dict: bool | None,
318
- inputs_embeds: torch.Tensor | None,
319
- cache_position: torch.Tensor | None,
320
- extra_kwargs: dict[str, Any],
321
- ) -> None:
322
- """Validate unsupported HuggingFace-facing wrapper inputs."""
323
- if output_attentions:
324
- raise NotImplementedError(
325
- "ShramForCausalLM does not expose output_attentions."
326
- )
327
- if return_dict is False:
328
- raise ValueError(
329
- "return_dict=False is not supported. "
330
- "ShramForCausalLM always returns ShramCausalLMOutput."
331
- )
332
- if inputs_embeds is not None:
333
- raise ValueError(
334
- "inputs_embeds is not supported at the SHRAM wrapper boundary. "
335
- "Pass input_ids instead."
336
- )
337
- if extra_kwargs:
338
- unsupported = ", ".join(sorted(extra_kwargs))
339
- raise TypeError(
340
- f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
341
- )
342
-
343
- def _standardize_full_attention_mask(
344
- self,
345
- input_ids: torch.Tensor,
346
- attention_mask: torch.Tensor | None,
347
- ) -> torch.BoolTensor:
348
- """Return a concrete full-sequence boolean attention mask."""
349
- if attention_mask is None:
350
- return torch.ones_like(input_ids, dtype=torch.bool)
351
- return attention_mask.to(dtype=torch.bool)
352
-
353
- def _resolve_current_position_ids(
354
- self,
355
- input_ids: torch.Tensor,
356
- position_ids: torch.Tensor | None,
357
- full_attention_mask: torch.BoolTensor,
358
- ) -> torch.LongTensor:
359
- """Resolve concrete current-step position IDs for the backbone."""
360
- if position_ids is not None:
361
- return position_ids.to(dtype=torch.long)
362
-
363
- full_position_ids = full_attention_mask.to(dtype=torch.long).cumsum(dim=-1) - 1
364
- full_position_ids = full_position_ids.masked_fill(~full_attention_mask, 0)
365
- current_length = input_ids.shape[1]
366
- return full_position_ids[:, -current_length:]
367
-
368
- def forward(
369
- self,
370
- input_ids: torch.Tensor,
371
- attention_mask: torch.Tensor | None = None,
372
- position_ids: torch.Tensor | None = None,
373
- past_key_values: Cache | None = None,
374
- use_cache: bool | None = None,
375
- output_hidden_states: bool | None = None,
376
- labels: torch.Tensor | None = None,
377
- return_dict: bool | None = None,
378
- ce_weight: float = 1.0,
379
- load_balance_weight: float = 0.01,
380
- **kwargs: Any,
381
- ) -> ShramCausalLMOutput:
382
- """Run the SHRAM causal language model wrapper.
383
-
384
- Args:
385
- input_ids: Current token IDs of shape ``(batch, seq_len)``.
386
- attention_mask: Optional full 2D mask of shape
387
- ``(batch, total_seq_len)``. The wrapper slices its recent chunk
388
- to produce the current semantic liveness mask expected by the
389
- backbone.
390
- position_ids: Optional current-step position IDs of shape
391
- ``(batch, seq_len)``. In ordinary HuggingFace generation this is
392
- already the current-step tensor when it reaches ``forward()``.
393
- past_key_values: Optional SHRAM cache. Required when
394
- ``use_cache=True``.
395
- use_cache: Whether to use and return a cache. Defaults to
396
- ``config.use_cache``.
397
- output_hidden_states: Whether to return backbone hidden states.
398
- Defaults to ``config.output_hidden_states``.
399
- labels: Optional target token IDs of shape ``(batch, seq_len)``.
400
- return_dict: Must be ``True`` or ``None``.
401
- ce_weight: Weight applied to the cross-entropy loss when combining with
402
- the load-balance loss. Default 1.0.
403
- load_balance_weight: Weight applied to the load-balance auxiliary loss.
404
- Default 0.01, matching the paper's recommendation.
405
- **kwargs: Unsupported HuggingFace kwargs fail explicitly.
406
-
407
- Returns:
408
- ``ShramCausalLMOutput`` with:
409
- - ``logits`` of shape ``(batch, seq_len, vocab_size)``,
410
- - ``loss`` = ``ce_weight * ce_loss + load_balance_weight * load_balance_loss``
411
- when labels are provided (``None`` otherwise),
412
- - ``ce_loss`` — raw unweighted cross-entropy loss for logging,
413
- - ``past_key_values`` as the active ``ShramCache`` or ``None``,
414
- - ``hidden_states`` when requested,
415
- - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
416
- - detached ``max_vio`` from the backbone.
417
- """
418
- use_cache = use_cache if use_cache is not None else self.config.use_cache
419
- output_hidden_states = (
420
- output_hidden_states
421
- if output_hidden_states is not None
422
- else self.config.output_hidden_states
423
- )
424
-
425
- inputs_embeds = kwargs.pop("inputs_embeds", None)
426
- output_attentions = kwargs.pop("output_attentions", None)
427
- cache_position = kwargs.pop("cache_position", None)
428
-
429
- # ------------------------------------------------------------------
430
- # Validation zone.
431
- #
432
- # The wrapper boundary is where HuggingFace-facing inputs are judged
433
- # for truthfulness before any internal work begins. These checks are
434
- # intentionally front-loaded so the core logic below can assume one
435
- # coherent interpretation of the call rather than defensively checking
436
- # shapes, cache policy, or unsupported HF knobs at the point of use.
437
- # This keeps the main sequence readable while ensuring invalid states
438
- # fail before they can silently contaminate backbone execution.
439
- # ------------------------------------------------------------------
440
- self._validate_input_ids(input_ids)
441
- self._validate_attention_mask(input_ids, attention_mask)
442
- self._validate_position_ids(input_ids, position_ids)
443
- self._validate_labels(input_ids, labels)
444
- self._validate_cache_inputs(use_cache, past_key_values)
445
- self._validate_position_sources(use_cache, attention_mask, position_ids)
446
- self._validate_hf_boundary(
447
- output_attentions=output_attentions,
448
- return_dict=return_dict,
449
- inputs_embeds=inputs_embeds,
450
- cache_position=cache_position,
451
- extra_kwargs=kwargs,
452
- )
453
-
454
- # ------------------------------------------------------------------
455
- # Standardization zone.
456
- #
457
- # HuggingFace and SHRAM use different boundary conventions: generation
458
- # carries a full-sequence 2D attention mask, while the SHRAM backbone
459
- # wants a current-step active mask and concrete current position IDs.
460
- # This zone collapses those wrapper-facing conventions into one valid
461
- # backbone-facing state. After this point the core no longer reasons
462
- # about optional or ambiguous input forms; it works only with concrete
463
- # tensors whose semantics are already fixed.
464
- # ------------------------------------------------------------------
465
- full_attention_mask: torch.BoolTensor = self._standardize_full_attention_mask(
466
- input_ids=input_ids,
467
- attention_mask=attention_mask,
468
- )
469
- current_length: int = input_ids.shape[1]
470
- current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
471
- current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
472
- input_ids=input_ids,
473
- position_ids=position_ids,
474
- full_attention_mask=full_attention_mask,
475
- )
476
- shram_cache: ShramCache | None = past_key_values if use_cache else None
477
-
478
- # ------------------------------------------------------------------
479
- # Core wrapper responsibilities.
480
- #
481
- # The wrapper's primary job is kept visible here: convert token IDs to
482
- # embeddings, delegate transformer computation to ShramModel, project
483
- # hidden states back to vocabulary logits, optionally compute the
484
- # wrapper-level shifted next-token loss, and return the HuggingFace-
485
- # facing output object. The backbone remains responsible only for
486
- # transformer semantics; token/vocabulary/loss concerns stay here.
487
- # ------------------------------------------------------------------
488
- token_embeddings: torch.FloatTensor = self.embed_tokens(input_ids)
489
- backbone_outputs = self.model(
490
- inputs_embeds=token_embeddings,
491
- position_ids=current_position_ids,
492
- active_mask=current_active_mask,
493
- cache=shram_cache,
494
- output_hidden_states=output_hidden_states,
495
- )
496
-
497
- logits: torch.FloatTensor = self.lm_head(backbone_outputs["last_hidden_state"])
498
-
499
- ce_loss: torch.FloatTensor | None = None
500
- loss: torch.FloatTensor | None = None
501
- if labels is not None:
502
- shift_logits = logits[:, :-1, :].contiguous()
503
- shift_labels = labels[:, 1:].contiguous()
504
- ce_loss = nn.functional.cross_entropy(
505
- shift_logits.view(-1, self.config.vocab_size),
506
- shift_labels.view(-1),
507
- )
508
- loss = ce_weight * ce_loss + load_balance_weight * backbone_outputs["load_balance_loss"]
509
-
510
- return ShramCausalLMOutput(
511
- loss=loss,
512
- ce_loss=ce_loss,
513
- logits=logits,
514
- past_key_values=backbone_outputs["past_key_values"],
515
- hidden_states=backbone_outputs["hidden_states"],
516
- load_balance_loss=backbone_outputs["load_balance_loss"],
517
- max_vio=backbone_outputs["max_vio"],
518
  )
 
1
+ """HuggingFace causal-LM wrapper for SHRAM.
2
+
3
+ ShramForCausalLM is the HuggingFace-facing language-model boundary for SHRAM.
4
+ It owns token embedding lookup, LM-head projection, wrapper-level next-token
5
+ cross-entropy loss, config-controlled tied embeddings, and generation/cache
6
+ orchestration at the wrapper boundary.
7
+
8
+ The backbone remains a pure transformer stack. ShramModel accepts pre-embedded
9
+ hidden states together with current position IDs, a current active mask, and an
10
+ optional ShramCache. It has no knowledge of token IDs, vocabulary projection,
11
+ or causal-LM loss.
12
+
13
+ HuggingFace generation reaches this wrapper with two different tensor
14
+ conventions:
15
+
16
+ - ``position_ids`` is a current-step tensor. GenerationMixin updates the total
17
+ sequence state between steps, then slices position-bearing tensors back down
18
+ before calling ``forward()``.
19
+ - ``attention_mask`` is a full 2D mask over the total sequence so far. This
20
+ wrapper slices its recent chunk to produce the current semantic liveness mask
21
+ expected by the backbone.
22
+
23
+ Generation-created caches are handled in ``_prepare_cache_for_generation``.
24
+ That hook ensures HuggingFace generation uses ShramCache rather than a generic
25
+ dynamic cache. The direct ``forward()`` path does not silently create caches;
26
+ when ``use_cache=True`` it expects a truthful ShramCache to have been supplied.
27
+ """
28
+
29
+ from dataclasses import dataclass
30
+ from typing import Any
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+ from transformers import GenerationMixin, PreTrainedModel
35
+ from transformers.cache_utils import Cache
36
+ from transformers.generation.configuration_utils import GenerationMode
37
+ from transformers.modeling_outputs import CausalLMOutputWithPast
38
+
39
+ from .__cache__shram_cache import ShramCache
40
+ from .configuration import ShramConfig
41
+ from .model import ShramModel
42
+
43
+
44
+ @dataclass
45
+ class ShramCausalLMOutput(CausalLMOutputWithPast):
46
+ """SHRAM causal-LM wrapper output.
47
+
48
+ This subclasses HuggingFace's standard ``CausalLMOutputWithPast``.
49
+ Dataclass inheritance is sufficient here: all standard causal-LM fields and
50
+ ModelOutput behavior are inherited from the parent, and this subclass adds
51
+ only the SHRAM-specific wrapper outputs.
52
+ """
53
+
54
+ ce_loss: torch.FloatTensor | None = None
55
+ load_balance_loss: torch.FloatTensor | None = None
56
+ max_vio: torch.FloatTensor | None = None
57
+
58
+
59
+ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
60
+ """HuggingFace-facing causal language model wrapper for SHRAM.
61
+
62
+ Owns token embeddings, LM-head projection, wrapper-level shifted CE loss,
63
+ tied embedding configuration, and generation/cache boundary behavior.
64
+ Delegates all transformer computation to ``ShramModel``.
65
+
66
+ Args:
67
+ config: SHRAM model configuration.
68
+ """
69
+
70
+ config_class = ShramConfig
71
+ base_model_prefix = "model"
72
+ _no_split_modules = ["DecoderLayer"]
73
+ supports_gradient_checkpointing = True
74
+
75
+ def __init__(self, config: ShramConfig) -> None:
76
+ super().__init__(config)
77
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
78
+ self.model = ShramModel(config)
79
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
80
+ self._configure_tied_embeddings()
81
+ self.post_init()
82
+
83
+ def _configure_tied_embeddings(self) -> None:
84
+ """Apply config-controlled tied embedding behavior on this instance."""
85
+ if self.config.tie_word_embeddings:
86
+ self.lm_head.weight = self.embed_tokens.weight
87
+ self._tied_weights_keys = {
88
+ "lm_head.weight": "embed_tokens.weight",
89
+ }
90
+ else:
91
+ self._tied_weights_keys = {}
92
+
93
+ def get_input_embeddings(self) -> nn.Embedding:
94
+ """Return the token embedding matrix."""
95
+ return self.embed_tokens
96
+
97
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
98
+ """Replace the token embedding matrix."""
99
+ self.embed_tokens = value
100
+ self._configure_tied_embeddings()
101
+
102
+ def get_output_embeddings(self) -> nn.Linear:
103
+ """Return the LM head."""
104
+ return self.lm_head
105
+
106
+ def set_output_embeddings(self, value: nn.Linear) -> None:
107
+ """Replace the LM head."""
108
+ self.lm_head = value
109
+ self._configure_tied_embeddings()
110
+
111
+ def _build_shram_cache(
112
+ self,
113
+ batch_size: int,
114
+ device: torch.device,
115
+ ) -> ShramCache:
116
+ """Construct a fresh top-level SHRAM cache."""
117
+ return ShramCache(
118
+ num_hidden_layers=self.config.num_hidden_layers,
119
+ sliding_window=self.config.window_size,
120
+ num_local_heads=self.config.num_sliding_window_heads,
121
+ local_head_dim=self.config.head_dim,
122
+ num_mosrah_heads=self.config.num_mosrah_heads,
123
+ mosrah_head_dim=self.config.head_dim,
124
+ batch_size=batch_size,
125
+ device=device,
126
+ )
127
+
128
+ def _validate_generation_cache_request(
129
+ self,
130
+ generation_config: Any,
131
+ model_kwargs: dict[str, Any],
132
+ generation_mode: GenerationMode,
133
+ ) -> None:
134
+ """Validate SHRAM's generation-side cache policy."""
135
+ if generation_mode in {
136
+ GenerationMode.ASSISTED_GENERATION,
137
+ GenerationMode.CONTRASTIVE_SEARCH,
138
+ }:
139
+ raise NotImplementedError(
140
+ "ShramForCausalLM does not currently support assisted generation "
141
+ "or contrastive search because ShramCache does not support crop()."
142
+ )
143
+
144
+ user_defined_cache = model_kwargs.get("past_key_values")
145
+ if user_defined_cache is not None:
146
+ if generation_config.cache_implementation is not None:
147
+ raise ValueError(
148
+ "Passing both `cache_implementation` and `past_key_values` "
149
+ "is unsupported. Please use only one."
150
+ )
151
+ if isinstance(user_defined_cache, tuple):
152
+ raise ValueError(
153
+ "Passing a tuple of `past_key_values` is not supported. "
154
+ "Please use a `ShramCache` instance."
155
+ )
156
+ if not isinstance(user_defined_cache, ShramCache):
157
+ raise TypeError(
158
+ "ShramForCausalLM requires `past_key_values` to be a "
159
+ "`ShramCache` instance."
160
+ )
161
+
162
+ if (
163
+ user_defined_cache is None
164
+ and generation_config.use_cache
165
+ and generation_config.cache_implementation is not None
166
+ ):
167
+ raise ValueError(
168
+ "ShramForCausalLM does not support `cache_implementation`. "
169
+ "Generation-created caches must be `ShramCache` objects."
170
+ )
171
+
172
+ def _prepare_cache_for_generation(
173
+ self,
174
+ generation_config: Any,
175
+ model_kwargs: dict[str, Any],
176
+ generation_mode: GenerationMode,
177
+ batch_size: int,
178
+ max_cache_length: int,
179
+ ) -> None:
180
+ """Ensure HuggingFace generation uses ShramCache.
181
+
182
+ This is the SHRAM-specific generation hook. The rest of the default
183
+ generation plumbing is kept intact as much as possible.
184
+
185
+ Args:
186
+ generation_config: Active generation configuration.
187
+ model_kwargs: Generation kwargs, updated in place.
188
+ generation_mode: HuggingFace generation mode.
189
+ batch_size: Effective generation batch size.
190
+ max_cache_length: Requested cache length. Accepted but unused here.
191
+ """
192
+ self._validate_generation_cache_request(
193
+ generation_config=generation_config,
194
+ model_kwargs=model_kwargs,
195
+ generation_mode=generation_mode,
196
+ )
197
+
198
+ if model_kwargs.get("past_key_values") is not None:
199
+ return
200
+
201
+ if not generation_config.use_cache:
202
+ return
203
+
204
+ num_repeats = max(
205
+ generation_config.num_beams or 1,
206
+ generation_config.num_return_sequences or 1,
207
+ )
208
+ model_kwargs["past_key_values"] = self._build_shram_cache(
209
+ batch_size=batch_size*num_repeats,
210
+ device=self.embed_tokens.weight.device,
211
+ )
212
+
213
+ def _reorder_cache(
214
+ self,
215
+ past_key_values: Cache,
216
+ beam_idx: torch.Tensor,
217
+ ) -> Cache:
218
+ """Reorder the cache in place for beam search."""
219
+ past_key_values.reorder_cache(beam_idx)
220
+ return past_key_values
221
+
222
+ def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
223
+ """Validate token IDs at the wrapper boundary."""
224
+ if input_ids.ndim != 2:
225
+ raise ValueError("input_ids must have shape (batch, seq_len).")
226
+ if input_ids.shape[1] == 0:
227
+ raise ValueError("input_ids sequence length must be nonzero.")
228
+ if input_ids.dtype != torch.long:
229
+ raise TypeError("input_ids must be an long int tensor.")
230
+
231
+ def _validate_attention_mask(
232
+ self,
233
+ input_ids: torch.Tensor,
234
+ attention_mask: torch.Tensor | None,
235
+ ) -> None:
236
+ """Validate the full-sequence attention mask."""
237
+ if attention_mask is None:
238
+ return
239
+ if attention_mask.ndim != 2:
240
+ raise ValueError("attention_mask must have shape (batch, total_seq_len).")
241
+ if attention_mask.shape[0] != input_ids.shape[0]:
242
+ raise ValueError("attention_mask batch dimension must match input_ids.")
243
+ if attention_mask.shape[1] < input_ids.shape[1]:
244
+ raise ValueError(
245
+ "attention_mask must be at least as long as the current input_ids chunk."
246
+ )
247
+
248
+ def _validate_position_ids(
249
+ self,
250
+ input_ids: torch.Tensor,
251
+ position_ids: torch.Tensor | None,
252
+ ) -> None:
253
+ """Validate current-step position IDs."""
254
+ if position_ids is None:
255
+ return
256
+ if position_ids.ndim != 2:
257
+ raise ValueError("position_ids must have shape (batch, seq_len).")
258
+ if position_ids.shape != input_ids.shape:
259
+ raise ValueError(
260
+ "position_ids must match the current input_ids shape exactly."
261
+ )
262
+ if input_ids.dtype != torch.long:
263
+ raise TypeError("position_ids must be an long tensor.")
264
+
265
+ def _validate_labels(
266
+ self,
267
+ input_ids: torch.Tensor,
268
+ labels: torch.Tensor | None,
269
+ ) -> None:
270
+ """Validate label shape at the wrapper boundary."""
271
+ if labels is None:
272
+ return
273
+ if labels.ndim != 2:
274
+ raise ValueError("labels must have shape (batch, seq_len).")
275
+ if labels.shape != input_ids.shape:
276
+ raise ValueError("labels must have the same shape as input_ids.")
277
+ if input_ids.dtype != torch.long:
278
+ raise TypeError("labels must be a long tensor.")
279
+
280
+ def _validate_cache_inputs(
281
+ self,
282
+ use_cache: bool,
283
+ past_key_values: Cache | None,
284
+ ) -> None:
285
+ """Validate cache policy for direct wrapper calls."""
286
+ if use_cache:
287
+ if past_key_values is None:
288
+ raise ValueError(
289
+ "use_cache=True requires an explicit ShramCache. During "
290
+ "generate(), HuggingFace should supply this through "
291
+ "_prepare_cache_for_generation()."
292
+ )
293
+ if not isinstance(past_key_values, ShramCache):
294
+ raise TypeError(
295
+ "past_key_values must be a ShramCache when use_cache=True."
296
+ )
297
+ return
298
+
299
+ if past_key_values is not None:
300
+ raise ValueError("past_key_values was provided while use_cache=False.")
301
+
302
+ def _validate_position_sources(
303
+ self,
304
+ use_cache: bool,
305
+ attention_mask: torch.Tensor | None,
306
+ position_ids: torch.Tensor | None,
307
+ ) -> None:
308
+ """Validate that cached forward has a truthful source of positions."""
309
+ if use_cache and attention_mask is None and position_ids is None:
310
+ raise ValueError(
311
+ "Cached forward requires either position_ids or attention_mask."
312
+ )
313
+
314
+ def _validate_hf_boundary(
315
+ self,
316
+ output_attentions: bool | None,
317
+ return_dict: bool | None,
318
+ inputs_embeds: torch.Tensor | None,
319
+ cache_position: torch.Tensor | None,
320
+ extra_kwargs: dict[str, Any],
321
+ ) -> None:
322
+ """Validate unsupported HuggingFace-facing wrapper inputs."""
323
+ if output_attentions:
324
+ raise NotImplementedError(
325
+ "ShramForCausalLM does not expose output_attentions."
326
+ )
327
+ if return_dict is False:
328
+ raise ValueError(
329
+ "return_dict=False is not supported. "
330
+ "ShramForCausalLM always returns ShramCausalLMOutput."
331
+ )
332
+ if inputs_embeds is not None:
333
+ raise ValueError(
334
+ "inputs_embeds is not supported at the SHRAM wrapper boundary. "
335
+ "Pass input_ids instead."
336
+ )
337
+ if extra_kwargs:
338
+ unsupported = ", ".join(sorted(extra_kwargs))
339
+ raise TypeError(
340
+ f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
341
+ )
342
+
343
+ def _standardize_full_attention_mask(
344
+ self,
345
+ input_ids: torch.Tensor,
346
+ attention_mask: torch.Tensor | None,
347
+ ) -> torch.BoolTensor:
348
+ """Return a concrete full-sequence boolean attention mask."""
349
+ if attention_mask is None:
350
+ return torch.ones_like(input_ids, dtype=torch.bool)
351
+ return attention_mask.to(dtype=torch.bool)
352
+
353
+ def _resolve_current_position_ids(
354
+ self,
355
+ input_ids: torch.Tensor,
356
+ position_ids: torch.Tensor | None,
357
+ full_attention_mask: torch.BoolTensor,
358
+ ) -> torch.LongTensor:
359
+ """Resolve concrete current-step position IDs for the backbone."""
360
+ if position_ids is not None:
361
+ return position_ids.to(dtype=torch.long)
362
+
363
+ full_position_ids = full_attention_mask.to(dtype=torch.long).cumsum(dim=-1) - 1
364
+ full_position_ids = full_position_ids.masked_fill(~full_attention_mask, 0)
365
+ current_length = input_ids.shape[1]
366
+ return full_position_ids[:, -current_length:]
367
+
368
+ def forward(
369
+ self,
370
+ input_ids: torch.Tensor,
371
+ attention_mask: torch.Tensor | None = None,
372
+ position_ids: torch.Tensor | None = None,
373
+ past_key_values: Cache | None = None,
374
+ use_cache: bool | None = None,
375
+ output_hidden_states: bool | None = None,
376
+ labels: torch.Tensor | None = None,
377
+ return_dict: bool | None = None,
378
+ ce_weight: float = 1.0,
379
+ load_balance_weight: float = 0.01,
380
+ **kwargs: Any,
381
+ ) -> ShramCausalLMOutput:
382
+ """Run the SHRAM causal language model wrapper.
383
+
384
+ Args:
385
+ input_ids: Current token IDs of shape ``(batch, seq_len)``.
386
+ attention_mask: Optional full 2D mask of shape
387
+ ``(batch, total_seq_len)``. The wrapper slices its recent chunk
388
+ to produce the current semantic liveness mask expected by the
389
+ backbone.
390
+ position_ids: Optional current-step position IDs of shape
391
+ ``(batch, seq_len)``. In ordinary HuggingFace generation this is
392
+ already the current-step tensor when it reaches ``forward()``.
393
+ past_key_values: Optional SHRAM cache. Required when
394
+ ``use_cache=True``.
395
+ use_cache: Whether to use and return a cache. Defaults to
396
+ ``config.use_cache``.
397
+ output_hidden_states: Whether to return backbone hidden states.
398
+ Defaults to ``config.output_hidden_states``.
399
+ labels: Optional target token IDs of shape ``(batch, seq_len)``.
400
+ return_dict: Must be ``True`` or ``None``.
401
+ ce_weight: Weight applied to the cross-entropy loss when combining with
402
+ the load-balance loss. Default 1.0.
403
+ load_balance_weight: Weight applied to the load-balance auxiliary loss.
404
+ Default 0.01, matching the paper's recommendation.
405
+ **kwargs: Unsupported HuggingFace kwargs fail explicitly.
406
+
407
+ Returns:
408
+ ``ShramCausalLMOutput`` with:
409
+ - ``logits`` of shape ``(batch, seq_len, vocab_size)``,
410
+ - ``loss`` = ``ce_weight * ce_loss + load_balance_weight * load_balance_loss``
411
+ when labels are provided (``None`` otherwise),
412
+ - ``ce_loss`` — raw unweighted cross-entropy loss for logging,
413
+ - ``past_key_values`` as the active ``ShramCache`` or ``None``,
414
+ - ``hidden_states`` when requested,
415
+ - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
416
+ - detached ``max_vio`` from the backbone.
417
+ """
418
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
419
+ output_hidden_states = (
420
+ output_hidden_states
421
+ if output_hidden_states is not None
422
+ else self.config.output_hidden_states
423
+ )
424
+
425
+ inputs_embeds = kwargs.pop("inputs_embeds", None)
426
+ output_attentions = kwargs.pop("output_attentions", None)
427
+ cache_position = kwargs.pop("cache_position", None)
428
+
429
+ # ------------------------------------------------------------------
430
+ # Validation zone.
431
+ #
432
+ # The wrapper boundary is where HuggingFace-facing inputs are judged
433
+ # for truthfulness before any internal work begins. These checks are
434
+ # intentionally front-loaded so the core logic below can assume one
435
+ # coherent interpretation of the call rather than defensively checking
436
+ # shapes, cache policy, or unsupported HF knobs at the point of use.
437
+ # This keeps the main sequence readable while ensuring invalid states
438
+ # fail before they can silently contaminate backbone execution.
439
+ # ------------------------------------------------------------------
440
+ self._validate_input_ids(input_ids)
441
+ self._validate_attention_mask(input_ids, attention_mask)
442
+ self._validate_position_ids(input_ids, position_ids)
443
+ self._validate_labels(input_ids, labels)
444
+ self._validate_cache_inputs(use_cache, past_key_values)
445
+ self._validate_position_sources(use_cache, attention_mask, position_ids)
446
+ self._validate_hf_boundary(
447
+ output_attentions=output_attentions,
448
+ return_dict=return_dict,
449
+ inputs_embeds=inputs_embeds,
450
+ cache_position=cache_position,
451
+ extra_kwargs=kwargs,
452
+ )
453
+
454
+ # ------------------------------------------------------------------
455
+ # Standardization zone.
456
+ #
457
+ # HuggingFace and SHRAM use different boundary conventions: generation
458
+ # carries a full-sequence 2D attention mask, while the SHRAM backbone
459
+ # wants a current-step active mask and concrete current position IDs.
460
+ # This zone collapses those wrapper-facing conventions into one valid
461
+ # backbone-facing state. After this point the core no longer reasons
462
+ # about optional or ambiguous input forms; it works only with concrete
463
+ # tensors whose semantics are already fixed.
464
+ # ------------------------------------------------------------------
465
+ full_attention_mask: torch.BoolTensor = self._standardize_full_attention_mask(
466
+ input_ids=input_ids,
467
+ attention_mask=attention_mask,
468
+ )
469
+ current_length: int = input_ids.shape[1]
470
+ current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
471
+ current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
472
+ input_ids=input_ids,
473
+ position_ids=position_ids,
474
+ full_attention_mask=full_attention_mask,
475
+ )
476
+ shram_cache: ShramCache | None = past_key_values if use_cache else None
477
+
478
+ # ------------------------------------------------------------------
479
+ # Core wrapper responsibilities.
480
+ #
481
+ # The wrapper's primary job is kept visible here: convert token IDs to
482
+ # embeddings, delegate transformer computation to ShramModel, project
483
+ # hidden states back to vocabulary logits, optionally compute the
484
+ # wrapper-level shifted next-token loss, and return the HuggingFace-
485
+ # facing output object. The backbone remains responsible only for
486
+ # transformer semantics; token/vocabulary/loss concerns stay here.
487
+ # ------------------------------------------------------------------
488
+ token_embeddings: torch.FloatTensor = self.embed_tokens(input_ids)
489
+ backbone_outputs = self.model(
490
+ inputs_embeds=token_embeddings,
491
+ position_ids=current_position_ids,
492
+ active_mask=current_active_mask,
493
+ cache=shram_cache,
494
+ output_hidden_states=output_hidden_states,
495
+ )
496
+
497
+ logits: torch.FloatTensor = self.lm_head(backbone_outputs["last_hidden_state"])
498
+
499
+ ce_loss: torch.FloatTensor | None = None
500
+ loss: torch.FloatTensor | None = None
501
+ if labels is not None:
502
+ shift_logits = logits[:, :-1, :].contiguous()
503
+ shift_labels = labels[:, 1:].contiguous()
504
+ ce_loss = nn.functional.cross_entropy(
505
+ shift_logits.view(-1, self.config.vocab_size),
506
+ shift_labels.view(-1),
507
+ )
508
+ loss = ce_weight * ce_loss + load_balance_weight * backbone_outputs["load_balance_loss"]
509
+
510
+ return ShramCausalLMOutput(
511
+ loss=loss,
512
+ ce_loss=ce_loss,
513
+ logits=logits,
514
+ past_key_values=backbone_outputs["past_key_values"],
515
+ hidden_states=backbone_outputs["hidden_states"],
516
+ load_balance_loss=backbone_outputs["load_balance_loss"],
517
+ max_vio=backbone_outputs["max_vio"],
518
  )
mlp.py CHANGED
@@ -1,52 +1,52 @@
1
- """SwiGLU feed-forward sublayer.
2
-
3
- SwiGLU is a gated linear unit variant that multiplies a SiLU-gated projection
4
- element-wise against a separate up-projection:
5
-
6
- output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))
7
-
8
- The gating mechanism gives the network more expressive control over which features
9
- to propagate than a plain two-matrix FFN. It requires three weight matrices instead
10
- of two, which is why intermediate_size in Llama 3 is set lower than the 4× multiplier
11
- typical of two-matrix FFNs — the total parameter count remains comparable.
12
-
13
- SiLU is used as the gate activation because Llama 3 committed to SwiGLU specifically
14
- — a fixed architectural choice.
15
- """
16
-
17
- import torch
18
- import torch.nn as nn
19
- import torch.nn.functional as F
20
- from transformers import PretrainedConfig
21
-
22
-
23
- class SwiGLUMLP(nn.Module):
24
- """SwiGLU feed-forward sublayer.
25
-
26
- Implements the three-matrix SwiGLU FFN used in Llama 3:
27
-
28
- output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))
29
-
30
- No bias on any projection. SiLU as the gate activation is an architectural
31
- constant — it is what defines SwiGLU specifically.
32
-
33
- Args:
34
- config: Model config. Must expose ``hidden_size`` and ``intermediate_size``.
35
- """
36
-
37
- def __init__(self, config: PretrainedConfig) -> None:
38
- super().__init__()
39
- self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
40
- self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
41
- self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
42
-
43
- def forward(self, x: torch.Tensor) -> torch.Tensor:
44
- """Apply the SwiGLU feed-forward transformation.
45
-
46
- Args:
47
- x: Input tensor of shape (batch, seq_len, hidden_size).
48
-
49
- Returns:
50
- Output tensor of shape (batch, seq_len, hidden_size).
51
- """
52
- return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
 
1
+ """SwiGLU feed-forward sublayer.
2
+
3
+ SwiGLU is a gated linear unit variant that multiplies a SiLU-gated projection
4
+ element-wise against a separate up-projection:
5
+
6
+ output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))
7
+
8
+ The gating mechanism gives the network more expressive control over which features
9
+ to propagate than a plain two-matrix FFN. It requires three weight matrices instead
10
+ of two, which is why intermediate_size in Llama 3 is set lower than the 4× multiplier
11
+ typical of two-matrix FFNs — the total parameter count remains comparable.
12
+
13
+ SiLU is used as the gate activation because Llama 3 committed to SwiGLU specifically
14
+ — a fixed architectural choice.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from transformers import PretrainedConfig
21
+
22
+
23
+ class SwiGLUMLP(nn.Module):
24
+ """SwiGLU feed-forward sublayer.
25
+
26
+ Implements the three-matrix SwiGLU FFN used in Llama 3:
27
+
28
+ output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))
29
+
30
+ No bias on any projection. SiLU as the gate activation is an architectural
31
+ constant — it is what defines SwiGLU specifically.
32
+
33
+ Args:
34
+ config: Model config. Must expose ``hidden_size`` and ``intermediate_size``.
35
+ """
36
+
37
+ def __init__(self, config: PretrainedConfig) -> None:
38
+ super().__init__()
39
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
40
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
41
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ """Apply the SwiGLU feed-forward transformation.
45
+
46
+ Args:
47
+ x: Input tensor of shape (batch, seq_len, hidden_size).
48
+
49
+ Returns:
50
+ Output tensor of shape (batch, seq_len, hidden_size).
51
+ """
52
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
model.py CHANGED
@@ -1,132 +1,132 @@
1
- """Transformer backbone for Shram.
2
-
3
- ShramModel is a pure PyTorch module: a sequence of DecoderLayer blocks followed
4
- by a final RMSNorm. It accepts pre-embedded hidden states and returns contextual
5
- representations. It has no knowledge of tokens, vocabulary, generation, or the
6
- HuggingFace causal-LM wrapper contract.
7
-
8
- Keeping the embedding out of the backbone is the correct convention and makes
9
- the backbone genuinely modality-agnostic. The token interface — embedding lookup,
10
- LM head, weight tying, and generation-facing naming conventions — belongs on the
11
- task wrapper (ShramForCausalLM), which is the only class that knows this
12
- backbone is being used for language modelling.
13
-
14
- The final RMSNorm is necessary because the decoder stack uses pre-norm throughout:
15
- each sublayer normalises its own input, leaving the residual stream itself
16
- unnormalised. After many layers of accumulated residuals, that stream arrives at
17
- the top with uncontrolled magnitude. The final norm brings it to a well-scaled
18
- state before any projection. Without it, the LM head would receive signals of
19
- arbitrary scale.
20
-
21
- Caching is caller-managed. If a ShramCache is provided, ShramModel threads the
22
- corresponding per-layer ShramLayerCache into each DecoderLayer and returns the
23
- same top-level ShramCache object in the output dict. If None is provided, no
24
- caching occurs.
25
-
26
- Returns a plain dict with keys:
27
- - "last_hidden_state": normed backbone output, shape (batch, seq_len, hidden_size)
28
- - "past_key_values": the ShramCache object passed in, or None
29
- - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
30
- - "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
31
- - "max_vio": detached scalar maximum routing-imbalance across all decoder layers
32
- """
33
-
34
- import torch
35
- import torch.nn as nn
36
-
37
- from .__cache__shram_cache import ShramCache
38
- from .configuration import ShramConfig
39
- from .decoder_layer import DecoderLayer
40
-
41
-
42
- class ShramModel(nn.Module):
43
- """Pure transformer backbone: decoder stack and final normalisation.
44
-
45
- Accepts pre-embedded hidden states of shape (batch, seq_len, hidden_size)
46
- and returns contextual representations of the same shape. No token embedding,
47
- vocabulary projection, or causal-LM lifecycle concerns.
48
-
49
- RoPE is applied inside each attention layer. Positional information is
50
- encoded in the relationship between Q and K, not added to the residual
51
- stream, so the backbone is agnostic to how positions are represented.
52
-
53
- Args:
54
- config: Model configuration. Must be a ``ShramConfig`` instance.
55
- """
56
-
57
- def __init__(self, config: ShramConfig) -> None:
58
- super().__init__()
59
- self.config = config
60
- self.layers = nn.ModuleList(
61
- [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
62
- )
63
- self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
64
-
65
- def forward(
66
- self,
67
- inputs_embeds: torch.Tensor,
68
- position_ids: torch.Tensor,
69
- active_mask: torch.Tensor,
70
- cache: ShramCache | None = None,
71
- output_hidden_states: bool = False,
72
- ) -> dict:
73
- """Run the transformer stack over a batch of pre-embedded sequences.
74
-
75
- Args:
76
- inputs_embeds: Pre-embedded input of shape (batch, seq_len, hidden_size).
77
- position_ids: Absolute positions of shape (batch, seq_len). Required.
78
- Must be provided explicitly by the caller — this module does not
79
- infer positions from cache state.
80
- active_mask: Current-chunk active mask of shape (batch, seq_len),
81
- where True means the token is semantically live. Forwarded
82
- unchanged to every decoder layer.
83
- cache: Optional top-level ShramCache. When provided, each DecoderLayer
84
- receives its own layer-local cache via ``cache.layers[layer_idx]``.
85
- The top-level cache object is updated in place and returned unchanged.
86
- output_hidden_states: When True, the output dict includes a tuple of
87
- per-layer hidden states: (inputs_embeds, layer_0_out, ..., layer_N_out),
88
- collected before the final norm.
89
-
90
- Returns:
91
- Plain dict with keys:
92
- - ``"last_hidden_state"``: normed backbone output,
93
- shape (batch, seq_len, hidden_size).
94
- - ``"past_key_values"``: the cache object passed in, or None.
95
- - ``"hidden_states"``: tuple of per-layer activations (including
96
- inputs_embeds as position 0) if ``output_hidden_states`` is True,
97
- else None. Collected before the final norm so each entry reflects the
98
- unnormalised residual stream at that depth.
99
- - ``"load_balance_loss"``: scalar sum of per-layer SHRAM
100
- load-balance losses.
101
- - ``"max_vio"``: detached scalar maximum routing-imbalance across
102
- all decoder layers. Zero means perfectly balanced routing across
103
- every layer; higher values identify the worst-case head imbalance.
104
- """
105
- hidden_states = inputs_embeds
106
- all_hidden_states = (hidden_states,) if output_hidden_states else None
107
- total_load_balance_loss = inputs_embeds.new_zeros(())
108
- max_vio = inputs_embeds.new_zeros(())
109
-
110
- for layer_idx, layer in enumerate(self.layers):
111
- layer_cache = None if cache is None else cache.layers[layer_idx]
112
- hidden_states, layer_load_balance_loss, layer_max_vio = layer(
113
- hidden_states,
114
- position_ids,
115
- active_mask,
116
- cache=layer_cache,
117
- )
118
- total_load_balance_loss = total_load_balance_loss + layer_load_balance_loss
119
- max_vio = torch.maximum(max_vio, layer_max_vio)
120
-
121
- if output_hidden_states:
122
- all_hidden_states = all_hidden_states + (hidden_states,)
123
-
124
- hidden_states = self.norm(hidden_states)
125
-
126
- return {
127
- "last_hidden_state": hidden_states,
128
- "past_key_values": cache,
129
- "hidden_states": all_hidden_states,
130
- "load_balance_loss": total_load_balance_loss,
131
- "max_vio": max_vio,
132
  }
 
1
+ """Transformer backbone for Shram.
2
+
3
+ ShramModel is a pure PyTorch module: a sequence of DecoderLayer blocks followed
4
+ by a final RMSNorm. It accepts pre-embedded hidden states and returns contextual
5
+ representations. It has no knowledge of tokens, vocabulary, generation, or the
6
+ HuggingFace causal-LM wrapper contract.
7
+
8
+ Keeping the embedding out of the backbone is the correct convention and makes
9
+ the backbone genuinely modality-agnostic. The token interface — embedding lookup,
10
+ LM head, weight tying, and generation-facing naming conventions — belongs on the
11
+ task wrapper (ShramForCausalLM), which is the only class that knows this
12
+ backbone is being used for language modelling.
13
+
14
+ The final RMSNorm is necessary because the decoder stack uses pre-norm throughout:
15
+ each sublayer normalises its own input, leaving the residual stream itself
16
+ unnormalised. After many layers of accumulated residuals, that stream arrives at
17
+ the top with uncontrolled magnitude. The final norm brings it to a well-scaled
18
+ state before any projection. Without it, the LM head would receive signals of
19
+ arbitrary scale.
20
+
21
+ Caching is caller-managed. If a ShramCache is provided, ShramModel threads the
22
+ corresponding per-layer ShramLayerCache into each DecoderLayer and returns the
23
+ same top-level ShramCache object in the output dict. If None is provided, no
24
+ caching occurs.
25
+
26
+ Returns a plain dict with keys:
27
+ - "last_hidden_state": normed backbone output, shape (batch, seq_len, hidden_size)
28
+ - "past_key_values": the ShramCache object passed in, or None
29
+ - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
30
+ - "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
31
+ - "max_vio": detached scalar maximum routing-imbalance across all decoder layers
32
+ """
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+
37
+ from .__cache__shram_cache import ShramCache
38
+ from .configuration import ShramConfig
39
+ from .decoder_layer import DecoderLayer
40
+
41
+
42
+ class ShramModel(nn.Module):
43
+ """Pure transformer backbone: decoder stack and final normalisation.
44
+
45
+ Accepts pre-embedded hidden states of shape (batch, seq_len, hidden_size)
46
+ and returns contextual representations of the same shape. No token embedding,
47
+ vocabulary projection, or causal-LM lifecycle concerns.
48
+
49
+ RoPE is applied inside each attention layer. Positional information is
50
+ encoded in the relationship between Q and K, not added to the residual
51
+ stream, so the backbone is agnostic to how positions are represented.
52
+
53
+ Args:
54
+ config: Model configuration. Must be a ``ShramConfig`` instance.
55
+ """
56
+
57
+ def __init__(self, config: ShramConfig) -> None:
58
+ super().__init__()
59
+ self.config = config
60
+ self.layers = nn.ModuleList(
61
+ [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
62
+ )
63
+ self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
64
+
65
+ def forward(
66
+ self,
67
+ inputs_embeds: torch.Tensor,
68
+ position_ids: torch.Tensor,
69
+ active_mask: torch.Tensor,
70
+ cache: ShramCache | None = None,
71
+ output_hidden_states: bool = False,
72
+ ) -> dict:
73
+ """Run the transformer stack over a batch of pre-embedded sequences.
74
+
75
+ Args:
76
+ inputs_embeds: Pre-embedded input of shape (batch, seq_len, hidden_size).
77
+ position_ids: Absolute positions of shape (batch, seq_len). Required.
78
+ Must be provided explicitly by the caller — this module does not
79
+ infer positions from cache state.
80
+ active_mask: Current-chunk active mask of shape (batch, seq_len),
81
+ where True means the token is semantically live. Forwarded
82
+ unchanged to every decoder layer.
83
+ cache: Optional top-level ShramCache. When provided, each DecoderLayer
84
+ receives its own layer-local cache via ``cache.layers[layer_idx]``.
85
+ The top-level cache object is updated in place and returned unchanged.
86
+ output_hidden_states: When True, the output dict includes a tuple of
87
+ per-layer hidden states: (inputs_embeds, layer_0_out, ..., layer_N_out),
88
+ collected before the final norm.
89
+
90
+ Returns:
91
+ Plain dict with keys:
92
+ - ``"last_hidden_state"``: normed backbone output,
93
+ shape (batch, seq_len, hidden_size).
94
+ - ``"past_key_values"``: the cache object passed in, or None.
95
+ - ``"hidden_states"``: tuple of per-layer activations (including
96
+ inputs_embeds as position 0) if ``output_hidden_states`` is True,
97
+ else None. Collected before the final norm so each entry reflects the
98
+ unnormalised residual stream at that depth.
99
+ - ``"load_balance_loss"``: scalar sum of per-layer SHRAM
100
+ load-balance losses.
101
+ - ``"max_vio"``: detached scalar maximum routing-imbalance across
102
+ all decoder layers. Zero means perfectly balanced routing across
103
+ every layer; higher values identify the worst-case head imbalance.
104
+ """
105
+ hidden_states = inputs_embeds
106
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
107
+ total_load_balance_loss = inputs_embeds.new_zeros(())
108
+ max_vio = inputs_embeds.new_zeros(())
109
+
110
+ for layer_idx, layer in enumerate(self.layers):
111
+ layer_cache = None if cache is None else cache.layers[layer_idx]
112
+ hidden_states, layer_load_balance_loss, layer_max_vio = layer(
113
+ hidden_states,
114
+ position_ids,
115
+ active_mask,
116
+ cache=layer_cache,
117
+ )
118
+ total_load_balance_loss = total_load_balance_loss + layer_load_balance_loss
119
+ max_vio = torch.maximum(max_vio, layer_max_vio)
120
+
121
+ if output_hidden_states:
122
+ all_hidden_states = all_hidden_states + (hidden_states,)
123
+
124
+ hidden_states = self.norm(hidden_states)
125
+
126
+ return {
127
+ "last_hidden_state": hidden_states,
128
+ "past_key_values": cache,
129
+ "hidden_states": all_hidden_states,
130
+ "load_balance_loss": total_load_balance_loss,
131
+ "max_vio": max_vio,
132
  }
rope.py CHANGED
@@ -1,291 +1,291 @@
1
- """Rotary Position Embeddings (RoPE).
2
-
3
- RoPE encodes position in the *relationship* between query and key vectors. When the
4
- attention dot product Q·Kᵀ is computed, the per-position rotations cancel to produce
5
- a score that depends only on the relative distance — not on absolute positions.
6
-
7
- Two modes are supported:
8
-
9
- default Standard RoPE with base frequency b. Each dimension pair d is assigned
10
- frequency θ_d = b^{-2d/u} where u is the head dimension. The attention
11
- scaling A_rope = 1.
12
-
13
- yarn YaRN frequency interpolation for long-context extrapolation (Peng et al.,
14
- "YaRN: Efficient Context Window Extension of Large Language Models", 2023,
15
- §A.2). Three frequency regimes:
16
- - Low-frequency dimensions (r < α): fully interpolated by scale s.
17
- These dimensions have long wavelengths relative to the training window
18
- and must be compressed to avoid out-of-distribution positions.
19
- - High-frequency dimensions (r > β): left unchanged. Short-wavelength
20
- dimensions already encode relative position accurately at any scale.
21
- - Intermediate dimensions (α ≤ r ≤ β): linearly blended via ramp γ(r).
22
- Returns A_rope = (0.1·ln(s)+1)². When s = 1, YaRN reduces exactly to
23
- standard RoPE.
24
-
25
- Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit
26
- parameters — no shared instance, no config reading. See Unit 5.A design decisions.
27
-
28
- Cache sharing: all instances with identical parameters share one cos/sin table via a
29
- class-level registry. The first instance that needs a particular (parameters, seq_len,
30
- device, dtype) combination builds the table; all subsequent instances reference it
31
- directly. This avoids redundant builds across the num_hidden_layers instances that
32
- share the same parametrisation.
33
- """
34
-
35
- import math
36
-
37
- import torch
38
- import torch.nn as nn
39
-
40
-
41
- # ---------------------------------------------------------------------------
42
- # Rotation helper
43
- # ---------------------------------------------------------------------------
44
-
45
- def _rotate_half(x: torch.Tensor) -> torch.Tensor:
46
- """Apply the 90° rotation used in the RoPE update formula.
47
-
48
- Splits the last dimension into two halves [x1, x2] and returns [-x2, x1].
49
- Combined with ``x * cos + rotate_half(x) * sin``, this implements a 2D rotation
50
- on each consecutive pair of dimensions, matching the block-diagonal operator
51
- R^u_{Θ,p} in the paper.
52
- """
53
- d = x.shape[-1] // 2
54
- x1, x2 = x[..., :d], x[..., d:]
55
- return torch.cat([-x2, x1], dim=-1)
56
-
57
-
58
- # ---------------------------------------------------------------------------
59
- # RotaryEmbedding
60
- # ---------------------------------------------------------------------------
61
-
62
- class RotaryEmbedding(nn.Module):
63
- """Rotary Position Embeddings with explicit mode and parameter control.
64
-
65
- Each caller constructs its own instance with the exact parameters it needs.
66
- h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
67
- config object is read inside this module.
68
-
69
- The cos/sin cache is built lazily on the first forward call and extended
70
- automatically when a longer sequence is encountered. Instances with identical
71
- parameters share one cache via the class-level ``_cache`` registry,
72
- avoiding redundant computation across decoder layers.
73
-
74
- Args:
75
- mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
76
- head_dim: Per-head embedding dimension ``u``. Must be even.
77
- theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
78
- initial_seq_length: ``C_train`` — context length the model was trained at.
79
- Required for ``mode="yarn"``.
80
- dilation: Scale factor ``s = C_target / C_train`` — how much the context
81
- window is extended beyond training length. Required for ``mode="yarn"``.
82
- When ``dilation=1.0``, YaRN reduces to standard RoPE.
83
- alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are fully
84
- interpolated. Required for ``mode="yarn"``.
85
- beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left
86
- unchanged. Required for ``mode="yarn"``.
87
- device: Optional device for initial buffer placement.
88
-
89
- Raises:
90
- NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
91
- ValueError: If ``mode="yarn"`` and any of ``initial_seq_length``,
92
- ``dilation``, ``alpha``, ``beta`` are absent.
93
- """
94
-
95
- # Maps (freq_key, seq_len, device_str, dtype_str) → (cos_table, sin_table).
96
- # Shared across all RotaryEmbedding instances in the process. Keys include device
97
- # and dtype so that tables built on different devices or in different precisions
98
- # are stored independently.
99
- _cache: dict = {}
100
-
101
- def __init__(
102
- self,
103
- mode: str,
104
- head_dim: int,
105
- theta: float,
106
- initial_seq_length: int | None = None,
107
- dilation: float | None = None,
108
- alpha: float | None = None,
109
- beta: float | None = None,
110
- device: torch.device | None = None,
111
- ) -> None:
112
- super().__init__()
113
-
114
- self._validate_mode(mode)
115
- self._validate_yarn_params(mode, initial_seq_length, dilation, alpha, beta)
116
- self.mode = mode
117
-
118
- # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
119
- # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
120
- # so rotation_freqs has head_dim/2 entries.
121
- d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
122
- base_freqs = 1.0 / (theta ** (d_index / head_dim)) # θ_d = b^{-2d/u}
123
-
124
- if mode == "default":
125
- rotation_freqs = base_freqs
126
- self.attention_scaling: float = 1.0
127
-
128
- else: # yarn
129
- s = dilation
130
-
131
- # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
132
- # function to classify each dimension into one of three regimes.
133
- normalized_freqs = initial_seq_length * base_freqs / (2.0 * math.pi)
134
-
135
- # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
136
- # linear blend between α and β.
137
- blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0)
138
-
139
- # θ_d' = (1 − γ) · θ_d / s + γ · θ_d
140
- rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs
141
-
142
- # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
143
- self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2
144
-
145
- # freq_key uniquely identifies the parameter set that produced rotation_freqs.
146
- # Used as the primary component of the cache registry key.
147
- if mode == "default":
148
- self._freq_key: tuple = ("default", head_dim, float(theta))
149
- else:
150
- self._freq_key = (
151
- "yarn", head_dim, float(theta),
152
- int(initial_seq_length), float(dilation),
153
- float(alpha), float(beta),
154
- )
155
-
156
- # rotation_freqs is a non-persistent buffer so it moves with the model across
157
- # devices via .to() / .cuda() without appearing in saved checkpoints.
158
- # It is stored per-instance rather than in the shared cache because it is
159
- # small (head_dim/2 floats) — negligible cost compared to the cos/sin tables
160
- # it is used to build. The meaningful sharing win is on those tables.
161
- self.register_buffer("rotation_freqs", rotation_freqs, persistent=False)
162
-
163
- # Cache tensors are plain instance attributes (not registered buffers) so that
164
- # sharing across identically-parametrised instances survives .to() calls.
165
- # Registered buffers are copied on device move; plain attributes are aliased,
166
- # preserving the shared-tensor identity that the cache design depends on.
167
- self._cos_cached: torch.Tensor | None = None
168
- self._sin_cached: torch.Tensor | None = None
169
-
170
- # ---------------------------------------------------------------------------
171
- # Validation helpers
172
- # ---------------------------------------------------------------------------
173
-
174
- @staticmethod
175
- def _validate_mode(mode: str) -> None:
176
- """Raise NotImplementedError if mode is not a supported value."""
177
- if mode not in {"default", "yarn"}:
178
- raise NotImplementedError(
179
- f"RoPE mode '{mode}' is not supported. Supported modes: 'default', 'yarn'."
180
- )
181
-
182
- @staticmethod
183
- def _validate_yarn_params(
184
- mode: str,
185
- initial_seq_length: int | None,
186
- dilation: float | None,
187
- alpha: float | None,
188
- beta: float | None,
189
- ) -> None:
190
- """Raise ValueError if mode='yarn' and any required parameter is absent."""
191
- if mode != "yarn":
192
- return
193
- missing = [
194
- name for name, val in [
195
- ("initial_seq_length", initial_seq_length),
196
- ("dilation", dilation),
197
- ("alpha", alpha),
198
- ("beta", beta),
199
- ]
200
- if val is None
201
- ]
202
- if missing:
203
- raise ValueError(f"mode='yarn' requires {missing}.")
204
-
205
- # ---------------------------------------------------------------------------
206
- # Cache management
207
- # ---------------------------------------------------------------------------
208
-
209
- def _extend_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
210
- """Build the cos/sin table to cover positions [0, seq_len).
211
-
212
- Checks the class-level registry first. If a table already exists for this
213
- exact (parameters, seq_len, device, dtype) combination it is reused directly;
214
- otherwise it is computed and stored. The instance attributes are pointed at
215
- the registry entry so that all layers sharing the same parametrisation
216
- reference the same tensor.
217
- """
218
- cache_key = (self._freq_key, seq_len, str(device), str(dtype))
219
-
220
- if cache_key not in RotaryEmbedding._cache:
221
- positions = torch.arange(seq_len, device=device, dtype=torch.float32)
222
- # outer product → (seq_len, head_dim // 2); duplicate to (seq_len, head_dim)
223
- freqs = torch.outer(
224
- positions,
225
- self.rotation_freqs.to(device=device, dtype=torch.float32),
226
- )
227
- angle_embedding = torch.cat((freqs, freqs), dim=-1)
228
- RotaryEmbedding._cache[cache_key] = (
229
- angle_embedding.cos().to(dtype),
230
- angle_embedding.sin().to(dtype),
231
- )
232
-
233
- self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key]
234
-
235
- def forward(
236
- self,
237
- q: torch.Tensor,
238
- k: torch.Tensor,
239
- position_ids: torch.Tensor,
240
- ) -> tuple[torch.Tensor, torch.Tensor, float]:
241
- """Apply rotary embeddings to query and key tensors.
242
-
243
- The cos/sin cache is extended lazily when position_ids reference positions
244
- beyond its current length, or when the device or dtype has changed.
245
-
246
- ``position_ids`` may be any integer tensor shape. Its values are valid
247
- position indices into the cos/sin cache:
248
-
249
- - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
250
- - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).
251
-
252
- When q/k have head dimensions absent from position_ids, broadcast dimensions
253
- are inserted automatically at dim 1.
254
-
255
- Args:
256
- q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim).
257
- k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim).
258
- position_ids: Integer positions of shape (batch, *pos_dims).
259
-
260
- Returns:
261
- Tuple of (q_rotated, k_rotated, attention_scaling). attention_scaling is
262
- 1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
263
- apply to attention logits before softmax.
264
- """
265
- seq_len = int(position_ids.max().item()) + 1
266
-
267
- # The cache is valid when it exists, covers all positions referenced by
268
- # position_ids, and matches q's dtype and device. Each condition is named
269
- # separately so the rebuild trigger is readable rather than a compound predicate.
270
- cache_missing = self._cos_cached is None
271
- cache_too_short = not cache_missing and seq_len > self._cos_cached.shape[0]
272
- wrong_dtype = not cache_missing and self._cos_cached.dtype != q.dtype
273
- wrong_device = not cache_missing and self._cos_cached.device != q.device
274
-
275
- if cache_missing or cache_too_short or wrong_dtype or wrong_device:
276
- self._extend_cache(seq_len, device=q.device, dtype=q.dtype)
277
-
278
- cos = self._cos_cached[position_ids]
279
- sin = self._sin_cached[position_ids]
280
-
281
- # Insert broadcast dimensions for any head axes present in q/k but absent
282
- # from position_ids. Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once.
283
- # BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed.
284
- while cos.ndim < q.ndim:
285
- cos = cos.unsqueeze(1)
286
- sin = sin.unsqueeze(1)
287
-
288
- q_rotated = q * cos + _rotate_half(q) * sin
289
- k_rotated = k * cos + _rotate_half(k) * sin
290
-
291
- return q_rotated, k_rotated, self.attention_scaling
 
1
+ """Rotary Position Embeddings (RoPE).
2
+
3
+ RoPE encodes position in the *relationship* between query and key vectors. When the
4
+ attention dot product Q·Kᵀ is computed, the per-position rotations cancel to produce
5
+ a score that depends only on the relative distance — not on absolute positions.
6
+
7
+ Two modes are supported:
8
+
9
+ default Standard RoPE with base frequency b. Each dimension pair d is assigned
10
+ frequency θ_d = b^{-2d/u} where u is the head dimension. The attention
11
+ scaling A_rope = 1.
12
+
13
+ yarn YaRN frequency interpolation for long-context extrapolation (Peng et al.,
14
+ "YaRN: Efficient Context Window Extension of Large Language Models", 2023,
15
+ §A.2). Three frequency regimes:
16
+ - Low-frequency dimensions (r < α): fully interpolated by scale s.
17
+ These dimensions have long wavelengths relative to the training window
18
+ and must be compressed to avoid out-of-distribution positions.
19
+ - High-frequency dimensions (r > β): left unchanged. Short-wavelength
20
+ dimensions already encode relative position accurately at any scale.
21
+ - Intermediate dimensions (α ≤ r ≤ β): linearly blended via ramp γ(r).
22
+ Returns A_rope = (0.1·ln(s)+1)². When s = 1, YaRN reduces exactly to
23
+ standard RoPE.
24
+
25
+ Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit
26
+ parameters — no shared instance, no config reading. See Unit 5.A design decisions.
27
+
28
+ Cache sharing: all instances with identical parameters share one cos/sin table via a
29
+ class-level registry. The first instance that needs a particular (parameters, seq_len,
30
+ device, dtype) combination builds the table; all subsequent instances reference it
31
+ directly. This avoids redundant builds across the num_hidden_layers instances that
32
+ share the same parametrisation.
33
+ """
34
+
35
+ import math
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # Rotation helper
43
+ # ---------------------------------------------------------------------------
44
+
45
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
46
+ """Apply the 90° rotation used in the RoPE update formula.
47
+
48
+ Splits the last dimension into two halves [x1, x2] and returns [-x2, x1].
49
+ Combined with ``x * cos + rotate_half(x) * sin``, this implements a 2D rotation
50
+ on each consecutive pair of dimensions, matching the block-diagonal operator
51
+ R^u_{Θ,p} in the paper.
52
+ """
53
+ d = x.shape[-1] // 2
54
+ x1, x2 = x[..., :d], x[..., d:]
55
+ return torch.cat([-x2, x1], dim=-1)
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # RotaryEmbedding
60
+ # ---------------------------------------------------------------------------
61
+
62
+ class RotaryEmbedding(nn.Module):
63
+ """Rotary Position Embeddings with explicit mode and parameter control.
64
+
65
+ Each caller constructs its own instance with the exact parameters it needs.
66
+ h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
67
+ config object is read inside this module.
68
+
69
+ The cos/sin cache is built lazily on the first forward call and extended
70
+ automatically when a longer sequence is encountered. Instances with identical
71
+ parameters share one cache via the class-level ``_cache`` registry,
72
+ avoiding redundant computation across decoder layers.
73
+
74
+ Args:
75
+ mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
76
+ head_dim: Per-head embedding dimension ``u``. Must be even.
77
+ theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
78
+ initial_seq_length: ``C_train`` — context length the model was trained at.
79
+ Required for ``mode="yarn"``.
80
+ dilation: Scale factor ``s = C_target / C_train`` — how much the context
81
+ window is extended beyond training length. Required for ``mode="yarn"``.
82
+ When ``dilation=1.0``, YaRN reduces to standard RoPE.
83
+ alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are fully
84
+ interpolated. Required for ``mode="yarn"``.
85
+ beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left
86
+ unchanged. Required for ``mode="yarn"``.
87
+ device: Optional device for initial buffer placement.
88
+
89
+ Raises:
90
+ NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
91
+ ValueError: If ``mode="yarn"`` and any of ``initial_seq_length``,
92
+ ``dilation``, ``alpha``, ``beta`` are absent.
93
+ """
94
+
95
+ # Maps (freq_key, seq_len, device_str, dtype_str) → (cos_table, sin_table).
96
+ # Shared across all RotaryEmbedding instances in the process. Keys include device
97
+ # and dtype so that tables built on different devices or in different precisions
98
+ # are stored independently.
99
+ _cache: dict = {}
100
+
101
+ def __init__(
102
+ self,
103
+ mode: str,
104
+ head_dim: int,
105
+ theta: float,
106
+ initial_seq_length: int | None = None,
107
+ dilation: float | None = None,
108
+ alpha: float | None = None,
109
+ beta: float | None = None,
110
+ device: torch.device | None = None,
111
+ ) -> None:
112
+ super().__init__()
113
+
114
+ self._validate_mode(mode)
115
+ self._validate_yarn_params(mode, initial_seq_length, dilation, alpha, beta)
116
+ self.mode = mode
117
+
118
+ # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
119
+ # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
120
+ # so rotation_freqs has head_dim/2 entries.
121
+ d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
122
+ base_freqs = 1.0 / (theta ** (d_index / head_dim)) # θ_d = b^{-2d/u}
123
+
124
+ if mode == "default":
125
+ rotation_freqs = base_freqs
126
+ self.attention_scaling: float = 1.0
127
+
128
+ else: # yarn
129
+ s = dilation
130
+
131
+ # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
132
+ # function to classify each dimension into one of three regimes.
133
+ normalized_freqs = initial_seq_length * base_freqs / (2.0 * math.pi)
134
+
135
+ # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
136
+ # linear blend between α and β.
137
+ blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0)
138
+
139
+ # θ_d' = (1 − γ) · θ_d / s + γ · θ_d
140
+ rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs
141
+
142
+ # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
143
+ self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2
144
+
145
+ # freq_key uniquely identifies the parameter set that produced rotation_freqs.
146
+ # Used as the primary component of the cache registry key.
147
+ if mode == "default":
148
+ self._freq_key: tuple = ("default", head_dim, float(theta))
149
+ else:
150
+ self._freq_key = (
151
+ "yarn", head_dim, float(theta),
152
+ int(initial_seq_length), float(dilation),
153
+ float(alpha), float(beta),
154
+ )
155
+
156
+ # rotation_freqs is a non-persistent buffer so it moves with the model across
157
+ # devices via .to() / .cuda() without appearing in saved checkpoints.
158
+ # It is stored per-instance rather than in the shared cache because it is
159
+ # small (head_dim/2 floats) — negligible cost compared to the cos/sin tables
160
+ # it is used to build. The meaningful sharing win is on those tables.
161
+ self.register_buffer("rotation_freqs", rotation_freqs, persistent=False)
162
+
163
+ # Cache tensors are plain instance attributes (not registered buffers) so that
164
+ # sharing across identically-parametrised instances survives .to() calls.
165
+ # Registered buffers are copied on device move; plain attributes are aliased,
166
+ # preserving the shared-tensor identity that the cache design depends on.
167
+ self._cos_cached: torch.Tensor | None = None
168
+ self._sin_cached: torch.Tensor | None = None
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # Validation helpers
172
+ # ---------------------------------------------------------------------------
173
+
174
+ @staticmethod
175
+ def _validate_mode(mode: str) -> None:
176
+ """Raise NotImplementedError if mode is not a supported value."""
177
+ if mode not in {"default", "yarn"}:
178
+ raise NotImplementedError(
179
+ f"RoPE mode '{mode}' is not supported. Supported modes: 'default', 'yarn'."
180
+ )
181
+
182
+ @staticmethod
183
+ def _validate_yarn_params(
184
+ mode: str,
185
+ initial_seq_length: int | None,
186
+ dilation: float | None,
187
+ alpha: float | None,
188
+ beta: float | None,
189
+ ) -> None:
190
+ """Raise ValueError if mode='yarn' and any required parameter is absent."""
191
+ if mode != "yarn":
192
+ return
193
+ missing = [
194
+ name for name, val in [
195
+ ("initial_seq_length", initial_seq_length),
196
+ ("dilation", dilation),
197
+ ("alpha", alpha),
198
+ ("beta", beta),
199
+ ]
200
+ if val is None
201
+ ]
202
+ if missing:
203
+ raise ValueError(f"mode='yarn' requires {missing}.")
204
+
205
+ # ---------------------------------------------------------------------------
206
+ # Cache management
207
+ # ---------------------------------------------------------------------------
208
+
209
+ def _extend_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
210
+ """Build the cos/sin table to cover positions [0, seq_len).
211
+
212
+ Checks the class-level registry first. If a table already exists for this
213
+ exact (parameters, seq_len, device, dtype) combination it is reused directly;
214
+ otherwise it is computed and stored. The instance attributes are pointed at
215
+ the registry entry so that all layers sharing the same parametrisation
216
+ reference the same tensor.
217
+ """
218
+ cache_key = (self._freq_key, seq_len, str(device), str(dtype))
219
+
220
+ if cache_key not in RotaryEmbedding._cache:
221
+ positions = torch.arange(seq_len, device=device, dtype=torch.float32)
222
+ # outer product → (seq_len, head_dim // 2); duplicate to (seq_len, head_dim)
223
+ freqs = torch.outer(
224
+ positions,
225
+ self.rotation_freqs.to(device=device, dtype=torch.float32),
226
+ )
227
+ angle_embedding = torch.cat((freqs, freqs), dim=-1)
228
+ RotaryEmbedding._cache[cache_key] = (
229
+ angle_embedding.cos().to(dtype),
230
+ angle_embedding.sin().to(dtype),
231
+ )
232
+
233
+ self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key]
234
+
235
+ def forward(
236
+ self,
237
+ q: torch.Tensor,
238
+ k: torch.Tensor,
239
+ position_ids: torch.Tensor,
240
+ ) -> tuple[torch.Tensor, torch.Tensor, float]:
241
+ """Apply rotary embeddings to query and key tensors.
242
+
243
+ The cos/sin cache is extended lazily when position_ids reference positions
244
+ beyond its current length, or when the device or dtype has changed.
245
+
246
+ ``position_ids`` may be any integer tensor shape. Its values are valid
247
+ position indices into the cos/sin cache:
248
+
249
+ - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
250
+ - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).
251
+
252
+ When q/k have head dimensions absent from position_ids, broadcast dimensions
253
+ are inserted automatically at dim 1.
254
+
255
+ Args:
256
+ q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim).
257
+ k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim).
258
+ position_ids: Integer positions of shape (batch, *pos_dims).
259
+
260
+ Returns:
261
+ Tuple of (q_rotated, k_rotated, attention_scaling). attention_scaling is
262
+ 1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
263
+ apply to attention logits before softmax.
264
+ """
265
+ seq_len = int(position_ids.max().item()) + 1
266
+
267
+ # The cache is valid when it exists, covers all positions referenced by
268
+ # position_ids, and matches q's dtype and device. Each condition is named
269
+ # separately so the rebuild trigger is readable rather than a compound predicate.
270
+ cache_missing = self._cos_cached is None
271
+ cache_too_short = not cache_missing and seq_len > self._cos_cached.shape[0]
272
+ wrong_dtype = not cache_missing and self._cos_cached.dtype != q.dtype
273
+ wrong_device = not cache_missing and self._cos_cached.device != q.device
274
+
275
+ if cache_missing or cache_too_short or wrong_dtype or wrong_device:
276
+ self._extend_cache(seq_len, device=q.device, dtype=q.dtype)
277
+
278
+ cos = self._cos_cached[position_ids]
279
+ sin = self._sin_cached[position_ids]
280
+
281
+ # Insert broadcast dimensions for any head axes present in q/k but absent
282
+ # from position_ids. Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once.
283
+ # BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed.
284
+ while cos.ndim < q.ndim:
285
+ cos = cos.unsqueeze(1)
286
+ sin = sin.unsqueeze(1)
287
+
288
+ q_rotated = q * cos + _rotate_half(q) * sin
289
+ k_rotated = k * cos + _rotate_half(k) * sin
290
+
291
+ return q_rotated, k_rotated, self.attention_scaling
tokenizer_config.json CHANGED
@@ -1,13 +1,14 @@
1
- {
2
- "add_prefix_space": false,
3
- "backend": "tokenizers",
4
- "bos_token": "<|endoftext|>",
5
- "eos_token": "<|endoftext|>",
6
- "errors": "replace",
7
- "is_local": false,
8
- "model_max_length": 1000000000000000019884624838656,
9
- "pad_token": "<|padding|>",
10
- "tokenizer_class": "GPTNeoXTokenizerFast",
11
- "trim_offsets": true,
12
- "unk_token": "<|endoftext|>"
 
13
  }
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|endoftext|>",
5
+ "eos_token": "<|endoftext|>",
6
+ "errors": "replace",
7
+ "is_local": false,
8
+ "local_files_only": false,
9
+ "model_max_length": 1000000000000000019884624838656,
10
+ "pad_token": "<|padding|>",
11
+ "tokenizer_class": "GPTNeoXTokenizerFast",
12
+ "trim_offsets": true,
13
+ "unk_token": "<|endoftext|>"
14
  }