Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM-dev with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "smithblack-0/SHRAM-dev"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM-dev",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "smithblack-0/SHRAM-dev" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM-dev",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "smithblack-0/SHRAM-dev" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM-dev",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
- Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
Update architecture and tokenizer
Browse files
- README.md +2 -1
- config.json +2 -1
- configuration.py +19 -1
- huggingface.py +176 -79
- tokenizer_config.json +1 -1
README.md
CHANGED
|
@@ -82,7 +82,8 @@ contains no weights. All values are overridable via kwargs.
|
|
| 82 |
| `embedding_width` | 512 |
|
| 83 |
| `head_dim` | 16 |
|
| 84 |
| `inference_sequence_length` | 1024 |
|
| 85 |
-
| `
|
|
|
|
| 86 |
| `local_rope_theta` | 10000.0 |
|
| 87 |
| `max_bid_rounds` | 10 |
|
| 88 |
| `mlp_width` | 1366 |
|
|
|
|
| 82 |
| `embedding_width` | 512 |
|
| 83 |
| `head_dim` | 16 |
|
| 84 |
| `inference_sequence_length` | 1024 |
|
| 85 |
+
| `load_balance_loss_type` | ce |
|
| 86 |
+
| `load_balance_p` | 1.0 |
|
| 87 |
| `local_rope_theta` | 10000.0 |
|
| 88 |
| `max_bid_rounds` | 10 |
|
| 89 |
| `mlp_width` | 1366 |
|
config.json
CHANGED
|
@@ -9,7 +9,8 @@
|
|
| 9 |
"embedding_width": 512,
|
| 10 |
"head_dim": 16,
|
| 11 |
"inference_sequence_length": 1024,
|
| 12 |
-
"
|
|
|
|
| 13 |
"local_rope_theta": 10000.0,
|
| 14 |
"max_bid_rounds": 10,
|
| 15 |
"mlp_width": 1366,
|
|
|
|
| 9 |
"embedding_width": 512,
|
| 10 |
"head_dim": 16,
|
| 11 |
"inference_sequence_length": 1024,
|
| 12 |
+
"load_balance_loss_type": "ce",
|
| 13 |
+
"load_balance_p": 1.0,
|
| 14 |
"local_rope_theta": 10000.0,
|
| 15 |
"max_bid_rounds": 10,
|
| 16 |
"mlp_width": 1366,
|
configuration.py
CHANGED
|
@@ -94,6 +94,11 @@ class ShramConfig(PretrainedConfig):
|
|
| 94 |
cases are not expected under normal training. The bound exists as a
|
| 95 |
correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 96 |
Default 10.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
"""
|
| 98 |
|
| 99 |
model_type = "shram"
|
|
@@ -127,8 +132,9 @@ class ShramConfig(PretrainedConfig):
|
|
| 127 |
output_hidden_states: bool = False,
|
| 128 |
tie_word_embeddings: bool = False,
|
| 129 |
mosrah_overallocation_factor: float = 2.0,
|
| 130 |
-
load_balance_p: float =
|
| 131 |
max_bid_rounds: int = 10,
|
|
|
|
| 132 |
**kwargs
|
| 133 |
):
|
| 134 |
if head_dim % 2 != 0:
|
|
@@ -174,6 +180,17 @@ class ShramConfig(PretrainedConfig):
|
|
| 174 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
| 175 |
)
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
self.vocab_size = vocab_size
|
| 178 |
self.embedding_width = embedding_width
|
| 179 |
self.mlp_width = mlp_width
|
|
@@ -194,6 +211,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 194 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 195 |
self.load_balance_p = load_balance_p
|
| 196 |
self.max_bid_rounds = max_bid_rounds
|
|
|
|
| 197 |
self.attention_dropout = attention_dropout
|
| 198 |
self.use_cache = use_cache
|
| 199 |
|
|
|
|
| 94 |
cases are not expected under normal training. The bound exists as a
|
| 95 |
correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 96 |
Default 10.
|
| 97 |
+
load_balance_loss_type: Formula used for the load-balance auxiliary loss.
|
| 98 |
+
One of ``"gshard"``, ``"ce"``, or ``"bce"``. ``"ce"`` (cross-entropy)
|
| 99 |
+
is the default; its log-probability signal scales with violation severity
|
| 100 |
+
and makes correction magnitude proportional to routing imbalance.
|
| 101 |
+
Default ``"ce"``.
|
| 102 |
"""
|
| 103 |
|
| 104 |
model_type = "shram"
|
|
|
|
| 132 |
output_hidden_states: bool = False,
|
| 133 |
tie_word_embeddings: bool = False,
|
| 134 |
mosrah_overallocation_factor: float = 2.0,
|
| 135 |
+
load_balance_p: float = 1.0,
|
| 136 |
max_bid_rounds: int = 10,
|
| 137 |
+
load_balance_loss_type: str = "ce",
|
| 138 |
**kwargs
|
| 139 |
):
|
| 140 |
if head_dim % 2 != 0:
|
|
|
|
| 180 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
| 181 |
)
|
| 182 |
|
| 183 |
+
_supported_loss_types = {"gshard", "ce", "bce"}
|
| 184 |
+
if load_balance_loss_type not in _supported_loss_types:
|
| 185 |
+
supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
|
| 186 |
+
raise ValueError(
|
| 187 |
+
f"load_balance_loss_type must be one of {supported}, "
|
| 188 |
+
f"got {load_balance_loss_type!r}."
|
| 189 |
+
)
|
| 190 |
+
if load_balance_loss_type == "ce" and load_balance_p != 1.0:
|
| 191 |
+
raise ValueError("In cross entropy mode, aggregation of "
|
| 192 |
+
"frequencies must be with mean 1.0")
|
| 193 |
+
|
| 194 |
self.vocab_size = vocab_size
|
| 195 |
self.embedding_width = embedding_width
|
| 196 |
self.mlp_width = mlp_width
|
|
|
|
| 211 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 212 |
self.load_balance_p = load_balance_p
|
| 213 |
self.max_bid_rounds = max_bid_rounds
|
| 214 |
+
self.load_balance_loss_type = load_balance_loss_type
|
| 215 |
self.attention_dropout = attention_dropout
|
| 216 |
self.use_cache = use_cache
|
| 217 |
|
huggingface.py
CHANGED
|
@@ -44,6 +44,7 @@ from torch import nn
|
|
| 44 |
from torch.nn.attention.flex_attention import create_block_mask
|
| 45 |
from torch.nn.attention.flex_attention import flex_attention
|
| 46 |
import torch.nn.functional as F
|
|
|
|
| 47 |
from typing import Optional
|
| 48 |
|
| 49 |
|
|
@@ -181,6 +182,11 @@ class ShramConfig(PretrainedConfig):
|
|
| 181 |
cases are not expected under normal training. The bound exists as a
|
| 182 |
correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 183 |
Default 10.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
"""
|
| 185 |
|
| 186 |
model_type = "shram"
|
|
@@ -214,8 +220,9 @@ class ShramConfig(PretrainedConfig):
|
|
| 214 |
output_hidden_states: bool = False,
|
| 215 |
tie_word_embeddings: bool = False,
|
| 216 |
mosrah_overallocation_factor: float = 2.0,
|
| 217 |
-
load_balance_p: float =
|
| 218 |
max_bid_rounds: int = 10,
|
|
|
|
| 219 |
**kwargs
|
| 220 |
):
|
| 221 |
if head_dim % 2 != 0:
|
|
@@ -261,6 +268,17 @@ class ShramConfig(PretrainedConfig):
|
|
| 261 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
| 262 |
)
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
self.vocab_size = vocab_size
|
| 265 |
self.embedding_width = embedding_width
|
| 266 |
self.mlp_width = mlp_width
|
|
@@ -281,6 +299,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 281 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 282 |
self.load_balance_p = load_balance_p
|
| 283 |
self.max_bid_rounds = max_bid_rounds
|
|
|
|
| 284 |
self.attention_dropout = attention_dropout
|
| 285 |
self.use_cache = use_cache
|
| 286 |
|
|
@@ -2714,9 +2733,10 @@ This separation is architecturally critical: expert_bias drives selection (and t
|
|
| 2714 |
balancing) but does not corrupt the gradient path from the output through routing_probs
|
| 2715 |
back to the routing projection weights.
|
| 2716 |
|
| 2717 |
-
The router also computes and returns the load balance loss via
|
| 2718 |
-
|
| 2719 |
-
|
|
|
|
| 2720 |
|
| 2721 |
The router additionally computes and returns MaxVio, a detached scalar summarising
|
| 2722 |
routing imbalance for the current forward pass:
|
|
@@ -2737,94 +2757,165 @@ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
|
|
| 2737 |
# -----------
|
| 2738 |
# Inlined from: load_balance_loss.py
|
| 2739 |
# -----------
|
| 2740 |
-
"""
|
| 2741 |
-
|
| 2742 |
-
This module
|
| 2743 |
-
|
| 2744 |
-
|
| 2745 |
-
|
| 2746 |
-
|
| 2747 |
-
|
| 2748 |
-
|
| 2749 |
-
|
| 2750 |
-
|
| 2751 |
-
|
| 2752 |
-
|
| 2753 |
-
|
| 2754 |
-
|
| 2755 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2756 |
"""
|
| 2757 |
|
| 2758 |
|
| 2759 |
|
| 2760 |
|
| 2761 |
-
class LoadBalanceLoss(torch.autograd.Function):
|
| 2762 |
-
"""Custom autograd operator for DeepSeek-style auxiliary-loss-free load balancing.
|
| 2763 |
|
| 2764 |
-
|
|
|
|
|
|
|
| 2765 |
|
| 2766 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2767 |
|
| 2768 |
-
|
|
|
|
|
|
|
| 2769 |
|
| 2770 |
-
|
|
|
|
|
|
|
|
|
|
| 2771 |
|
| 2772 |
-
|
| 2773 |
-
|
| 2774 |
-
|
| 2775 |
-
|
| 2776 |
|
| 2777 |
-
|
|
|
|
| 2778 |
"""
|
|
|
|
|
|
|
| 2779 |
|
| 2780 |
-
@staticmethod
|
| 2781 |
-
def forward(
|
| 2782 |
-
ctx: torch.autograd.function.FunctionCtx,
|
| 2783 |
-
expert_bias: torch.Tensor,
|
| 2784 |
-
routing_freqs: torch.Tensor,
|
| 2785 |
-
) -> torch.Tensor:
|
| 2786 |
-
"""Compute the load balance loss.
|
| 2787 |
|
| 2788 |
-
|
| 2789 |
-
|
| 2790 |
-
|
| 2791 |
-
|
| 2792 |
-
|
| 2793 |
-
from the discrete TopK selection — not differentiable.
|
| 2794 |
|
| 2795 |
-
|
| 2796 |
-
|
| 2797 |
-
|
| 2798 |
-
|
| 2799 |
-
# imbalance = f_l - 1/L for each head: positive means overloaded, negative means
|
| 2800 |
-
# underloaded. Saved for backward where sign(imbalance) determines the direction
|
| 2801 |
-
# of the bias-correction update.
|
| 2802 |
-
imbalance = routing_freqs - 1.0 / L
|
| 2803 |
-
ctx.save_for_backward(imbalance)
|
| 2804 |
-
return imbalance.abs().sum()
|
| 2805 |
|
| 2806 |
-
|
| 2807 |
-
|
| 2808 |
-
|
| 2809 |
-
grad_output: torch.Tensor,
|
| 2810 |
-
) -> tuple[torch.Tensor, None]:
|
| 2811 |
-
"""Emit the DeepSeek-style bias-correction gradient.
|
| 2812 |
|
| 2813 |
-
|
| 2814 |
-
|
| 2815 |
-
|
| 2816 |
-
|
| 2817 |
-
correction magnitude is proportional to the loss weight chosen by the
|
| 2818 |
-
consumer.
|
| 2819 |
|
| 2820 |
-
|
| 2821 |
-
|
| 2822 |
-
|
| 2823 |
-
|
| 2824 |
-
|
| 2825 |
-
|
| 2826 |
-
|
| 2827 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2828 |
|
| 2829 |
|
| 2830 |
|
|
@@ -2857,6 +2948,7 @@ class MoSRAHRouter(nn.Module):
|
|
| 2857 |
self.capacity = config.mosrah_packed_length
|
| 2858 |
|
| 2859 |
self.max_bid_rounds = config.max_bid_rounds
|
|
|
|
| 2860 |
|
| 2861 |
# W_r: routing projection, no bias (paper specifies xW_r, no additional term).
|
| 2862 |
self.routing_projection = nn.Linear(
|
|
@@ -3143,12 +3235,13 @@ class MoSRAHRouter(nn.Module):
|
|
| 3143 |
logits, self.expert_bias.expand_as(logits), dim=-1
|
| 3144 |
).mean().detach()
|
| 3145 |
|
|
|
|
| 3146 |
routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
|
| 3147 |
|
| 3148 |
# Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
|
| 3149 |
# selection. expert_bias is added to logits before softmax so that the bias
|
| 3150 |
# shifts selection probability without rescaling the unbiased distribution.
|
| 3151 |
-
biased_logits = logits + self.expert_bias
|
| 3152 |
biased_logits = self.balance_capacity(
|
| 3153 |
biased_logits,
|
| 3154 |
used_capacity,
|
|
@@ -3189,10 +3282,14 @@ class MoSRAHRouter(nn.Module):
|
|
| 3189 |
p = self.load_balance_p
|
| 3190 |
routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
|
| 3191 |
|
| 3192 |
-
#
|
| 3193 |
-
#
|
| 3194 |
-
#
|
| 3195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3196 |
|
| 3197 |
# MaxVio is a detached monitoring scalar following the paper's formula
|
| 3198 |
# L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
|
|
|
|
| 44 |
from torch.nn.attention.flex_attention import create_block_mask
|
| 45 |
from torch.nn.attention.flex_attention import flex_attention
|
| 46 |
import torch.nn.functional as F
|
| 47 |
+
from typing import Callable
|
| 48 |
from typing import Optional
|
| 49 |
|
| 50 |
|
|
|
|
| 182 |
cases are not expected under normal training. The bound exists as a
|
| 183 |
correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 184 |
Default 10.
|
| 185 |
+
load_balance_loss_type: Formula used for the load-balance auxiliary loss.
|
| 186 |
+
One of ``"gshard"``, ``"ce"``, or ``"bce"``. ``"ce"`` (cross-entropy)
|
| 187 |
+
is the default; its log-probability signal scales with violation severity
|
| 188 |
+
and makes correction magnitude proportional to routing imbalance.
|
| 189 |
+
Default ``"ce"``.
|
| 190 |
"""
|
| 191 |
|
| 192 |
model_type = "shram"
|
|
|
|
| 220 |
output_hidden_states: bool = False,
|
| 221 |
tie_word_embeddings: bool = False,
|
| 222 |
mosrah_overallocation_factor: float = 2.0,
|
| 223 |
+
load_balance_p: float = 1.0,
|
| 224 |
max_bid_rounds: int = 10,
|
| 225 |
+
load_balance_loss_type: str = "ce",
|
| 226 |
**kwargs
|
| 227 |
):
|
| 228 |
if head_dim % 2 != 0:
|
|
|
|
| 268 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
| 269 |
)
|
| 270 |
|
| 271 |
+
_supported_loss_types = {"gshard", "ce", "bce"}
|
| 272 |
+
if load_balance_loss_type not in _supported_loss_types:
|
| 273 |
+
supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
|
| 274 |
+
raise ValueError(
|
| 275 |
+
f"load_balance_loss_type must be one of {supported}, "
|
| 276 |
+
f"got {load_balance_loss_type!r}."
|
| 277 |
+
)
|
| 278 |
+
if load_balance_loss_type == "ce" and load_balance_p != 1.0:
|
| 279 |
+
raise ValueError("In cross entropy mode, aggregation of "
|
| 280 |
+
"frequencies must be with mean 1.0")
|
| 281 |
+
|
| 282 |
self.vocab_size = vocab_size
|
| 283 |
self.embedding_width = embedding_width
|
| 284 |
self.mlp_width = mlp_width
|
|
|
|
| 299 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 300 |
self.load_balance_p = load_balance_p
|
| 301 |
self.max_bid_rounds = max_bid_rounds
|
| 302 |
+
self.load_balance_loss_type = load_balance_loss_type
|
| 303 |
self.attention_dropout = attention_dropout
|
| 304 |
self.use_cache = use_cache
|
| 305 |
|
|
|
|
| 2733 |
balancing) but does not corrupt the gradient path from the output through routing_probs
|
| 2734 |
back to the routing projection weights.
|
| 2735 |
|
| 2736 |
+
The router also computes and returns the load balance loss via a log-probability auxiliary
|
| 2737 |
+
loss (see load_balance_loss.py). The loss formulation is selected by config; the default
|
| 2738 |
+
is cross-entropy. Gradients flow only to expert_bias — routing_projection.weight is
|
| 2739 |
+
isolated by detaching logits before computing assignment probabilities.
|
| 2740 |
|
| 2741 |
The router additionally computes and returns MaxVio, a detached scalar summarising
|
| 2742 |
routing imbalance for the current forward pass:
|
|
|
|
| 2757 |
# -----------
|
| 2758 |
# Inlined from: load_balance_loss.py
|
| 2759 |
# -----------
|
| 2760 |
+
"""Log-probability auxiliary loss functions for MoSRAH load balancing.
|
| 2761 |
+
|
| 2762 |
+
This module provides three load-balance loss formulations and a factory that selects
|
| 2763 |
+
among them. All formulations share the same external contract and the same gradient
|
| 2764 |
+
isolation property: assignment probabilities are computed from detached logits plus
|
| 2765 |
+
expert_bias, so only expert_bias receives gradients from the loss signal. The routing
|
| 2766 |
+
projection weights are not reachable from any returned loss.
|
| 2767 |
+
|
| 2768 |
+
The factory is the intended entry point. The caller (MoSRAHRouter) constructs the
|
| 2769 |
+
loss callable once at init and invokes it each forward pass.
|
| 2770 |
+
|
| 2771 |
+
Log-probability formulations (ce, bce) are preferred over linear ones (gshard) because
|
| 2772 |
+
their gradient magnitude scales with how far the distribution deviates from the target.
|
| 2773 |
+
A linear signal can be outrun by routing concentrations that diverge nonlinearly; a
|
| 2774 |
+
log-probability signal cannot.
|
| 2775 |
+
|
| 2776 |
+
The external contract for all returned callables is:
|
| 2777 |
+
|
| 2778 |
+
loss_fn(routing_freqs, assignment_probs) -> scalar Tensor
|
| 2779 |
+
|
| 2780 |
+
routing_freqs: (L,) realized routing frequencies f_i, detached.
|
| 2781 |
+
assignment_probs: (L,) soft assignment probabilities p_i with gradient through
|
| 2782 |
+
expert_bias. Caller must compute these via
|
| 2783 |
+
softmax(logits.detach() + expert_bias) to preserve isolation.
|
| 2784 |
"""
|
| 2785 |
|
| 2786 |
|
| 2787 |
|
| 2788 |
|
|
|
|
|
|
|
| 2789 |
|
| 2790 |
+
# ---------------------------------------------------------------------------
|
| 2791 |
+
# Loss functions
|
| 2792 |
+
# ---------------------------------------------------------------------------
|
| 2793 |
|
| 2794 |
+
def gshard_loss(
|
| 2795 |
+
routing_freqs: torch.Tensor,
|
| 2796 |
+
assignment_probs: torch.Tensor,
|
| 2797 |
+
) -> torch.Tensor:
|
| 2798 |
+
"""GShard-style linear load-balance loss.
|
| 2799 |
|
| 2800 |
+
Computes (1/L) * Σ_i f_i * p_i, where L is the number of expert heads,
|
| 2801 |
+
f_i is the realized routing frequency for head i, and p_i is the soft
|
| 2802 |
+
assignment probability for head i.
|
| 2803 |
|
| 2804 |
+
The fixed point of this loss under gradient descent is uniform routing:
|
| 2805 |
+
when p_i = 1/L for all i, the loss is minimized at 1/L (independent of f_i).
|
| 2806 |
+
The linear signal is the weakest of the three formulations — gradient magnitude
|
| 2807 |
+
does not grow with deviation from the target. Provided for comparison.
|
| 2808 |
|
| 2809 |
+
Args:
|
| 2810 |
+
routing_freqs: Realized routing frequencies f_i, shape (L,). Detached.
|
| 2811 |
+
assignment_probs: Soft assignment probabilities p_i, shape (L,). Gradient
|
| 2812 |
+
flows to expert_bias through this tensor.
|
| 2813 |
|
| 2814 |
+
Returns:
|
| 2815 |
+
Scalar loss tensor.
|
| 2816 |
"""
|
| 2817 |
+
L = routing_freqs.shape[0]
|
| 2818 |
+
return (routing_freqs * assignment_probs).sum() / L
|
| 2819 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2820 |
|
| 2821 |
+
def ce_loss(
|
| 2822 |
+
routing_freqs: torch.Tensor,
|
| 2823 |
+
assignment_probs: torch.Tensor,
|
| 2824 |
+
) -> torch.Tensor:
|
| 2825 |
+
"""Cross-entropy load-balance loss.
|
|
|
|
| 2826 |
|
| 2827 |
+
Computes -(1/(L-1)) * Σ_i (1 - f_i) * log(p_i), where the weight (1 - f_i)
|
| 2828 |
+
suppresses the signal for overloaded heads (high f_i → weight near zero) and
|
| 2829 |
+
amplifies it for underloaded heads (low f_i → weight near 1). This makes the
|
| 2830 |
+
loss push probability mass toward under-utilized experts.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2831 |
|
| 2832 |
+
The (1/(L-1)) normalization makes the coefficient interpretable as a controller
|
| 2833 |
+
strength independent of expert count. The log-probability signal grows as p_i
|
| 2834 |
+
deviates from the target, providing correction that scales with violation severity.
|
|
|
|
|
|
|
|
|
|
| 2835 |
|
| 2836 |
+
Args:
|
| 2837 |
+
routing_freqs: Realized routing frequencies f_i, shape (L,). Detached.
|
| 2838 |
+
assignment_probs: Soft assignment probabilities p_i, shape (L,). Gradient
|
| 2839 |
+
flows to expert_bias through this tensor.
|
|
|
|
|
|
|
| 2840 |
|
| 2841 |
+
Returns:
|
| 2842 |
+
Scalar loss tensor.
|
| 2843 |
+
"""
|
| 2844 |
+
L = routing_freqs.shape[0]
|
| 2845 |
+
# Numerical stability: torch.log is safe here because softmax outputs are
|
| 2846 |
+
# strictly positive. The (1 - f_i) weight goes to zero exactly when f_i = 1,
|
| 2847 |
+
# which can only occur with a single head, so the 0 * (-inf) degenerate case
|
| 2848 |
+
# does not arise in practice.
|
| 2849 |
+
return -(((1.0 - routing_freqs) * torch.log(assignment_probs)).sum()) / (L - 1)
|
| 2850 |
+
|
| 2851 |
+
|
| 2852 |
+
def bce_loss(
|
| 2853 |
+
routing_freqs: torch.Tensor,
|
| 2854 |
+
assignment_probs: torch.Tensor,
|
| 2855 |
+
) -> torch.Tensor:
|
| 2856 |
+
"""Binary cross-entropy load-balance loss.
|
| 2857 |
+
|
| 2858 |
+
Computes -(1/L) * Σ_i [(1 - f_i) * log(p_i) + f_i * log(1 - p_i)], where
|
| 2859 |
+
each head is treated as an independent binary target. Unlike CE, BCE maintains
|
| 2860 |
+
a repulsion signal from saturated experts: when f_i → 1, the weight on
|
| 2861 |
+
log(1 - p_i) drives p_i away from 1, preventing runaway concentration.
|
| 2862 |
+
|
| 2863 |
+
log(1 - p_i) is computed as log1p(-p_i) for numerical safety near p_i = 1.
|
| 2864 |
+
|
| 2865 |
+
Args:
|
| 2866 |
+
routing_freqs: Realized routing frequencies f_i, shape (L,). Detached.
|
| 2867 |
+
assignment_probs: Soft assignment probabilities p_i, shape (L,). Gradient
|
| 2868 |
+
flows to expert_bias through this tensor.
|
| 2869 |
+
|
| 2870 |
+
Returns:
|
| 2871 |
+
Scalar loss tensor.
|
| 2872 |
+
"""
|
| 2873 |
+
L = routing_freqs.shape[0]
|
| 2874 |
+
positive_term = (1.0 - routing_freqs) * torch.log(assignment_probs)
|
| 2875 |
+
# log1p(-p) instead of log(1-p): avoids catastrophic cancellation when p is
|
| 2876 |
+
# close to 1, where (1 - p) loses precision and log produces large errors.
|
| 2877 |
+
negative_term = routing_freqs * torch.log1p(-assignment_probs)
|
| 2878 |
+
return -(positive_term + negative_term).sum() / L
|
| 2879 |
+
|
| 2880 |
+
|
| 2881 |
+
# ---------------------------------------------------------------------------
|
| 2882 |
+
# Factory
|
| 2883 |
+
# ---------------------------------------------------------------------------
|
| 2884 |
+
|
| 2885 |
+
_LOSS_REGISTRY: dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {
|
| 2886 |
+
"gshard": gshard_loss,
|
| 2887 |
+
"ce": ce_loss,
|
| 2888 |
+
"bce": bce_loss,
|
| 2889 |
+
}
|
| 2890 |
+
|
| 2891 |
+
|
| 2892 |
+
def make_load_balance_loss(
|
| 2893 |
+
loss_type: str,
|
| 2894 |
+
) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
|
| 2895 |
+
"""Return a load-balance loss callable for the requested formulation.
|
| 2896 |
+
|
| 2897 |
+
All returned callables share the same external contract:
|
| 2898 |
+
|
| 2899 |
+
loss_fn(routing_freqs: Tensor, assignment_probs: Tensor) -> scalar Tensor
|
| 2900 |
+
|
| 2901 |
+
The caller is responsible for computing assignment_probs via
|
| 2902 |
+
softmax(logits.detach() + expert_bias) to ensure gradient isolation.
|
| 2903 |
+
|
| 2904 |
+
Args:
|
| 2905 |
+
loss_type: One of ``"gshard"``, ``"ce"``, or ``"bce"``.
|
| 2906 |
+
|
| 2907 |
+
Returns:
|
| 2908 |
+
Loss callable matching the shared contract.
|
| 2909 |
+
|
| 2910 |
+
Raises:
|
| 2911 |
+
ValueError: If loss_type is not one of the supported values.
|
| 2912 |
+
"""
|
| 2913 |
+
if loss_type not in _LOSS_REGISTRY:
|
| 2914 |
+
supported = ", ".join(f'"{k}"' for k in _LOSS_REGISTRY)
|
| 2915 |
+
raise ValueError(
|
| 2916 |
+
f"load_balance_loss_type must be one of {supported}, got {loss_type!r}."
|
| 2917 |
+
)
|
| 2918 |
+
return _LOSS_REGISTRY[loss_type]
|
| 2919 |
|
| 2920 |
|
| 2921 |
|
|
|
|
| 2948 |
self.capacity = config.mosrah_packed_length
|
| 2949 |
|
| 2950 |
self.max_bid_rounds = config.max_bid_rounds
|
| 2951 |
+
self._load_balance_loss = make_load_balance_loss(config.load_balance_loss_type)
|
| 2952 |
|
| 2953 |
# W_r: routing projection, no bias (paper specifies xW_r, no additional term).
|
| 2954 |
self.routing_projection = nn.Linear(
|
|
|
|
| 3235 |
logits, self.expert_bias.expand_as(logits), dim=-1
|
| 3236 |
).mean().detach()
|
| 3237 |
|
| 3238 |
+
# Routing scores. Direct.
|
| 3239 |
routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
|
| 3240 |
|
| 3241 |
# Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
|
| 3242 |
# selection. expert_bias is added to logits before softmax so that the bias
|
| 3243 |
# shifts selection probability without rescaling the unbiased distribution.
|
| 3244 |
+
biased_logits = logits.detach() + self.expert_bias
|
| 3245 |
biased_logits = self.balance_capacity(
|
| 3246 |
biased_logits,
|
| 3247 |
used_capacity,
|
|
|
|
| 3282 |
p = self.load_balance_p
|
| 3283 |
routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
|
| 3284 |
|
| 3285 |
+
# Active-token mean softmax probabilities. Detaching logits before softmax
|
| 3286 |
+
# ensures the only differentiable path into p is through expert_bias — the
|
| 3287 |
+
# load balance loss cannot reach routing_projection.weight.
|
| 3288 |
+
biased_probs = biased_routing_scores # (B, N, L)
|
| 3289 |
+
active_float = active_mask.float().unsqueeze(-1) # (B, N, 1)
|
| 3290 |
+
assignment_probs = (biased_probs * active_float).sum(dim=(0, 1)) # (L,) unnorm
|
| 3291 |
+
assignment_probs = assignment_probs / active_mask.float().sum() # (L,) norm
|
| 3292 |
+
load_balance_loss = self._load_balance_loss(routing_freqs, assignment_probs)
|
| 3293 |
|
| 3294 |
# MaxVio is a detached monitoring scalar following the paper's formula
|
| 3295 |
# L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
|
tokenizer_config.json
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
"bos_token": "<|endoftext|>",
|
| 5 |
"eos_token": "<|endoftext|>",
|
| 6 |
"errors": "replace",
|
| 7 |
-
"is_local":
|
| 8 |
"local_files_only": false,
|
| 9 |
"model_max_length": 1000000000000000019884624838656,
|
| 10 |
"pad_token": "<|padding|>",
|
|
|
|
| 4 |
"bos_token": "<|endoftext|>",
|
| 5 |
"eos_token": "<|endoftext|>",
|
| 6 |
"errors": "replace",
|
| 7 |
+
"is_local": true,
|
| 8 |
"local_files_only": false,
|
| 9 |
"model_max_length": 1000000000000000019884624838656,
|
| 10 |
"pad_token": "<|padding|>",
|