Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="smithblack-0/SHRAM", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("smithblack-0/SHRAM", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use smithblack-0/SHRAM with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "smithblack-0/SHRAM"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/smithblack-0/SHRAM
- SGLang
How to use smithblack-0/SHRAM with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "smithblack-0/SHRAM" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "smithblack-0/SHRAM" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use smithblack-0/SHRAM with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM
Update architecture and tokenizer
Browse files
- README.md +5 -3
- __attention__bottlenecked_ensemble_attention.py +15 -4
- __attention__expert_packing.py +148 -142
- __attention__mosrah.py +14 -14
- __attention__positions_converter.py +14 -8
- __attention__router.py +23 -15
- __attention__shram.py +0 -13
- __attention__sliding_window_attention.py +2 -1
- __cache__mosrah_cache.py +69 -61
- __cache__shram_cache.py +14 -28
- __cache__shram_layer_cache.py +23 -35
- __cache__sliding_window_cache.py +1 -1
- __cache__slow_mosrah_cache.py +20 -39
- config.json +6 -4
- configuration.py +76 -5
- decoder_layer.py +2 -2
- huggingface.py +85 -8
- mlp.py +3 -3
- model.py +2 -2
- rope.py +55 -48
README.md
CHANGED
|
@@ -79,13 +79,15 @@ contains no weights. All values are overridable via kwargs.
|
|
| 79 |
| `attention_dropout` | 0.0 |
|
| 80 |
| `beta` | 32.0 |
|
| 81 |
| `dtype` | None |
|
|
|
|
| 82 |
| `head_dim` | 16 |
|
| 83 |
-
| `hidden_size` | 512 |
|
| 84 |
| `inference_sequence_length` | 1024 |
|
| 85 |
-
| `
|
| 86 |
| `local_rope_theta` | 10000.0 |
|
|
|
|
|
|
|
| 87 |
| `mosrah_rope_theta` | 10000.0 |
|
| 88 |
-
| `
|
| 89 |
| `num_mosrah_heads` | 16 |
|
| 90 |
| `num_selected_heads` | 16 |
|
| 91 |
| `num_sliding_window_heads` | 16 |
|
|
|
|
| 79 |
| `attention_dropout` | 0.0 |
|
| 80 |
| `beta` | 32.0 |
|
| 81 |
| `dtype` | None |
|
| 82 |
+
| `embedding_width` | 512 |
|
| 83 |
| `head_dim` | 16 |
|
|
|
|
| 84 |
| `inference_sequence_length` | 1024 |
|
| 85 |
+
| `load_balance_p` | 2.0 |
|
| 86 |
| `local_rope_theta` | 10000.0 |
|
| 87 |
+
| `mlp_width` | 1366 |
|
| 88 |
+
| `mosrah_overallocation_factor` | 2.0 |
|
| 89 |
| `mosrah_rope_theta` | 10000.0 |
|
| 90 |
+
| `num_decoder_layers` | 12 |
|
| 91 |
| `num_mosrah_heads` | 16 |
|
| 92 |
| `num_selected_heads` | 16 |
|
| 93 |
| `num_sliding_window_heads` | 16 |
|
__attention__bottlenecked_ensemble_attention.py
CHANGED
|
@@ -38,14 +38,14 @@ class BottleneckedEnsembleAttention(nn.Module):
|
|
| 38 |
|
| 39 |
Args:
|
| 40 |
config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
|
| 41 |
-
`head_dim`, `mosrah_rope_theta`, `
|
| 42 |
-
`
|
| 43 |
"""
|
| 44 |
|
| 45 |
def __init__(self, config: ShramConfig) -> None:
|
| 46 |
super().__init__()
|
| 47 |
|
| 48 |
-
self.hidden_size = config.
|
| 49 |
self.num_heads = config.num_mosrah_heads
|
| 50 |
self.head_dim = config.head_dim
|
| 51 |
|
|
@@ -68,11 +68,22 @@ class BottleneckedEnsembleAttention(nn.Module):
|
|
| 68 |
# BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
|
| 69 |
# this unit only consumes it. In training modes, dilation will be 1.0 and so
|
| 70 |
# no yarn dilation occurs.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
self.rope = RotaryEmbedding(
|
| 72 |
mode="yarn",
|
| 73 |
head_dim=self.head_dim,
|
| 74 |
theta=config.mosrah_rope_theta,
|
| 75 |
-
|
| 76 |
dilation=config.scale,
|
| 77 |
alpha=config.alpha,
|
| 78 |
beta=config.beta,
|
|
|
|
| 38 |
|
| 39 |
Args:
|
| 40 |
config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
|
| 41 |
+
`head_dim`, `mosrah_rope_theta`, `inference_sequence_length`,
|
| 42 |
+
`scale`, `alpha`, and `beta`.
|
| 43 |
"""
|
| 44 |
|
| 45 |
def __init__(self, config: ShramConfig) -> None:
|
| 46 |
super().__init__()
|
| 47 |
|
| 48 |
+
self.hidden_size = config.embedding_width
|
| 49 |
self.num_heads = config.num_mosrah_heads
|
| 50 |
self.head_dim = config.head_dim
|
| 51 |
|
|
|
|
| 68 |
# BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
|
| 69 |
# this unit only consumes it. In training modes, dilation will be 1.0 and so
|
| 70 |
# no yarn dilation occurs.
|
| 71 |
+
#
|
| 72 |
+
# The required table size depends on position semantics:
|
| 73 |
+
# main_sequence — positions are original token positions, bounded by
|
| 74 |
+
# inference_sequence_length.
|
| 75 |
+
# semantic_sequence — positions are local per-expert slot indices, bounded
|
| 76 |
+
# by mosrah_packed_length.
|
| 77 |
+
maximum_rope_length = (
|
| 78 |
+
config.mosrah_packed_length
|
| 79 |
+
if config.rope_mode == "semantic_sequence"
|
| 80 |
+
else config.inference_sequence_length
|
| 81 |
+
)
|
| 82 |
self.rope = RotaryEmbedding(
|
| 83 |
mode="yarn",
|
| 84 |
head_dim=self.head_dim,
|
| 85 |
theta=config.mosrah_rope_theta,
|
| 86 |
+
maximum_sequence_length=maximum_rope_length,
|
| 87 |
dilation=config.scale,
|
| 88 |
alpha=config.alpha,
|
| 89 |
beta=config.beta,
|
__attention__expert_packing.py
CHANGED
|
@@ -3,26 +3,30 @@
|
|
| 3 |
This module implements the low-level token-choice -> expert-choice -> token-choice
|
| 4 |
conversion boundary specified in the paper. The externally visible behavior is fixed:
|
| 5 |
|
| 6 |
-
- setup_packing() prepares the auxiliary ordering data
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
| 8 |
- unpack_experts() restores token-choice ordering afterward.
|
| 9 |
|
| 10 |
Stable sort is a correctness requirement. It preserves causal ordering inside each
|
| 11 |
expert bucket, which is the foundation on which BEA's later triangular causal mask
|
| 12 |
is correct.
|
| 13 |
|
| 14 |
-
pack_experts() returns
|
| 15 |
-
be interchanged:
|
| 16 |
|
| 17 |
- unpacking_mask: marks every packed slot that contains a routed token copy,
|
| 18 |
live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
|
| 19 |
so its reshape invariant holds regardless of outer token liveness.
|
| 20 |
-
- active_mask: marks only the packed slots whose source
|
| 21 |
-
live. This is what BEA consumes for attention gating.
|
| 22 |
-
not influence sparse attention outputs.
|
| 23 |
"""
|
| 24 |
|
| 25 |
import torch
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
# ---------------------------------------------------------------------------
|
|
@@ -31,7 +35,7 @@ import torch
|
|
| 31 |
|
| 32 |
def setup_packing(
|
| 33 |
selected_heads: torch.Tensor,
|
| 34 |
-
) ->
|
| 35 |
"""Prepare the auxiliary ordering data used by pack/unpack.
|
| 36 |
|
| 37 |
Routing produces token-choice state I of shape (B, N, K): for each token, which
|
|
@@ -48,10 +52,11 @@ def setup_packing(
|
|
| 48 |
selected_heads: Routed token-choice head selections I of shape (B, N, K).
|
| 49 |
|
| 50 |
Returns:
|
| 51 |
-
|
| 52 |
-
- flattened_selected_heads: H of shape (B, N*K)
|
| 53 |
-
- permutation: stable expert-major permutation Pi of shape (B, N*K)
|
| 54 |
-
- inverse_permutation: inverse permutation Pi^{-1} of shape (B, N*K)
|
|
|
|
| 55 |
"""
|
| 56 |
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
| 57 |
flattened_selected_heads = selected_heads.reshape(
|
|
@@ -62,7 +67,11 @@ def setup_packing(
|
|
| 62 |
permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
|
| 63 |
inverse_permutation = torch.argsort(permutation, dim=-1)
|
| 64 |
|
| 65 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
# ---------------------------------------------------------------------------
|
|
@@ -70,27 +79,22 @@ def setup_packing(
|
|
| 70 |
# ---------------------------------------------------------------------------
|
| 71 |
|
| 72 |
def pack_experts(
|
| 73 |
-
|
| 74 |
-
|
| 75 |
selected_heads: torch.Tensor,
|
| 76 |
num_experts: int,
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 81 |
-
"""Pack token-choice hidden states into expert-choice padded form.
|
| 82 |
|
| 83 |
The paper's packing path has two jobs:
|
| 84 |
|
| 85 |
1. Convert routed token-choice copies into expert-major order.
|
| 86 |
2. Materialize that expert-major order into a padded tensor layout BEA can consume.
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
reordering those indices with Pi, then gathering hidden states, positions, and outer
|
| 92 |
-
liveness in that packed order. All three are carried through the same expert-major
|
| 93 |
-
rearrangement so they remain aligned in the packed frame.
|
| 94 |
|
| 95 |
Packed positions are sourced from the authoritative upstream position_ids tensor
|
| 96 |
rather than synthesized locally from arange(N). This preserves advanced positions
|
|
@@ -98,40 +102,40 @@ def pack_experts(
|
|
| 98 |
unchanged when position_ids is the ordinary sequential token positions.
|
| 99 |
|
| 100 |
Args:
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
selected_heads: Routed head selections I of shape (B, N, K).
|
| 104 |
num_experts: Total number of experts L.
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
semantically active in the packed sparse representation.
|
| 110 |
|
| 111 |
Returns:
|
| 112 |
Tuple of:
|
| 113 |
-
-
|
| 114 |
-
|
| 115 |
-
- unpacking_mask: of shape (B, L, T). True where a slot
|
| 116 |
-
routed token copy, live or dead. Always has exactly
|
| 117 |
-
Pass this to unpack_experts — not active_mask.
|
| 118 |
-
- active_mask: of shape (B, L, T). True only where a slot contains a
|
| 119 |
-
copy of a live outer token. Pass this to BEA for attention gating.
|
| 120 |
"""
|
| 121 |
-
batch_size, sequence_length,
|
| 122 |
-
|
|
|
|
|
|
|
| 123 |
|
| 124 |
# -----------------------------------------------------------------------
|
| 125 |
# Reconstruct routed local source-token indices in token-choice order.
|
| 126 |
#
|
| 127 |
-
# The internal arange(N) is
|
| 128 |
-
#
|
| 129 |
-
#
|
| 130 |
-
# token-major routed-copy order.
|
| 131 |
# -----------------------------------------------------------------------
|
| 132 |
source_token_indices = torch.arange(
|
| 133 |
sequence_length,
|
| 134 |
-
device=
|
| 135 |
dtype=torch.long,
|
| 136 |
).view(1, sequence_length, 1).expand(
|
| 137 |
batch_size,
|
|
@@ -147,89 +151,71 @@ def pack_experts(
|
|
| 147 |
# Reorder source-token indices into expert-major order.
|
| 148 |
#
|
| 149 |
# Applying Pi yields the local source-token rows in the packed expert-major
|
| 150 |
-
# order required by the paper.
|
| 151 |
-
#
|
| 152 |
-
# liveness so all three remain aligned under the exact same packing
|
| 153 |
-
# transformation.
|
| 154 |
# -----------------------------------------------------------------------
|
| 155 |
sorted_source_indices = flattened_source_indices.gather(
|
| 156 |
dim=1,
|
| 157 |
index=permutation,
|
| 158 |
)
|
| 159 |
-
sorted_hidden_states = hidden_states.gather(
|
| 160 |
-
dim=1,
|
| 161 |
-
index=sorted_source_indices.unsqueeze(-1).expand(-1, -1, hidden_dim),
|
| 162 |
-
)
|
| 163 |
-
sorted_positions = position_ids.gather(
|
| 164 |
-
dim=1,
|
| 165 |
-
index=sorted_source_indices,
|
| 166 |
-
)
|
| 167 |
-
sorted_active_mask = outer_active_mask.gather(
|
| 168 |
-
dim=1,
|
| 169 |
-
index=sorted_source_indices,
|
| 170 |
-
)
|
| 171 |
|
| 172 |
# -----------------------------------------------------------------------
|
| 173 |
-
# Count how many routed copies land in each expert bucket
|
|
|
|
| 174 |
#
|
| 175 |
-
# S[b, l] is the number of routed token copies assigned to expert l in
|
| 176 |
-
#
|
| 177 |
-
#
|
|
|
|
| 178 |
# -----------------------------------------------------------------------
|
| 179 |
-
tokens_per_expert =
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
)
|
| 183 |
-
max_tokens_per_expert = int(tokens_per_expert.max().item())
|
| 184 |
|
| 185 |
# -----------------------------------------------------------------------
|
| 186 |
-
# Construct the
|
| 187 |
#
|
| 188 |
# Each expert bucket is left-justified: if S[b, l] = s, then slots
|
| 189 |
-
# t = 0, ..., s-1 are
|
| 190 |
-
#
|
| 191 |
-
#
|
| 192 |
-
# outer token liveness, and always has exactly B*N*K True entries.
|
| 193 |
# -----------------------------------------------------------------------
|
| 194 |
time_axis = torch.arange(
|
| 195 |
-
|
| 196 |
-
device=
|
| 197 |
dtype=torch.long,
|
| 198 |
-
).view(1, 1,
|
| 199 |
unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)
|
| 200 |
|
| 201 |
# -----------------------------------------------------------------------
|
| 202 |
-
# Materialize
|
| 203 |
#
|
| 204 |
-
#
|
| 205 |
-
#
|
| 206 |
-
#
|
| 207 |
-
#
|
| 208 |
# -----------------------------------------------------------------------
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
active_mask = torch.zeros(
|
| 221 |
-
batch_size,
|
| 222 |
-
num_experts,
|
| 223 |
-
max_tokens_per_expert,
|
| 224 |
-
dtype=torch.bool,
|
| 225 |
-
device=hidden_states.device,
|
| 226 |
-
)
|
| 227 |
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
-
return
|
| 233 |
|
| 234 |
|
| 235 |
# ---------------------------------------------------------------------------
|
|
@@ -238,9 +224,9 @@ def pack_experts(
|
|
| 238 |
|
| 239 |
def unpack_experts(
|
| 240 |
expert_outputs: torch.Tensor,
|
| 241 |
-
|
| 242 |
unpacking_mask: torch.Tensor,
|
| 243 |
-
|
| 244 |
) -> torch.Tensor:
|
| 245 |
"""Restore token-choice ordering from BEA expert-choice output.
|
| 246 |
|
|
@@ -257,14 +243,16 @@ def unpack_experts(
|
|
| 257 |
|
| 258 |
Args:
|
| 259 |
expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
|
| 260 |
-
|
| 261 |
unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
|
| 262 |
occupied packed slots regardless of outer token liveness.
|
| 263 |
-
|
| 264 |
|
| 265 |
Returns:
|
| 266 |
Restored token-choice tensor y_tilde of shape (B, N, K, d).
|
| 267 |
"""
|
|
|
|
|
|
|
| 268 |
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
| 269 |
hidden_dim = expert_outputs.shape[-1]
|
| 270 |
|
|
@@ -291,45 +279,63 @@ def unpack_experts(
|
|
| 291 |
# Helpers
|
| 292 |
# ---------------------------------------------------------------------------
|
| 293 |
|
| 294 |
-
def
|
| 295 |
-
|
| 296 |
-
num_bins: int,
|
| 297 |
-
) -> torch.Tensor:
|
| 298 |
-
"""Count per-row integer occurrences for a 2D tensor.
|
| 299 |
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
|
|
|
| 303 |
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
|
| 309 |
-
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
-
|
| 313 |
-
|
|
|
|
| 314 |
|
| 315 |
Args:
|
| 316 |
-
|
| 317 |
-
|
|
|
|
| 318 |
|
| 319 |
Returns:
|
| 320 |
-
Counts tensor of shape (B,
|
| 321 |
"""
|
| 322 |
-
batch_size =
|
| 323 |
-
|
| 324 |
-
row_offsets = torch.arange(
|
| 325 |
batch_size,
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
)
|
| 335 |
-
return counts
|
|
|
|
| 3 |
This module implements the low-level token-choice -> expert-choice -> token-choice
|
| 4 |
conversion boundary specified in the paper. The externally visible behavior is fixed:
|
| 5 |
|
| 6 |
+
- setup_packing() prepares the auxiliary ordering data and returns it as a dict
|
| 7 |
+
payload forwarded whole to pack_experts and unpack_experts.
|
| 8 |
+
- pack_experts() converts a dict of routed token-choice tensors into packed
|
| 9 |
+
expert-choice form. Each entry is paired with its intended padding value; all
|
| 10 |
+
entries undergo the same expert-major gather-scatter so they remain aligned.
|
| 11 |
- unpack_experts() restores token-choice ordering afterward.
|
| 12 |
|
| 13 |
Stable sort is a correctness requirement. It preserves causal ordering inside each
|
| 14 |
expert bucket, which is the foundation on which BEA's later triangular causal mask
|
| 15 |
is correct.
|
| 16 |
|
| 17 |
+
pack_experts() returns the packed entries dict together with a separate unpacking_mask.
|
| 18 |
+
Two masks serve different roles and must not be interchanged:
|
| 19 |
|
| 20 |
- unpacking_mask: marks every packed slot that contains a routed token copy,
|
| 21 |
live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
|
| 22 |
so its reshape invariant holds regardless of outer token liveness.
|
| 23 |
+
- active_mask (caller-supplied entry): marks only the packed slots whose source
|
| 24 |
+
token was semantically live. This is what BEA consumes for attention gating.
|
| 25 |
+
Dead outer tokens must not influence sparse attention outputs.
|
| 26 |
"""
|
| 27 |
|
| 28 |
import torch
|
| 29 |
+
from typing import Any
|
| 30 |
|
| 31 |
|
| 32 |
# ---------------------------------------------------------------------------
|
|
|
|
| 35 |
|
| 36 |
def setup_packing(
|
| 37 |
selected_heads: torch.Tensor,
|
| 38 |
+
) -> dict[str, torch.Tensor]:
|
| 39 |
"""Prepare the auxiliary ordering data used by pack/unpack.
|
| 40 |
|
| 41 |
Routing produces token-choice state I of shape (B, N, K): for each token, which
|
|
|
|
| 52 |
selected_heads: Routed token-choice head selections I of shape (B, N, K).
|
| 53 |
|
| 54 |
Returns:
|
| 55 |
+
Auxiliary payload dict with keys:
|
| 56 |
+
- "flattened_selected_heads": H of shape (B, N*K)
|
| 57 |
+
- "permutation": stable expert-major permutation Pi of shape (B, N*K)
|
| 58 |
+
- "inverse_permutation": inverse permutation Pi^{-1} of shape (B, N*K)
|
| 59 |
+
This dict is forwarded whole to pack_experts and unpack_experts.
|
| 60 |
"""
|
| 61 |
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
| 62 |
flattened_selected_heads = selected_heads.reshape(
|
|
|
|
| 67 |
permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
|
| 68 |
inverse_permutation = torch.argsort(permutation, dim=-1)
|
| 69 |
|
| 70 |
+
return {
|
| 71 |
+
"flattened_selected_heads": flattened_selected_heads,
|
| 72 |
+
"permutation": permutation,
|
| 73 |
+
"inverse_permutation": inverse_permutation,
|
| 74 |
+
}
|
| 75 |
|
| 76 |
|
| 77 |
# ---------------------------------------------------------------------------
|
|
|
|
| 79 |
# ---------------------------------------------------------------------------
|
| 80 |
|
| 81 |
def pack_experts(
|
| 82 |
+
entries: dict[str, tuple[torch.Tensor, Any]],
|
| 83 |
+
setup: dict[str, torch.Tensor],
|
| 84 |
selected_heads: torch.Tensor,
|
| 85 |
num_experts: int,
|
| 86 |
+
packed_length: int,
|
| 87 |
+
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
|
| 88 |
+
"""Pack token-choice tensors into expert-choice padded form.
|
|
|
|
|
|
|
| 89 |
|
| 90 |
The paper's packing path has two jobs:
|
| 91 |
|
| 92 |
1. Convert routed token-choice copies into expert-major order.
|
| 93 |
2. Materialize that expert-major order into a padded tensor layout BEA can consume.
|
| 94 |
|
| 95 |
+
All entries in the provided dict undergo the same expert-major gather-scatter so
|
| 96 |
+
they remain mutually aligned in the packed frame. Each entry is paired with its
|
| 97 |
+
intended padding value, which fills slots that contain no routed token copy.
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
Packed positions are sourced from the authoritative upstream position_ids tensor
|
| 100 |
rather than synthesized locally from arange(N). This preserves advanced positions
|
|
|
|
| 102 |
unchanged when position_ids is the ordinary sequential token positions.
|
| 103 |
|
| 104 |
Args:
|
| 105 |
+
entries: Mapping from string keys to (tensor, padding_value) pairs. Each
|
| 106 |
+
tensor has shape (B, N, ...) and is rearranged into expert-choice layout
|
| 107 |
+
(B, L, T, ...). The returned dict carries the same keys.
|
| 108 |
+
setup: Auxiliary payload returned by setup_packing().
|
| 109 |
selected_heads: Routed head selections I of shape (B, N, K).
|
| 110 |
num_experts: Total number of experts L.
|
| 111 |
+
packed_length: Static packed time dimension T. All per-expert buffers are
|
| 112 |
+
allocated to exactly this length. Use config.mosrah_packed_length as the
|
| 113 |
+
source of this value. Raises if any actual per-expert token count exceeds
|
| 114 |
+
this value.
|
|
|
|
| 115 |
|
| 116 |
Returns:
|
| 117 |
Tuple of:
|
| 118 |
+
- packed_entries: Dict with same keys as entries; each value is the
|
| 119 |
+
packed tensor of shape (B, L, T, ...).
|
| 120 |
+
- unpacking_mask: Boolean tensor of shape (B, L, T). True where a slot
|
| 121 |
+
contains any routed token copy, live or dead. Always has exactly
|
| 122 |
+
B*N*K True entries. Pass this to unpack_experts — not active_mask.
|
|
|
|
|
|
|
| 123 |
"""
|
| 124 |
+
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
| 125 |
+
|
| 126 |
+
flattened_selected_heads = setup["flattened_selected_heads"]
|
| 127 |
+
permutation = setup["permutation"]
|
| 128 |
|
| 129 |
# -----------------------------------------------------------------------
|
| 130 |
# Reconstruct routed local source-token indices in token-choice order.
|
| 131 |
#
|
| 132 |
+
# The internal arange(N) is only the local source-row index object used to
|
| 133 |
+
# gather from the current chunk tensors. Flattening gives a (B, N*K) tensor
|
| 134 |
+
# aligned with H's token-major routed-copy order.
|
|
|
|
| 135 |
# -----------------------------------------------------------------------
|
| 136 |
source_token_indices = torch.arange(
|
| 137 |
sequence_length,
|
| 138 |
+
device=flattened_selected_heads.device,
|
| 139 |
dtype=torch.long,
|
| 140 |
).view(1, sequence_length, 1).expand(
|
| 141 |
batch_size,
|
|
|
|
| 151 |
# Reorder source-token indices into expert-major order.
|
| 152 |
#
|
| 153 |
# Applying Pi yields the local source-token rows in the packed expert-major
|
| 154 |
+
# order required by the paper. All entries are then gathered using these same
|
| 155 |
+
# reordered indices so they remain aligned under the exact same transformation.
|
|
|
|
|
|
|
| 156 |
# -----------------------------------------------------------------------
|
| 157 |
sorted_source_indices = flattened_source_indices.gather(
|
| 158 |
dim=1,
|
| 159 |
index=permutation,
|
| 160 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
# -----------------------------------------------------------------------
|
| 163 |
+
# Count how many routed copies land in each expert bucket and verify
|
| 164 |
+
# that no bucket exceeds the statically preallocated packed_length T.
|
| 165 |
#
|
| 166 |
+
# S[b, l] is the number of routed token copies assigned to expert l in
|
| 167 |
+
# batch b. T (packed_length) is a static allocation derived from config,
|
| 168 |
+
# not a data-dependent maximum. Overflow is detected here and raises in
|
| 169 |
+
# both eager and compiled modes.
|
| 170 |
# -----------------------------------------------------------------------
|
| 171 |
+
tokens_per_expert = _count_tokens_per_expert(flattened_selected_heads, num_experts)
|
| 172 |
+
max_count = tokens_per_expert.max().item()
|
| 173 |
+
no_overflow = max_count <= packed_length
|
| 174 |
+
_enforce_no_overflow(no_overflow)
|
|
|
|
| 175 |
|
| 176 |
# -----------------------------------------------------------------------
|
| 177 |
+
# Construct the unpacking mask.
|
| 178 |
#
|
| 179 |
# Each expert bucket is left-justified: if S[b, l] = s, then slots
|
| 180 |
+
# t = 0, ..., s-1 are occupied and all later slots are padding. The mask
|
| 181 |
+
# marks slot occupancy regardless of outer token liveness, and always has
|
| 182 |
+
# exactly B*N*K True entries.
|
|
|
|
| 183 |
# -----------------------------------------------------------------------
|
| 184 |
time_axis = torch.arange(
|
| 185 |
+
packed_length,
|
| 186 |
+
device=flattened_selected_heads.device,
|
| 187 |
dtype=torch.long,
|
| 188 |
+
).view(1, 1, packed_length)
|
| 189 |
unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)
|
| 190 |
|
| 191 |
# -----------------------------------------------------------------------
|
| 192 |
+
# Materialize all entries into the packed expert-choice frame.
|
| 193 |
#
|
| 194 |
+
# Each entry is gathered using the expert-major sorted source indices, then
|
| 195 |
+
# scattered into a padded buffer. The gather index is expanded to cover each
|
| 196 |
+
# tensor's trailing dimensions. Padding slots receive the caller-supplied fill
|
| 197 |
+
# value rather than an implicit zero.
|
| 198 |
# -----------------------------------------------------------------------
|
| 199 |
+
packed_entries: dict[str, torch.Tensor] = {}
|
| 200 |
+
for key, (tensor, padding_value) in entries.items():
|
| 201 |
+
extra_shape = tensor.shape[2:]
|
| 202 |
+
|
| 203 |
+
# Expand gather index to cover trailing dimensions, if any.
|
| 204 |
+
idx = sorted_source_indices.view(
|
| 205 |
+
batch_size,
|
| 206 |
+
sequence_length * num_selected_heads,
|
| 207 |
+
*(1,) * len(extra_shape),
|
| 208 |
+
).expand(-1, -1, *extra_shape)
|
| 209 |
+
sorted_tensor = tensor.gather(dim=1, index=idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
+
packed_tensor = tensor.new_full(
|
| 212 |
+
(batch_size, num_experts, packed_length, *extra_shape),
|
| 213 |
+
fill_value=padding_value,
|
| 214 |
+
)
|
| 215 |
+
packed_tensor[unpacking_mask] = sorted_tensor.reshape(-1, *extra_shape)
|
| 216 |
+
packed_entries[key] = packed_tensor
|
| 217 |
|
| 218 |
+
return packed_entries, unpacking_mask
|
| 219 |
|
| 220 |
|
| 221 |
# ---------------------------------------------------------------------------
|
|
|
|
| 224 |
|
| 225 |
def unpack_experts(
|
| 226 |
expert_outputs: torch.Tensor,
|
| 227 |
+
setup: dict[str, torch.Tensor],
|
| 228 |
unpacking_mask: torch.Tensor,
|
| 229 |
+
selected_heads: torch.Tensor,
|
| 230 |
) -> torch.Tensor:
|
| 231 |
"""Restore token-choice ordering from BEA expert-choice output.
|
| 232 |
|
|
|
|
| 243 |
|
| 244 |
Args:
|
| 245 |
expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
|
| 246 |
+
setup: Auxiliary payload returned by setup_packing().
|
| 247 |
unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
|
| 248 |
occupied packed slots regardless of outer token liveness.
|
| 249 |
+
selected_heads: Routed head selections I of shape (B, N, K).
|
| 250 |
|
| 251 |
Returns:
|
| 252 |
Restored token-choice tensor y_tilde of shape (B, N, K, d).
|
| 253 |
"""
|
| 254 |
+
inverse_permutation = setup["inverse_permutation"]
|
| 255 |
+
|
| 256 |
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
| 257 |
hidden_dim = expert_outputs.shape[-1]
|
| 258 |
|
|
|
|
| 279 |
# Helpers
|
| 280 |
# ---------------------------------------------------------------------------
|
| 281 |
|
| 282 |
+
def _enforce_no_overflow(condition: bool) -> None:
|
| 283 |
+
"""Enforce that no expert bucket exceeds the preallocated packed length.
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
+
This check fires when the number of tokens assigned to any expert in any
|
| 286 |
+
batch item exceeds mosrah_packed_length. When that limit is exceeded, the
|
| 287 |
+
packed buffer is too small to hold all assignments and data would be dropped.
|
| 288 |
+
Increase mosrah_overallocation_factor in ShramConfig to resolve.
|
| 289 |
|
| 290 |
+
The caller must derive condition via .item() on the max count tensor so that
|
| 291 |
+
dynamo captures a SymInt and the comparison produces a SymBool. Passing a
|
| 292 |
+
tensor comparison result directly bypasses the SymInt mechanism and prevents
|
| 293 |
+
the check from firing at compiled runtime.
|
| 294 |
|
| 295 |
+
Args:
|
| 296 |
+
condition: True means no overflow has occurred; False means at least one
|
| 297 |
+
expert bucket exceeds packed_length. In compiled mode this is a SymBool
|
| 298 |
+
produced by comparing a SymInt against the static packed_length.
|
| 299 |
+
"""
|
| 300 |
+
if torch.compiler.is_compiling():
|
| 301 |
+
torch._check(condition)
|
| 302 |
+
else:
|
| 303 |
+
if not condition:
|
| 304 |
+
raise RuntimeError(
|
| 305 |
+
"Expert packing overflow: at least one expert bucket contains more "
|
| 306 |
+
"tokens than mosrah_packed_length allows. Increase "
|
| 307 |
+
"mosrah_overallocation_factor in ShramConfig to resolve."
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def _count_tokens_per_expert(
|
| 312 |
+
flattened_selected_heads: torch.Tensor,
|
| 313 |
+
num_experts: int,
|
| 314 |
+
) -> torch.Tensor:
|
| 315 |
+
"""Count how many routed token copies are assigned to each expert per batch item.
|
| 316 |
|
| 317 |
+
Uses scatter_add into a pre-sized (B, num_experts) zero buffer, producing a
|
| 318 |
+
statically-shaped output that compiles without graph breaks. Each position in
|
| 319 |
+
flattened_selected_heads contributes one count to the corresponding expert slot.
|
| 320 |
|
| 321 |
Args:
|
| 322 |
+
flattened_selected_heads: Expert assignments of shape (B, N*K) with values
|
| 323 |
+
in [0, num_experts).
|
| 324 |
+
num_experts: Total number of experts L.
|
| 325 |
|
| 326 |
Returns:
|
| 327 |
+
Counts tensor of shape (B, num_experts).
|
| 328 |
"""
|
| 329 |
+
batch_size = flattened_selected_heads.shape[0]
|
| 330 |
+
counts = torch.zeros(
|
|
|
|
| 331 |
batch_size,
|
| 332 |
+
num_experts,
|
| 333 |
+
device=flattened_selected_heads.device,
|
| 334 |
+
dtype=flattened_selected_heads.dtype,
|
| 335 |
+
)
|
| 336 |
+
counts.scatter_add_(
|
| 337 |
+
dim=1,
|
| 338 |
+
index=flattened_selected_heads,
|
| 339 |
+
src=torch.ones_like(flattened_selected_heads),
|
| 340 |
)
|
| 341 |
+
return counts
|
__attention__mosrah.py
CHANGED
|
@@ -40,6 +40,7 @@ class MoSRAHLayer(nn.Module):
|
|
| 40 |
def __init__(self, config: ShramConfig) -> None:
|
| 41 |
super().__init__()
|
| 42 |
self.num_experts = config.num_mosrah_heads
|
|
|
|
| 43 |
|
| 44 |
self.router = MoSRAHRouter(config)
|
| 45 |
self.positions = SparseMoSRAHPositions(config)
|
|
@@ -91,18 +92,16 @@ class MoSRAHLayer(nn.Module):
|
|
| 91 |
hidden_states, active_mask
|
| 92 |
)
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
outer_active_mask=active_mask,
|
| 105 |
-
)
|
| 106 |
|
| 107 |
# -------------------------------------------------------------------
|
| 108 |
# Sparse attention runs entirely in the packed expert-choice frame, so
|
|
@@ -114,6 +113,7 @@ class MoSRAHLayer(nn.Module):
|
|
| 114 |
# -------------------------------------------------------------------
|
| 115 |
bea_positions = self.positions(
|
| 116 |
packed_positions=packed_positions,
|
|
|
|
| 117 |
cache=cache,
|
| 118 |
)
|
| 119 |
packed_outputs = self.bea(
|
|
@@ -133,9 +133,9 @@ class MoSRAHLayer(nn.Module):
|
|
| 133 |
# -------------------------------------------------------------------
|
| 134 |
token_choice_outputs = unpack_experts(
|
| 135 |
expert_outputs=packed_outputs,
|
| 136 |
-
|
| 137 |
unpacking_mask=unpacking_mask,
|
| 138 |
-
|
| 139 |
)
|
| 140 |
final_output = (
|
| 141 |
token_choice_outputs * routing_probs.unsqueeze(-1)
|
|
|
|
| 40 |
def __init__(self, config: ShramConfig) -> None:
|
| 41 |
super().__init__()
|
| 42 |
self.num_experts = config.num_mosrah_heads
|
| 43 |
+
self.packed_length = config.mosrah_packed_length
|
| 44 |
|
| 45 |
self.router = MoSRAHRouter(config)
|
| 46 |
self.positions = SparseMoSRAHPositions(config)
|
|
|
|
| 92 |
hidden_states, active_mask
|
| 93 |
)
|
| 94 |
|
| 95 |
+
setup = setup_packing(selected_heads)
|
| 96 |
+
entries = {
|
| 97 |
+
"hidden_states": (hidden_states, 0.0),
|
| 98 |
+
"position_ids": (position_ids, 0),
|
| 99 |
+
"active_mask": (active_mask, False),
|
| 100 |
+
}
|
| 101 |
+
packed, unpacking_mask = pack_experts(entries, setup, selected_heads, self.num_experts, self.packed_length)
|
| 102 |
+
packed_hidden_states = packed["hidden_states"]
|
| 103 |
+
packed_positions = packed["position_ids"]
|
| 104 |
+
active_mask = packed["active_mask"]
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# -------------------------------------------------------------------
|
| 107 |
# Sparse attention runs entirely in the packed expert-choice frame, so
|
|
|
|
| 113 |
# -------------------------------------------------------------------
|
| 114 |
bea_positions = self.positions(
|
| 115 |
packed_positions=packed_positions,
|
| 116 |
+
active_mask=active_mask,
|
| 117 |
cache=cache,
|
| 118 |
)
|
| 119 |
packed_outputs = self.bea(
|
|
|
|
| 133 |
# -------------------------------------------------------------------
|
| 134 |
token_choice_outputs = unpack_experts(
|
| 135 |
expert_outputs=packed_outputs,
|
| 136 |
+
setup=setup,
|
| 137 |
unpacking_mask=unpacking_mask,
|
| 138 |
+
selected_heads=selected_heads,
|
| 139 |
)
|
| 140 |
final_output = (
|
| 141 |
token_choice_outputs * routing_probs.unsqueeze(-1)
|
__attention__positions_converter.py
CHANGED
|
@@ -32,12 +32,17 @@ class SparseMoSRAHPositions(nn.Module):
|
|
| 32 |
def forward(
|
| 33 |
self,
|
| 34 |
packed_positions: torch.Tensor,
|
|
|
|
| 35 |
cache: MoSRAHCache | None,
|
| 36 |
) -> torch.Tensor:
|
| 37 |
"""Compute the packed position tensor P consumed by BEA.
|
| 38 |
|
| 39 |
Args:
|
| 40 |
packed_positions: Packed original-token positions J' of shape (B, L, T).
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
|
| 42 |
mode, the current per-head occupancies offset the local packed sequence.
|
| 43 |
|
|
@@ -45,14 +50,15 @@ class SparseMoSRAHPositions(nn.Module):
|
|
| 45 |
Packed position tensor P of shape (B, L, T).
|
| 46 |
"""
|
| 47 |
if self.rope_mode == "main_sequence":
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 56 |
|
| 57 |
def _main_sequence_positions(
|
| 58 |
self,
|
|
|
|
| 32 |
def forward(
|
| 33 |
self,
|
| 34 |
packed_positions: torch.Tensor,
|
| 35 |
+
active_mask: torch.Tensor,
|
| 36 |
cache: MoSRAHCache | None,
|
| 37 |
) -> torch.Tensor:
|
| 38 |
"""Compute the packed position tensor P consumed by BEA.
|
| 39 |
|
| 40 |
Args:
|
| 41 |
packed_positions: Packed original-token positions J' of shape (B, L, T).
|
| 42 |
+
active_mask: Boolean active-token mask of shape (B, L, T). Inactive
|
| 43 |
+
positions are zeroed in the returned tensor regardless of mode —
|
| 44 |
+
their position value is semantically irrelevant and 0 is guaranteed
|
| 45 |
+
to be within any valid RoPE table.
|
| 46 |
cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
|
| 47 |
mode, the current per-head occupancies offset the local packed sequence.
|
| 48 |
|
|
|
|
| 50 |
Packed position tensor P of shape (B, L, T).
|
| 51 |
"""
|
| 52 |
if self.rope_mode == "main_sequence":
|
| 53 |
+
positions = self._main_sequence_positions(packed_positions)
|
| 54 |
+
elif self.rope_mode == "semantic_sequence":
|
| 55 |
+
positions = self._semantic_sequence_positions(packed_positions, cache)
|
| 56 |
+
else:
|
| 57 |
+
raise NotImplementedError(
|
| 58 |
+
f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
return torch.where(active_mask, positions, torch.zeros_like(positions))
|
| 62 |
|
| 63 |
def _main_sequence_positions(
|
| 64 |
self,
|
__attention__router.py
CHANGED
|
@@ -57,10 +57,11 @@ class MoSRAHRouter(nn.Module):
|
|
| 57 |
super().__init__()
|
| 58 |
self.num_mosrah_heads = config.num_mosrah_heads
|
| 59 |
self.num_selected_heads = config.num_selected_heads
|
|
|
|
| 60 |
|
| 61 |
# W_r: routing projection, no bias (paper specifies xW_r, no additional term).
|
| 62 |
self.routing_projection = nn.Linear(
|
| 63 |
-
config.
|
| 64 |
)
|
| 65 |
|
| 66 |
# b: learned per-head bias for load balancing. Initialized to zero so that all
|
|
@@ -117,25 +118,31 @@ class MoSRAHRouter(nn.Module):
|
|
| 117 |
gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
|
| 118 |
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
|
| 119 |
|
| 120 |
-
#
|
| 121 |
-
#
|
| 122 |
-
#
|
| 123 |
-
#
|
| 124 |
-
# are live in this chunk.
|
| 125 |
assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
|
| 126 |
assignment_mask.scatter_(-1, selected_heads, 1.0)
|
| 127 |
active_assignments = assignment_mask * active_mask.unsqueeze(-1)
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
# Load balance loss via custom autograd. expert_bias is an input so PyTorch
|
| 132 |
# registers it as a graph node; the custom backward writes the DeepSeek-style
|
| 133 |
# correction gradient to expert_bias.grad for the optimizer to consume.
|
| 134 |
load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
|
| 135 |
|
| 136 |
-
# MaxVio is a detached monitoring scalar
|
| 137 |
-
#
|
| 138 |
-
# point of computation rather than left to callers to detach.
|
| 139 |
max_vio = self._compute_max_vio(routing_freqs, L)
|
| 140 |
|
| 141 |
return selected_heads, routing_probs, load_balance_loss, max_vio
|
|
@@ -145,15 +152,16 @@ class MoSRAHRouter(nn.Module):
|
|
| 145 |
"""Compute the MaxVio routing-imbalance scalar.
|
| 146 |
|
| 147 |
MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
|
| 148 |
-
head l and 1/L is the perfectly balanced target.
|
| 149 |
-
|
| 150 |
-
|
|
|
|
| 151 |
|
| 152 |
The result is detached from the autograd graph — MaxVio is a monitoring scalar
|
| 153 |
and must never contribute gradients to any parameter.
|
| 154 |
|
| 155 |
Args:
|
| 156 |
-
routing_freqs: Per-head routing frequencies of shape (L,).
|
| 157 |
num_heads: Total number of MoSRAH heads L.
|
| 158 |
|
| 159 |
Returns:
|
|
|
|
| 57 |
super().__init__()
|
| 58 |
self.num_mosrah_heads = config.num_mosrah_heads
|
| 59 |
self.num_selected_heads = config.num_selected_heads
|
| 60 |
+
self.load_balance_p = config.load_balance_p
|
| 61 |
|
| 62 |
# W_r: routing projection, no bias (paper specifies xW_r, no additional term).
|
| 63 |
self.routing_projection = nn.Linear(
|
| 64 |
+
config.embedding_width, config.num_mosrah_heads, bias=False
|
| 65 |
)
|
| 66 |
|
| 67 |
# b: learned per-head bias for load balancing. Initialized to zero so that all
|
|
|
|
| 118 |
gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
|
| 119 |
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
|
| 120 |
|
| 121 |
+
# Per-item routing frequencies f_{b,l}: for each batch item b and head l, the
|
| 122 |
+
# fraction of that item's assignments (K per active token) routed to head l.
|
| 123 |
+
# Dead tokens are excluded before reduction. Normalization is per batch item so
|
| 124 |
+
# each item's frequencies sum to 1 independently of other items in the batch.
|
|
|
|
| 125 |
assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
|
| 126 |
assignment_mask.scatter_(-1, selected_heads, 1.0)
|
| 127 |
active_assignments = assignment_mask * active_mask.unsqueeze(-1)
|
| 128 |
+
per_item_counts = active_assignments.sum(dim=1) # (B, L)
|
| 129 |
+
per_item_total = active_mask.sum(dim=1, keepdim=True) * K # (B, 1)
|
| 130 |
+
per_item_freqs = per_item_counts / per_item_total # (B, L)
|
| 131 |
+
|
| 132 |
+
# p-mean of per_item_freqs over the batch dimension produces routing_freqs (L,).
|
| 133 |
+
# p-mean weights aggregation toward the worst-case batch item relative to
|
| 134 |
+
# arithmetic mean, making the load balance signal sensitive to per-item spikes
|
| 135 |
+
# that cause packing overflow.
|
| 136 |
+
p = self.load_balance_p
|
| 137 |
+
routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
|
| 138 |
|
| 139 |
# Load balance loss via custom autograd. expert_bias is an input so PyTorch
|
| 140 |
# registers it as a graph node; the custom backward writes the DeepSeek-style
|
| 141 |
# correction gradient to expert_bias.grad for the optimizer to consume.
|
| 142 |
load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
|
| 143 |
|
| 144 |
+
# MaxVio is a detached monitoring scalar following the paper's formula
|
| 145 |
+
# L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
|
|
|
|
| 146 |
max_vio = self._compute_max_vio(routing_freqs, L)
|
| 147 |
|
| 148 |
return selected_heads, routing_probs, load_balance_loss, max_vio
|
|
|
|
| 152 |
"""Compute the MaxVio routing-imbalance scalar.
|
| 153 |
|
| 154 |
MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
|
| 155 |
+
head l and 1/L is the perfectly balanced target. Follows the paper's definition
|
| 156 |
+
(Wang et al.) applied to routing_freqs. A value of zero indicates perfect
|
| 157 |
+
balance; a value of 0.5 means the most overloaded head received 50% more routed
|
| 158 |
+
tokens than ideal.
|
| 159 |
|
| 160 |
The result is detached from the autograd graph — MaxVio is a monitoring scalar
|
| 161 |
and must never contribute gradients to any parameter.
|
| 162 |
|
| 163 |
Args:
|
| 164 |
+
routing_freqs: Per-head routing frequencies of shape (L,).
|
| 165 |
num_heads: Total number of MoSRAH heads L.
|
| 166 |
|
| 167 |
Returns:
|
__attention__shram.py
CHANGED
|
@@ -64,19 +64,6 @@ class SHRAMHybridLayer(nn.Module):
|
|
| 64 |
max_vio: Detached scalar routing-imbalance summary. Passed through
|
| 65 |
unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
|
| 66 |
"""
|
| 67 |
-
# ------------------------------------------------
|
| 68 |
-
# It is not possible, due to how bea constructs its block mask,
|
| 69 |
-
# for the model to process a sequence that does not start at zero
|
| 70 |
-
# without a cache to track the per-head offsets
|
| 71 |
-
# ------------------------------------------------
|
| 72 |
-
|
| 73 |
-
if cache is None and torch.any(position_ids[:, 0] != 0):
|
| 74 |
-
raise ValueError(
|
| 75 |
-
"Uncached SHRAMHybridLayer does not support nonzero starting positions. "
|
| 76 |
-
"Either provide a matching ShramLayerCache populated by the prefix for "
|
| 77 |
-
"continued decoding, or rebase the uncached sequence to start at 0."
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
# -------------------------------------------------------------------
|
| 81 |
# The hybrid layer's first responsibility is cache dispatch. The layer
|
| 82 |
# cache already owns the concrete sub-cache objects required by each
|
|
|
|
| 64 |
max_vio: Detached scalar routing-imbalance summary. Passed through
|
| 65 |
unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
|
| 66 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
# -------------------------------------------------------------------
|
| 68 |
# The hybrid layer's first responsibility is cache dispatch. The layer
|
| 69 |
# cache already owns the concrete sub-cache objects required by each
|
__attention__sliding_window_attention.py
CHANGED
|
@@ -44,7 +44,7 @@ class SlidingWindowAttention(nn.Module):
|
|
| 44 |
def __init__(self, config: ShramConfig) -> None:
|
| 45 |
super().__init__()
|
| 46 |
|
| 47 |
-
self.hidden_size = config.
|
| 48 |
self.num_heads = config.num_sliding_window_heads
|
| 49 |
self.head_dim = config.head_dim
|
| 50 |
self.window_size = config.window_size
|
|
@@ -69,6 +69,7 @@ class SlidingWindowAttention(nn.Module):
|
|
| 69 |
mode="default",
|
| 70 |
head_dim=self.head_dim,
|
| 71 |
theta=config.local_rope_theta,
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
def forward(
|
|
|
|
| 44 |
def __init__(self, config: ShramConfig) -> None:
|
| 45 |
super().__init__()
|
| 46 |
|
| 47 |
+
self.hidden_size = config.embedding_width
|
| 48 |
self.num_heads = config.num_sliding_window_heads
|
| 49 |
self.head_dim = config.head_dim
|
| 50 |
self.window_size = config.window_size
|
|
|
|
| 69 |
mode="default",
|
| 70 |
head_dim=self.head_dim,
|
| 71 |
theta=config.local_rope_theta,
|
| 72 |
+
maximum_sequence_length=config.inference_sequence_length,
|
| 73 |
)
|
| 74 |
|
| 75 |
def forward(
|
__cache__mosrah_cache.py
CHANGED
|
@@ -61,12 +61,13 @@ class MoSRAHCache(CacheLayerMixin):
|
|
| 61 |
batch_size: Number of sequences in the batch. Determines the first dimension
|
| 62 |
of all storage tensors.
|
| 63 |
device: Device on which to allocate all tensors. Should match the model device.
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
| 67 |
"""
|
| 68 |
|
| 69 |
-
is_compileable =
|
| 70 |
is_sliding = False
|
| 71 |
|
| 72 |
def __init__(
|
|
@@ -75,22 +76,23 @@ class MoSRAHCache(CacheLayerMixin):
|
|
| 75 |
head_dim: int,
|
| 76 |
batch_size: int,
|
| 77 |
device: torch.device,
|
| 78 |
-
|
| 79 |
) -> None:
|
| 80 |
super().__init__()
|
| 81 |
self.num_mosrah_heads = num_mosrah_heads
|
| 82 |
self.head_dim = head_dim
|
| 83 |
self.batch_size = batch_size
|
| 84 |
self.device = device
|
|
|
|
| 85 |
|
| 86 |
# Allocate primary storage into the mixin-standard self.keys / self.values so
|
| 87 |
# that inherited methods (offload, prefetch) operate on real tensors. _counts
|
| 88 |
# tracks valid occupancy per (batch, head) slot.
|
| 89 |
self.keys: torch.Tensor = torch.zeros(
|
| 90 |
-
batch_size, num_mosrah_heads,
|
| 91 |
)
|
| 92 |
self.values: torch.Tensor = torch.zeros(
|
| 93 |
-
batch_size, num_mosrah_heads,
|
| 94 |
)
|
| 95 |
self._counts: torch.Tensor = torch.zeros(
|
| 96 |
batch_size, num_mosrah_heads, dtype=torch.long, device=device
|
|
@@ -107,8 +109,8 @@ class MoSRAHCache(CacheLayerMixin):
|
|
| 107 |
def buffer_capacity(self) -> int:
|
| 108 |
"""Current number of slots allocated per (batch, head) pair.
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
"""
|
| 113 |
return self.keys.shape[2]
|
| 114 |
|
|
@@ -129,10 +131,11 @@ class MoSRAHCache(CacheLayerMixin):
|
|
| 129 |
active_mask is (B, L, T) bool with True marking real tokens. Only active
|
| 130 |
positions are written; inactive positions are ignored.
|
| 131 |
|
| 132 |
-
Uses a
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
| 136 |
|
| 137 |
Returns the full accumulated (keys, values, active_mask) across the cached
|
| 138 |
sparse sequence. The returned active_mask is True exactly for slots t <
|
|
@@ -150,35 +153,36 @@ class MoSRAHCache(CacheLayerMixin):
|
|
| 150 |
|
| 151 |
Returns:
|
| 152 |
Tuple of (keys, values, active_mask):
|
| 153 |
-
keys: (B, L,
|
| 154 |
-
values: (B, L,
|
| 155 |
-
active_mask: (B, L,
|
| 156 |
"""
|
| 157 |
incoming_delta = active_mask.long().sum(dim=2) # (B, L)
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
#
|
| 163 |
-
#
|
| 164 |
-
#
|
| 165 |
-
#
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
# Scatter key and value vectors at all active positions.
|
| 173 |
-
b_idx, l_idx, t_idx = torch.where(active_mask)
|
| 174 |
-
self.keys[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
|
| 175 |
-
key_states[b_idx, l_idx, t_idx]
|
| 176 |
-
)
|
| 177 |
-
self.values[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
|
| 178 |
-
value_states[b_idx, l_idx, t_idx]
|
| 179 |
)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
return self.keys, self.values, self._make_active_mask()
|
| 184 |
|
|
@@ -303,10 +307,13 @@ class MoSRAHCache(CacheLayerMixin):
|
|
| 303 |
)
|
| 304 |
|
| 305 |
def get_max_cache_shape(self) -> int: # type: ignore[override]
|
| 306 |
-
"""
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
def get_mask_sizes( # type: ignore[override]
|
| 312 |
self,
|
|
@@ -335,25 +342,26 @@ class MoSRAHCache(CacheLayerMixin):
|
|
| 335 |
< self._counts.unsqueeze(-1)
|
| 336 |
)
|
| 337 |
|
| 338 |
-
|
| 339 |
-
|
|
|
|
| 340 |
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
| 346 |
"""
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
new_values[:, :, :old_cap, :] = self.values
|
| 358 |
-
self.keys = new_keys
|
| 359 |
-
self.values = new_values
|
|
|
|
| 61 |
batch_size: Number of sequences in the batch. Determines the first dimension
|
| 62 |
of all storage tensors.
|
| 63 |
device: Device on which to allocate all tensors. Should match the model device.
|
| 64 |
+
mosrah_cache_length: Static sequence capacity per (batch, head) slot. Equal to
|
| 65 |
+
config.mosrah_cache_length. The buffer never grows; if any slot would exceed
|
| 66 |
+
this capacity, update() raises in both eager and compiled modes. Increase
|
| 67 |
+
mosrah_overallocation_factor in ShramConfig to resolve an overflow.
|
| 68 |
"""
|
| 69 |
|
| 70 |
+
is_compileable = True
|
| 71 |
is_sliding = False
|
| 72 |
|
| 73 |
def __init__(
|
|
|
|
| 76 |
head_dim: int,
|
| 77 |
batch_size: int,
|
| 78 |
device: torch.device,
|
| 79 |
+
mosrah_cache_length: int,
|
| 80 |
) -> None:
|
| 81 |
super().__init__()
|
| 82 |
self.num_mosrah_heads = num_mosrah_heads
|
| 83 |
self.head_dim = head_dim
|
| 84 |
self.batch_size = batch_size
|
| 85 |
self.device = device
|
| 86 |
+
self.mosrah_cache_length = mosrah_cache_length
|
| 87 |
|
| 88 |
# Allocate primary storage into the mixin-standard self.keys / self.values so
|
| 89 |
# that inherited methods (offload, prefetch) operate on real tensors. _counts
|
| 90 |
# tracks valid occupancy per (batch, head) slot.
|
| 91 |
self.keys: torch.Tensor = torch.zeros(
|
| 92 |
+
batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
|
| 93 |
)
|
| 94 |
self.values: torch.Tensor = torch.zeros(
|
| 95 |
+
batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
|
| 96 |
)
|
| 97 |
self._counts: torch.Tensor = torch.zeros(
|
| 98 |
batch_size, num_mosrah_heads, dtype=torch.long, device=device
|
|
|
|
| 109 |
def buffer_capacity(self) -> int:
|
| 110 |
"""Current number of slots allocated per (batch, head) pair.
|
| 111 |
|
| 112 |
+
Equal to mosrah_cache_length as supplied at construction. Derived from
|
| 113 |
+
self.keys so it remains consistent with the actual buffer shape.
|
| 114 |
"""
|
| 115 |
return self.keys.shape[2]
|
| 116 |
|
|
|
|
| 131 |
active_mask is (B, L, T) bool with True marking real tokens. Only active
|
| 132 |
positions are written; inactive positions are ignored.
|
| 133 |
|
| 134 |
+
Uses a fixed-shape destination mask constructed from per-slot write intervals
|
| 135 |
+
to transfer active tokens into the buffer without any data-dependent shape
|
| 136 |
+
operations. Active tokens are left-justified within each packed slot by the
|
| 137 |
+
packing machinery, so the destination positions are a contiguous range
|
| 138 |
+
starting at the current slot count — no cumsum or torch.where needed.
|
| 139 |
|
| 140 |
Returns the full accumulated (keys, values, active_mask) across the cached
|
| 141 |
sparse sequence. The returned active_mask is True exactly for slots t <
|
|
|
|
| 153 |
|
| 154 |
Returns:
|
| 155 |
Tuple of (keys, values, active_mask):
|
| 156 |
+
keys: (B, L, mosrah_cache_length, u) float — full key buffer including junk slots.
|
| 157 |
+
values: (B, L, mosrah_cache_length, u) float — full value buffer including junk slots.
|
| 158 |
+
active_mask: (B, L, mosrah_cache_length) bool — True iff slot t has been written.
|
| 159 |
"""
|
| 160 |
incoming_delta = active_mask.long().sum(dim=2) # (B, L)
|
| 161 |
|
| 162 |
+
post_counts = self._counts + incoming_delta
|
| 163 |
+
self._check_no_overflow(post_counts.max(), self.mosrah_cache_length)
|
| 164 |
+
|
| 165 |
+
# Build a fixed-shape destination mask in cache space. Active tokens within
|
| 166 |
+
# each (b, l) slot are left-justified by the packing machinery, so they occupy
|
| 167 |
+
# positions 0..s-1 in their packed slot. The corresponding cache positions are
|
| 168 |
+
# write_start[b,l]..write_start[b,l]+write_count[b,l]-1. Broadcasting a
|
| 169 |
+
# time arange against these per-slot intervals selects exactly the target
|
| 170 |
+
# positions without any data-dependent shape query.
|
| 171 |
+
write_start = self._counts.unsqueeze(-1) # cache position where new tokens begin
|
| 172 |
+
write_count = incoming_delta.unsqueeze(-1) # number of new tokens arriving per slot
|
| 173 |
+
time_arange = torch.arange(
|
| 174 |
+
self.mosrah_cache_length, device=active_mask.device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
)
|
| 176 |
+
dest_mask = (time_arange >= write_start) & (time_arange < write_start + write_count)
|
| 177 |
+
# dest_mask: (B, L, mosrah_cache_length)
|
| 178 |
|
| 179 |
+
# Transfer key and value vectors. dest_mask holds exactly incoming_delta True
|
| 180 |
+
# entries per (b, l) slot (the overflow check above keeps the interval inside
|
| 181 |
+
# the buffer), matching active_mask, so the boolean-mask transfer lines up.
|
| 182 |
+
self.keys[dest_mask] = key_states[active_mask]
|
| 183 |
+
self.values[dest_mask] = value_states[active_mask]
|
| 184 |
+
|
| 185 |
+
self._counts = post_counts
|
| 186 |
|
| 187 |
return self.keys, self.values, self._make_active_mask()
|
| 188 |
|
|
|
|
| 307 |
)
|
| 308 |
|
| 309 |
def get_max_cache_shape(self) -> int: # type: ignore[override]
|
| 310 |
+
"""Return the static per-(batch, head) slot capacity of this cache.
|
| 311 |
+
|
| 312 |
+
Equal to mosrah_cache_length as supplied at construction, which is derived
|
| 313 |
+
from config.mosrah_cache_length. Required by the HuggingFace static cache
|
| 314 |
+
contract; generation machinery uses this to size attention masks.
|
| 315 |
+
"""
|
| 316 |
+
return self.mosrah_cache_length
|
| 317 |
|
| 318 |
def get_mask_sizes( # type: ignore[override]
|
| 319 |
self,
|
|
|
|
| 342 |
< self._counts.unsqueeze(-1)
|
| 343 |
)
|
| 344 |
|
| 345 |
+
@staticmethod
|
| 346 |
+
def _check_no_overflow(max_count: torch.Tensor, capacity: int) -> None:
|
| 347 |
+
"""Raise if any (batch, head) slot would exceed the static buffer capacity.
|
| 348 |
|
| 349 |
+
Uses the 19.F.1 pattern: branches on whether the graph is being compiled.
|
| 350 |
+
In compiled mode, `.item()` folds into the graph when capture_scalar_outputs=True
|
| 351 |
+
and `torch._check` issues a compile-time assertion. In eager mode, a plain
|
| 352 |
+
RuntimeError is raised with a descriptive message.
|
| 353 |
+
|
| 354 |
+
Args:
|
| 355 |
+
max_count: Scalar tensor — the maximum post-update count across all slots.
|
| 356 |
+
capacity: The static buffer capacity (mosrah_cache_length).
|
| 357 |
"""
|
| 358 |
+
if torch.compiler.is_compiling():
|
| 359 |
+
torch._check(max_count.item() <= capacity)
|
| 360 |
+
else:
|
| 361 |
+
if max_count.item() > capacity:
|
| 362 |
+
raise RuntimeError(
|
| 363 |
+
f"MoSRAHCache overflow: a (batch, head) slot would reach "
|
| 364 |
+
f"{max_count.item()} tokens but the static buffer capacity is "
|
| 365 |
+
f"{capacity}. Increase mosrah_overallocation_factor in ShramConfig."
|
| 366 |
+
)
|
| 367 |
+
|
|
|
|
|
|
|
|
|
__cache__shram_cache.py
CHANGED
|
@@ -21,6 +21,7 @@ what HuggingFace generation reads through get_seq_length().
|
|
| 21 |
import torch
|
| 22 |
from transformers.cache_utils import Cache
|
| 23 |
|
|
|
|
| 24 |
from .__cache__shram_layer_cache import ShramLayerCache
|
| 25 |
|
| 26 |
|
|
@@ -36,44 +37,28 @@ class ShramCache(Cache):
|
|
| 36 |
via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
|
| 37 |
|
| 38 |
Args:
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
num_local_heads: Number of local attention heads per layer.
|
| 43 |
-
local_head_dim: Per-head embedding width for the local path.
|
| 44 |
-
num_mosrah_heads: Total number of MoSRAH expert heads (L) per layer.
|
| 45 |
-
mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
|
| 46 |
batch_size: Number of sequences in the batch.
|
| 47 |
device: Device on which to allocate cache tensors.
|
| 48 |
-
initial_buffer_size: Initial per-(batch, head) capacity for each MoSRAHCache.
|
| 49 |
-
Doubled when any slot overflows. Defaults to 64 to avoid repeated reallocation
|
| 50 |
-
during prompt processing.
|
| 51 |
"""
|
| 52 |
|
|
|
|
|
|
|
| 53 |
def __init__(
|
| 54 |
self,
|
| 55 |
-
|
| 56 |
-
sliding_window: int,
|
| 57 |
-
num_local_heads: int,
|
| 58 |
-
local_head_dim: int,
|
| 59 |
-
num_mosrah_heads: int,
|
| 60 |
-
mosrah_head_dim: int,
|
| 61 |
batch_size: int,
|
| 62 |
device: torch.device,
|
| 63 |
-
initial_buffer_size: int = 64,
|
| 64 |
) -> None:
|
| 65 |
layers = [
|
| 66 |
ShramLayerCache(
|
| 67 |
-
|
| 68 |
-
num_local_heads=num_local_heads,
|
| 69 |
-
local_head_dim=local_head_dim,
|
| 70 |
-
num_mosrah_heads=num_mosrah_heads,
|
| 71 |
-
mosrah_head_dim=mosrah_head_dim,
|
| 72 |
batch_size=batch_size,
|
| 73 |
device=device,
|
| 74 |
-
initial_buffer_size=initial_buffer_size,
|
| 75 |
)
|
| 76 |
-
for _ in range(
|
| 77 |
]
|
| 78 |
super().__init__(layers=layers)
|
| 79 |
|
|
@@ -133,9 +118,10 @@ class ShramCache(Cache):
|
|
| 133 |
|
| 134 |
@property
|
| 135 |
def max_cache_len(self) -> int:
|
| 136 |
-
"""
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
|
|
|
| 140 |
"""
|
| 141 |
-
|
|
|
|
| 21 |
import torch
|
| 22 |
from transformers.cache_utils import Cache
|
| 23 |
|
| 24 |
+
from .configuration import ShramConfig
|
| 25 |
from .__cache__shram_layer_cache import ShramLayerCache
|
| 26 |
|
| 27 |
|
|
|
|
| 37 |
via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
|
| 38 |
|
| 39 |
Args:
|
| 40 |
+
config: ShramConfig instance. All layer counts, buffer sizes, and sub-cache
|
| 41 |
+
dimensions are derived from config so that a single source of truth governs
|
| 42 |
+
every buffer size across the full cache stack.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
batch_size: Number of sequences in the batch.
|
| 44 |
device: Device on which to allocate cache tensors.
|
|
|
|
|
|
|
|
|
|
| 45 |
"""
|
| 46 |
|
| 47 |
+
is_compileable = True
|
| 48 |
+
|
| 49 |
def __init__(
|
| 50 |
self,
|
| 51 |
+
config: ShramConfig,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
batch_size: int,
|
| 53 |
device: torch.device,
|
|
|
|
| 54 |
) -> None:
|
| 55 |
layers = [
|
| 56 |
ShramLayerCache(
|
| 57 |
+
config=config,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
batch_size=batch_size,
|
| 59 |
device=device,
|
|
|
|
| 60 |
)
|
| 61 |
+
for _ in range(config.num_decoder_layers)
|
| 62 |
]
|
| 63 |
super().__init__(layers=layers)
|
| 64 |
|
|
|
|
| 118 |
|
| 119 |
@property
|
| 120 |
def max_cache_len(self) -> int:
|
| 121 |
+
"""Return the maximum sequence length the cache can serve.
|
| 122 |
|
| 123 |
+
Delegates to layers[0].get_max_cache_shape(), which returns
|
| 124 |
+
config.inference_sequence_length. HuggingFace's static-cache machinery reads
|
| 125 |
+
this value to size generation loops and verify compileable cache contracts.
|
| 126 |
"""
|
| 127 |
+
return self.layers[0].get_max_cache_shape()
|
__cache__shram_layer_cache.py
CHANGED
|
@@ -21,6 +21,7 @@ quantity HuggingFace generation reads through get_seq_length().
|
|
| 21 |
import torch
|
| 22 |
from transformers.cache_utils import CacheLayerMixin
|
| 23 |
|
|
|
|
| 24 |
from .__cache__mosrah_cache import MoSRAHCache
|
| 25 |
from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
|
| 26 |
|
|
@@ -40,46 +41,36 @@ class ShramLayerCache(CacheLayerMixin):
|
|
| 40 |
tracks the cumulative count of token positions processed across all update() calls.
|
| 41 |
|
| 42 |
Args:
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
local_head_dim: Per-head embedding width for the local path.
|
| 46 |
-
num_mosrah_heads: Total number of MoSRAH expert heads (L).
|
| 47 |
-
mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
|
| 48 |
batch_size: Number of sequences in the batch.
|
| 49 |
device: Device on which to allocate cache tensors.
|
| 50 |
-
initial_buffer_size: Initial per-(batch, head) capacity for MoSRAHCache. Doubled
|
| 51 |
-
when any slot overflows. Defaults to 64 to avoid repeated reallocation during
|
| 52 |
-
prompt processing.
|
| 53 |
"""
|
| 54 |
|
| 55 |
-
is_compileable =
|
| 56 |
is_sliding = False
|
| 57 |
|
| 58 |
def __init__(
|
| 59 |
self,
|
| 60 |
-
|
| 61 |
-
num_local_heads: int,
|
| 62 |
-
local_head_dim: int,
|
| 63 |
-
num_mosrah_heads: int,
|
| 64 |
-
mosrah_head_dim: int,
|
| 65 |
batch_size: int,
|
| 66 |
device: torch.device,
|
| 67 |
-
initial_buffer_size: int = 64,
|
| 68 |
) -> None:
|
| 69 |
super().__init__()
|
|
|
|
| 70 |
self.sliding_window_cache = LocalSlidingWindowLayerCache(
|
| 71 |
-
sliding_window=
|
| 72 |
-
num_heads=
|
| 73 |
-
head_dim=
|
| 74 |
batch_size=batch_size,
|
| 75 |
device=device,
|
| 76 |
)
|
| 77 |
self.mosrah_cache = MoSRAHCache(
|
| 78 |
-
num_mosrah_heads=num_mosrah_heads,
|
| 79 |
-
head_dim=
|
| 80 |
batch_size=batch_size,
|
| 81 |
device=device,
|
| 82 |
-
|
| 83 |
)
|
| 84 |
|
| 85 |
# ---------------------------------------------------------------------------
|
|
@@ -208,26 +199,23 @@ class ShramLayerCache(CacheLayerMixin):
|
|
| 208 |
)
|
| 209 |
|
| 210 |
def get_max_cache_shape(self) -> int: # type: ignore[override]
|
| 211 |
-
"""
|
| 212 |
|
| 213 |
-
The
|
| 214 |
-
|
|
|
|
|
|
|
| 215 |
"""
|
| 216 |
-
|
| 217 |
-
"ShramLayerCache has no single maximum cache shape. "
|
| 218 |
-
"Query sliding_window_cache or mosrah_cache directly."
|
| 219 |
-
)
|
| 220 |
|
| 221 |
def get_mask_sizes( # type: ignore[override]
|
| 222 |
self,
|
| 223 |
cache_position: torch.Tensor,
|
| 224 |
) -> tuple[int, int]:
|
| 225 |
-
"""
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
|
|
|
| 229 |
"""
|
| 230 |
-
|
| 231 |
-
"ShramLayerCache does not support get_mask_sizes(). "
|
| 232 |
-
"Query sliding_window_cache or mosrah_cache directly."
|
| 233 |
-
)
|
|
|
|
| 21 |
import torch
|
| 22 |
from transformers.cache_utils import CacheLayerMixin
|
| 23 |
|
| 24 |
+
from .configuration import ShramConfig
|
| 25 |
from .__cache__mosrah_cache import MoSRAHCache
|
| 26 |
from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
|
| 27 |
|
|
|
|
| 41 |
tracks the cumulative count of token positions processed across all update() calls.
|
| 42 |
|
| 43 |
Args:
|
| 44 |
+
config: ShramConfig instance. All sub-cache dimensions and capacities are derived
|
| 45 |
+
from config so that a single source of truth governs every buffer size.
|
|
|
|
|
|
|
|
|
|
| 46 |
batch_size: Number of sequences in the batch.
|
| 47 |
device: Device on which to allocate cache tensors.
|
|
|
|
|
|
|
|
|
|
| 48 |
"""
|
| 49 |
|
| 50 |
+
is_compileable = True
|
| 51 |
is_sliding = False
|
| 52 |
|
| 53 |
def __init__(
    self,
    config: ShramConfig,
    batch_size: int,
    device: torch.device,
) -> None:
    """Build the two per-layer sub-caches from a single config object.

    Every buffer dimension is read from ``config`` so the sliding-window
    cache and the MoSRAH cache are sized from one source of truth.

    Args:
        config: ShramConfig providing ``inference_sequence_length``,
            ``window_size``, ``num_sliding_window_heads``, ``head_dim``,
            ``num_mosrah_heads`` and ``mosrah_cache_length``.
        batch_size: Number of sequences in the batch; forwarded to both
            sub-caches.
        device: Device on which the sub-caches allocate their tensors;
            forwarded to both sub-caches.
    """
    super().__init__()

    # Remembered so capacity queries can answer without holding the config.
    self._inference_sequence_length = config.inference_sequence_length

    # Local (sliding-window) attention path.
    local_cache = LocalSlidingWindowLayerCache(
        sliding_window=config.window_size,
        num_heads=config.num_sliding_window_heads,
        head_dim=config.head_dim,
        batch_size=batch_size,
        device=device,
    )
    self.sliding_window_cache = local_cache

    # Sparse (MoSRAH) attention path with a statically sized slot buffer.
    sparse_cache = MoSRAHCache(
        num_mosrah_heads=config.num_mosrah_heads,
        head_dim=config.head_dim,
        batch_size=batch_size,
        device=device,
        mosrah_cache_length=config.mosrah_cache_length,
    )
    self.mosrah_cache = sparse_cache
|
| 75 |
|
| 76 |
# ---------------------------------------------------------------------------
|
|
|
|
| 199 |
)
|
| 200 |
|
| 201 |
def get_max_cache_shape(self) -> int:  # type: ignore[override]
    """Report the longest sequence this layer cache is configured to serve.

    The answer is the ``inference_sequence_length`` value stored on the
    instance as ``_inference_sequence_length``. Per the original notes,
    HuggingFace's static-cache machinery consults this value when deciding
    whether the cache is compileable and when sizing generation loops.

    Returns:
        The stored inference sequence length.
    """
    max_sequence_length = self._inference_sequence_length
    return max_sequence_length
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
def get_mask_sizes(  # type: ignore[override]
    self,
    cache_position: torch.Tensor,
) -> tuple[int, int]:
    """Provide the (kv_length, kv_offset) pair used for causal mask sizing.

    The key/value length is the full static capacity recorded in
    ``_inference_sequence_length`` and the offset is always zero. Per the
    original notes, HuggingFace reads these values to size the causal
    attention mask when ``is_compileable`` is True.

    Args:
        cache_position: Accepted for interface compatibility; this
            implementation does not read it.

    Returns:
        ``(inference_sequence_length, 0)``.
    """
    kv_length = self._inference_sequence_length
    kv_offset = 0
    return kv_length, kv_offset
|
|
|
|
|
|
|
|
|
__cache__sliding_window_cache.py
CHANGED
|
@@ -39,7 +39,7 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
|
|
| 39 |
device: Device on which to allocate cache storage.
|
| 40 |
"""
|
| 41 |
|
| 42 |
-
is_compileable =
|
| 43 |
is_sliding = True
|
| 44 |
|
| 45 |
def __init__(
|
|
|
|
| 39 |
device: Device on which to allocate cache storage.
|
| 40 |
"""
|
| 41 |
|
| 42 |
+
is_compileable = True
|
| 43 |
is_sliding = True
|
| 44 |
|
| 45 |
def __init__(
|
__cache__slow_mosrah_cache.py
CHANGED
|
@@ -41,9 +41,9 @@ class SlowMoSRAHCache(CacheLayerMixin):
|
|
| 41 |
batch_size: Number of sequences in the batch. Determines the first dimension
|
| 42 |
of all storage tensors.
|
| 43 |
device: Device on which to allocate all tensors. Should match the model device.
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
"""
|
| 48 |
|
| 49 |
is_compileable = False
|
|
@@ -55,22 +55,23 @@ class SlowMoSRAHCache(CacheLayerMixin):
|
|
| 55 |
head_dim: int,
|
| 56 |
batch_size: int,
|
| 57 |
device: torch.device,
|
| 58 |
-
|
| 59 |
) -> None:
|
| 60 |
super().__init__()
|
| 61 |
self.num_mosrah_heads = num_mosrah_heads
|
| 62 |
self.head_dim = head_dim
|
| 63 |
self.batch_size = batch_size
|
| 64 |
self.device = device
|
|
|
|
| 65 |
|
| 66 |
# Allocate primary storage into the mixin-standard self.keys / self.values so
|
| 67 |
# that inherited methods (offload, prefetch) operate on real tensors. _counts
|
| 68 |
# tracks valid occupancy per (batch, head) slot.
|
| 69 |
self.keys: torch.Tensor = torch.zeros(
|
| 70 |
-
batch_size, num_mosrah_heads,
|
| 71 |
)
|
| 72 |
self.values: torch.Tensor = torch.zeros(
|
| 73 |
-
batch_size, num_mosrah_heads,
|
| 74 |
)
|
| 75 |
self._counts: torch.Tensor = torch.zeros(
|
| 76 |
batch_size, num_mosrah_heads, dtype=torch.long, device=device
|
|
@@ -87,8 +88,8 @@ class SlowMoSRAHCache(CacheLayerMixin):
|
|
| 87 |
def buffer_capacity(self) -> int:
|
| 88 |
"""Current number of slots allocated per (batch, head) pair.
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
| 92 |
"""
|
| 93 |
return self.keys.shape[2]
|
| 94 |
|
|
@@ -111,8 +112,8 @@ class SlowMoSRAHCache(CacheLayerMixin):
|
|
| 111 |
because the t dimension is traversed from 0 to T-1 and counts are updated
|
| 112 |
immediately after each write.
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
|
| 117 |
Args:
|
| 118 |
key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
|
|
@@ -122,17 +123,19 @@ class SlowMoSRAHCache(CacheLayerMixin):
|
|
| 122 |
|
| 123 |
Returns:
|
| 124 |
Tuple of (keys, values, active_mask):
|
| 125 |
-
keys: (B, L,
|
| 126 |
-
values: (B, L,
|
| 127 |
-
active_mask: (B, L,
|
| 128 |
"""
|
| 129 |
B, L, T = active_mask.shape
|
| 130 |
|
| 131 |
-
# Expansion check uses the total active tokens per slot, same as the
|
| 132 |
-
# vectorized implementation, so both expand under identical conditions.
|
| 133 |
incoming_delta = active_mask.long().sum(dim=2) # (B, L)
|
| 134 |
-
if (self._counts + incoming_delta).max().item() > self.
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
# Write each active position into the next available slot for its (batch, head)
|
| 138 |
# pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
|
|
@@ -297,25 +300,3 @@ class SlowMoSRAHCache(CacheLayerMixin):
|
|
| 297 |
< self._counts.unsqueeze(-1)
|
| 298 |
)
|
| 299 |
|
| 300 |
-
def _expand(self) -> None:
|
| 301 |
-
"""Double the buffer capacity, preserving existing data.
|
| 302 |
-
|
| 303 |
-
Called by update() when an incoming batch of tokens would cause any
|
| 304 |
-
(batch, head) slot to exceed the current buffer capacity. All existing
|
| 305 |
-
key and value data is copied into the low half of the new buffer; the
|
| 306 |
-
high half is zero-initialised and will be filled by subsequent writes.
|
| 307 |
-
After reassignment, buffer_capacity reflects the new size automatically.
|
| 308 |
-
"""
|
| 309 |
-
old_cap = self.buffer_capacity
|
| 310 |
-
new_cap = old_cap * 2
|
| 311 |
-
dev = self.keys.device
|
| 312 |
-
new_keys = torch.zeros(
|
| 313 |
-
self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
|
| 314 |
-
)
|
| 315 |
-
new_values = torch.zeros(
|
| 316 |
-
self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
|
| 317 |
-
)
|
| 318 |
-
new_keys[:, :, :old_cap, :] = self.keys
|
| 319 |
-
new_values[:, :, :old_cap, :] = self.values
|
| 320 |
-
self.keys = new_keys
|
| 321 |
-
self.values = new_values
|
|
|
|
| 41 |
batch_size: Number of sequences in the batch. Determines the first dimension
|
| 42 |
of all storage tensors.
|
| 43 |
device: Device on which to allocate all tensors. Should match the model device.
|
| 44 |
+
mosrah_cache_length: Static sequence capacity per (batch, head) slot. Equal to
|
| 45 |
+
config.mosrah_cache_length. The buffer never grows; if any slot would exceed
|
| 46 |
+
this capacity, update() raises a RuntimeError.
|
| 47 |
"""
|
| 48 |
|
| 49 |
is_compileable = False
|
|
|
|
| 55 |
head_dim: int,
|
| 56 |
batch_size: int,
|
| 57 |
device: torch.device,
|
| 58 |
+
mosrah_cache_length: int,
|
| 59 |
) -> None:
|
| 60 |
super().__init__()
|
| 61 |
self.num_mosrah_heads = num_mosrah_heads
|
| 62 |
self.head_dim = head_dim
|
| 63 |
self.batch_size = batch_size
|
| 64 |
self.device = device
|
| 65 |
+
self.mosrah_cache_length = mosrah_cache_length
|
| 66 |
|
| 67 |
# Allocate primary storage into the mixin-standard self.keys / self.values so
|
| 68 |
# that inherited methods (offload, prefetch) operate on real tensors. _counts
|
| 69 |
# tracks valid occupancy per (batch, head) slot.
|
| 70 |
self.keys: torch.Tensor = torch.zeros(
|
| 71 |
+
batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
|
| 72 |
)
|
| 73 |
self.values: torch.Tensor = torch.zeros(
|
| 74 |
+
batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
|
| 75 |
)
|
| 76 |
self._counts: torch.Tensor = torch.zeros(
|
| 77 |
batch_size, num_mosrah_heads, dtype=torch.long, device=device
|
|
|
|
| 88 |
def buffer_capacity(self) -> int:
    """Number of slots currently allocated per (batch, head) pair.

    Read from the third axis of ``self.keys`` rather than from the stored
    ``mosrah_cache_length`` so the reported capacity always matches the
    real buffer shape.

    Returns:
        Size of axis 2 of the key buffer.
    """
    key_buffer_shape = self.keys.shape
    return key_buffer_shape[2]
|
| 95 |
|
|
|
|
| 112 |
because the t dimension is traversed from 0 to T-1 and counts are updated
|
| 113 |
immediately after each write.
|
| 114 |
|
| 115 |
+
Raises RuntimeError before any writes if the incoming tokens would cause any
|
| 116 |
+
slot to exceed the static mosrah_cache_length capacity.
|
| 117 |
|
| 118 |
Args:
|
| 119 |
key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
|
|
|
|
| 123 |
|
| 124 |
Returns:
|
| 125 |
Tuple of (keys, values, active_mask):
|
| 126 |
+
keys: (B, L, mosrah_cache_length, u) float — full key buffer including junk slots.
|
| 127 |
+
values: (B, L, mosrah_cache_length, u) float — full value buffer including junk slots.
|
| 128 |
+
active_mask: (B, L, mosrah_cache_length) bool — True iff slot t has been written.
|
| 129 |
"""
|
| 130 |
B, L, T = active_mask.shape
|
| 131 |
|
|
|
|
|
|
|
| 132 |
incoming_delta = active_mask.long().sum(dim=2) # (B, L)
|
| 133 |
+
if (self._counts + incoming_delta).max().item() > self.mosrah_cache_length:
|
| 134 |
+
raise RuntimeError(
|
| 135 |
+
f"SlowMoSRAHCache overflow: a (batch, head) slot would exceed the "
|
| 136 |
+
f"static buffer capacity of {self.mosrah_cache_length}. Increase "
|
| 137 |
+
f"mosrah_overallocation_factor in ShramConfig."
|
| 138 |
+
)
|
| 139 |
|
| 140 |
# Write each active position into the next available slot for its (batch, head)
|
| 141 |
# pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
|
|
|
|
| 300 |
< self._counts.unsqueeze(-1)
|
| 301 |
)
|
| 302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.json
CHANGED
|
@@ -6,14 +6,16 @@
|
|
| 6 |
"AutoModelForCausalLM": "huggingface.ShramForCausalLM"
|
| 7 |
},
|
| 8 |
"beta": 32.0,
|
|
|
|
| 9 |
"head_dim": 16,
|
| 10 |
-
"hidden_size": 512,
|
| 11 |
"inference_sequence_length": 1024,
|
| 12 |
-
"
|
| 13 |
"local_rope_theta": 10000.0,
|
|
|
|
| 14 |
"model_type": "shram",
|
|
|
|
| 15 |
"mosrah_rope_theta": 10000.0,
|
| 16 |
-
"
|
| 17 |
"num_mosrah_heads": 16,
|
| 18 |
"num_selected_heads": 16,
|
| 19 |
"num_sliding_window_heads": 16,
|
|
@@ -21,7 +23,7 @@
|
|
| 21 |
"rope_mode": "main_sequence",
|
| 22 |
"tie_word_embeddings": false,
|
| 23 |
"training_sequence_length": 1024,
|
| 24 |
-
"transformers_version": "5.8.
|
| 25 |
"use_cache": true,
|
| 26 |
"vocab_size": 50277,
|
| 27 |
"window_size": 128
|
|
|
|
| 6 |
"AutoModelForCausalLM": "huggingface.ShramForCausalLM"
|
| 7 |
},
|
| 8 |
"beta": 32.0,
|
| 9 |
+
"embedding_width": 512,
|
| 10 |
"head_dim": 16,
|
|
|
|
| 11 |
"inference_sequence_length": 1024,
|
| 12 |
+
"load_balance_p": 2.0,
|
| 13 |
"local_rope_theta": 10000.0,
|
| 14 |
+
"mlp_width": 1366,
|
| 15 |
"model_type": "shram",
|
| 16 |
+
"mosrah_overallocation_factor": 2.0,
|
| 17 |
"mosrah_rope_theta": 10000.0,
|
| 18 |
+
"num_decoder_layers": 12,
|
| 19 |
"num_mosrah_heads": 16,
|
| 20 |
"num_selected_heads": 16,
|
| 21 |
"num_sliding_window_heads": 16,
|
|
|
|
| 23 |
"rope_mode": "main_sequence",
|
| 24 |
"tie_word_embeddings": false,
|
| 25 |
"training_sequence_length": 1024,
|
| 26 |
+
"transformers_version": "5.8.1",
|
| 27 |
"use_cache": true,
|
| 28 |
"vocab_size": 50277,
|
| 29 |
"window_size": 128
|
configuration.py
CHANGED
|
@@ -11,6 +11,8 @@ parameters directly and constructs its own RotaryEmbedding instance explicitly
|
|
| 11 |
HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
|
| 12 |
"""
|
| 13 |
|
|
|
|
|
|
|
| 14 |
from transformers import PretrainedConfig
|
| 15 |
|
| 16 |
|
|
@@ -77,6 +79,15 @@ class ShramConfig(PretrainedConfig):
|
|
| 77 |
use_cache: Whether to return past_key_values for KV caching.
|
| 78 |
output_hidden_states: Whether to return hidden states after each layer.
|
| 79 |
tie_word_embeddings: Whether input embedding and LM head share weights.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
"""
|
| 81 |
|
| 82 |
model_type = "shram"
|
|
@@ -109,7 +120,9 @@ class ShramConfig(PretrainedConfig):
|
|
| 109 |
use_cache: bool = True,
|
| 110 |
output_hidden_states: bool = False,
|
| 111 |
tie_word_embeddings: bool = False,
|
| 112 |
-
|
|
|
|
|
|
|
| 113 |
):
|
| 114 |
if head_dim % 2 != 0:
|
| 115 |
raise ValueError(
|
|
@@ -137,10 +150,22 @@ class ShramConfig(PretrainedConfig):
|
|
| 137 |
f"got {inference_sequence_length}."
|
| 138 |
)
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
self.vocab_size = vocab_size
|
| 141 |
-
self.
|
| 142 |
-
self.
|
| 143 |
-
self.
|
| 144 |
self.num_sliding_window_heads = num_sliding_window_heads
|
| 145 |
self.num_mosrah_heads = num_mosrah_heads
|
| 146 |
self.num_selected_heads = num_selected_heads
|
|
@@ -154,13 +179,15 @@ class ShramConfig(PretrainedConfig):
|
|
| 154 |
self.inference_sequence_length = inference_sequence_length
|
| 155 |
self.alpha = alpha
|
| 156 |
self.beta = beta
|
|
|
|
|
|
|
| 157 |
self.attention_dropout = attention_dropout
|
| 158 |
self.use_cache = use_cache
|
| 159 |
|
| 160 |
super().__init__(
|
| 161 |
tie_word_embeddings=tie_word_embeddings,
|
| 162 |
output_hidden_states=output_hidden_states,
|
| 163 |
-
**kwargs
|
| 164 |
)
|
| 165 |
|
| 166 |
# Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
|
|
@@ -176,3 +203,47 @@ class ShramConfig(PretrainedConfig):
|
|
| 176 |
"""
|
| 177 |
return self.inference_sequence_length / self.training_sequence_length
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
|
| 12 |
"""
|
| 13 |
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
from transformers import PretrainedConfig
|
| 17 |
|
| 18 |
|
|
|
|
| 79 |
use_cache: Whether to return past_key_values for KV caching.
|
| 80 |
output_hidden_states: Whether to return hidden states after each layer.
|
| 81 |
tie_word_embeddings: Whether input embedding and LM head share weights.
|
| 82 |
+
mosrah_overallocation_factor: Overallocation multiplier for the expert packing
|
| 83 |
+
buffer. ``mosrah_packed_length`` = ceil(training_sequence_length *
|
| 84 |
+
num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
|
| 85 |
+
Must be > 1.0 to guarantee a buffer larger than the balanced-routing
|
| 86 |
+
baseline. Default 2.0.
|
| 87 |
+
load_balance_p: Exponent p for the p-mean aggregation of per-item routing
|
| 88 |
+
frequencies into the load balance signal. Higher p weights aggregation
|
| 89 |
+
toward the worst-case batch item, making the correction signal more
|
| 90 |
+
sensitive to per-item allocation spikes. Must be positive. Default 2.0.
|
| 91 |
"""
|
| 92 |
|
| 93 |
model_type = "shram"
|
|
|
|
| 120 |
use_cache: bool = True,
|
| 121 |
output_hidden_states: bool = False,
|
| 122 |
tie_word_embeddings: bool = False,
|
| 123 |
+
mosrah_overallocation_factor: float = 2.0,
|
| 124 |
+
load_balance_p: float = 2.0,
|
| 125 |
+
**kwargs
|
| 126 |
):
|
| 127 |
if head_dim % 2 != 0:
|
| 128 |
raise ValueError(
|
|
|
|
| 150 |
f"got {inference_sequence_length}."
|
| 151 |
)
|
| 152 |
|
| 153 |
+
if mosrah_overallocation_factor <= 1.0:
|
| 154 |
+
raise ValueError(
|
| 155 |
+
f"mosrah_overallocation_factor must be > 1.0 to guarantee a packed "
|
| 156 |
+
f"buffer larger than the balanced-routing baseline. "
|
| 157 |
+
f"Got {mosrah_overallocation_factor}."
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
if load_balance_p <= 0.0:
|
| 161 |
+
raise ValueError(
|
| 162 |
+
f"load_balance_p must be positive, got {load_balance_p}."
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
self.vocab_size = vocab_size
|
| 166 |
+
self.embedding_width = embedding_width
|
| 167 |
+
self.mlp_width = mlp_width
|
| 168 |
+
self.num_decoder_layers = num_decoder_layers
|
| 169 |
self.num_sliding_window_heads = num_sliding_window_heads
|
| 170 |
self.num_mosrah_heads = num_mosrah_heads
|
| 171 |
self.num_selected_heads = num_selected_heads
|
|
|
|
| 179 |
self.inference_sequence_length = inference_sequence_length
|
| 180 |
self.alpha = alpha
|
| 181 |
self.beta = beta
|
| 182 |
+
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 183 |
+
self.load_balance_p = load_balance_p
|
| 184 |
self.attention_dropout = attention_dropout
|
| 185 |
self.use_cache = use_cache
|
| 186 |
|
| 187 |
super().__init__(
|
| 188 |
tie_word_embeddings=tie_word_embeddings,
|
| 189 |
output_hidden_states=output_hidden_states,
|
| 190 |
+
**kwargs
|
| 191 |
)
|
| 192 |
|
| 193 |
# Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
|
|
|
|
| 203 |
"""
|
| 204 |
return self.inference_sequence_length / self.training_sequence_length
|
| 205 |
|
| 206 |
+
@property
def mosrah_packed_length(self) -> int:
    """Static packed time dimension T used for expert packing.

    Under perfectly balanced routing each expert is expected to receive
    ``training_sequence_length * num_selected_heads / num_mosrah_heads``
    tokens. That baseline is scaled by ``mosrah_overallocation_factor`` to
    leave headroom, then rounded up so the result is a whole number of
    slots.

    Every consumer of the packed buffer size should read this property
    instead of recomputing T on its own.

    Returns:
        Ceiling of the over-allocated per-expert token count.
    """
    # Same left-to-right operation order as the single-expression form, so
    # floating-point results are unchanged.
    balanced_tokens_per_head = (
        self.training_sequence_length
        * self.num_selected_heads
        / self.num_mosrah_heads
    )
    padded_tokens_per_head = (
        balanced_tokens_per_head * self.mosrah_overallocation_factor
    )
    return math.ceil(padded_tokens_per_head)
|
| 224 |
+
|
| 225 |
+
@property
def mosrah_cache_length(self) -> int:
    """Static per-(batch, head) slot capacity of the MoSRAH inference cache.

    Under perfectly balanced routing each expert is expected to accumulate
    ``inference_sequence_length * num_selected_heads / num_mosrah_heads``
    tokens over a full inference context. That baseline is scaled by
    ``mosrah_overallocation_factor`` to leave headroom, then rounded up so
    the result is a whole number of slots.

    This differs from ``mosrah_packed_length``, which uses
    ``training_sequence_length``: the cache has to hold the token history
    accumulated across the whole inference run.

    Every consumer of the MoSRAH cache buffer size should read this
    property instead of recomputing the capacity on its own.

    Returns:
        Ceiling of the over-allocated per-expert token count.
    """
    # Same left-to-right operation order as the single-expression form, so
    # floating-point results are unchanged.
    balanced_tokens_per_head = (
        self.inference_sequence_length
        * self.num_selected_heads
        / self.num_mosrah_heads
    )
    padded_tokens_per_head = (
        balanced_tokens_per_head * self.mosrah_overallocation_factor
    )
    return math.ceil(padded_tokens_per_head)
|
| 249 |
+
|
decoder_layer.py
CHANGED
|
@@ -46,8 +46,8 @@ class DecoderLayer(nn.Module):
|
|
| 46 |
|
| 47 |
def __init__(self, config: ShramConfig) -> None:
|
| 48 |
super().__init__()
|
| 49 |
-
self.attn_norm = nn.RMSNorm(config.
|
| 50 |
-
self.mlp_norm = nn.RMSNorm(config.
|
| 51 |
self.attention = SHRAMHybridLayer(config)
|
| 52 |
self.mlp = SwiGLUMLP(config)
|
| 53 |
|
|
|
|
| 46 |
|
| 47 |
def __init__(self, config: ShramConfig) -> None:
    """Assemble one decoder layer: two RMSNorms, hybrid attention, and MLP.

    Args:
        config: ShramConfig supplying ``embedding_width`` and
            ``rms_norm_eps`` for the norms; it is also handed unchanged to
            the attention and MLP sub-modules.
    """
    super().__init__()

    width = config.embedding_width
    norm_eps = config.rms_norm_eps

    # Sub-modules are created in the same order as before so parameter
    # registration order is unchanged.
    self.attn_norm = nn.RMSNorm(width, eps=norm_eps)
    self.mlp_norm = nn.RMSNorm(width, eps=norm_eps)
    self.attention = SHRAMHybridLayer(config)
    self.mlp = SwiGLUMLP(config)
|
| 53 |
|
huggingface.py
CHANGED
|
@@ -74,9 +74,9 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 74 |
|
| 75 |
def __init__(self, config: ShramConfig) -> None:
|
| 76 |
super().__init__(config)
|
| 77 |
-
self.embed_tokens = nn.Embedding(config.vocab_size, config.
|
| 78 |
self.model = ShramModel(config)
|
| 79 |
-
self.lm_head = nn.Linear(config.
|
| 80 |
self._configure_tied_embeddings()
|
| 81 |
self.post_init()
|
| 82 |
|
|
@@ -127,12 +127,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 127 |
) -> ShramCache:
|
| 128 |
"""Construct a fresh top-level SHRAM cache."""
|
| 129 |
return ShramCache(
|
| 130 |
-
|
| 131 |
-
sliding_window=self.config.window_size,
|
| 132 |
-
num_local_heads=self.config.num_sliding_window_heads,
|
| 133 |
-
local_head_dim=self.config.head_dim,
|
| 134 |
-
num_mosrah_heads=self.config.num_mosrah_heads,
|
| 135 |
-
mosrah_head_dim=self.config.head_dim,
|
| 136 |
batch_size=batch_size,
|
| 137 |
device=device,
|
| 138 |
)
|
|
@@ -231,6 +226,26 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 231 |
past_key_values.reorder_cache(beam_idx)
|
| 232 |
return past_key_values
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
|
| 235 |
"""Validate token IDs at the wrapper boundary."""
|
| 236 |
if input_ids.ndim != 2:
|
|
@@ -352,6 +367,63 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 352 |
f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
|
| 353 |
)
|
| 354 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
def _standardize_full_attention_mask(
|
| 356 |
self,
|
| 357 |
input_ids: torch.Tensor,
|
|
@@ -449,6 +521,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 449 |
# This keeps the main sequence readable while ensuring invalid states
|
| 450 |
# fail before they can silently contaminate backbone execution.
|
| 451 |
# ------------------------------------------------------------------
|
|
|
|
| 452 |
self._validate_input_ids(input_ids)
|
| 453 |
self._validate_attention_mask(input_ids, attention_mask)
|
| 454 |
self._validate_position_ids(input_ids, position_ids)
|
|
@@ -487,6 +560,10 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 487 |
)
|
| 488 |
shram_cache: ShramCache | None = past_key_values if use_cache else None
|
| 489 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
# ------------------------------------------------------------------
|
| 491 |
# Core wrapper responsibilities.
|
| 492 |
#
|
|
|
|
| 74 |
|
| 75 |
def __init__(self, config: ShramConfig) -> None:
    """Create the token embedding, backbone, and LM head, then finalize.

    Args:
        config: ShramConfig supplying ``vocab_size`` and
            ``embedding_width``; it is also passed to the parent
            constructor and to ``ShramModel``.
    """
    super().__init__(config)

    vocab_size = config.vocab_size
    width = config.embedding_width

    # Construction order is kept as-is so parameter registration and
    # initialization order do not change.
    self.embed_tokens = nn.Embedding(vocab_size, width)
    self.model = ShramModel(config)
    self.lm_head = nn.Linear(width, vocab_size, bias=False)

    self._configure_tied_embeddings()
    self.post_init()
|
| 82 |
|
|
|
|
| 127 |
) -> ShramCache:
|
| 128 |
"""Construct a fresh top-level SHRAM cache."""
|
| 129 |
return ShramCache(
|
| 130 |
+
config=self.config,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
batch_size=batch_size,
|
| 132 |
device=device,
|
| 133 |
)
|
|
|
|
| 226 |
past_key_values.reorder_cache(beam_idx)
|
| 227 |
return past_key_values
|
| 228 |
|
| 229 |
+
@staticmethod
def create_masks_for_generate(
    config: Any,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None = None,
    **kwargs: Any,
) -> torch.Tensor | None:
    """Hand the 2D ``attention_mask`` back untouched.

    During compiled generation HuggingFace invokes this hook to turn the 2D
    attention mask into a 4D causal additive-bias mask. SHRAM relies on
    flex attention with its own masking and builds causality internally, so
    the 4D form does not fit the SHRAM masking contract. Making the hook a
    pass-through keeps the compiled and non-compiled pathways symmetric
    with no loss of correctness or performance (see Unit 19.G.4).

    Args:
        config: Unused.
        inputs_embeds: Unused.
        attention_mask: The mask to return as-is; may be None.
        past_key_values: Unused.
        position_ids: Unused.
        **kwargs: Unused; accepted so extra keyword arguments do not fail.

    Returns:
        The same ``attention_mask`` object that was passed in.
    """
    # Deliberate no-op: every other argument is ignored.
    return attention_mask
|
| 248 |
+
|
| 249 |
def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
|
| 250 |
"""Validate token IDs at the wrapper boundary."""
|
| 251 |
if input_ids.ndim != 2:
|
|
|
|
| 367 |
f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
|
| 368 |
)
|
| 369 |
|
| 370 |
+
@staticmethod
|
| 371 |
+
def _enforce_uncached_starting_position(condition: torch.Tensor) -> None:
|
| 372 |
+
"""Enforce that an uncached forward pass begins at position 0.
|
| 373 |
+
|
| 374 |
+
An uncached forward has no prior KV state. Nonzero starting positions
|
| 375 |
+
produce silently incorrect RoPE encoding and attention outputs with no
|
| 376 |
+
downstream diagnostic. This method intercepts that misuse at the
|
| 377 |
+
outermost boundary before any backbone computation runs.
|
| 378 |
+
|
| 379 |
+
To resolve a violation: either supply a ShramCache populated with the
|
| 380 |
+
prefix (for continued decoding), or rebase the sequence so positions
|
| 381 |
+
start at 0.
|
| 382 |
+
|
| 383 |
+
Args:
|
| 384 |
+
condition: Scalar bool tensor. True = all batch items start at 0
|
| 385 |
+
(valid); False = at least one batch item starts nonzero
|
| 386 |
+
(violated).
|
| 387 |
+
"""
|
| 388 |
+
if torch.compiler.is_compiling():
|
| 389 |
+
# bool.item() is not captured as a SymBool by dynamo; converting to
|
| 390 |
+
# int first produces a SymInt, and the Python comparison (!=0) then
|
| 391 |
+
# yields a SymBool that torch._check folds into the compiled graph.
|
| 392 |
+
condition_as_int = condition.to(torch.int).item()
|
| 393 |
+
torch._check(condition_as_int != 0)
|
| 394 |
+
else:
|
| 395 |
+
if not condition.item():
|
| 396 |
+
raise RuntimeError(
|
| 397 |
+
"Uncached ShramForCausalLM forward does not support nonzero "
|
| 398 |
+
"starting positions. Either provide a ShramCache populated "
|
| 399 |
+
"with the prefix for continued decoding, or rebase the "
|
| 400 |
+
"uncached sequence to start at 0.",
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
@staticmethod
|
| 404 |
+
def _enforce_capture_scalar_outputs() -> None:
|
| 405 |
+
"""Enforce that capture_scalar_outputs is enabled when compiling.
|
| 406 |
+
|
| 407 |
+
The safety checks in this model (e.g. position-zero constraint, packing
|
| 408 |
+
overflow detection) rely on torch._check folding into the compiled graph,
|
| 409 |
+
which requires torch._dynamo.config.capture_scalar_outputs = True. Without
|
| 410 |
+
it those checks are silently absent in the compiled model while appearing
|
| 411 |
+
to work in eager mode — a misconfiguration with no diagnostic output.
|
| 412 |
+
|
| 413 |
+
This method fires during dynamo tracing so the missing flag is surfaced
|
| 414 |
+
immediately at compile time rather than discovered from downstream failures.
|
| 415 |
+
"""
|
| 416 |
+
if torch.compiler.is_compiling():
|
| 417 |
+
torch._check(
|
| 418 |
+
torch._dynamo.config.capture_scalar_outputs,
|
| 419 |
+
lambda: RuntimeError(
|
| 420 |
+
"ShramForCausalLM requires torch._dynamo.config.capture_scalar_outputs = True "
|
| 421 |
+
"when compiled. Without it, runtime safety checks (position constraints, "
|
| 422 |
+
"overflow detection) are silently absent in the compiled model. Set the flag "
|
| 423 |
+
"before calling torch.compile()."
|
| 424 |
+
),
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
def _standardize_full_attention_mask(
|
| 428 |
self,
|
| 429 |
input_ids: torch.Tensor,
|
|
|
|
| 521 |
# This keeps the main sequence readable while ensuring invalid states
|
| 522 |
# fail before they can silently contaminate backbone execution.
|
| 523 |
# ------------------------------------------------------------------
|
| 524 |
+
self._enforce_capture_scalar_outputs()
|
| 525 |
self._validate_input_ids(input_ids)
|
| 526 |
self._validate_attention_mask(input_ids, attention_mask)
|
| 527 |
self._validate_position_ids(input_ids, position_ids)
|
|
|
|
| 560 |
)
|
| 561 |
shram_cache: ShramCache | None = past_key_values if use_cache else None
|
| 562 |
|
| 563 |
+
if shram_cache is None:
|
| 564 |
+
positions_start_sane = torch.all(current_position_ids[:, 0] == 0)
|
| 565 |
+
self._enforce_uncached_starting_position(positions_start_sane)
|
| 566 |
+
|
| 567 |
# ------------------------------------------------------------------
|
| 568 |
# Core wrapper responsibilities.
|
| 569 |
#
|
mlp.py
CHANGED
|
@@ -36,9 +36,9 @@ class SwiGLUMLP(nn.Module):
|
|
| 36 |
|
| 37 |
def __init__(self, config: PretrainedConfig) -> None:
|
| 38 |
super().__init__()
|
| 39 |
-
self.gate_proj = nn.Linear(config.
|
| 40 |
-
self.up_proj = nn.Linear(config.
|
| 41 |
-
self.down_proj = nn.Linear(config.
|
| 42 |
|
| 43 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 44 |
"""Apply the SwiGLU feed-forward transformation.
|
|
|
|
| 36 |
|
| 37 |
def __init__(self, config: PretrainedConfig) -> None:
    """Create the three bias-free projections of the SwiGLU feed-forward block.

    Args:
        config: Configuration supplying ``embedding_width`` (model width)
            and ``mlp_width`` (hidden width of the feed-forward block).
    """
    super().__init__()

    model_width = config.embedding_width
    hidden_width = config.mlp_width

    # Layers are created in the same order as before so parameter
    # registration order is unchanged.
    self.gate_proj = nn.Linear(model_width, hidden_width, bias=False)
    self.up_proj = nn.Linear(model_width, hidden_width, bias=False)
    self.down_proj = nn.Linear(hidden_width, model_width, bias=False)
|
| 42 |
|
| 43 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 44 |
"""Apply the SwiGLU feed-forward transformation.
|
model.py
CHANGED
|
@@ -58,9 +58,9 @@ class ShramModel(nn.Module):
|
|
| 58 |
super().__init__()
|
| 59 |
self.config = config
|
| 60 |
self.layers = nn.ModuleList(
|
| 61 |
-
[DecoderLayer(config) for _ in range(config.
|
| 62 |
)
|
| 63 |
-
self.norm = nn.RMSNorm(config.
|
| 64 |
|
| 65 |
def num_mosrah_parameters(self) -> int:
|
| 66 |
"""Return the total number of trainable MoSRAH parameters across all decoder layers."""
|
|
|
|
| 58 |
super().__init__()
|
| 59 |
self.config = config
|
| 60 |
self.layers = nn.ModuleList(
|
| 61 |
+
[DecoderLayer(config) for _ in range(config.num_decoder_layers)]
|
| 62 |
)
|
| 63 |
+
self.norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
|
| 64 |
|
| 65 |
def num_mosrah_parameters(self) -> int:
|
| 66 |
"""Return the total number of trainable MoSRAH parameters across all decoder layers."""
|
rope.py
CHANGED
|
@@ -26,10 +26,10 @@ Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explic
|
|
| 26 |
parameters — no shared instance, no config reading. See Unit 5.A design decisions.
|
| 27 |
|
| 28 |
Cache sharing: all instances with identical parameters share one cos/sin table via a
|
| 29 |
-
class-level registry. The first instance that needs a particular (parameters,
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
"""
|
| 34 |
|
| 35 |
import math
|
|
@@ -66,17 +66,21 @@ class RotaryEmbedding(nn.Module):
|
|
| 66 |
h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
|
| 67 |
config object is read inside this module.
|
| 68 |
|
| 69 |
-
The cos/sin
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
|
| 74 |
Args:
|
| 75 |
mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
|
| 76 |
head_dim: Per-head embedding dimension ``u``. Must be even.
|
| 77 |
theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
dilation: Scale factor ``s = C_target / C_train`` — how much the context
|
| 81 |
window is extended beyond training length. Required for ``mode="yarn"``.
|
| 82 |
When ``dilation=1.0``, YaRN reduces to standard RoPE.
|
|
@@ -88,11 +92,11 @@ class RotaryEmbedding(nn.Module):
|
|
| 88 |
|
| 89 |
Raises:
|
| 90 |
NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
|
| 91 |
-
ValueError: If ``mode="yarn"`` and any of ``
|
| 92 |
-
``
|
| 93 |
"""
|
| 94 |
|
| 95 |
-
# Maps (freq_key,
|
| 96 |
# Shared across all RotaryEmbedding instances in the process. Keys include device
|
| 97 |
# and dtype so that tables built on different devices or in different precisions
|
| 98 |
# are stored independently.
|
|
@@ -103,7 +107,7 @@ class RotaryEmbedding(nn.Module):
|
|
| 103 |
mode: str,
|
| 104 |
head_dim: int,
|
| 105 |
theta: float,
|
| 106 |
-
|
| 107 |
dilation: float | None = None,
|
| 108 |
alpha: float | None = None,
|
| 109 |
beta: float | None = None,
|
|
@@ -112,8 +116,9 @@ class RotaryEmbedding(nn.Module):
|
|
| 112 |
super().__init__()
|
| 113 |
|
| 114 |
self._validate_mode(mode)
|
| 115 |
-
self._validate_yarn_params(mode,
|
| 116 |
self.mode = mode
|
|
|
|
| 117 |
|
| 118 |
# Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
|
| 119 |
# d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
|
|
@@ -128,9 +133,14 @@ class RotaryEmbedding(nn.Module):
|
|
| 128 |
else: # yarn
|
| 129 |
s = dilation
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
# r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
|
| 132 |
# function to classify each dimension into one of three regimes.
|
| 133 |
-
normalized_freqs =
|
| 134 |
|
| 135 |
# γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
|
| 136 |
# linear blend between α and β.
|
|
@@ -142,16 +152,13 @@ class RotaryEmbedding(nn.Module):
|
|
| 142 |
# A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
|
| 143 |
self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2
|
| 144 |
|
| 145 |
-
# freq_key uniquely identifies the parameter set that produced rotation_freqs
|
| 146 |
-
#
|
|
|
|
| 147 |
if mode == "default":
|
| 148 |
-
self._freq_key: tuple = ("default", head_dim,
|
| 149 |
else:
|
| 150 |
-
self._freq_key = (
|
| 151 |
-
"yarn", head_dim, float(theta),
|
| 152 |
-
int(initial_seq_length), float(dilation),
|
| 153 |
-
float(alpha), float(beta),
|
| 154 |
-
)
|
| 155 |
|
| 156 |
# rotation_freqs is a non-persistent buffer so it moves with the model across
|
| 157 |
# devices via .to() / .cuda() without appearing in saved checkpoints.
|
|
@@ -167,6 +174,11 @@ class RotaryEmbedding(nn.Module):
|
|
| 167 |
self._cos_cached: torch.Tensor | None = None
|
| 168 |
self._sin_cached: torch.Tensor | None = None
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
# ---------------------------------------------------------------------------
|
| 171 |
# Validation helpers
|
| 172 |
# ---------------------------------------------------------------------------
|
|
@@ -182,7 +194,6 @@ class RotaryEmbedding(nn.Module):
|
|
| 182 |
@staticmethod
|
| 183 |
def _validate_yarn_params(
|
| 184 |
mode: str,
|
| 185 |
-
initial_seq_length: int | None,
|
| 186 |
dilation: float | None,
|
| 187 |
alpha: float | None,
|
| 188 |
beta: float | None,
|
|
@@ -192,7 +203,6 @@ class RotaryEmbedding(nn.Module):
|
|
| 192 |
return
|
| 193 |
missing = [
|
| 194 |
name for name, val in [
|
| 195 |
-
("initial_seq_length", initial_seq_length),
|
| 196 |
("dilation", dilation),
|
| 197 |
("alpha", alpha),
|
| 198 |
("beta", beta),
|
|
@@ -206,20 +216,23 @@ class RotaryEmbedding(nn.Module):
|
|
| 206 |
# Cache management
|
| 207 |
# ---------------------------------------------------------------------------
|
| 208 |
|
| 209 |
-
def
|
| 210 |
-
"""Build the cos/sin table to cover positions [0,
|
| 211 |
|
| 212 |
Checks the class-level registry first. If a table already exists for this
|
| 213 |
-
exact (parameters,
|
| 214 |
otherwise it is computed and stored. The instance attributes are pointed at
|
| 215 |
the registry entry so that all layers sharing the same parametrisation
|
| 216 |
reference the same tensor.
|
| 217 |
"""
|
| 218 |
-
cache_key = (self._freq_key,
|
| 219 |
|
| 220 |
if cache_key not in RotaryEmbedding._cache:
|
| 221 |
-
positions = torch.arange(
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
| 223 |
freqs = torch.outer(
|
| 224 |
positions,
|
| 225 |
self.rotation_freqs.to(device=device, dtype=torch.float32),
|
|
@@ -240,11 +253,12 @@ class RotaryEmbedding(nn.Module):
|
|
| 240 |
) -> tuple[torch.Tensor, torch.Tensor, float]:
|
| 241 |
"""Apply rotary embeddings to query and key tensors.
|
| 242 |
|
| 243 |
-
The cos/sin
|
| 244 |
-
|
|
|
|
| 245 |
|
| 246 |
-
``position_ids`` may be any integer tensor shape. Its values
|
| 247 |
-
|
| 248 |
|
| 249 |
- h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
|
| 250 |
- BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).
|
|
@@ -262,18 +276,11 @@ class RotaryEmbedding(nn.Module):
|
|
| 262 |
1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
|
| 263 |
apply to attention logits before softmax.
|
| 264 |
"""
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
cache_missing = self._cos_cached is None
|
| 271 |
-
cache_too_short = not cache_missing and seq_len > self._cos_cached.shape[0]
|
| 272 |
-
wrong_dtype = not cache_missing and self._cos_cached.dtype != q.dtype
|
| 273 |
-
wrong_device = not cache_missing and self._cos_cached.device != q.device
|
| 274 |
-
|
| 275 |
-
if cache_missing or cache_too_short or wrong_dtype or wrong_device:
|
| 276 |
-
self._extend_cache(seq_len, device=q.device, dtype=q.dtype)
|
| 277 |
|
| 278 |
cos = self._cos_cached[position_ids]
|
| 279 |
sin = self._sin_cached[position_ids]
|
|
|
|
| 26 |
parameters — no shared instance, no config reading. See Unit 5.A design decisions.
|
| 27 |
|
| 28 |
Cache sharing: all instances with identical parameters share one cos/sin table via a
|
| 29 |
+
class-level registry. The first instance that needs a particular (parameters, device,
|
| 30 |
+
dtype) combination builds the table; all subsequent instances reference it directly.
|
| 31 |
+
This avoids redundant builds across the num_hidden_layers instances that share the
|
| 32 |
+
same parametrisation.
|
| 33 |
"""
|
| 34 |
|
| 35 |
import math
|
|
|
|
| 66 |
h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
|
| 67 |
config object is read inside this module.
|
| 68 |
|
| 69 |
+
The cos/sin table is built at construction time to cover all positions in
|
| 70 |
+
``[0, maximum_sequence_length)``. In forward, the table is rebuilt only if
|
| 71 |
+
the query tensor's dtype or device has changed since construction.
|
| 72 |
+
|
| 73 |
+
Instances with identical parameters share one cos/sin table via the class-level
|
| 74 |
+
``_cache`` registry, avoiding redundant computation across decoder layers.
|
| 75 |
|
| 76 |
Args:
|
| 77 |
mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
|
| 78 |
head_dim: Per-head embedding dimension ``u``. Must be even.
|
| 79 |
theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
|
| 80 |
+
maximum_sequence_length: Maximum number of positions the table must cover.
|
| 81 |
+
The cos/sin table is preallocated to this length at construction time.
|
| 82 |
+
For ``mode="yarn"``, the training context length C_train is derived
|
| 83 |
+
internally as ``round(maximum_sequence_length / dilation)``.
|
| 84 |
dilation: Scale factor ``s = C_target / C_train`` — how much the context
|
| 85 |
window is extended beyond training length. Required for ``mode="yarn"``.
|
| 86 |
When ``dilation=1.0``, YaRN reduces to standard RoPE.
|
|
|
|
| 92 |
|
| 93 |
Raises:
|
| 94 |
NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
|
| 95 |
+
ValueError: If ``mode="yarn"`` and any of ``dilation``, ``alpha``,
|
| 96 |
+
``beta`` are absent.
|
| 97 |
"""
|
| 98 |
|
| 99 |
+
# Maps (freq_key, device_str, dtype_str) → (cos_table, sin_table).
|
| 100 |
# Shared across all RotaryEmbedding instances in the process. Keys include device
|
| 101 |
# and dtype so that tables built on different devices or in different precisions
|
| 102 |
# are stored independently.
|
|
|
|
| 107 |
mode: str,
|
| 108 |
head_dim: int,
|
| 109 |
theta: float,
|
| 110 |
+
maximum_sequence_length: int,
|
| 111 |
dilation: float | None = None,
|
| 112 |
alpha: float | None = None,
|
| 113 |
beta: float | None = None,
|
|
|
|
| 116 |
super().__init__()
|
| 117 |
|
| 118 |
self._validate_mode(mode)
|
| 119 |
+
self._validate_yarn_params(mode, dilation, alpha, beta)
|
| 120 |
self.mode = mode
|
| 121 |
+
self._maximum_sequence_length = maximum_sequence_length
|
| 122 |
|
| 123 |
# Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
|
| 124 |
# d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
|
|
|
|
| 133 |
else: # yarn
|
| 134 |
s = dilation
|
| 135 |
|
| 136 |
+
# C_train is the training context length, recovered from the inference
|
| 137 |
+
# context length and the dilation factor. round() guards against floating
|
| 138 |
+
# point error since both underlying quantities are integers.
|
| 139 |
+
c_train: int = round(maximum_sequence_length / dilation)
|
| 140 |
+
|
| 141 |
# r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
|
| 142 |
# function to classify each dimension into one of three regimes.
|
| 143 |
+
normalized_freqs = c_train * base_freqs / (2.0 * math.pi)
|
| 144 |
|
| 145 |
# γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
|
| 146 |
# linear blend between α and β.
|
|
|
|
| 152 |
# A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
|
| 153 |
self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2
|
| 154 |
|
| 155 |
+
# freq_key uniquely identifies the parameter set that produced rotation_freqs,
|
| 156 |
+
# including maximum_sequence_length so instances with different table sizes
|
| 157 |
+
# do not collide in the registry.
|
| 158 |
if mode == "default":
|
| 159 |
+
self._freq_key: tuple = ("default", head_dim, theta, maximum_sequence_length)
|
| 160 |
else:
|
| 161 |
+
self._freq_key = ("yarn", head_dim, theta, maximum_sequence_length, dilation, alpha, beta)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
# rotation_freqs is a non-persistent buffer so it moves with the model across
|
| 164 |
# devices via .to() / .cuda() without appearing in saved checkpoints.
|
|
|
|
| 174 |
self._cos_cached: torch.Tensor | None = None
|
| 175 |
self._sin_cached: torch.Tensor | None = None
|
| 176 |
|
| 177 |
+
# Build the table at construction time. Forward rebuilds only on dtype or
|
| 178 |
+
# device change. If no device is specified, build on CPU as the default.
|
| 179 |
+
build_device = device if device is not None else torch.device("cpu")
|
| 180 |
+
self._build_cache(device=build_device, dtype=torch.float32)
|
| 181 |
+
|
| 182 |
# ---------------------------------------------------------------------------
|
| 183 |
# Validation helpers
|
| 184 |
# ---------------------------------------------------------------------------
|
|
|
|
| 194 |
@staticmethod
|
| 195 |
def _validate_yarn_params(
|
| 196 |
mode: str,
|
|
|
|
| 197 |
dilation: float | None,
|
| 198 |
alpha: float | None,
|
| 199 |
beta: float | None,
|
|
|
|
| 203 |
return
|
| 204 |
missing = [
|
| 205 |
name for name, val in [
|
|
|
|
| 206 |
("dilation", dilation),
|
| 207 |
("alpha", alpha),
|
| 208 |
("beta", beta),
|
|
|
|
| 216 |
# Cache management
|
| 217 |
# ---------------------------------------------------------------------------
|
| 218 |
|
| 219 |
+
def _build_cache(self, device: torch.device, dtype: torch.dtype) -> None:
|
| 220 |
+
"""Build the cos/sin table to cover positions [0, maximum_sequence_length).
|
| 221 |
|
| 222 |
Checks the class-level registry first. If a table already exists for this
|
| 223 |
+
exact (parameters, device, dtype) combination it is reused directly;
|
| 224 |
otherwise it is computed and stored. The instance attributes are pointed at
|
| 225 |
the registry entry so that all layers sharing the same parametrisation
|
| 226 |
reference the same tensor.
|
| 227 |
"""
|
| 228 |
+
cache_key = (self._freq_key, str(device), str(dtype))
|
| 229 |
|
| 230 |
if cache_key not in RotaryEmbedding._cache:
|
| 231 |
+
positions = torch.arange(
|
| 232 |
+
self._maximum_sequence_length, device=device, dtype=torch.float32
|
| 233 |
+
)
|
| 234 |
+
# outer product → (maximum_sequence_length, head_dim // 2);
|
| 235 |
+
# duplicate to (maximum_sequence_length, head_dim)
|
| 236 |
freqs = torch.outer(
|
| 237 |
positions,
|
| 238 |
self.rotation_freqs.to(device=device, dtype=torch.float32),
|
|
|
|
| 253 |
) -> tuple[torch.Tensor, torch.Tensor, float]:
|
| 254 |
"""Apply rotary embeddings to query and key tensors.
|
| 255 |
|
| 256 |
+
The cos/sin table is built at construction time. It is rebuilt here only
|
| 257 |
+
if ``q``'s dtype or device differs from the cached table — for example,
|
| 258 |
+
after moving the model to a different device via ``.cuda()``.
|
| 259 |
|
| 260 |
+
``position_ids`` may be any integer tensor shape. Its values must be in
|
| 261 |
+
``[0, maximum_sequence_length)``:
|
| 262 |
|
| 263 |
- h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
|
| 264 |
- BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).
|
|
|
|
| 276 |
1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
|
| 277 |
apply to attention logits before softmax.
|
| 278 |
"""
|
| 279 |
+
wrong_dtype = self._cos_cached.dtype != q.dtype
|
| 280 |
+
wrong_device = self._cos_cached.device != q.device
|
| 281 |
+
|
| 282 |
+
if wrong_dtype or wrong_device:
|
| 283 |
+
self._build_cache(device=q.device, dtype=q.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
cos = self._cos_cached[position_ids]
|
| 286 |
sin = self._sin_cached[position_ids]
|