Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM-dev with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "smithblack-0/SHRAM-dev" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "smithblack-0/SHRAM-dev" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "smithblack-0/SHRAM-dev" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
Update architecture and tokenizer
Browse files- config.json +1 -1
- huggingface.py +108 -32
config.json
CHANGED
|
@@ -24,7 +24,7 @@
|
|
| 24 |
"rope_mode": "main_sequence",
|
| 25 |
"tie_word_embeddings": false,
|
| 26 |
"training_sequence_length": 1024,
|
| 27 |
-
"transformers_version": "5.10.
|
| 28 |
"use_cache": true,
|
| 29 |
"vocab_size": 50277,
|
| 30 |
"window_size": 128
|
|
|
|
| 24 |
"rope_mode": "main_sequence",
|
| 25 |
"tie_word_embeddings": false,
|
| 26 |
"training_sequence_length": 1024,
|
| 27 |
+
"transformers_version": "5.10.2",
|
| 28 |
"use_cache": true,
|
| 29 |
"vocab_size": 50277,
|
| 30 |
"window_size": 128
|
huggingface.py
CHANGED
|
@@ -1458,6 +1458,10 @@ Returns a plain dict with keys:
|
|
| 1458 |
- "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
|
| 1459 |
- "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
|
| 1460 |
- "max_vio": detached scalar maximum routing-imbalance across all decoder layers
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1461 |
"""
|
| 1462 |
|
| 1463 |
|
|
@@ -1474,7 +1478,7 @@ Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
|
|
| 1474 |
gated residual connections around both sublayers:
|
| 1475 |
|
| 1476 |
normed_attn = RMSNorm(x)
|
| 1477 |
-
attn_out,
|
| 1478 |
h = x + residual_gate * attn_out
|
| 1479 |
|
| 1480 |
normed_mlp = RMSNorm(h)
|
|
@@ -3094,7 +3098,7 @@ class MoSRAHRouter(nn.Module):
|
|
| 3094 |
x: torch.Tensor,
|
| 3095 |
active_mask: torch.Tensor,
|
| 3096 |
used_capacity: torch.Tensor | None
|
| 3097 |
-
) -> tuple[torch.Tensor, torch.Tensor,
|
| 3098 |
"""Route input tokens to K expert heads each and compute routing probabilities.
|
| 3099 |
|
| 3100 |
Args:
|
|
@@ -3103,17 +3107,23 @@ class MoSRAHRouter(nn.Module):
|
|
| 3103 |
True means the token is semantically live. Dead tokens do not
|
| 3104 |
contribute to routing frequencies, load_balance_loss, or max_vio.
|
| 3105 |
used_capacity: Used for capacity management during inference, missing during training.
|
|
|
|
| 3106 |
Returns:
|
| 3107 |
selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
|
| 3108 |
Each token's K selected head indices, determined by TopK on biased scores.
|
| 3109 |
routing_probs: Routing probabilities P of shape (batch, seq_len,
|
| 3110 |
num_selected_heads). Gathered from unbiased scores at selected_heads
|
| 3111 |
indices and renormalized to sum to 1 per token.
|
| 3112 |
-
|
| 3113 |
-
|
| 3114 |
-
|
| 3115 |
-
|
| 3116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3117 |
"""
|
| 3118 |
B, N, _ = x.shape
|
| 3119 |
L = self.num_mosrah_heads
|
|
@@ -3122,6 +3132,17 @@ class MoSRAHRouter(nn.Module):
|
|
| 3122 |
# Unbiased routing scores R = Softmax(xW_r). These are the scores used to
|
| 3123 |
# compute routing_probs — expert_bias must not influence them.
|
| 3124 |
logits = self.routing_projection(x) # (B, N, L)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3125 |
routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
|
| 3126 |
|
| 3127 |
# Biased routing scores RÌ‚ = Softmax(xW_r + b). Used only for TopK head
|
|
@@ -3177,7 +3198,15 @@ class MoSRAHRouter(nn.Module):
|
|
| 3177 |
# L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
|
| 3178 |
max_vio = self._compute_max_vio(routing_freqs, L)
|
| 3179 |
|
| 3180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3181 |
|
| 3182 |
@staticmethod
|
| 3183 |
def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
|
|
@@ -3322,8 +3351,9 @@ class MoSRAHLayer(nn.Module):
|
|
| 3322 |
|
| 3323 |
The MoSRAH path consumes model-space hidden states together with
|
| 3324 |
authoritative per-token positions and returns the model-space sparse-path
|
| 3325 |
-
contribution
|
| 3326 |
-
routing-imbalance scalar
|
|
|
|
| 3327 |
"""
|
| 3328 |
|
| 3329 |
def __init__(self, config: ShramConfig) -> None:
|
|
@@ -3348,7 +3378,7 @@ class MoSRAHLayer(nn.Module):
|
|
| 3348 |
position_ids: torch.Tensor,
|
| 3349 |
active_mask: torch.Tensor,
|
| 3350 |
cache: MoSRAHCache | None,
|
| 3351 |
-
) -> tuple[torch.Tensor,
|
| 3352 |
"""Run the full MoSRAH sparse path.
|
| 3353 |
|
| 3354 |
Args:
|
|
@@ -3364,9 +3394,10 @@ class MoSRAHLayer(nn.Module):
|
|
| 3364 |
|
| 3365 |
Returns:
|
| 3366 |
sparse_output: Model-space sparse-path output of shape (B, N, d).
|
| 3367 |
-
|
| 3368 |
-
|
| 3369 |
-
|
|
|
|
| 3370 |
"""
|
| 3371 |
|
| 3372 |
# -------------------------------------------------------------------
|
|
@@ -3381,7 +3412,7 @@ class MoSRAHLayer(nn.Module):
|
|
| 3381 |
# active_mask is rebound to the packed form after this point.
|
| 3382 |
# -------------------------------------------------------------------
|
| 3383 |
used_capacity = cache.get_heads_lengths() if cache is not None else None
|
| 3384 |
-
selected_heads, routing_probs,
|
| 3385 |
hidden_states, active_mask, used_capacity
|
| 3386 |
)
|
| 3387 |
|
|
@@ -3434,7 +3465,7 @@ class MoSRAHLayer(nn.Module):
|
|
| 3434 |
token_choice_outputs * routing_probs.unsqueeze(-1)
|
| 3435 |
).sum(dim=2)
|
| 3436 |
|
| 3437 |
-
return final_output,
|
| 3438 |
|
| 3439 |
|
| 3440 |
|
|
@@ -3463,7 +3494,7 @@ class SHRAMHybridLayer(nn.Module):
|
|
| 3463 |
position_ids: torch.Tensor,
|
| 3464 |
active_mask: torch.Tensor,
|
| 3465 |
cache: ShramLayerCache | None,
|
| 3466 |
-
) -> tuple[torch.Tensor,
|
| 3467 |
"""Apply the SHRAM hybrid attention layer.
|
| 3468 |
|
| 3469 |
Args:
|
|
@@ -3478,8 +3509,7 @@ class SHRAMHybridLayer(nn.Module):
|
|
| 3478 |
|
| 3479 |
Returns:
|
| 3480 |
hybrid_output: Model-space hybrid attention output of shape (B, N, d).
|
| 3481 |
-
|
| 3482 |
-
max_vio: Detached scalar routing-imbalance summary. Passed through
|
| 3483 |
unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
|
| 3484 |
"""
|
| 3485 |
# -------------------------------------------------------------------
|
|
@@ -3507,7 +3537,7 @@ class SHRAMHybridLayer(nn.Module):
|
|
| 3507 |
active_mask=active_mask,
|
| 3508 |
cache=sliding_window_cache,
|
| 3509 |
)
|
| 3510 |
-
sparse_output,
|
| 3511 |
hidden_states=hidden_states,
|
| 3512 |
position_ids=position_ids,
|
| 3513 |
active_mask=active_mask,
|
|
@@ -3522,7 +3552,7 @@ class SHRAMHybridLayer(nn.Module):
|
|
| 3522 |
# -------------------------------------------------------------------
|
| 3523 |
hybrid_output = local_output + sparse_output
|
| 3524 |
|
| 3525 |
-
return hybrid_output,
|
| 3526 |
|
| 3527 |
|
| 3528 |
# -----------
|
|
@@ -3612,7 +3642,7 @@ class DecoderLayer(nn.Module):
|
|
| 3612 |
position_ids: torch.Tensor,
|
| 3613 |
active_mask: torch.Tensor,
|
| 3614 |
cache: ShramLayerCache | None = None,
|
| 3615 |
-
) -> tuple[torch.Tensor,
|
| 3616 |
"""Apply one decoder block to the input.
|
| 3617 |
|
| 3618 |
Args:
|
|
@@ -3626,12 +3656,10 @@ class DecoderLayer(nn.Module):
|
|
| 3626 |
|
| 3627 |
Returns:
|
| 3628 |
output: Tensor of shape (batch, seq_len, hidden_size).
|
| 3629 |
-
|
| 3630 |
-
from SHRAMHybridLayer.
|
| 3631 |
-
max_vio: Detached scalar routing-imbalance summary. Passed through
|
| 3632 |
unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics.
|
| 3633 |
"""
|
| 3634 |
-
attn_out,
|
| 3635 |
hidden_states=self.attn_norm(x),
|
| 3636 |
position_ids=position_ids,
|
| 3637 |
active_mask=active_mask,
|
|
@@ -3639,7 +3667,7 @@ class DecoderLayer(nn.Module):
|
|
| 3639 |
)
|
| 3640 |
hidden_states = x + self.residual_gate*attn_out
|
| 3641 |
output = hidden_states + self.residual_gate*self.mlp(self.mlp_norm(hidden_states))
|
| 3642 |
-
return output,
|
| 3643 |
|
| 3644 |
|
| 3645 |
class ShramModel(nn.Module):
|
|
@@ -3708,27 +3736,51 @@ class ShramModel(nn.Module):
|
|
| 3708 |
- ``"max_vio"``: detached scalar maximum routing-imbalance across
|
| 3709 |
all decoder layers. Zero means perfectly balanced routing across
|
| 3710 |
every layer; higher values identify the worst-case head imbalance.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3711 |
"""
|
| 3712 |
hidden_states = inputs_embeds
|
| 3713 |
all_hidden_states = (hidden_states,) if output_hidden_states else None
|
| 3714 |
total_load_balance_loss = inputs_embeds.new_zeros(())
|
| 3715 |
max_vio = inputs_embeds.new_zeros(())
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3716 |
|
| 3717 |
for layer_idx, layer in enumerate(self.layers):
|
| 3718 |
layer_cache = None if cache is None else cache.layers[layer_idx]
|
| 3719 |
-
hidden_states,
|
| 3720 |
hidden_states,
|
| 3721 |
position_ids,
|
| 3722 |
active_mask,
|
| 3723 |
cache=layer_cache,
|
| 3724 |
)
|
| 3725 |
-
total_load_balance_loss = total_load_balance_loss +
|
| 3726 |
-
max_vio = torch.maximum(max_vio,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3727 |
|
| 3728 |
if output_hidden_states:
|
| 3729 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 3730 |
|
| 3731 |
hidden_states = self.norm(hidden_states)
|
|
|
|
| 3732 |
|
| 3733 |
return {
|
| 3734 |
"last_hidden_state": hidden_states,
|
|
@@ -3736,6 +3788,10 @@ class ShramModel(nn.Module):
|
|
| 3736 |
"hidden_states": all_hidden_states,
|
| 3737 |
"load_balance_loss": total_load_balance_loss,
|
| 3738 |
"max_vio": max_vio,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3739 |
}
|
| 3740 |
|
| 3741 |
|
|
@@ -3749,10 +3805,20 @@ class ShramCausalLMOutput(CausalLMOutputWithPast):
|
|
| 3749 |
only the SHRAM-specific wrapper outputs.
|
| 3750 |
"""
|
| 3751 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3752 |
ce_loss: torch.FloatTensor | None = None
|
| 3753 |
load_balance_loss: torch.FloatTensor | None = None
|
| 3754 |
max_vio: torch.FloatTensor | None = None
|
| 3755 |
-
|
|
|
|
|
|
|
|
|
|
| 3756 |
|
| 3757 |
class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
| 3758 |
"""HuggingFace-facing causal language model wrapper for SHRAM.
|
|
@@ -4181,6 +4247,9 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 4181 |
output_hidden_states: Whether to return backbone hidden states.
|
| 4182 |
Defaults to ``config.output_hidden_states``.
|
| 4183 |
labels: Optional target token IDs of shape ``(batch, seq_len)``.
|
|
|
|
|
|
|
|
|
|
| 4184 |
return_dict: Must be ``True`` or ``None``.
|
| 4185 |
ce_weight: Weight applied to the cross-entropy loss when combining with
|
| 4186 |
the load-balance loss. Default 1.0.
|
|
@@ -4197,7 +4266,10 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 4197 |
- ``past_key_values`` as the active ``ShramCache`` or ``None``,
|
| 4198 |
- ``hidden_states`` when requested,
|
| 4199 |
- ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
|
| 4200 |
-
-
|
|
|
|
|
|
|
|
|
|
| 4201 |
"""
|
| 4202 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 4203 |
output_hidden_states = (
|
|
@@ -4304,4 +4376,8 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 4304 |
hidden_states=backbone_outputs["hidden_states"],
|
| 4305 |
load_balance_loss=backbone_outputs["load_balance_loss"],
|
| 4306 |
max_vio=backbone_outputs["max_vio"],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4307 |
)
|
|
|
|
| 1458 |
- "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
|
| 1459 |
- "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
|
| 1460 |
- "max_vio": detached scalar maximum routing-imbalance across all decoder layers
|
| 1461 |
+
- "bias_std": detached scalar mean per-layer std of the expert bias vector
|
| 1462 |
+
- "raw_logit_std": detached scalar mean per-layer per-token routing logit spread
|
| 1463 |
+
- "logit_std": detached scalar mean per-layer per-token combined (logit + bias) spread
|
| 1464 |
+
- "bias_alignment": detached scalar mean per-layer cosine similarity of bias vs logits
|
| 1465 |
"""
|
| 1466 |
|
| 1467 |
|
|
|
|
| 1478 |
gated residual connections around both sublayers:
|
| 1479 |
|
| 1480 |
normed_attn = RMSNorm(x)
|
| 1481 |
+
attn_out, router_diagnostics = SHRAMHybridLayer(normed_attn, ...)
|
| 1482 |
h = x + residual_gate * attn_out
|
| 1483 |
|
| 1484 |
normed_mlp = RMSNorm(h)
|
|
|
|
| 3098 |
x: torch.Tensor,
|
| 3099 |
active_mask: torch.Tensor,
|
| 3100 |
used_capacity: torch.Tensor | None
|
| 3101 |
+
) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
|
| 3102 |
"""Route input tokens to K expert heads each and compute routing probabilities.
|
| 3103 |
|
| 3104 |
Args:
|
|
|
|
| 3107 |
True means the token is semantically live. Dead tokens do not
|
| 3108 |
contribute to routing frequencies, load_balance_loss, or max_vio.
|
| 3109 |
used_capacity: Used for capacity management during inference, missing during training.
|
| 3110 |
+
|
| 3111 |
Returns:
|
| 3112 |
selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
|
| 3113 |
Each token's K selected head indices, determined by TopK on biased scores.
|
| 3114 |
routing_probs: Routing probabilities P of shape (batch, seq_len,
|
| 3115 |
num_selected_heads). Gathered from unbiased scores at selected_heads
|
| 3116 |
indices and renormalized to sum to 1 per token.
|
| 3117 |
+
router_diagnostics: Dict of routing feedback scalars. Keys:
|
| 3118 |
+
- ``load_balance_loss``: scalar load-balance loss with gradient.
|
| 3119 |
+
- ``max_vio``: detached scalar routing-imbalance summary.
|
| 3120 |
+
- ``bias_std``: std of expert_bias; near-zero means corrections have not built up.
|
| 3121 |
+
- ``raw_logit_std``: mean per-token std of unbiased logits; the natural routing scale.
|
| 3122 |
+
- ``logit_std``: mean per-token std of (logits + expert_bias); lower than
|
| 3123 |
+
raw_logit_std means bias is flattening preferences (healthy correction).
|
| 3124 |
+
- ``bias_alignment``: mean cosine similarity of expert_bias against per-token
|
| 3125 |
+
logits. Negative means bias opposes routing direction (healthy correction);
|
| 3126 |
+
positive means runaway reinforcement.
|
| 3127 |
"""
|
| 3128 |
B, N, _ = x.shape
|
| 3129 |
L = self.num_mosrah_heads
|
|
|
|
| 3132 |
# Unbiased routing scores R = Softmax(xW_r). These are the scores used to
|
| 3133 |
# compute routing_probs — expert_bias must not influence them.
|
| 3134 |
logits = self.routing_projection(x) # (B, N, L)
|
| 3135 |
+
|
| 3136 |
+
# Diagnostic scalars characterising the load-balance mechanism. Must be
|
| 3137 |
+
# computed here — before balance_capacity injects -1e8 sentinels that
|
| 3138 |
+
# would corrupt std and cosine similarity.
|
| 3139 |
+
bias_std = self.expert_bias.std().detach()
|
| 3140 |
+
raw_logit_std = logits.std(dim=-1).mean().detach()
|
| 3141 |
+
logit_std = (logits + self.expert_bias).std(dim=-1).mean().detach()
|
| 3142 |
+
bias_alignment = F.cosine_similarity(
|
| 3143 |
+
logits, self.expert_bias.expand_as(logits), dim=-1
|
| 3144 |
+
).mean().detach()
|
| 3145 |
+
|
| 3146 |
routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
|
| 3147 |
|
| 3148 |
# Biased routing scores RÌ‚ = Softmax(xW_r + b). Used only for TopK head
|
|
|
|
| 3198 |
# L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
|
| 3199 |
max_vio = self._compute_max_vio(routing_freqs, L)
|
| 3200 |
|
| 3201 |
+
router_diagnostics = {
|
| 3202 |
+
"load_balance_loss": load_balance_loss,
|
| 3203 |
+
"max_vio": max_vio,
|
| 3204 |
+
"bias_std": bias_std,
|
| 3205 |
+
"raw_logit_std": raw_logit_std,
|
| 3206 |
+
"logit_std": logit_std,
|
| 3207 |
+
"bias_alignment": bias_alignment,
|
| 3208 |
+
}
|
| 3209 |
+
return selected_heads, routing_probs, router_diagnostics
|
| 3210 |
|
| 3211 |
@staticmethod
|
| 3212 |
def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
|
|
|
|
| 3351 |
|
| 3352 |
The MoSRAH path consumes model-space hidden states together with
|
| 3353 |
authoritative per-token positions and returns the model-space sparse-path
|
| 3354 |
+
contribution and a diagnostics dict from the router containing
|
| 3355 |
+
load-balance loss, routing-imbalance scalar, and load-balance health
|
| 3356 |
+
scalars.
|
| 3357 |
"""
|
| 3358 |
|
| 3359 |
def __init__(self, config: ShramConfig) -> None:
|
|
|
|
| 3378 |
position_ids: torch.Tensor,
|
| 3379 |
active_mask: torch.Tensor,
|
| 3380 |
cache: MoSRAHCache | None,
|
| 3381 |
+
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
| 3382 |
"""Run the full MoSRAH sparse path.
|
| 3383 |
|
| 3384 |
Args:
|
|
|
|
| 3394 |
|
| 3395 |
Returns:
|
| 3396 |
sparse_output: Model-space sparse-path output of shape (B, N, d).
|
| 3397 |
+
router_diagnostics: Dict of router feedback scalars. Keys:
|
| 3398 |
+
``load_balance_loss`` (has grad), ``max_vio``, ``bias_std``,
|
| 3399 |
+
``raw_logit_std``, ``logit_std``, ``bias_alignment`` (all
|
| 3400 |
+
detached). See MoSRAHRouter for semantics.
|
| 3401 |
"""
|
| 3402 |
|
| 3403 |
# -------------------------------------------------------------------
|
|
|
|
| 3412 |
# active_mask is rebound to the packed form after this point.
|
| 3413 |
# -------------------------------------------------------------------
|
| 3414 |
used_capacity = cache.get_heads_lengths() if cache is not None else None
|
| 3415 |
+
selected_heads, routing_probs, router_diagnostics = self.router(
|
| 3416 |
hidden_states, active_mask, used_capacity
|
| 3417 |
)
|
| 3418 |
|
|
|
|
| 3465 |
token_choice_outputs * routing_probs.unsqueeze(-1)
|
| 3466 |
).sum(dim=2)
|
| 3467 |
|
| 3468 |
+
return final_output, router_diagnostics
|
| 3469 |
|
| 3470 |
|
| 3471 |
|
|
|
|
| 3494 |
position_ids: torch.Tensor,
|
| 3495 |
active_mask: torch.Tensor,
|
| 3496 |
cache: ShramLayerCache | None,
|
| 3497 |
+
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
| 3498 |
"""Apply the SHRAM hybrid attention layer.
|
| 3499 |
|
| 3500 |
Args:
|
|
|
|
| 3509 |
|
| 3510 |
Returns:
|
| 3511 |
hybrid_output: Model-space hybrid attention output of shape (B, N, d).
|
| 3512 |
+
router_diagnostics: Dict of router feedback scalars passed through
|
|
|
|
| 3513 |
unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
|
| 3514 |
"""
|
| 3515 |
# -------------------------------------------------------------------
|
|
|
|
| 3537 |
active_mask=active_mask,
|
| 3538 |
cache=sliding_window_cache,
|
| 3539 |
)
|
| 3540 |
+
sparse_output, router_diagnostics = self.sparse_attention(
|
| 3541 |
hidden_states=hidden_states,
|
| 3542 |
position_ids=position_ids,
|
| 3543 |
active_mask=active_mask,
|
|
|
|
| 3552 |
# -------------------------------------------------------------------
|
| 3553 |
hybrid_output = local_output + sparse_output
|
| 3554 |
|
| 3555 |
+
return hybrid_output, router_diagnostics
|
| 3556 |
|
| 3557 |
|
| 3558 |
# -----------
|
|
|
|
| 3642 |
position_ids: torch.Tensor,
|
| 3643 |
active_mask: torch.Tensor,
|
| 3644 |
cache: ShramLayerCache | None = None,
|
| 3645 |
+
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
| 3646 |
"""Apply one decoder block to the input.
|
| 3647 |
|
| 3648 |
Args:
|
|
|
|
| 3656 |
|
| 3657 |
Returns:
|
| 3658 |
output: Tensor of shape (batch, seq_len, hidden_size).
|
| 3659 |
+
router_diagnostics: Dict of router feedback scalars passed through
|
|
|
|
|
|
|
| 3660 |
unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics.
|
| 3661 |
"""
|
| 3662 |
+
attn_out, router_diagnostics = self.attention(
|
| 3663 |
hidden_states=self.attn_norm(x),
|
| 3664 |
position_ids=position_ids,
|
| 3665 |
active_mask=active_mask,
|
|
|
|
| 3667 |
)
|
| 3668 |
hidden_states = x + self.residual_gate*attn_out
|
| 3669 |
output = hidden_states + self.residual_gate*self.mlp(self.mlp_norm(hidden_states))
|
| 3670 |
+
return output, router_diagnostics
|
| 3671 |
|
| 3672 |
|
| 3673 |
class ShramModel(nn.Module):
|
|
|
|
| 3736 |
- ``"max_vio"``: detached scalar maximum routing-imbalance across
|
| 3737 |
all decoder layers. Zero means perfectly balanced routing across
|
| 3738 |
every layer; higher values identify the worst-case head imbalance.
|
| 3739 |
+
- ``"bias_std"``: detached scalar — mean across layers of the std
|
| 3740 |
+
of each layer's expert bias vector. Near-zero means corrections
|
| 3741 |
+
have not built up; large relative to ``raw_logit_std`` means the
|
| 3742 |
+
bias dominates routing.
|
| 3743 |
+
- ``"raw_logit_std"``: detached scalar — mean across layers of the
|
| 3744 |
+
per-token routing logit spread before bias addition. Baseline
|
| 3745 |
+
natural routing preference scale.
|
| 3746 |
+
- ``"logit_std"``: detached scalar — mean across layers of the
|
| 3747 |
+
per-token combined (logit + bias) spread. Lower than
|
| 3748 |
+
``raw_logit_std`` indicates healthy flattening; higher indicates
|
| 3749 |
+
amplification.
|
| 3750 |
+
- ``"bias_alignment"``: detached scalar — mean across layers of the
|
| 3751 |
+
per-token cosine similarity between the expert bias vector and the
|
| 3752 |
+
routing logits. Negative is healthy correction; positive is
|
| 3753 |
+
runaway feedback.
|
| 3754 |
"""
|
| 3755 |
hidden_states = inputs_embeds
|
| 3756 |
all_hidden_states = (hidden_states,) if output_hidden_states else None
|
| 3757 |
total_load_balance_loss = inputs_embeds.new_zeros(())
|
| 3758 |
max_vio = inputs_embeds.new_zeros(())
|
| 3759 |
+
total_bias_std = inputs_embeds.new_zeros(())
|
| 3760 |
+
total_raw_logit_std = inputs_embeds.new_zeros(())
|
| 3761 |
+
total_logit_std = inputs_embeds.new_zeros(())
|
| 3762 |
+
total_bias_alignment = inputs_embeds.new_zeros(())
|
| 3763 |
|
| 3764 |
for layer_idx, layer in enumerate(self.layers):
|
| 3765 |
layer_cache = None if cache is None else cache.layers[layer_idx]
|
| 3766 |
+
hidden_states, layer_diagnostics = layer(
|
| 3767 |
hidden_states,
|
| 3768 |
position_ids,
|
| 3769 |
active_mask,
|
| 3770 |
cache=layer_cache,
|
| 3771 |
)
|
| 3772 |
+
total_load_balance_loss = total_load_balance_loss + layer_diagnostics["load_balance_loss"]
|
| 3773 |
+
max_vio = torch.maximum(max_vio, layer_diagnostics["max_vio"])
|
| 3774 |
+
total_bias_std = total_bias_std + layer_diagnostics["bias_std"]
|
| 3775 |
+
total_raw_logit_std = total_raw_logit_std + layer_diagnostics["raw_logit_std"]
|
| 3776 |
+
total_logit_std = total_logit_std + layer_diagnostics["logit_std"]
|
| 3777 |
+
total_bias_alignment = total_bias_alignment + layer_diagnostics["bias_alignment"]
|
| 3778 |
|
| 3779 |
if output_hidden_states:
|
| 3780 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 3781 |
|
| 3782 |
hidden_states = self.norm(hidden_states)
|
| 3783 |
+
num_layers = len(self.layers)
|
| 3784 |
|
| 3785 |
return {
|
| 3786 |
"last_hidden_state": hidden_states,
|
|
|
|
| 3788 |
"hidden_states": all_hidden_states,
|
| 3789 |
"load_balance_loss": total_load_balance_loss,
|
| 3790 |
"max_vio": max_vio,
|
| 3791 |
+
"bias_std": total_bias_std / num_layers,
|
| 3792 |
+
"raw_logit_std": total_raw_logit_std / num_layers,
|
| 3793 |
+
"logit_std": total_logit_std / num_layers,
|
| 3794 |
+
"bias_alignment": total_bias_alignment / num_layers,
|
| 3795 |
}
|
| 3796 |
|
| 3797 |
|
|
|
|
| 3805 |
only the SHRAM-specific wrapper outputs.
|
| 3806 |
"""
|
| 3807 |
|
| 3808 |
+
## Python dataclass inheritance violation: CausalLMOutputWithPast defaults all
|
| 3809 |
+
## fields to None, which forces every subclass field to also carry a default.
|
| 3810 |
+
## The = None below is a language constraint, not a semantic statement. In
|
| 3811 |
+
## practice, load_balance_loss, max_vio, bias_std, raw_logit_std, logit_std,
|
| 3812 |
+
## and bias_alignment are always populated by ShramForCausalLM.forward().
|
| 3813 |
+
## ce_loss is genuinely optional — present only when labels are supplied.
|
| 3814 |
+
|
| 3815 |
ce_loss: torch.FloatTensor | None = None
|
| 3816 |
load_balance_loss: torch.FloatTensor | None = None
|
| 3817 |
max_vio: torch.FloatTensor | None = None
|
| 3818 |
+
bias_std: torch.Tensor | None = None
|
| 3819 |
+
raw_logit_std: torch.Tensor | None = None
|
| 3820 |
+
logit_std: torch.Tensor | None = None
|
| 3821 |
+
bias_alignment: torch.Tensor | None = None
|
| 3822 |
|
| 3823 |
class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
| 3824 |
"""HuggingFace-facing causal language model wrapper for SHRAM.
|
|
|
|
| 4247 |
output_hidden_states: Whether to return backbone hidden states.
|
| 4248 |
Defaults to ``config.output_hidden_states``.
|
| 4249 |
labels: Optional target token IDs of shape ``(batch, seq_len)``.
|
| 4250 |
+
Pass unshifted labels (same alignment as ``input_ids``). This
|
| 4251 |
+
wrapper shifts internally: ``logits[:, :-1]`` is compared
|
| 4252 |
+
against ``labels[:, 1:]``. Do not pre-shift the caller side.
|
| 4253 |
return_dict: Must be ``True`` or ``None``.
|
| 4254 |
ce_weight: Weight applied to the cross-entropy loss when combining with
|
| 4255 |
the load-balance loss. Default 1.0.
|
|
|
|
| 4266 |
- ``past_key_values`` as the active ``ShramCache`` or ``None``,
|
| 4267 |
- ``hidden_states`` when requested,
|
| 4268 |
- ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
|
| 4269 |
+
- ``max_vio`` — detached worst-case routing imbalance across layers,
|
| 4270 |
+
- ``bias_std``, ``raw_logit_std``, ``logit_std``, ``bias_alignment`` —
|
| 4271 |
+
detached load-balance health scalars averaged across decoder layers;
|
| 4272 |
+
see ``ShramModel`` for interpretation.
|
| 4273 |
"""
|
| 4274 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 4275 |
output_hidden_states = (
|
|
|
|
| 4376 |
hidden_states=backbone_outputs["hidden_states"],
|
| 4377 |
load_balance_loss=backbone_outputs["load_balance_loss"],
|
| 4378 |
max_vio=backbone_outputs["max_vio"],
|
| 4379 |
+
bias_std=backbone_outputs["bias_std"],
|
| 4380 |
+
raw_logit_std=backbone_outputs["raw_logit_std"],
|
| 4381 |
+
logit_std=backbone_outputs["logit_std"],
|
| 4382 |
+
bias_alignment=backbone_outputs["bias_alignment"],
|
| 4383 |
)
|