Update architecture and tokenizer
- README.md +1 -1
- configuration.py +262 -262
- huggingface.py +164 -203
README.md
CHANGED
@@ -48,7 +48,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 config = AutoConfig.from_pretrained(
     "smithblack-0/SHRAM-dev",
     trust_remote_code=True,
-
+    num_decoder_layers=16,  # example override
     num_mosrah_heads=32,  # example override
 )
configuration.py
CHANGED
@@ -1,262 +1,262 @@
-"""Configuration for the SHRAM transformer.
-
-All architectural parameters that vary across model scales or are meaningful research
-variables are expressed here. Architectural constants (no bias in linear layers,
-SwiGLU activation with SiLU gate) are implemented in the relevant modules and
-documented at the point of use — they are not config parameters because they do not
-vary and changing them produces a different architecture, not a different scale.
-
-RoPE configuration is owned entirely by this config. Each attention path reads its
-parameters directly and constructs its own RotaryEmbedding instance explicitly — no
-HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
-"""
-
-import math
-
-from transformers import PretrainedConfig
-
-
-class ShramConfig(PretrainedConfig):
-    """Configuration class for the SHRAM decoder-only transformer.
-
-    SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
-    attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
-    local sliding-window causal attention path and h_s is the MoSRAH sparse routed
-    path. All other components follow the Llama 3 baseline.
-
-    This config is the single source of truth for every architectural dimension of the
-    model. Nothing in the architecture may use a literal number that belongs here.
-
-    Two independent RoPE configurations exist — one per attention path:
-
-    - h_l always uses standard RoPE with ``local_rope_theta``.
-    - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
-      ``inference_sequence_length``, ``alpha``, and ``beta``. When
-      ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
-      ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
-      and the correct setting for experiments that do not require context extension.
-
-    Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
-
-        config = AutoConfig.from_pretrained(
-            "your-namespace/advanced-transformers-lib",
-            trust_remote_code=True,
-
-        )
-        model = AutoModelForCausalLM.from_config(config)
-
-    Args:
-        vocab_size: Vocabulary size. Controls the embedding table and output logits
-            dimension. Must match the tokenizer.
-        embedding_width: Model width ``d``. The dimension of the residual stream.
-        mlp_width: FFN hidden dimension.
-        num_decoder_layers: Number of transformer blocks stacked in sequence.
-        num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
-        num_mosrah_heads: Total MoSRAH expert heads available ``L``.
-        num_selected_heads: MoSRAH heads each token selects ``K``.
-        head_dim: Per-head dimension, shared by both attention paths. Must be even
-            (RoPE rotates dimensions in pairs). Paper uses 16.
-        window_size: Sliding window size for h_l. Paper uses 128.
-        rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
-            original sequence positions; ``"semantic_sequence"`` supplies local slot
-            indices. Both are required; experimentally correct mode is undetermined
-            (paper §4). Default ``"main_sequence"``.
-        rms_norm_eps: Epsilon for RMSNorm layers.
-        local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
-            Paper uses b=10000.
-        mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
-            b=10000.
-        training_sequence_length: Context length ``C_train`` the model was or will be
-            trained at. Used to compute the YaRN scale factor for BEA.
-        inference_sequence_length: Context length ``C_target`` the model must support
-            at inference. Optional; defaults to ``training_sequence_length`` so that
-            ``scale=1`` and YaRN reduces to standard RoPE unless explicitly extended.
-        alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
-            ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
-        beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
-            ``r(d) > beta`` are left unscaled. Paper value: 32.0.
-        attention_dropout: Dropout probability on attention weights. Default 0.0.
-        use_cache: Whether to return past_key_values for KV caching.
-        output_hidden_states: Whether to return hidden states after each layer.
-        tie_word_embeddings: Whether input embedding and LM head share weights.
-        mosrah_overallocation_factor: Overallocation multiplier for the expert packing
-            buffer. ``mosrah_packed_length`` = ceil(training_sequence_length *
-            num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
-            Must be > 1.0 to guarantee a buffer larger than the balanced-routing
-            baseline. Default 2.0.
-        load_balance_p: Exponent p for the p-mean aggregation of per-item routing
-            frequencies into the load balance signal. Higher p weights aggregation
-            toward the worst-case batch item, making the correction signal more
-            sensitive to per-item allocation spikes. Must be positive. Default 2.0.
-        max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
-            solver in ``balance_capacity``. 10 covers convergence at approximately
-            the 98th percentile of routing densities; the top 2% of extreme-density
-            cases are not expected under normal training. The bound exists as a
-            correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
-            Default 10.
-    """
-
-    model_type = "shram"
-
-    auto_map = {
-        "AutoConfig": "configuration.ShramConfig",
-        "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
-    }
-
-    def __init__(
-        self,
-        vocab_size: int = 50277,
-        embedding_width: int = 512,
-        mlp_width: int = 1366,
-        num_decoder_layers: int = 12,
-        num_sliding_window_heads: int = 16,
-        num_mosrah_heads: int = 16,
-        num_selected_heads: int = 16,
-        head_dim: int = 16,
-        window_size: int = 128,
-        rope_mode: str = "main_sequence",
-        rms_norm_eps: float = 1e-5,
-        local_rope_theta: float = 10000.0,
-        mosrah_rope_theta: float = 10000.0,
-        training_sequence_length: int = 1024,
-        inference_sequence_length: int | None = None,
-        alpha: float = 1.0,
-        beta: float = 32.0,
-        attention_dropout: float = 0.0,
-        use_cache: bool = True,
-        output_hidden_states: bool = False,
-        tie_word_embeddings: bool = False,
-        mosrah_overallocation_factor: float = 2.0,
-        load_balance_p: float = 2.0,
-        max_bid_rounds: int = 10,
-        **kwargs
-    ):
-        if head_dim % 2 != 0:
-            raise ValueError(
-                f"head_dim must be even (RoPE rotates dimensions in pairs). "
-                f"Got head_dim={head_dim}."
-            )
-
-        if rope_mode not in {"main_sequence", "semantic_sequence"}:
-            raise ValueError(
-                f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
-                f"got '{rope_mode}'."
-            )
-
-        if training_sequence_length <= 0:
-            raise ValueError(
-                f"training_sequence_length must be positive, "
-                f"got {training_sequence_length}."
-            )
-
-        if inference_sequence_length is None:
-            inference_sequence_length = training_sequence_length
-        if inference_sequence_length <= 0:
-            raise ValueError(
-                f"inference_sequence_length must be positive, "
-                f"got {inference_sequence_length}."
-            )
-
-        if mosrah_overallocation_factor <= 1.0:
-            raise ValueError(
-                f"mosrah_overallocation_factor must be > 1.0 to guarantee a packed "
-                f"buffer larger than the balanced-routing baseline. "
-                f"Got {mosrah_overallocation_factor}."
-            )
-
-        if load_balance_p <= 0.0:
-            raise ValueError(
-                f"load_balance_p must be positive, got {load_balance_p}."
-            )
-
-        if max_bid_rounds < 1:
-            raise ValueError(
-                f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
-            )
-
-        self.vocab_size = vocab_size
-        self.embedding_width = embedding_width
-        self.mlp_width = mlp_width
-        self.num_decoder_layers = num_decoder_layers
-        self.num_sliding_window_heads = num_sliding_window_heads
-        self.num_mosrah_heads = num_mosrah_heads
-        self.num_selected_heads = num_selected_heads
-        self.head_dim = head_dim
-        self.window_size = window_size
-        self.rope_mode = rope_mode
-        self.rms_norm_eps = rms_norm_eps
-        self.local_rope_theta = local_rope_theta
-        self.mosrah_rope_theta = mosrah_rope_theta
-        self.training_sequence_length = training_sequence_length
-        self.inference_sequence_length = inference_sequence_length
-        self.alpha = alpha
-        self.beta = beta
-        self.mosrah_overallocation_factor = mosrah_overallocation_factor
-        self.load_balance_p = load_balance_p
-        self.max_bid_rounds = max_bid_rounds
-        self.attention_dropout = attention_dropout
-        self.use_cache = use_cache
-
-        super().__init__(
-            tie_word_embeddings=tie_word_embeddings,
-            output_hidden_states=output_hidden_states,
-            **kwargs
-        )
-
-        # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
-        # serialises it into config.json.
-        self.auto_map = type(self).auto_map
-
-    @property
-    def scale(self) -> float:
-        """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
-
-        When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
-        adjustments cancel and A_rope = 1. This is the default state.
-        """
-        return self.inference_sequence_length / self.training_sequence_length
-
-    @property
-    def mosrah_packed_length(self) -> int:
-        """Static packed time dimension T for expert packing.
-
-        The expected tokens per expert under perfectly balanced routing is
-        ``training_sequence_length * num_selected_heads / num_mosrah_heads``.
-        Multiplying by ``mosrah_overallocation_factor`` provides a buffer above
-        that baseline. The ceiling ensures T is always an integer >= 1.
-
-        All consumers of the packed buffer size must read this property rather
-        than deriving T independently.
-        """
-        return math.ceil(
-            self.training_sequence_length
-            * self.num_selected_heads
-            / self.num_mosrah_heads
-            * self.mosrah_overallocation_factor
-        )
-
-    @property
-    def mosrah_cache_length(self) -> int:
-        """Static per-(batch, head) slot capacity for the MoSRAH inference cache.
-
-        The expected tokens per expert over the full inference context under perfectly
-        balanced routing is ``inference_sequence_length * num_selected_heads /
-        num_mosrah_heads``. Multiplying by ``mosrah_overallocation_factor`` provides
-        a buffer above that baseline. The ceiling ensures the result is always an
-        integer >= 1.
-
-        Distinct from ``mosrah_packed_length``, which sizes the training packing buffer
-        using ``training_sequence_length``. This property uses
-        ``inference_sequence_length`` because the cache must hold the full accumulated
-        token history across the entire inference run.
-
-        All consumers of the MoSRAH cache buffer size must read this property rather
-        than deriving the capacity independently.
-        """
-        return math.ceil(
-            self.inference_sequence_length
-            * self.num_selected_heads
-            / self.num_mosrah_heads
-            * self.mosrah_overallocation_factor
-        )
-
+"""Configuration for the SHRAM transformer.
+
+All architectural parameters that vary across model scales or are meaningful research
+variables are expressed here. Architectural constants (no bias in linear layers,
+SwiGLU activation with SiLU gate) are implemented in the relevant modules and
+documented at the point of use — they are not config parameters because they do not
+vary and changing them produces a different architecture, not a different scale.
+
+RoPE configuration is owned entirely by this config. Each attention path reads its
+parameters directly and constructs its own RotaryEmbedding instance explicitly — no
+HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
+"""
+
+import math
+
+from transformers import PretrainedConfig
+
+
+class ShramConfig(PretrainedConfig):
+    """Configuration class for the SHRAM decoder-only transformer.
+
+    SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
+    attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
+    local sliding-window causal attention path and h_s is the MoSRAH sparse routed
+    path. All other components follow the Llama 3 baseline.
+
+    This config is the single source of truth for every architectural dimension of the
+    model. Nothing in the architecture may use a literal number that belongs here.
+
+    Two independent RoPE configurations exist — one per attention path:
+
+    - h_l always uses standard RoPE with ``local_rope_theta``.
+    - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
+      ``inference_sequence_length``, ``alpha``, and ``beta``. When
+      ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
+      ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
+      and the correct setting for experiments that do not require context extension.
+
+    Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
+
+        config = AutoConfig.from_pretrained(
+            "your-namespace/advanced-transformers-lib",
+            trust_remote_code=True,
+            num_decoder_layers=12,
+        )
+        model = AutoModelForCausalLM.from_config(config)
+
+    Args:
+        vocab_size: Vocabulary size. Controls the embedding table and output logits
+            dimension. Must match the tokenizer.
+        embedding_width: Model width ``d``. The dimension of the residual stream.
+        mlp_width: FFN hidden dimension.
+        num_decoder_layers: Number of transformer blocks stacked in sequence.
+        num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
+        num_mosrah_heads: Total MoSRAH expert heads available ``L``.
+        num_selected_heads: MoSRAH heads each token selects ``K``.
+        head_dim: Per-head dimension, shared by both attention paths. Must be even
+            (RoPE rotates dimensions in pairs). Paper uses 16.
+        window_size: Sliding window size for h_l. Paper uses 128.
+        rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
+            original sequence positions; ``"semantic_sequence"`` supplies local slot
+            indices. Both are required; experimentally correct mode is undetermined
+            (paper §4). Default ``"main_sequence"``.
+        rms_norm_eps: Epsilon for RMSNorm layers.
+        local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
+            Paper uses b=10000.
+        mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
+            b=10000.
+        training_sequence_length: Context length ``C_train`` the model was or will be
+            trained at. Used to compute the YaRN scale factor for BEA.
+        inference_sequence_length: Context length ``C_target`` the model must support
+            at inference. Optional; defaults to ``training_sequence_length`` so that
+            ``scale=1`` and YaRN reduces to standard RoPE unless explicitly extended.
+        alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
+            ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
+        beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
+            ``r(d) > beta`` are left unscaled. Paper value: 32.0.
+        attention_dropout: Dropout probability on attention weights. Default 0.0.
+        use_cache: Whether to return past_key_values for KV caching.
+        output_hidden_states: Whether to return hidden states after each layer.
+        tie_word_embeddings: Whether input embedding and LM head share weights.
+        mosrah_overallocation_factor: Overallocation multiplier for the expert packing
+            buffer. ``mosrah_packed_length`` = ceil(training_sequence_length *
+            num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
+            Must be > 1.0 to guarantee a buffer larger than the balanced-routing
+            baseline. Default 2.0.
+        load_balance_p: Exponent p for the p-mean aggregation of per-item routing
+            frequencies into the load balance signal. Higher p weights aggregation
+            toward the worst-case batch item, making the correction signal more
+            sensitive to per-item allocation spikes. Must be positive. Default 2.0.
+        max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
+            solver in ``balance_capacity``. 10 covers convergence at approximately
+            the 98th percentile of routing densities; the top 2% of extreme-density
+            cases are not expected under normal training. The bound exists as a
+            correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
+            Default 10.
+    """
+
+    model_type = "shram"
+
+    auto_map = {
+        "AutoConfig": "configuration.ShramConfig",
+        "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
+    }
+
+    def __init__(
+        self,
+        vocab_size: int = 50277,
+        embedding_width: int = 512,
+        mlp_width: int = 1366,
+        num_decoder_layers: int = 12,
+        num_sliding_window_heads: int = 16,
+        num_mosrah_heads: int = 16,
+        num_selected_heads: int = 16,
+        head_dim: int = 16,
+        window_size: int = 128,
+        rope_mode: str = "main_sequence",
+        rms_norm_eps: float = 1e-5,
+        local_rope_theta: float = 10000.0,
+        mosrah_rope_theta: float = 10000.0,
+        training_sequence_length: int = 1024,
+        inference_sequence_length: int | None = None,
+        alpha: float = 1.0,
+        beta: float = 32.0,
+        attention_dropout: float = 0.0,
+        use_cache: bool = True,
+        output_hidden_states: bool = False,
+        tie_word_embeddings: bool = False,
+        mosrah_overallocation_factor: float = 2.0,
+        load_balance_p: float = 2.0,
+        max_bid_rounds: int = 10,
+        **kwargs
+    ):
+        if head_dim % 2 != 0:
+            raise ValueError(
+                f"head_dim must be even (RoPE rotates dimensions in pairs). "
+                f"Got head_dim={head_dim}."
+            )
+
+        if rope_mode not in {"main_sequence", "semantic_sequence"}:
+            raise ValueError(
+                f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
+                f"got '{rope_mode}'."
+            )
+
+        if training_sequence_length <= 0:
+            raise ValueError(
+                f"training_sequence_length must be positive, "
+                f"got {training_sequence_length}."
+            )
+
+        if inference_sequence_length is None:
+            inference_sequence_length = training_sequence_length
+        if inference_sequence_length <= 0:
+            raise ValueError(
+                f"inference_sequence_length must be positive, "
+                f"got {inference_sequence_length}."
+            )
+
+        if mosrah_overallocation_factor <= 1.0:
+            raise ValueError(
+                f"mosrah_overallocation_factor must be > 1.0 to guarantee a packed "
+                f"buffer larger than the balanced-routing baseline. "
+                f"Got {mosrah_overallocation_factor}."
+            )
+
+        if load_balance_p <= 0.0:
+            raise ValueError(
+                f"load_balance_p must be positive, got {load_balance_p}."
+            )
+
+        if max_bid_rounds < 1:
+            raise ValueError(
+                f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
+            )
+
+        self.vocab_size = vocab_size
+        self.embedding_width = embedding_width
+        self.mlp_width = mlp_width
+        self.num_decoder_layers = num_decoder_layers
+        self.num_sliding_window_heads = num_sliding_window_heads
+        self.num_mosrah_heads = num_mosrah_heads
+        self.num_selected_heads = num_selected_heads
+        self.head_dim = head_dim
+        self.window_size = window_size
+        self.rope_mode = rope_mode
+        self.rms_norm_eps = rms_norm_eps
+        self.local_rope_theta = local_rope_theta
+        self.mosrah_rope_theta = mosrah_rope_theta
+        self.training_sequence_length = training_sequence_length
+        self.inference_sequence_length = inference_sequence_length
+        self.alpha = alpha
+        self.beta = beta
+        self.mosrah_overallocation_factor = mosrah_overallocation_factor
+        self.load_balance_p = load_balance_p
+        self.max_bid_rounds = max_bid_rounds
+        self.attention_dropout = attention_dropout
+        self.use_cache = use_cache
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            output_hidden_states=output_hidden_states,
+            **kwargs
+        )
+
+        # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
+        # serialises it into config.json.
+        self.auto_map = type(self).auto_map
+
+    @property
+    def scale(self) -> float:
+        """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
+
+        When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
+        adjustments cancel and A_rope = 1. This is the default state.
+        """
+        return self.inference_sequence_length / self.training_sequence_length
+
+    @property
+    def mosrah_packed_length(self) -> int:
+        """Static packed time dimension T for expert packing.
+
+        The expected tokens per expert under perfectly balanced routing is
+        ``training_sequence_length * num_selected_heads / num_mosrah_heads``.
+        Multiplying by ``mosrah_overallocation_factor`` provides a buffer above
+        that baseline. The ceiling ensures T is always an integer >= 1.
+
+        All consumers of the packed buffer size must read this property rather
+        than deriving T independently.
+        """
+        return math.ceil(
+            self.training_sequence_length
+            * self.num_selected_heads
+            / self.num_mosrah_heads
+            * self.mosrah_overallocation_factor
+        )
+
+    @property
+    def mosrah_cache_length(self) -> int:
+        """Static per-(batch, head) slot capacity for the MoSRAH inference cache.
+
+        The expected tokens per expert over the full inference context under perfectly
+        balanced routing is ``inference_sequence_length * num_selected_heads /
+        num_mosrah_heads``. Multiplying by ``mosrah_overallocation_factor`` provides
+        a buffer above that baseline. The ceiling ensures the result is always an
+        integer >= 1.
+
+        Distinct from ``mosrah_packed_length``, which sizes the training packing buffer
+        using ``training_sequence_length``. This property uses
+        ``inference_sequence_length`` because the cache must hold the full accumulated
+        token history across the entire inference run.
+
+        All consumers of the MoSRAH cache buffer size must read this property rather
+        than deriving the capacity independently.
+        """
+        return math.ceil(
+            self.inference_sequence_length
+            * self.num_selected_heads
+            / self.num_mosrah_heads
+            * self.mosrah_overallocation_factor
+        )
+
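As a rough illustration of the derived quantities documented in `ShramConfig` (`scale`, `mosrah_packed_length`, `mosrah_cache_length`), the sketch below re-derives them in plain Python without importing `transformers`. The helper names `yarn_scale` and `expert_buffer_length` are hypothetical, as are the 4096-token inference length and the K = 4 setting in the second scenario; only the formulas and the default values are taken from the diff above.

```python
import math
from typing import Optional


def yarn_scale(training_len: int, inference_len: Optional[int] = None) -> float:
    """s = inference_sequence_length / training_sequence_length (1.0 when unset)."""
    if inference_len is None:
        inference_len = training_len
    return inference_len / training_len


def expert_buffer_length(seq_len: int, num_selected: int, num_heads: int, factor: float) -> int:
    """ceil(seq_len * K / L * factor), the formula both buffer properties share."""
    if factor <= 1.0:
        raise ValueError("overallocation factor must be > 1.0")
    return math.ceil(seq_len * num_selected / num_heads * factor)


# Defaults from the diff: 1024-token training context, K = L = 16, factor 2.0.
default_scale = yarn_scale(1024)
default_packed = expert_buffer_length(1024, 16, 16, 2.0)

# Hypothetical sparser, context-extended setup: K = 4 of L = 16, inference at 4096.
extended_scale = yarn_scale(1024, 4096)
extended_packed = expert_buffer_length(1024, 4, 16, 2.0)  # training packing buffer
extended_cache = expert_buffer_length(4096, 4, 16, 2.0)   # inference cache slots

print(default_scale, default_packed)
print(extended_scale, extended_packed, extended_cache)
```

With the defaults every token selects all 16 heads, so the packed buffer is simply twice the training context. In the hypothetical extended setup the cache is sized from the inference length, which is why it comes out four times larger than the training packing buffer.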
huggingface.py
CHANGED
@@ -128,7 +128,7 @@ class ShramConfig(PretrainedConfig):
         config = AutoConfig.from_pretrained(
             "your-namespace/advanced-transformers-lib",
             trust_remote_code=True,
-
         )
         model = AutoModelForCausalLM.from_config(config)
@@ -725,17 +725,21 @@ class MoSRAHCache(CacheLayerMixin):
     def _check_no_overflow(max_count: torch.Tensor, capacity: int) -> None:
         """Raise if any (batch, head) slot would exceed the static buffer capacity.

-
-
-
-

         Args:
             max_count: Scalar tensor — the maximum post-update count across all slots.
             capacity: The static buffer capacity (mosrah_cache_length).
         """
         if torch.compiler.is_compiling():
-            torch.
         else:
             if max_count.item() > capacity:
                 raise RuntimeError(
@@ -856,7 +860,7 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
         # Cumulative count of all token positions presented through update() for
         # this cache instance. This is the quantity HuggingFace generation reads
         # through get_seq_length() to track how far along the sequence we are.
-        self._total_processed

     def update(  # type: ignore[override]
         self,
@@ -996,7 +1000,7 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
         generation reads to track sequence progress and is not the same as active-token
         count or current window occupancy.
         """
-        return self._total_processed

     def get_max_cache_shape(self) -> int:
         return self.sliding_window
@@ -2299,29 +2303,24 @@ class BottleneckedEnsembleAttention(nn.Module):
 # -----------
 """Expert packing and unpacking for the MoSRAH path.

-This module
-
-
-- setup_packing() prepares the auxiliary ordering data
-
-- pack_experts() converts
-
-
-
-
-
-
-
-
-
-pack_experts() returns
-
-
-
-  live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
-  so its reshape invariant holds regardless of outer token liveness.
-- active_mask (caller-supplied entry): marks only the packed slots whose source
-  token was semantically live. This is what BEA consumes for attention gating.
-  Dead outer tokens must not influence sparse attention outputs.
 """
@@ -2337,23 +2336,13 @@ def setup_packing(
|
|
| 2337 |
) -> dict[str, torch.Tensor]:
|
| 2338 |
"""Prepare the auxiliary ordering data used by pack/unpack.
|
| 2339 |
|
| 2340 |
-
Routing produces token-choice state I of shape (B, N, K): for each token, which
|
| 2341 |
-
K experts were selected. Packing needs the same routed token copies reordered into
|
| 2342 |
-
expert-major order so each expert bucket becomes contiguous.
|
| 2343 |
-
|
| 2344 |
-
The paper's setup step does this by flattening (N, K) into one axis to produce
|
| 2345 |
-
H in token-major order, then computing a stable argsort permutation Pi over the
|
| 2346 |
-
expert indices stored in H. Applying Pi reorders the flattened routed copies into
|
| 2347 |
-
expert-major order while preserving their original token order *within* each expert
|
| 2348 |
-
bucket. That preservation is why stable sort is required for causality.
|
| 2349 |
-
|
| 2350 |
Args:
|
| 2351 |
selected_heads: Routed token-choice head selections I of shape (B, N, K).
|
| 2352 |
|
| 2353 |
Returns:
|
| 2354 |
Auxiliary payload dict with keys:
|
| 2355 |
- "flattened_selected_heads": H of shape (B, N*K)
|
| 2356 |
-
- "permutation":
|
| 2357 |
- "inverse_permutation": inverse permutation Pi^{-1} of shape (B, N*K)
|
| 2358 |
This dict is forwarded whole to pack_experts and unpack_experts.
|
| 2359 |
"""
|
|
@@ -2362,7 +2351,14 @@ def setup_packing(
|
|
| 2362 |
batch_size,
|
| 2363 |
sequence_length * num_selected_heads,
|
| 2364 |
)
|
| 2365 |
-
|
| 2366 |
permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
|
| 2367 |
inverse_permutation = torch.argsort(permutation, dim=-1)
|
| 2368 |
|
|
@@ -2370,7 +2366,6 @@ def setup_packing(
|
|
| 2370 |
"flattened_selected_heads": flattened_selected_heads,
|
| 2371 |
"permutation": permutation,
|
| 2372 |
"inverse_permutation": inverse_permutation,
|
| 2373 |
-
"num_elements" : num_elements,
|
| 2374 |
}
|
| 2375 |
|
| 2376 |
|
|
@@ -2387,20 +2382,6 @@ def pack_experts(
|
|
| 2387 |
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
|
| 2388 |
"""Pack token-choice tensors into expert-choice padded form.
|
| 2389 |
|
| 2390 |
-
The paper's packing path has two jobs:
|
| 2391 |
-
|
| 2392 |
-
1. Convert routed token-choice copies into expert-major order.
|
| 2393 |
-
2. Materialize that expert-major order into a padded tensor layout BEA can consume.
|
| 2394 |
-
|
| 2395 |
-
All entries in the provided dict undergo the same expert-major gather-scatter so
|
| 2396 |
-
they remain mutually aligned in the packed frame. Each entry is paired with its
|
| 2397 |
-
intended padding value, which fills slots that contain no routed token copy.
|
| 2398 |
-
|
| 2399 |
-
Packed positions are sourced from the authoritative upstream position_ids tensor
|
| 2400 |
-
rather than synthesized locally from arange(N). This preserves advanced positions
|
| 2401 |
-
correctly during cached inference while leaving training/full-sequence behavior
|
| 2402 |
-
unchanged when position_ids is the ordinary sequential token positions.
|
| 2403 |
-
|
| 2404 |
Args:
|
| 2405 |
entries: Mapping from string keys to (tensor, padding_value) pairs. Each
|
| 2406 |
tensor has shape (B, N, ...) and is rearranged into expert-choice layout
|
|
@@ -2409,29 +2390,40 @@ def pack_experts(
|
|
| 2409 |
selected_heads: Routed head selections I of shape (B, N, K).
|
| 2410 |
num_experts: Total number of experts L.
|
| 2411 |
packed_length: Static packed time dimension T. All per-expert buffers are
|
| 2412 |
-
allocated to exactly this length.
|
| 2413 |
-
|
| 2414 |
-
this value.
|
| 2415 |
|
| 2416 |
Returns:
|
| 2417 |
Tuple of:
|
| 2418 |
- packed_entries: Dict with same keys as entries; each value is the
|
| 2419 |
packed tensor of shape (B, L, T, ...).
|
| 2420 |
-
-
|
| 2421 |
-
|
| 2422 |
-
|
| 2423 |
"""
|
| 2424 |
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
| 2425 |
|
| 2426 |
flattened_selected_heads = setup["flattened_selected_heads"]
|
| 2427 |
permutation = setup["permutation"]
|
| 2428 |
|
| 2429 |
# -----------------------------------------------------------------------
|
| 2430 |
-
#
|
| 2431 |
#
|
| 2432 |
-
#
|
| 2433 |
-
#
|
| 2434 |
-
#
|
| 2435 |
# -----------------------------------------------------------------------
|
| 2436 |
source_token_indices = torch.arange(
|
| 2437 |
sequence_length,
|
|
@@ -2442,81 +2434,91 @@ def pack_experts(
|
|
| 2442 |
sequence_length,
|
| 2443 |
num_selected_heads,
|
| 2444 |
)
|
| 2445 |
-
|
| 2446 |
batch_size,
|
| 2447 |
-
|
| 2448 |
)
|
| 2449 |
-
|
| 2450 |
-
# -----------------------------------------------------------------------
|
| 2451 |
-
# Reorder source-token indices into expert-major order.
|
| 2452 |
-
#
|
| 2453 |
-
# Applying Pi yields the local source-token rows in the packed expert-major
|
| 2454 |
-
# order required by the paper. All entries are then gathered using these same
|
| 2455 |
-
# reordered indices so they remain aligned under the exact same transformation.
|
| 2456 |
-
# -----------------------------------------------------------------------
|
| 2457 |
-
sorted_source_indices = flattened_source_indices.gather(
|
| 2458 |
dim=1,
|
| 2459 |
index=permutation,
|
| 2460 |
)
|
| 2461 |
|
| 2462 |
# -----------------------------------------------------------------------
|
| 2463 |
-
#
|
| 2464 |
-
# that no bucket exceeds the statically preallocated packed_length T.
|
| 2465 |
#
|
| 2466 |
-
#
|
| 2467 |
-
#
|
| 2468 |
-
#
|
| 2469 |
-
# both eager and compiled modes.
|
| 2470 |
# -----------------------------------------------------------------------
|
| 2471 |
tokens_per_expert = _count_tokens_per_expert(flattened_selected_heads, num_experts)
|
| 2472 |
-
|
| 2473 |
-
no_overflow = max_count <= packed_length
|
| 2474 |
-
_enforce_no_overflow(no_overflow, tokens_per_expert, packed_length)
|
| 2475 |
|
| 2476 |
# -----------------------------------------------------------------------
|
| 2477 |
-
#
|
| 2478 |
#
|
| 2479 |
-
#
|
| 2480 |
-
#
|
| 2481 |
-
#
|
| 2482 |
-
#
|
|
|
|
| 2483 |
# -----------------------------------------------------------------------
|
| 2484 |
-
|
| 2485 |
-
|
| 2486 |
device=flattened_selected_heads.device,
|
| 2487 |
dtype=torch.long,
|
| 2488 |
-
)
|
| 2489 |
-
|
| 2490 |
|
| 2491 |
# -----------------------------------------------------------------------
|
| 2492 |
-
# Materialize
|
| 2493 |
#
|
| 2494 |
-
# Each entry
|
| 2495 |
-
#
|
| 2496 |
-
#
|
| 2497 |
-
# value rather than an implicit zero.
|
| 2498 |
# -----------------------------------------------------------------------
|
| 2499 |
packed_entries: dict[str, torch.Tensor] = {}
|
| 2500 |
for key, (tensor, padding_value) in entries.items():
|
| 2501 |
extra_shape = tensor.shape[2:]
|
| 2502 |
|
| 2503 |
-
#
|
| 2504 |
-
|
| 2505 |
batch_size,
|
| 2506 |
-
|
| 2507 |
*(1,) * len(extra_shape),
|
| 2508 |
).expand(-1, -1, *extra_shape)
|
| 2509 |
-
sorted_tensor = tensor.gather(dim=1, index=
|
| 2510 |
|
| 2511 |
packed_tensor = tensor.new_full(
|
| 2512 |
-
(batch_size
|
| 2513 |
fill_value=padding_value,
|
| 2514 |
)
|
| 2515 |
|
| 2516 |
-
|
| 2517 |
-
packed_entries[key] = packed_tensor
|
| 2518 |
-
|
| 2519 |
-
return packed_entries, unpacking_mask
|
| 2520 |
|
| 2521 |
|
| 2522 |
# ---------------------------------------------------------------------------
|
|
@@ -2526,27 +2528,17 @@ def pack_experts(
|
|
| 2526 |
def unpack_experts(
|
| 2527 |
expert_outputs: torch.Tensor,
|
| 2528 |
setup: dict[str, torch.Tensor],
|
| 2529 |
-
|
| 2530 |
selected_heads: torch.Tensor,
|
| 2531 |
) -> torch.Tensor:
|
| 2532 |
"""Restore token-choice ordering from BEA expert-choice output.
|
| 2533 |
|
| 2534 |
-
Unpacking inverts the packing path only on occupied entries. Padding does not
|
| 2535 |
-
participate: the output tensor is first filtered by unpacking_mask to recover
|
| 2536 |
-
only the real routed-token copies in expert-major order, then Pi^{-1} restores
|
| 2537 |
-
the original token-choice ordering, and finally the tensor is reshaped back to
|
| 2538 |
-
(B, N, K, d).
|
| 2539 |
-
|
| 2540 |
-
The unpacking_mask — not active_mask — must be used here. Even copies of dead
|
| 2541 |
-
outer tokens occupy slots and must be un-scattered correctly for the inverse
|
| 2542 |
-
permutation to hold. The total True entry count in unpacking_mask is always
|
| 2543 |
-
B*N*K, which is exactly what the reshape to (B, N*K, d) requires.
|
| 2544 |
-
|
| 2545 |
Args:
|
| 2546 |
expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
|
| 2547 |
setup: Auxiliary payload returned by setup_packing().
|
| 2548 |
-
|
| 2549 |
-
|
|
|
|
| 2550 |
selected_heads: Routed head selections I of shape (B, N, K).
|
| 2551 |
|
| 2552 |
Returns:
|
|
@@ -2555,22 +2547,22 @@ def unpack_experts(
|
|
| 2555 |
inverse_permutation = setup["inverse_permutation"]
|
| 2556 |
|
| 2557 |
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
|
|
|
| 2558 |
hidden_dim = expert_outputs.shape[-1]
|
| 2559 |
|
| 2560 |
-
|
| 2561 |
-
|
| 2562 |
-
|
| 2563 |
-
|
| 2564 |
-
|
| 2565 |
-
|
| 2566 |
-
|
| 2567 |
-
|
| 2568 |
-
|
| 2569 |
-
] # shape: (B*N*K, d)
|
| 2570 |
|
| 2571 |
-
sorted_token_choice_outputs =
|
| 2572 |
batch_size,
|
| 2573 |
-
|
| 2574 |
hidden_dim,
|
| 2575 |
)
|
| 2576 |
restored_outputs = sorted_token_choice_outputs.gather(
|
|
@@ -2589,34 +2581,34 @@ def unpack_experts(
|
|
| 2589 |
# Helpers
|
| 2590 |
# ---------------------------------------------------------------------------
|
| 2591 |
|
| 2592 |
-
def _enforce_no_overflow(condition: bool, tokens_per_expert, max_length) -> None:
|
| 2593 |
-
"""Enforce that no expert bucket exceeds the preallocated packed length.
|
| 2594 |
|
| 2595 |
-
|
| 2596 |
-
|
| 2597 |
-
packed buffer is too small to hold all assignments and data would be dropped.
|
| 2598 |
-
Increase mosrah_overallocation_factor in ShramConfig to resolve.
|
| 2599 |
|
| 2600 |
-
|
| 2601 |
-
|
| 2602 |
-
|
| 2603 |
-
|
| 2604 |
|
| 2605 |
Args:
|
| 2606 |
-
|
| 2607 |
-
|
| 2608 |
-
produced by comparing a SymInt against the static packed_length.
|
| 2609 |
"""
|
| 2610 |
if torch.compiler.is_compiling():
|
| 2611 |
-
torch.
|
| 2612 |
else:
|
| 2613 |
-
|
|
|
|
| 2614 |
raise RuntimeError(
|
| 2615 |
"Expert packing overflow: at least one expert bucket contains more "
|
| 2616 |
"tokens than mosrah_packed_length allows. Increase "
|
| 2617 |
"mosrah_overallocation_factor in ShramConfig to resolve.\n"
|
| 2618 |
-
f"
|
| 2619 |
-
f"
|
| 2620 |
)
|
| 2621 |
|
| 2622 |
|
|
@@ -2626,8 +2618,7 @@ def _count_tokens_per_expert(
|
|
| 2626 |
) -> torch.Tensor:
|
| 2627 |
"""Count how many routed token copies are assigned to each expert per batch item.
|
| 2628 |
|
| 2629 |
-
Uses scatter_add into a pre-sized (B, num_experts)
|
| 2630 |
-
statically-shaped output that compiles without graph breaks. Each position in
|
| 2631 |
flattened_selected_heads contributes one count to the corresponding expert slot.
|
| 2632 |
|
| 2633 |
Args:
|
|
@@ -2639,19 +2630,18 @@ def _count_tokens_per_expert(
|
|
| 2639 |
Counts tensor of shape (B, num_experts).
|
| 2640 |
"""
|
| 2641 |
batch_size = flattened_selected_heads.shape[0]
|
| 2642 |
-
|
| 2643 |
batch_size,
|
| 2644 |
num_experts,
|
| 2645 |
device=flattened_selected_heads.device,
|
| 2646 |
-
dtype=
|
| 2647 |
)
|
| 2648 |
-
|
| 2649 |
dim=1,
|
| 2650 |
index=flattened_selected_heads,
|
| 2651 |
-
src=torch.ones_like(flattened_selected_heads),
|
| 2652 |
)
|
| 2653 |
-
return
|
| 2654 |
-
|
| 2655 |
# -----------
|
| 2656 |
# Inlined from: router.py
|
| 2657 |
# -----------
|
|
@@ -2825,7 +2815,7 @@ class MoSRAHRouter(nn.Module):
|
|
| 2825 |
self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
|
| 2826 |
|
| 2827 |
@staticmethod
|
| 2828 |
-
def
|
| 2829 |
tensor: torch.Tensor,
|
| 2830 |
dim: int,
|
| 2831 |
n: int | torch.Tensor,
|
|
@@ -2958,7 +2948,7 @@ class MoSRAHRouter(nn.Module):
|
|
| 2958 |
choices_deficit = (min_choices - accepted_per_token).clamp_min(0)
|
| 2959 |
|
| 2960 |
unproposed_logits = logits.masked_fill(proposals, float('-inf'))
|
| 2961 |
-
new_proposals = cls.
|
| 2962 |
unproposed_logits, dim=-1, n=choices_deficit, capacity_scalar=min_choices,
|
| 2963 |
)
|
| 2964 |
proposals = proposals | new_proposals
|
|
@@ -2969,7 +2959,7 @@ class MoSRAHRouter(nn.Module):
|
|
| 2969 |
# Acceptances are recomputed from scratch each round so that a
|
| 2970 |
# stronger new proposal can displace a weaker prior one.
|
| 2971 |
proposed_logits = logits.masked_fill(~proposals, float('-inf'))
|
| 2972 |
-
acceptances = cls.
|
| 2973 |
proposed_logits, dim=-2, n=remaining_capacity, capacity_scalar=capacity_scalar,
|
| 2974 |
)
|
| 2975 |
|
|
@@ -3351,7 +3341,7 @@ class MoSRAHLayer(nn.Module):
|
|
| 3351 |
"position_ids": (position_ids, 0),
|
| 3352 |
"active_mask": (active_mask, False),
|
| 3353 |
}
|
| 3354 |
-
packed,
|
| 3355 |
packed_hidden_states = packed["hidden_states"]
|
| 3356 |
packed_positions = packed["position_ids"]
|
| 3357 |
active_mask = packed["active_mask"]
|
|
@@ -3387,7 +3377,7 @@ class MoSRAHLayer(nn.Module):
|
|
| 3387 |
token_choice_outputs = unpack_experts(
|
| 3388 |
expert_outputs=packed_outputs,
|
| 3389 |
setup=setup,
|
| 3390 |
-
|
| 3391 |
selected_heads=selected_heads,
|
| 3392 |
)
|
| 3393 |
final_output = (
|
|
@@ -3886,11 +3876,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 3886 |
|
| 3887 |
@staticmethod
|
| 3888 |
def create_masks_for_generate(
|
| 3889 |
-
config: Any,
|
| 3890 |
-
inputs_embeds: torch.Tensor,
|
| 3891 |
attention_mask: torch.Tensor | None,
|
| 3892 |
-
past_key_values: Cache | None,
|
| 3893 |
-
position_ids: torch.Tensor | None = None,
|
| 3894 |
**kwargs: Any,
|
| 3895 |
) -> torch.Tensor | None:
|
| 3896 |
"""Return the 2D attention_mask unchanged.
|
|
@@ -3944,7 +3930,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 3944 |
raise ValueError(
|
| 3945 |
"position_ids must match the current input_ids shape exactly."
|
| 3946 |
)
|
| 3947 |
-
if
|
| 3948 |
raise TypeError("position_ids must be a long tensor.")
|
| 3949 |
|
| 3950 |
def _validate_labels(
|
|
@@ -3959,7 +3945,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 3959 |
raise ValueError("labels must have shape (batch, seq_len).")
|
| 3960 |
if labels.shape != input_ids.shape:
|
| 3961 |
raise ValueError("labels must have the same shape as input_ids.")
|
| 3962 |
-
if
|
| 3963 |
raise TypeError("labels must be a long tensor.")
|
| 3964 |
|
| 3965 |
def _validate_cache_inputs(
|
|
@@ -4044,11 +4030,11 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 4044 |
(violated).
|
| 4045 |
"""
|
| 4046 |
if torch.compiler.is_compiling():
|
| 4047 |
-
|
| 4048 |
-
|
| 4049 |
-
|
| 4050 |
-
|
| 4051 |
-
|
| 4052 |
else:
|
| 4053 |
if not condition.item():
|
| 4054 |
raise RuntimeError(
|
|
@@ -4058,30 +4044,6 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 4058 |
"uncached sequence to start at 0.",
|
| 4059 |
)
|
| 4060 |
|
| 4061 |
-
@staticmethod
|
| 4062 |
-
def _enforce_capture_scalar_outputs() -> None:
|
| 4063 |
-
"""Enforce that capture_scalar_outputs is enabled when compiling.
|
| 4064 |
-
|
| 4065 |
-
The safety checks in this model (e.g. position-zero constraint, packing
|
| 4066 |
-
overflow detection) rely on torch._check folding into the compiled graph,
|
| 4067 |
-
which requires torch._dynamo.config.capture_scalar_outputs = True. Without
|
| 4068 |
-
it those checks are silently absent in the compiled model while appearing
|
| 4069 |
-
to work in eager mode — a misconfiguration with no diagnostic output.
|
| 4070 |
-
|
| 4071 |
-
This method fires during dynamo tracing so the missing flag is surfaced
|
| 4072 |
-
immediately at compile time rather than discovered from downstream failures.
|
| 4073 |
-
"""
|
| 4074 |
-
if torch.compiler.is_compiling():
|
| 4075 |
-
torch._check(
|
| 4076 |
-
torch._dynamo.config.capture_scalar_outputs,
|
| 4077 |
-
lambda: RuntimeError(
|
| 4078 |
-
"ShramForCausalLM requires torch._dynamo.config.capture_scalar_outputs = True "
|
| 4079 |
-
"when compiled. Without it, runtime safety checks (position constraints, "
|
| 4080 |
-
"overflow detection) are silently absent in the compiled model. Set the flag "
|
| 4081 |
-
"before calling torch.compile()."
|
| 4082 |
-
),
|
| 4083 |
-
)
|
| 4084 |
-
|
| 4085 |
def _standardize_full_attention_mask(
|
| 4086 |
self,
|
| 4087 |
input_ids: torch.Tensor,
|
|
@@ -4179,7 +4141,6 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 4179 |
# This keeps the main sequence readable while ensuring invalid states
|
| 4180 |
# fail before they can silently contaminate backbone execution.
|
| 4181 |
# ------------------------------------------------------------------
|
| 4182 |
-
self._enforce_capture_scalar_outputs()
|
| 4183 |
self._validate_input_ids(input_ids)
|
| 4184 |
self._validate_attention_mask(input_ids, attention_mask)
|
| 4185 |
self._validate_position_ids(input_ids, position_ids)
|
|
|
|
| 128 |
config = AutoConfig.from_pretrained(
|
| 129 |
"your-namespace/advanced-transformers-lib",
|
| 130 |
trust_remote_code=True,
|
| 131 |
+
num_decoder_layers=12,
|
| 132 |
)
|
| 133 |
model = AutoModelForCausalLM.from_config(config)
|
| 134 |
|
|
|
|
| 725 |
def _check_no_overflow(max_count: torch.Tensor, capacity: int) -> None:
|
| 726 |
"""Raise if any (batch, head) slot would exceed the static buffer capacity.
|
| 727 |
|
| 728 |
+
Branches on whether the graph is being compiled. In compiled mode,
|
| 729 |
+
torch._assert_async fires asynchronously on the GPU when the condition
|
| 730 |
+
tensor is False. In eager mode, a plain RuntimeError is raised with a
|
| 731 |
+
descriptive message.
|
| 732 |
|
| 733 |
Args:
|
| 734 |
max_count: Scalar tensor — the maximum post-update count across all slots.
|
| 735 |
capacity: The static buffer capacity (mosrah_cache_length).
|
| 736 |
"""
|
| 737 |
if torch.compiler.is_compiling():
|
| 738 |
+
torch._assert_async(
|
| 739 |
+
max_count <= capacity,
|
| 740 |
+
"MoSRAHCache overflow: buffer capacity exceeded. "
|
| 741 |
+
"Increase mosrah_overallocation_factor in ShramConfig.",
|
| 742 |
+
)
|
| 743 |
else:
|
| 744 |
if max_count.item() > capacity:
|
| 745 |
raise RuntimeError(
|
|
|
|
| 860 |
# Cumulative count of all token positions presented through update() for
|
| 861 |
# this cache instance. This is the quantity HuggingFace generation reads
|
| 862 |
# through get_seq_length() to track how far along the sequence we are.
|
| 863 |
+
self._total_processed = torch.tensor(0)
|
| 864 |
|
| 865 |
def update( # type: ignore[override]
|
| 866 |
self,
|
|
|
|
| 1000 |
generation reads to track sequence progress and is not the same as active-token
|
| 1001 |
count or current window occupancy.
|
| 1002 |
"""
|
| 1003 |
+
return int(self._total_processed)
|
| 1004 |
|
| 1005 |
def get_max_cache_shape(self) -> int:
|
| 1006 |
return self.sliding_window
|
|
|
|
| 2303 |
# -----------
|
| 2304 |
"""Expert packing and unpacking for the MoSRAH path.
|
| 2305 |
|
| 2306 |
+
This module owns the token-choice -> expert-choice -> token-choice conversion
|
| 2307 |
+
boundary used by the sparse routed attention path. Its public behavior is fixed:
|
| 2308 |
+
|
| 2309 |
+
- setup_packing() prepares the auxiliary ordering data forwarded through packing
|
| 2310 |
+
and unpacking.
|
| 2311 |
+
- pack_experts() converts routed token-choice tensors into padded expert-choice
|
| 2312 |
+
tensors.
|
| 2313 |
+
- unpack_experts() restores token-choice ordering from padded expert-choice output.
|
| 2314 |
+
|
| 2315 |
+
Packed expert-choice tensors are expert-major and left-justified. For each expert,
|
| 2316 |
+
routed token copies occupy the prefix of that expert's packed block; padding occupies
|
| 2317 |
+
the suffix. Every packed entry uses the same ordering and transfer artifact, so
|
| 2318 |
+
hidden states, positions, masks, and probabilities remain aligned across the boundary.
|
| 2319 |
+
|
| 2320 |
+
pack_experts() returns a flat transfer index together with the packed entries. This
|
| 2321 |
+
index replaces the old boolean unpacking artifact as the source of truth for
|
| 2322 |
+
pack/unpack data movement: packing writes to those flat packed slots, and unpacking
|
| 2323 |
+
reads from those same slots.
|
| 2324 |
"""
|
| 2325 |
|
| 2326 |
|
|
|
|
| 2336 |
) -> dict[str, torch.Tensor]:
|
| 2337 |
"""Prepare the auxiliary ordering data used by pack/unpack.
|
| 2338 |
|
| 2339 |
Args:
|
| 2340 |
selected_heads: Routed token-choice head selections I of shape (B, N, K).
|
| 2341 |
|
| 2342 |
Returns:
|
| 2343 |
Auxiliary payload dict with keys:
|
| 2344 |
- "flattened_selected_heads": H of shape (B, N*K)
|
| 2345 |
+
- "permutation": expert-major permutation Pi of shape (B, N*K)
|
| 2346 |
- "inverse_permutation": inverse permutation Pi^{-1} of shape (B, N*K)
|
| 2347 |
This dict is forwarded whole to pack_experts and unpack_experts.
|
| 2348 |
"""
|
|
|
|
| 2351 |
batch_size,
|
| 2352 |
sequence_length * num_selected_heads,
|
| 2353 |
)
|
| 2354 |
+
|
| 2355 |
+
# -----------------------------------------------------------------------
|
| 2356 |
+
# Establish the expert-major ordering invariant.
|
| 2357 |
+
#
|
| 2358 |
+
# BEA later applies a triangular causal mask inside each expert bucket. That
|
| 2359 |
+
# mask is only meaningful if routed copies for the same expert preserve their
|
| 2360 |
+
# source-token order. Stable sorting by selected head establishes that order.
|
| 2361 |
+
# -----------------------------------------------------------------------
|
| 2362 |
permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
|
| 2363 |
inverse_permutation = torch.argsort(permutation, dim=-1)
|
| 2364 |
|
|
|
|
| 2366 |
"flattened_selected_heads": flattened_selected_heads,
|
| 2367 |
"permutation": permutation,
|
| 2368 |
"inverse_permutation": inverse_permutation,
|
|
|
|
| 2369 |
}
|
| 2370 |
|
| 2371 |
|
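The stable-sort invariant described in setup_packing() can be illustrated without tensors. The following is a minimal sketch, assuming a single batch item and plain Python lists in place of torch.argsort(..., stable=True); the helper name `stable_argsort` and the routing values are hypothetical and not part of the model code.

```python
def stable_argsort(values):
    """Return indices that sort `values`, keeping ties in original order."""
    # Python's sorted() is guaranteed stable, which mirrors argsort(stable=True).
    return sorted(range(len(values)), key=lambda i: values[i])


# Hypothetical routing for one batch item: N=3 tokens, K=2 selected heads each.
selected_heads = [[2, 0], [0, 1], [2, 1]]

# Flatten (N, K) into one token-major axis, as setup_packing() does.
flattened = [head for token in selected_heads for head in token]
source_token = [n for n, token in enumerate(selected_heads) for _ in token]

# Expert-major permutation and its inverse.
permutation = stable_argsort(flattened)
inverse_permutation = stable_argsort(permutation)

# Applying the permutation groups copies by expert while preserving
# source-token order inside each expert bucket.
sorted_heads = [flattened[i] for i in permutation]
sorted_tokens = [source_token[i] for i in permutation]

# Applying the inverse permutation restores token-major order.
restored_heads = [sorted_heads[i] for i in inverse_permutation]
```

Within each expert bucket the source-token indices come out ascending, which is the property a triangular causal mask inside the bucket relies on.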
|
|
|
| 2382 |
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
|
| 2383 |
"""Pack token-choice tensors into expert-choice padded form.
|
| 2384 |
|
| 2385 |
Args:
|
| 2386 |
entries: Mapping from string keys to (tensor, padding_value) pairs. Each
|
| 2387 |
tensor has shape (B, N, ...) and is rearranged into expert-choice layout
|
|
|
|
| 2390 |
selected_heads: Routed head selections I of shape (B, N, K).
|
| 2391 |
num_experts: Total number of experts L.
|
| 2392 |
packed_length: Static packed time dimension T. All per-expert buffers are
|
| 2393 |
+
allocated to exactly this length. Raises if any actual per-expert token
|
| 2394 |
+
count exceeds this value.
|
|
|
|
| 2395 |
|
| 2396 |
Returns:
|
| 2397 |
Tuple of:
|
| 2398 |
- packed_entries: Dict with same keys as entries; each value is the
|
| 2399 |
packed tensor of shape (B, L, T, ...).
|
| 2400 |
+
- flat_packed_transfer_indices: Long tensor of shape (B*N*K,). Each value
|
| 2401 |
+
is the flattened padded expert-choice slot occupied by the corresponding
|
| 2402 |
+
routed-copy row. Pass this to unpack_experts().
|
| 2403 |
"""
|
| 2404 |
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
| 2405 |
+
num_routed_copies_per_batch = sequence_length * num_selected_heads
|
| 2406 |
+
num_routed_copies = batch_size * num_routed_copies_per_batch
|
| 2407 |
|
| 2408 |
flattened_selected_heads = setup["flattened_selected_heads"]
|
| 2409 |
permutation = setup["permutation"]
|
| 2410 |
|
| 2411 |
# -----------------------------------------------------------------------
|
| 2412 |
+
# Algorithm overview.
|
| 2413 |
+
#
|
| 2414 |
+
# Packing first builds one routed-copy row for each selected token/expert
|
| 2415 |
+
# pair, ordered by the stable expert-major permutation. Those rows contain
|
| 2416 |
+
# no padding. The final packed tensor reserves packed_length slots per expert.
|
| 2417 |
+
# The flat transfer index bridges those layouts by adding back the cumulative
|
| 2418 |
+
# padding skipped before each expert block.
|
| 2419 |
+
# -----------------------------------------------------------------------
|
| 2420 |
+
|
| 2421 |
+
# -----------------------------------------------------------------------
|
| 2422 |
+
# Build the shared routed-copy source rows.
|
| 2423 |
#
|
| 2424 |
+
# This tensor identifies the source token row for each selected token/expert
|
| 2425 |
+
# pair after the stable expert-major permutation. Every packed entry uses this
|
| 2426 |
+
# same row plan, so all entries remain aligned before padded materialization.
|
| 2427 |
# -----------------------------------------------------------------------
|
| 2428 |
source_token_indices = torch.arange(
|
| 2429 |
sequence_length,
|
|
|
|
| 2434 |
sequence_length,
|
| 2435 |
num_selected_heads,
|
| 2436 |
)
|
| 2437 |
+
flattened_source_token_indices = source_token_indices.reshape(
|
| 2438 |
batch_size,
|
| 2439 |
+
num_routed_copies_per_batch,
|
| 2440 |
)
|
| 2441 |
+
sorted_source_token_indices = flattened_source_token_indices.gather(
|
| 2442 |
dim=1,
|
| 2443 |
index=permutation,
|
| 2444 |
)
|
| 2445 |
|
| 2446 |
# -----------------------------------------------------------------------
|
| 2447 |
+
# Establish packed expert occupancy and capacity.
|
|
|
|
| 2448 |
#
|
| 2449 |
+
# tokens_per_expert tells how many routed-copy rows occupy the prefix of each
|
| 2450 |
+
# expert block. The padded layout is valid only when every prefix fits inside
|
| 2451 |
+
# the configured packed_length.
|
|
|
|
| 2452 |
# -----------------------------------------------------------------------
|
| 2453 |
tokens_per_expert = _count_tokens_per_expert(flattened_selected_heads, num_experts)
|
| 2454 |
+
_enforce_no_overflow(tokens_per_expert, packed_length)
|
|
|
|
|
|
|
| 2455 |
|
| 2456 |
# -----------------------------------------------------------------------
|
| 2457 |
+
# Build the flat insertion points for the padded expert frame.
|
| 2458 |
#
|
| 2459 |
+
# Routed-copy rows omit padding, while the packed frame reserves packed_length
|
| 2460 |
+
# slots for every expert. The transfer index adds back the cumulative padding
|
| 2461 |
+
# skipped before each expert block, producing one flat destination slot for
|
| 2462 |
+
# every routed-copy row. This tensor is forwarded to unpack_experts so removal
|
| 2463 |
+
# uses the same positions that insertion used.
|
| 2464 |
# -----------------------------------------------------------------------
|
| 2465 |
+
flat_tokens_per_expert = tokens_per_expert.reshape(-1)
|
| 2466 |
+
flat_padding_per_expert = packed_length - flat_tokens_per_expert
|
| 2467 |
+
flat_padding_before_expert = (
|
| 2468 |
+
flat_padding_per_expert.cumsum(dim=0) - flat_padding_per_expert
|
| 2469 |
+
)
|
| 2470 |
+
|
| 2471 |
+
flat_padding_for_routed_rows = torch.repeat_interleave(
|
| 2472 |
+
flat_padding_before_expert,
|
| 2473 |
+
flat_tokens_per_expert,
|
| 2474 |
+
output_size=num_routed_copies,
|
| 2475 |
+
)
|
| 2476 |
+
flat_routed_row_indices = torch.arange(
|
| 2477 |
+
num_routed_copies,
|
| 2478 |
device=flattened_selected_heads.device,
|
| 2479 |
dtype=torch.long,
|
| 2480 |
+
)
|
| 2481 |
+
flat_packed_transfer_indices = (
|
| 2482 |
+
flat_routed_row_indices + flat_padding_for_routed_rows
|
| 2483 |
+
)
|
| 2484 |
|
| 2485 |
# -----------------------------------------------------------------------
|
| 2486 |
+
# Materialize each entry through the shared routing and transfer artifacts.
|
| 2487 |
#
|
| 2488 |
+
# Each entry first gathers into the shared routed-copy order. The flat packed
|
| 2489 |
+
# allocation supplies padding, and the transfer index writes each routed-copy
|
| 2490 |
+
# row into its padded expert slot before the public shape is restored.
|
|
|
|
| 2491 |
# -----------------------------------------------------------------------
|
| 2492 |
packed_entries: dict[str, torch.Tensor] = {}
|
| 2493 |
for key, (tensor, padding_value) in entries.items():
|
| 2494 |
extra_shape = tensor.shape[2:]
|
| 2495 |
|
| 2496 |
+
# The sorted source index is shared across all entries; expanding it over
|
| 2497 |
+
# trailing dimensions lets the same routing/order plan apply to hidden
|
| 2498 |
+
# states, positions, masks, probabilities, and any other packed tensor.
|
| 2499 |
+
sorted_gather_indices = sorted_source_token_indices.view(
|
| 2500 |
batch_size,
|
| 2501 |
+
num_routed_copies_per_batch,
|
| 2502 |
*(1,) * len(extra_shape),
|
| 2503 |
).expand(-1, -1, *extra_shape)
|
| 2504 |
+
sorted_tensor = tensor.gather(dim=1, index=sorted_gather_indices)
|
| 2505 |
|
| 2506 |
packed_tensor = tensor.new_full(
|
| 2507 |
+
(batch_size * num_experts * packed_length, *extra_shape),
|
| 2508 |
fill_value=padding_value,
|
| 2509 |
)
|
| 2510 |
+
packed_tensor[flat_packed_transfer_indices] = sorted_tensor.reshape(
|
| 2511 |
+
num_routed_copies,
|
| 2512 |
+
*extra_shape,
|
| 2513 |
+
)
|
| 2514 |
+
packed_entries[key] = packed_tensor.reshape(
|
| 2515 |
+
batch_size,
|
| 2516 |
+
num_experts,
|
| 2517 |
+
packed_length,
|
| 2518 |
+
*extra_shape,
|
| 2519 |
+
)
|
| 2520 |
|
| 2521 |
+
return packed_entries, flat_packed_transfer_indices
|
| 2522 |
|
| 2523 |
|
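The flat transfer index built in pack_experts() can be traced by hand on a small case. This is a minimal sketch using plain Python lists, assuming one flat list of per-block counts ordered by (batch, expert); the function name `flat_transfer_indices` and the example values are illustrative stand-ins for the cumsum and repeat_interleave tensor operations in the diff.

```python
from itertools import accumulate


def flat_transfer_indices(tokens_per_expert, packed_length):
    """Map each unpadded routed-copy row to its slot in the padded frame."""
    padding_per_expert = [packed_length - count for count in tokens_per_expert]
    # Exclusive cumulative sum: padding skipped before each expert block.
    inclusive = list(accumulate(padding_per_expert))
    padding_before = [inc - pad for inc, pad in zip(inclusive, padding_per_expert)]
    # Repeat each block's offset once per routed-copy row that block holds.
    padding_for_rows = [
        offset
        for offset, count in zip(padding_before, tokens_per_expert)
        for _ in range(count)
    ]
    # Row i of the unpadded layout lands at i plus the padding skipped so far.
    return [row + pad for row, pad in enumerate(padding_for_rows)]


# Three expert blocks holding 2, 0, and 3 routed copies, 4 slots per block.
counts = [2, 0, 3]
indices = flat_transfer_indices(counts, packed_length=4)

# Materialize: allocate the padded frame, then write rows into their slots.
rows = ["a0", "a1", "c0", "c1", "c2"]
packed = ["pad"] * (len(counts) * 4)
for index, value in zip(indices, rows):
    packed[index] = value
```

Each expert block ends up left-justified: real rows fill the prefix and the padding value fills the suffix, matching the layout the module docstring describes.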
| 2524 |
# ---------------------------------------------------------------------------
|
|
|
|
| 2528 |
def unpack_experts(
|
| 2529 |
expert_outputs: torch.Tensor,
|
| 2530 |
setup: dict[str, torch.Tensor],
|
| 2531 |
+
flat_packed_transfer_indices: torch.Tensor,
|
| 2532 |
selected_heads: torch.Tensor,
|
| 2533 |
) -> torch.Tensor:
|
| 2534 |
"""Restore token-choice ordering from BEA expert-choice output.
|
| 2535 |
|
| 2536 |
Args:
|
| 2537 |
expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
|
| 2538 |
setup: Auxiliary payload returned by setup_packing().
|
| 2539 |
+
flat_packed_transfer_indices: Transfer index returned by pack_experts().
|
| 2540 |
+
Each value identifies a routed-copy slot in the flattened padded
|
| 2541 |
+
expert-choice frame.
|
| 2542 |
selected_heads: Routed head selections I of shape (B, N, K).
|
| 2543 |
|
| 2544 |
Returns:
|
|
|
|
| 2547 |
inverse_permutation = setup["inverse_permutation"]
|
| 2548 |
|
| 2549 |
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
| 2550 |
+
num_routed_copies_per_batch = sequence_length * num_selected_heads
|
| 2551 |
hidden_dim = expert_outputs.shape[-1]
|
| 2552 |
|
| 2553 |
+
# -----------------------------------------------------------------------
|
| 2554 |
+
# Recover routed-copy rows from the same packed slots used at insertion.
|
| 2555 |
+
#
|
| 2556 |
+
# Packing writes into the forwarded flat slots, and unpacking reads from those
|
| 2557 |
+
# same slots before applying the inverse routing permutation back to
|
| 2558 |
+
# token-choice order.
|
| 2559 |
+
# -----------------------------------------------------------------------
|
| 2560 |
+
flat_expert_outputs = expert_outputs.reshape(-1, hidden_dim)
|
| 2561 |
+
flat_routed_copy_outputs = flat_expert_outputs[flat_packed_transfer_indices]
|
|
|
|
| 2562 |
|
| 2563 |
+
sorted_token_choice_outputs = flat_routed_copy_outputs.reshape(
|
| 2564 |
batch_size,
|
| 2565 |
+
num_routed_copies_per_batch,
|
| 2566 |
hidden_dim,
|
| 2567 |
)
|
| 2568 |
restored_outputs = sorted_token_choice_outputs.gather(
|
|
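The comments in `unpack_experts` describe reading routed copies back from the same flat slots that packing wrote to, then applying the inverse routing permutation. The following is a minimal sketch of that round trip, not the repository's implementation: `toy_round_trip` is a hypothetical name, and the slot layout (`head * packed_length + rank` inside each batch item) is an assumption made only for illustration.

```python
import torch


def toy_round_trip(values, selected_heads, num_experts, packed_length):
    """Pack routed copies into a padded (B, L, T, d) frame, then unpack them.

    values: (B, N, K, d) routed copies in token-choice order.
    selected_heads: (B, N, K) expert index chosen for each routed copy.
    Assumes no expert bucket holds more than packed_length copies.
    """
    B, N, K = selected_heads.shape
    d = values.shape[-1]
    flat_heads = selected_heads.reshape(B, N * K)

    # A stable sort groups copies by expert and keeps token order inside a bucket.
    sorted_heads, permutation = torch.sort(flat_heads, dim=1, stable=True)
    inverse_permutation = torch.argsort(permutation, dim=1)
    sorted_values = values.reshape(B, N * K, d).gather(
        1, permutation.unsqueeze(-1).expand(-1, -1, d)
    )

    # Rank of each sorted copy inside its expert bucket.
    counts = torch.zeros(B, num_experts, dtype=torch.long)
    counts.scatter_add_(1, flat_heads, torch.ones_like(flat_heads))
    bucket_starts = counts.cumsum(dim=1) - counts
    rank = torch.arange(N * K).expand(B, -1) - bucket_starts.gather(1, sorted_heads)

    # Flat slot of each copy in the padded expert-choice frame (assumed layout).
    batch_offset = torch.arange(B).unsqueeze(1) * (num_experts * packed_length)
    flat_slots = (batch_offset + sorted_heads * packed_length + rank).reshape(-1)

    # Pack: write every sorted copy into its slot.
    packed = torch.zeros(B * num_experts * packed_length, d)
    packed[flat_slots] = sorted_values.reshape(-1, d)
    packed = packed.reshape(B, num_experts, packed_length, d)

    # Unpack: read the same slots, then undo the sort.
    recovered = packed.reshape(-1, d)[flat_slots].reshape(B, N * K, d)
    restored = recovered.gather(
        1, inverse_permutation.unsqueeze(-1).expand(-1, -1, d)
    )
    return restored.reshape(B, N, K, d)


torch.manual_seed(0)
toy_values = torch.randn(2, 4, 2, 3)
toy_heads = torch.randint(0, 4, (2, 4, 2))
toy_restored = toy_round_trip(toy_values, toy_heads, num_experts=4, packed_length=8)
```

Because the unpack step indexes with the same flat slots that the pack step used, padding rows in the frame are never read, and the result matches the input exactly when no bucket overflows.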
|
|
| 2581 |
# Helpers
|
| 2582 |
# ---------------------------------------------------------------------------
|
| 2583 |
|
| 2584 |
|
| 2585 |
+
def _enforce_no_overflow(tokens_per_expert: torch.Tensor, packed_length: int) -> None:
|
| 2586 |
+
"""Enforce that no expert bucket exceeds the preallocated packed length.
|
| 2587 |
|
| 2588 |
+
This check fires when the number of tokens assigned to any expert in any batch
|
| 2589 |
+
item exceeds mosrah_packed_length. When that limit is exceeded, the packed buffer
|
| 2590 |
+
is too small to hold all assignments and data would be dropped. Increase
|
| 2591 |
+
mosrah_overallocation_factor in ShramConfig to resolve.
|
| 2592 |
|
| 2593 |
Args:
|
| 2594 |
+
tokens_per_expert: Per-expert token counts, shape (B, num_experts).
|
| 2595 |
+
packed_length: The preallocated packed time dimension.
|
|
|
|
| 2596 |
"""
|
| 2597 |
if torch.compiler.is_compiling():
|
| 2598 |
+
torch._assert_async(
|
| 2599 |
+
tokens_per_expert.max() <= packed_length,
|
| 2600 |
+
"Expert packing overflow: expert bucket exceeds mosrah_packed_length. "
|
| 2601 |
+
"Increase mosrah_overallocation_factor in ShramConfig.",
|
| 2602 |
+
)
|
| 2603 |
else:
|
| 2604 |
+
max_count = tokens_per_expert.max().item()
|
| 2605 |
+
if max_count > packed_length:
|
| 2606 |
raise RuntimeError(
|
| 2607 |
"Expert packing overflow: at least one expert bucket contains more "
|
| 2608 |
"tokens than mosrah_packed_length allows. Increase "
|
| 2609 |
"mosrah_overallocation_factor in ShramConfig to resolve.\n"
|
| 2610 |
+
f"Packed length: {packed_length}\n"
|
| 2611 |
+
f"Head lengths: {tokens_per_expert}\n"
|
| 2612 |
)
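As a usage illustration of the eager branch above, the sketch below shows which inputs trip the check. `check_bucket_capacity` is a hypothetical stand-in written for this example, not the function in the diff, and it omits the compiled branch entirely.

```python
import torch


def check_bucket_capacity(tokens_per_expert: torch.Tensor, packed_length: int) -> None:
    """Hypothetical eager-only stand-in for the overflow check."""
    # The largest bucket over all batch items and experts decides the outcome.
    max_count = int(tokens_per_expert.max())
    if max_count > packed_length:
        raise RuntimeError(
            f"Expert packing overflow: largest bucket holds {max_count} tokens, "
            f"packed length is {packed_length}."
        )


# Shape (B=2, num_experts=4); the largest bucket holds 4 tokens.
counts = torch.tensor([[3, 1, 0, 4], [2, 2, 2, 2]])

check_bucket_capacity(counts, packed_length=4)  # passes: 4 <= 4

try:
    check_bucket_capacity(counts, packed_length=3)
    overflow_message = None
except RuntimeError as err:
    overflow_message = str(err)
```

The comparison is strict, so a bucket that exactly fills the packed length is accepted.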
|
| 2613 |
|
| 2614 |
|
|
|
|
| 2618 |
) -> torch.Tensor:
|
| 2619 |
"""Count how many routed token copies are assigned to each expert per batch item.
|
| 2620 |
|
| 2621 |
+
Uses scatter_add into a pre-sized (B, num_experts) buffer. Each position in
|
|
|
|
| 2622 |
flattened_selected_heads contributes one count to the corresponding expert slot.
|
| 2623 |
|
| 2624 |
Args:
|
|
|
|
| 2630 |
Counts tensor of shape (B, num_experts).
|
| 2631 |
"""
|
| 2632 |
batch_size = flattened_selected_heads.shape[0]
|
| 2633 |
+
tokens_per_expert = torch.zeros(
|
| 2634 |
batch_size,
|
| 2635 |
num_experts,
|
| 2636 |
device=flattened_selected_heads.device,
|
| 2637 |
+
dtype=torch.long,
|
| 2638 |
)
|
| 2639 |
+
tokens_per_expert.scatter_add_(
|
| 2640 |
dim=1,
|
| 2641 |
index=flattened_selected_heads,
|
| 2642 |
+
src=torch.ones_like(flattened_selected_heads, dtype=torch.long),
|
| 2643 |
)
|
| 2644 |
+
return tokens_per_expert
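A small worked example of the `scatter_add_` counting pattern described in the docstring, using toy values chosen for illustration. The `torch.bincount` cross-check is added here only to confirm the result for a single batch item.

```python
import torch

num_experts = 4
# (B=1, N=3, K=2): three tokens, each routed to two experts.
selected_heads = torch.tensor([[[0, 2], [2, 3], [0, 2]]])
flattened_selected_heads = selected_heads.reshape(1, -1)  # (B, N * K)

tokens_per_expert = torch.zeros(1, num_experts, dtype=torch.long)
tokens_per_expert.scatter_add_(
    dim=1,
    index=flattened_selected_heads,
    src=torch.ones_like(flattened_selected_heads),
)
# Expert 0 appears twice, expert 1 never, expert 2 three times, expert 3 once.

# Cross-check for the single batch item with bincount.
reference = torch.bincount(flattened_selected_heads[0], minlength=num_experts)
```

Pre-sizing the buffer to `(B, num_experts)` keeps the output shape fixed regardless of which experts were actually selected.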
|
|
|
|
| 2645 |
# -----------
|
| 2646 |
# Inlined from: router.py
|
| 2647 |
# -----------
|
|
|
|
| 2815 |
self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
|
| 2816 |
|
| 2817 |
@staticmethod
|
| 2818 |
+
def get_best_proposals(
|
| 2819 |
tensor: torch.Tensor,
|
| 2820 |
dim: int,
|
| 2821 |
n: int | torch.Tensor,
|
|
|
|
| 2948 |
choices_deficit = (min_choices - accepted_per_token).clamp_min(0)
|
| 2949 |
|
| 2950 |
unproposed_logits = logits.masked_fill(proposals, float('-inf'))
|
| 2951 |
+
new_proposals = cls.get_best_proposals(
|
| 2952 |
unproposed_logits, dim=-1, n=choices_deficit, capacity_scalar=min_choices,
|
| 2953 |
)
|
| 2954 |
proposals = proposals | new_proposals
|
|
|
|
| 2959 |
# Acceptances are recomputed from scratch each round so that a
|
| 2960 |
# stronger new proposal can displace a weaker prior one.
|
| 2961 |
proposed_logits = logits.masked_fill(~proposals, float('-inf'))
|
| 2962 |
+
acceptances = cls.get_best_proposals(
|
| 2963 |
proposed_logits, dim=-2, n=remaining_capacity, capacity_scalar=capacity_scalar,
|
| 2964 |
)
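The calls above pass a per-slice count `n` as a tensor together with a static `capacity_scalar`. The body of `get_best_proposals` is not shown in this diff, so the following is only a sketch of one way such a selection could work, assuming the scalar is a static upper bound on `n`. The name `top_n_mask` and its parameters are hypothetical.

```python
import torch


def top_n_mask(scores: torch.Tensor, n: torch.Tensor, max_n: int) -> torch.Tensor:
    """Boolean mask of the n[i] highest finite scores in row i.

    scores: (rows, cols), with -inf marking ineligible entries.
    n: (rows,) number of entries to keep per row, each at most max_n.
    max_n: static bound, so topk always runs with a fixed k.
    """
    top_values, top_indices = scores.topk(max_n, dim=-1)
    rank = torch.arange(max_n, device=scores.device)
    # Keep a candidate only if its rank is below n and it was not masked out.
    keep = (rank < n.unsqueeze(-1)) & torch.isfinite(top_values)
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask.scatter_(-1, top_indices, keep)
    return mask


neg_inf = float("-inf")
example_scores = torch.tensor(
    [[0.1, 0.9, 0.5, neg_inf],
     [0.3, 0.2, neg_inf, neg_inf]]
)
example_mask = top_n_mask(example_scores, n=torch.tensor([2, 3]), max_n=3)
```

In the second row, three entries are requested but only two are finite, so only those two are selected. Running `topk` with a fixed `max_n` and masking by rank afterwards keeps tensor shapes static even though `n` varies per row.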
|
| 2965 |
|
|
|
|
| 3341 |
"position_ids": (position_ids, 0),
|
| 3342 |
"active_mask": (active_mask, False),
|
| 3343 |
}
|
| 3344 |
+
packed, unpacking_map = pack_experts(entries, setup, selected_heads, self.num_experts, self.packed_length)
|
| 3345 |
packed_hidden_states = packed["hidden_states"]
|
| 3346 |
packed_positions = packed["position_ids"]
|
| 3347 |
active_mask = packed["active_mask"]
|
|
|
|
| 3377 |
token_choice_outputs = unpack_experts(
|
| 3378 |
expert_outputs=packed_outputs,
|
| 3379 |
setup=setup,
|
| 3380 |
+
flat_packed_transfer_indices=unpacking_map,
|
| 3381 |
selected_heads=selected_heads,
|
| 3382 |
)
|
| 3383 |
final_output = (
|
|
|
|
| 3876 |
|
| 3877 |
@staticmethod
|
| 3878 |
def create_masks_for_generate(
|
| 3879 |
attention_mask: torch.Tensor | None,
|
| 3880 |
**kwargs: Any,
|
| 3881 |
) -> torch.Tensor | None:
|
| 3882 |
"""Return the 2D attention_mask unchanged.
|
|
|
|
| 3930 |
raise ValueError(
|
| 3931 |
"position_ids must match the current input_ids shape exactly."
|
| 3932 |
)
|
| 3933 |
+
if position_ids.dtype != torch.long:
|
| 3934 |
raise TypeError("position_ids must be a long tensor.")
|
| 3935 |
|
| 3936 |
def _validate_labels(
|
|
|
|
| 3945 |
raise ValueError("labels must have shape (batch, seq_len).")
|
| 3946 |
if labels.shape != input_ids.shape:
|
| 3947 |
raise ValueError("labels must have the same shape as input_ids.")
|
| 3948 |
+
if labels.dtype != torch.long:
|
| 3949 |
raise TypeError("labels must be a long tensor.")
|
| 3950 |
|
| 3951 |
def _validate_cache_inputs(
|
|
|
|
| 4030 |
(violated).
|
| 4031 |
"""
|
| 4032 |
if torch.compiler.is_compiling():
|
| 4033 |
+
torch._assert_async(
|
| 4034 |
+
condition,
|
| 4035 |
+
"Uncached ShramForCausalLM: nonzero starting positions. "
|
| 4036 |
+
"Supply a ShramCache with the prefix, or rebase the sequence to start at 0.",
|
| 4037 |
+
)
|
| 4038 |
else:
|
| 4039 |
if not condition.item():
|
| 4040 |
raise RuntimeError(
|
|
|
|
| 4044 |
"uncached sequence to start at 0.",
|
| 4045 |
)
|
| 4046 |
|
| 4047 |
def _standardize_full_attention_mask(
|
| 4048 |
self,
|
| 4049 |
input_ids: torch.Tensor,
|
|
|
|
| 4141 |
# This keeps the main sequence readable while ensuring invalid states
|
| 4142 |
# fail before they can silently contaminate backbone execution.
|
| 4143 |
# ------------------------------------------------------------------
|
|
|
|
| 4144 |
self._validate_input_ids(input_ids)
|
| 4145 |
self._validate_attention_mask(input_ids, attention_mask)
|
| 4146 |
self._validate_position_ids(input_ids, position_ids)
|