Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM-dev with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "smithblack-0/SHRAM-dev" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "smithblack-0/SHRAM-dev" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "smithblack-0/SHRAM-dev" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
Update architecture and tokenizer
Browse files- README.md +1 -1
- config.json +2 -2
- configuration.py +13 -11
- huggingface.py +171 -15
README.md
CHANGED
|
@@ -82,7 +82,7 @@ contains no weights. All values are overridable via kwargs.
|
|
| 82 |
| `embedding_width` | 512 |
|
| 83 |
| `head_dim` | 16 |
|
| 84 |
| `inference_sequence_length` | 1024 |
|
| 85 |
-
| `load_balance_loss_type` |
|
| 86 |
| `local_rope_theta` | 10000.0 |
|
| 87 |
| `max_bid_rounds` | 10 |
|
| 88 |
| `maximum_expert_overclaim` | 20 |
|
|
|
|
| 82 |
| `embedding_width` | 512 |
|
| 83 |
| `head_dim` | 16 |
|
| 84 |
| `inference_sequence_length` | 1024 |
|
| 85 |
+
| `load_balance_loss_type` | causal_overcapacity |
|
| 86 |
| `local_rope_theta` | 10000.0 |
|
| 87 |
| `max_bid_rounds` | 10 |
|
| 88 |
| `maximum_expert_overclaim` | 20 |
|
config.json
CHANGED
|
@@ -9,7 +9,7 @@
|
|
| 9 |
"embedding_width": 512,
|
| 10 |
"head_dim": 16,
|
| 11 |
"inference_sequence_length": 1024,
|
| 12 |
-
"load_balance_loss_type": "
|
| 13 |
"local_rope_theta": 10000.0,
|
| 14 |
"max_bid_rounds": 10,
|
| 15 |
"maximum_expert_overclaim": 20,
|
|
@@ -25,7 +25,7 @@
|
|
| 25 |
"rope_mode": "main_sequence",
|
| 26 |
"tie_word_embeddings": false,
|
| 27 |
"training_sequence_length": 1024,
|
| 28 |
-
"transformers_version": "5.
|
| 29 |
"use_cache": true,
|
| 30 |
"vocab_size": 50277,
|
| 31 |
"window_size": 128
|
|
|
|
| 9 |
"embedding_width": 512,
|
| 10 |
"head_dim": 16,
|
| 11 |
"inference_sequence_length": 1024,
|
| 12 |
+
"load_balance_loss_type": "causal_overcapacity",
|
| 13 |
"local_rope_theta": 10000.0,
|
| 14 |
"max_bid_rounds": 10,
|
| 15 |
"maximum_expert_overclaim": 20,
|
|
|
|
| 25 |
"rope_mode": "main_sequence",
|
| 26 |
"tie_word_embeddings": false,
|
| 27 |
"training_sequence_length": 1024,
|
| 28 |
+
"transformers_version": "5.11.0",
|
| 29 |
"use_cache": true,
|
| 30 |
"vocab_size": 50277,
|
| 31 |
"window_size": 128
|
configuration.py
CHANGED
|
@@ -91,17 +91,19 @@ class ShramConfig(PretrainedConfig):
|
|
| 91 |
correctness guard β exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 92 |
Default 10.
|
| 93 |
load_balance_loss_type: Formula used for the load-balance auxiliary loss.
|
| 94 |
-
One of ``"gshard"``, ``"ce"``, ``"bce"``,
|
| 95 |
-
``"
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
| 100 |
maximum_expert_overclaim: Maximum number of tokens an expert may receive above
|
| 101 |
-
its ideal allocation trajectory before
|
| 102 |
-
|
| 103 |
Larger values permit short-lived semantic specialization before correction.
|
| 104 |
-
|
| 105 |
Must be non-negative. Default 20.
|
| 106 |
"""
|
| 107 |
|
|
@@ -137,7 +139,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 137 |
tie_word_embeddings: bool = False,
|
| 138 |
mosrah_overallocation_factor: float = 2.0,
|
| 139 |
max_bid_rounds: int = 10,
|
| 140 |
-
load_balance_loss_type: str = "
|
| 141 |
maximum_expert_overclaim: int = 20,
|
| 142 |
**kwargs
|
| 143 |
):
|
|
@@ -185,7 +187,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 185 |
f"got {maximum_expert_overclaim}."
|
| 186 |
)
|
| 187 |
|
| 188 |
-
_supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity"}
|
| 189 |
if load_balance_loss_type not in _supported_loss_types:
|
| 190 |
supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
|
| 191 |
raise ValueError(
|
|
|
|
| 91 |
correctness guard β exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 92 |
Default 10.
|
| 93 |
load_balance_loss_type: Formula used for the load-balance auxiliary loss.
|
| 94 |
+
One of ``"gshard"``, ``"ce"``, ``"bce"``, ``"temporal_overcapacity"``, or
|
| 95 |
+
``"causal_overcapacity"``. ``"causal_overcapacity"`` (the default) attributes
|
| 96 |
+
violations to the causal trajectory that produced them β each expert
|
| 97 |
+
accumulates a running mean of its selection log-probability and the loss
|
| 98 |
+
penalises the gap between overloaded and typical trajectories. Like
|
| 99 |
+
``"temporal_overcapacity"``, it fires only when a violation exists and shuts
|
| 100 |
+
off automatically, making it safe to weight strongly. Default
|
| 101 |
+
``"causal_overcapacity"``.
|
| 102 |
maximum_expert_overclaim: Maximum number of tokens an expert may receive above
|
| 103 |
+
its ideal allocation trajectory before either overcapacity loss fires.
|
| 104 |
+
A value of 0 means violations trigger immediately at any imbalance.
|
| 105 |
Larger values permit short-lived semantic specialization before correction.
|
| 106 |
+
Used by both ``"temporal_overcapacity"`` and ``"causal_overcapacity"``.
|
| 107 |
Must be non-negative. Default 20.
|
| 108 |
"""
|
| 109 |
|
|
|
|
| 139 |
tie_word_embeddings: bool = False,
|
| 140 |
mosrah_overallocation_factor: float = 2.0,
|
| 141 |
max_bid_rounds: int = 10,
|
| 142 |
+
load_balance_loss_type: str = "causal_overcapacity",
|
| 143 |
maximum_expert_overclaim: int = 20,
|
| 144 |
**kwargs
|
| 145 |
):
|
|
|
|
| 187 |
f"got {maximum_expert_overclaim}."
|
| 188 |
)
|
| 189 |
|
| 190 |
+
_supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity", "causal_overcapacity"}
|
| 191 |
if load_balance_loss_type not in _supported_loss_types:
|
| 192 |
supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
|
| 193 |
raise ValueError(
|
huggingface.py
CHANGED
|
@@ -178,17 +178,19 @@ class ShramConfig(PretrainedConfig):
|
|
| 178 |
correctness guard β exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 179 |
Default 10.
|
| 180 |
load_balance_loss_type: Formula used for the load-balance auxiliary loss.
|
| 181 |
-
One of ``"gshard"``, ``"ce"``, ``"bce"``,
|
| 182 |
-
``"
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
| 187 |
maximum_expert_overclaim: Maximum number of tokens an expert may receive above
|
| 188 |
-
its ideal allocation trajectory before
|
| 189 |
-
|
| 190 |
Larger values permit short-lived semantic specialization before correction.
|
| 191 |
-
|
| 192 |
Must be non-negative. Default 20.
|
| 193 |
"""
|
| 194 |
|
|
@@ -224,7 +226,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 224 |
tie_word_embeddings: bool = False,
|
| 225 |
mosrah_overallocation_factor: float = 2.0,
|
| 226 |
max_bid_rounds: int = 10,
|
| 227 |
-
load_balance_loss_type: str = "
|
| 228 |
maximum_expert_overclaim: int = 20,
|
| 229 |
**kwargs
|
| 230 |
):
|
|
@@ -272,7 +274,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 272 |
f"got {maximum_expert_overclaim}."
|
| 273 |
)
|
| 274 |
|
| 275 |
-
_supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity"}
|
| 276 |
if load_balance_loss_type not in _supported_loss_types:
|
| 277 |
supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
|
| 278 |
raise ValueError(
|
|
@@ -2768,7 +2770,7 @@ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, Β§MaxVio.
|
|
| 2768 |
# -----------
|
| 2769 |
"""Log-probability auxiliary loss functions for MoSRAH load balancing.
|
| 2770 |
|
| 2771 |
-
This module provides
|
| 2772 |
helpers, and a factory that selects among the formulations. All formulations
|
| 2773 |
share the same external contract:
|
| 2774 |
|
|
@@ -3103,6 +3105,139 @@ def _temporal_overcapacity_loss(
|
|
| 3103 |
return final_loss
|
| 3104 |
|
| 3105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3106 |
# ---------------------------------------------------------------------------
|
| 3107 |
# Factory
|
| 3108 |
# ---------------------------------------------------------------------------
|
|
@@ -3139,11 +3274,32 @@ def _temporal_overcapacity_factory(
|
|
| 3139 |
return _runtime
|
| 3140 |
|
| 3141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3142 |
_LOSS_REGISTRY: dict[str, Callable[..., Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]]] = {
|
| 3143 |
"gshard": _gshard_factory,
|
| 3144 |
"ce": _ce_factory,
|
| 3145 |
"bce": _bce_factory,
|
| 3146 |
"temporal_overcapacity": _temporal_overcapacity_factory,
|
|
|
|
| 3147 |
}
|
| 3148 |
|
| 3149 |
|
|
@@ -3163,11 +3319,11 @@ def make_load_balance_loss(
|
|
| 3163 |
|
| 3164 |
Keyword arguments are forwarded to the selected factory. The gshard, ce, and bce
|
| 3165 |
factories silently ignore all kwargs; this allows callers to pass loss-type-specific
|
| 3166 |
-
parameters (e.g. for
|
| 3167 |
|
| 3168 |
Args:
|
| 3169 |
-
loss_type: One of ``"gshard"``, ``"ce"``, ``"bce"``,
|
| 3170 |
-
``"temporal_overcapacity"``.
|
| 3171 |
**loss_parameters: Construction-time parameters forwarded to the factory.
|
| 3172 |
|
| 3173 |
Returns:
|
|
|
|
| 178 |
correctness guard β exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 179 |
Default 10.
|
| 180 |
load_balance_loss_type: Formula used for the load-balance auxiliary loss.
|
| 181 |
+
One of ``"gshard"``, ``"ce"``, ``"bce"``, ``"temporal_overcapacity"``, or
|
| 182 |
+
``"causal_overcapacity"``. ``"causal_overcapacity"`` (the default) attributes
|
| 183 |
+
violations to the causal trajectory that produced them β each expert
|
| 184 |
+
accumulates a running mean of its selection log-probability and the loss
|
| 185 |
+
penalises the gap between overloaded and typical trajectories. Like
|
| 186 |
+
``"temporal_overcapacity"``, it fires only when a violation exists and shuts
|
| 187 |
+
off automatically, making it safe to weight strongly. Default
|
| 188 |
+
``"causal_overcapacity"``.
|
| 189 |
maximum_expert_overclaim: Maximum number of tokens an expert may receive above
|
| 190 |
+
its ideal allocation trajectory before either overcapacity loss fires.
|
| 191 |
+
A value of 0 means violations trigger immediately at any imbalance.
|
| 192 |
Larger values permit short-lived semantic specialization before correction.
|
| 193 |
+
Used by both ``"temporal_overcapacity"`` and ``"causal_overcapacity"``.
|
| 194 |
Must be non-negative. Default 20.
|
| 195 |
"""
|
| 196 |
|
|
|
|
| 226 |
tie_word_embeddings: bool = False,
|
| 227 |
mosrah_overallocation_factor: float = 2.0,
|
| 228 |
max_bid_rounds: int = 10,
|
| 229 |
+
load_balance_loss_type: str = "causal_overcapacity",
|
| 230 |
maximum_expert_overclaim: int = 20,
|
| 231 |
**kwargs
|
| 232 |
):
|
|
|
|
| 274 |
f"got {maximum_expert_overclaim}."
|
| 275 |
)
|
| 276 |
|
| 277 |
+
_supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity", "causal_overcapacity"}
|
| 278 |
if load_balance_loss_type not in _supported_loss_types:
|
| 279 |
supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
|
| 280 |
raise ValueError(
|
|
|
|
| 2770 |
# -----------
|
| 2771 |
"""Log-probability auxiliary loss functions for MoSRAH load balancing.
|
| 2772 |
|
| 2773 |
+
This module provides five load-balance loss formulations, two token-reduction
|
| 2774 |
helpers, and a factory that selects among the formulations. All formulations
|
| 2775 |
share the same external contract:
|
| 2776 |
|
|
|
|
| 3105 |
return final_loss
|
| 3106 |
|
| 3107 |
|
| 3108 |
+
def _causal_overcapacity_loss(
|
| 3109 |
+
logits: torch.Tensor,
|
| 3110 |
+
assignment_mask: torch.Tensor,
|
| 3111 |
+
active_mask: torch.Tensor,
|
| 3112 |
+
expected_tokens_rate: float,
|
| 3113 |
+
maximum_expert_overclaim: int,
|
| 3114 |
+
) -> torch.Tensor:
|
| 3115 |
+
"""Causal overcapacity loss for MoSRAH load balancing.
|
| 3116 |
+
|
| 3117 |
+
Penalises selected expert trajectories that exceed their ideal cumulative
|
| 3118 |
+
allocation budget. A selected expert assignment is over capacity when its
|
| 3119 |
+
inclusive active assignment count exceeds cumulative_active_tokens * M + C,
|
| 3120 |
+
where M is the expected_tokens_rate (K/L) and C is the
|
| 3121 |
+
maximum_expert_overclaim slack.
|
| 3122 |
+
|
| 3123 |
+
The loss consumes discrete TopK assignment structure but only routes gradients
|
| 3124 |
+
through logits. It returns an fp32 scalar and is exactly inactive when no active
|
| 3125 |
+
selected expert exceeds its allowed trajectory.
|
| 3126 |
+
|
| 3127 |
+
Args:
|
| 3128 |
+
logits: Pre-softmax routing scores, shape (B, N, L).
|
| 3129 |
+
Gradient flows through this tensor.
|
| 3130 |
+
assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
|
| 3131 |
+
1.0 if token (b, n) is assigned to head l.
|
| 3132 |
+
active_mask: Boolean active-token mask, shape (B, N).
|
| 3133 |
+
expected_tokens_rate (M): Ideal per-head allocation rate K/L. Pre-computed
|
| 3134 |
+
by the factory so the division is not repeated each
|
| 3135 |
+
forward pass.
|
| 3136 |
+
maximum_expert_overclaim (C): Slack above the ideal trajectory before
|
| 3137 |
+
overcapacity fires. Larger C tolerates more deviation.
|
| 3138 |
+
|
| 3139 |
+
Returns:
|
| 3140 |
+
Scalar fp32 loss tensor. Exactly 0.0 when no active selected expert exceeds
|
| 3141 |
+
its allowed trajectory. Can be interpreted as the difference in nats of preference
|
| 3142 |
+
between the violating and typical paths.
|
| 3143 |
+
"""
|
| 3144 |
+
# ββ Algorithm overview ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3145 |
+
#
|
| 3146 |
+
# Expert selections form causal trajectories through the sequence. Each trajectory
|
| 3147 |
+
# is scored by the mean signed nats of the selected routing events that produced
|
| 3148 |
+
# it: larger trajectory nats mean the router preferred that path more strongly.
|
| 3149 |
+
#
|
| 3150 |
+
# When a selected trajectory exceeds its cumulative budget, the loss forms a
|
| 3151 |
+
# preference contrast between the violating trajectory field and the baseline
|
| 3152 |
+
# trajectory field. Minimizing that contrast suppresses the over-preferred path
|
| 3153 |
+
# while lifting alternatives through the router softmax.
|
| 3154 |
+
#
|
| 3155 |
+
# This is not precisely equivalent to log likihood due to the selection
|
| 3156 |
+
# of multiple experts per round, but we deem this issue to be insignificant.
|
| 3157 |
+
|
| 3158 |
+
# ββ Process setup ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3159 |
+
#
|
| 3160 |
+
# A small amount of standardization is needed before the loss-specific trajectory
|
| 3161 |
+
# logic begins. Active selected assignments define the event structure. Routing
|
| 3162 |
+
# log-probabilities remain the only differentiable source and are computed in fp32
|
| 3163 |
+
# so the downstream trajectory accumulation does not inherit reduced precision.
|
| 3164 |
+
|
| 3165 |
+
selected_assignment_mask = assignment_mask.bool() # (B, N, L)
|
| 3166 |
+
active_assignment_mask = selected_assignment_mask & active_mask.unsqueeze(-1) # (B, N, L)
|
| 3167 |
+
routing_log_probability = F.log_softmax(logits.float(), dim=-1) # (B, N, L)
|
| 3168 |
+
|
| 3169 |
+
# ββ Mask construction ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3170 |
+
#
|
| 3171 |
+
# The corrective target set is defined by active selected assignments whose
|
| 3172 |
+
# inclusive count crosses the allowed causal budget. Position and sequence masks
|
| 3173 |
+
# identify where that target set exists; they are reduction structure, not a
|
| 3174 |
+
# separate source of gradient.
|
| 3175 |
+
|
| 3176 |
+
inclusive_assignment_count = active_assignment_mask.to(torch.int32).cumsum(dim=1) # (B, N, L)
|
| 3177 |
+
inclusive_active_token_count = active_mask.to(torch.int32).cumsum(dim=1) # (B, N)
|
| 3178 |
+
|
| 3179 |
+
maximum_allowed_assignment_count = (
|
| 3180 |
+
inclusive_active_token_count.float().unsqueeze(-1) * expected_tokens_rate
|
| 3181 |
+
+ maximum_expert_overclaim
|
| 3182 |
+
) # (B, N, 1) β broadcasts to (B, N, L)
|
| 3183 |
+
|
| 3184 |
+
violating_assignment_mask = ( # (B, N, L)
|
| 3185 |
+
active_assignment_mask
|
| 3186 |
+
& (inclusive_assignment_count.float() > maximum_allowed_assignment_count)
|
| 3187 |
+
)
|
| 3188 |
+
has_violation_at_position = violating_assignment_mask.any(dim=-1) # (B, N)
|
| 3189 |
+
has_violation_in_sequence = has_violation_at_position.any(dim=-1) # (B,)
|
| 3190 |
+
|
| 3191 |
+
# ββ Trajectory construction ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3192 |
+
#
|
| 3193 |
+
# The current selection is part of the trajectory being judged, so the trajectory
|
| 3194 |
+
# score is inclusive. Empty histories intentionally receive the neutral zero score;
|
| 3195 |
+
# this keeps the later baseline compact without introducing a second eligibility
|
| 3196 |
+
# system.
|
| 3197 |
+
|
| 3198 |
+
selected_trajectory_nat_sum = ( # (B, N, L)
|
| 3199 |
+
active_assignment_mask.float() * routing_log_probability
|
| 3200 |
+
).cumsum(dim=1)
|
| 3201 |
+
mean_selected_trajectory_nats = ( # (B, N, L)
|
| 3202 |
+
selected_trajectory_nat_sum
|
| 3203 |
+
/ inclusive_assignment_count.clamp(min=1).float()
|
| 3204 |
+
)
|
| 3205 |
+
|
| 3206 |
+
# ββ Contrast construction ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3207 |
+
#
|
| 3208 |
+
# This is the correction moment. The violating trajectory field is compared to
|
| 3209 |
+
# the baseline trajectory field at the same sequence position, producing a signed
|
| 3210 |
+
# preference contrast measured in nats.
|
| 3211 |
+
|
| 3212 |
+
violating_assignment_count = violating_assignment_mask.float().sum(dim=-1).clamp(min=1.0) # (B, N)
|
| 3213 |
+
mean_violating_trajectory_nats = ( # (B, N)
|
| 3214 |
+
(violating_assignment_mask.float() * mean_selected_trajectory_nats).sum(dim=-1)
|
| 3215 |
+
/ violating_assignment_count
|
| 3216 |
+
)
|
| 3217 |
+
mean_baseline_trajectory_nats = mean_selected_trajectory_nats.mean(dim=-1) # (B, N)
|
| 3218 |
+
contrastive_preference_nats = ( # (B, N)
|
| 3219 |
+
mean_violating_trajectory_nats
|
| 3220 |
+
- mean_baseline_trajectory_nats
|
| 3221 |
+
)
|
| 3222 |
+
|
| 3223 |
+
# ββ Violation-only reduction βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3224 |
+
#
|
| 3225 |
+
# Non-violating positions and sequences are not anchors for this loss. The scalar
|
| 3226 |
+
# is an average violation contrast, not total violation mass, and the entire loss
|
| 3227 |
+
# remains exactly inactive when no corrective target exists.
|
| 3228 |
+
|
| 3229 |
+
violation_position_count = has_violation_at_position.float().sum(dim=-1).clamp(min=1.0) # (B,)
|
| 3230 |
+
sequence_preference_nats = ( # (B,)
|
| 3231 |
+
(contrastive_preference_nats * has_violation_at_position.float()).sum(dim=-1)
|
| 3232 |
+
/ violation_position_count
|
| 3233 |
+
)
|
| 3234 |
+
violating_sequence_count = has_violation_in_sequence.float().sum().clamp(min=1.0) # scalar
|
| 3235 |
+
final_loss = ( # scalar
|
| 3236 |
+
sequence_preference_nats * has_violation_in_sequence.float()
|
| 3237 |
+
).sum() / violating_sequence_count
|
| 3238 |
+
return final_loss
|
| 3239 |
+
|
| 3240 |
+
|
| 3241 |
# ---------------------------------------------------------------------------
|
| 3242 |
# Factory
|
| 3243 |
# ---------------------------------------------------------------------------
|
|
|
|
| 3274 |
return _runtime
|
| 3275 |
|
| 3276 |
|
| 3277 |
+
def _causal_overcapacity_factory(
|
| 3278 |
+
num_selected_heads: int,
|
| 3279 |
+
num_total_heads: int,
|
| 3280 |
+
maximum_expert_overclaim: int,
|
| 3281 |
+
**kwargs: object,
|
| 3282 |
+
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
|
| 3283 |
+
expected_tokens_rate = num_selected_heads / num_total_heads
|
| 3284 |
+
def _runtime(
|
| 3285 |
+
logits: torch.Tensor,
|
| 3286 |
+
assignment_mask: torch.Tensor,
|
| 3287 |
+
active_mask: torch.Tensor,
|
| 3288 |
+
) -> torch.Tensor:
|
| 3289 |
+
return _causal_overcapacity_loss(
|
| 3290 |
+
logits, assignment_mask, active_mask,
|
| 3291 |
+
expected_tokens_rate=expected_tokens_rate,
|
| 3292 |
+
maximum_expert_overclaim=maximum_expert_overclaim,
|
| 3293 |
+
)
|
| 3294 |
+
return _runtime
|
| 3295 |
+
|
| 3296 |
+
|
| 3297 |
_LOSS_REGISTRY: dict[str, Callable[..., Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]]] = {
|
| 3298 |
"gshard": _gshard_factory,
|
| 3299 |
"ce": _ce_factory,
|
| 3300 |
"bce": _bce_factory,
|
| 3301 |
"temporal_overcapacity": _temporal_overcapacity_factory,
|
| 3302 |
+
"causal_overcapacity": _causal_overcapacity_factory,
|
| 3303 |
}
|
| 3304 |
|
| 3305 |
|
|
|
|
| 3319 |
|
| 3320 |
Keyword arguments are forwarded to the selected factory. The gshard, ce, and bce
|
| 3321 |
factories silently ignore all kwargs; this allows callers to pass loss-type-specific
|
| 3322 |
+
parameters (e.g. for overcapacity losses) without branching on loss_type.
|
| 3323 |
|
| 3324 |
Args:
|
| 3325 |
+
loss_type: One of ``"gshard"``, ``"ce"``, ``"bce"``,
|
| 3326 |
+
``"temporal_overcapacity"``, or ``"causal_overcapacity"``.
|
| 3327 |
**loss_parameters: Construction-time parameters forwarded to the factory.
|
| 3328 |
|
| 3329 |
Returns:
|