Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM-dev with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "smithblack-0/SHRAM-dev"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM-dev",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "smithblack-0/SHRAM-dev" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM-dev",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "smithblack-0/SHRAM-dev" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM-dev",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
- Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
Update architecture and tokenizer
Browse files- huggingface.py +102 -16
huggingface.py
CHANGED
|
@@ -44,6 +44,7 @@ from torch import nn
|
|
| 44 |
from torch.nn.attention.flex_attention import create_block_mask
|
| 45 |
from torch.nn.attention.flex_attention import flex_attention
|
| 46 |
import torch.nn.functional as F
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
|
|
@@ -547,8 +548,7 @@ class MoSRAHCache(CacheLayerMixin):
|
|
| 547 |
# boolean-mask transfer is correct without any explicit count verification.
|
| 548 |
self.keys[dest_mask] = key_states[active_mask]
|
| 549 |
self.values[dest_mask] = value_states[active_mask]
|
| 550 |
-
|
| 551 |
-
self._counts = post_counts
|
| 552 |
|
| 553 |
return self.keys, self.values, self._make_active_mask()
|
| 554 |
|
|
@@ -1405,15 +1405,20 @@ Returns a plain dict with keys:
|
|
| 1405 |
"""Decoder layer — a single transformer block.
|
| 1406 |
|
| 1407 |
Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
|
| 1408 |
-
residual connections around both sublayers:
|
| 1409 |
|
| 1410 |
normed_attn = RMSNorm(x)
|
| 1411 |
attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
|
| 1412 |
-
h = x + attn_out
|
| 1413 |
|
| 1414 |
normed_mlp = RMSNorm(h)
|
| 1415 |
mlp_out = SwiGLUMLP(normed_mlp)
|
| 1416 |
-
out = h + mlp_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1417 |
|
| 1418 |
Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
|
| 1419 |
through unnormalised residuals at depth, and each sublayer receives a stable,
|
|
@@ -2344,7 +2349,7 @@ def setup_packing(
|
|
| 2344 |
batch_size,
|
| 2345 |
sequence_length * num_selected_heads,
|
| 2346 |
)
|
| 2347 |
-
|
| 2348 |
permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
|
| 2349 |
inverse_permutation = torch.argsort(permutation, dim=-1)
|
| 2350 |
|
|
@@ -2352,6 +2357,7 @@ def setup_packing(
|
|
| 2352 |
"flattened_selected_heads": flattened_selected_heads,
|
| 2353 |
"permutation": permutation,
|
| 2354 |
"inverse_permutation": inverse_permutation,
|
|
|
|
| 2355 |
}
|
| 2356 |
|
| 2357 |
|
|
@@ -2493,6 +2499,7 @@ def pack_experts(
|
|
| 2493 |
(batch_size, num_experts, packed_length, *extra_shape),
|
| 2494 |
fill_value=padding_value,
|
| 2495 |
)
|
|
|
|
| 2496 |
packed_tensor[unpacking_mask] = sorted_tensor.reshape(-1, *extra_shape)
|
| 2497 |
packed_entries[key] = packed_tensor
|
| 2498 |
|
|
@@ -2537,7 +2544,17 @@ def unpack_experts(
|
|
| 2537 |
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
| 2538 |
hidden_dim = expert_outputs.shape[-1]
|
| 2539 |
|
| 2540 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2541 |
sorted_token_choice_outputs = active_outputs.reshape(
|
| 2542 |
batch_size,
|
| 2543 |
sequence_length * num_selected_heads,
|
|
@@ -2547,7 +2564,6 @@ def unpack_experts(
|
|
| 2547 |
dim=1,
|
| 2548 |
index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
|
| 2549 |
)
|
| 2550 |
-
|
| 2551 |
return restored_outputs.reshape(
|
| 2552 |
batch_size,
|
| 2553 |
sequence_length,
|
|
@@ -2753,6 +2769,7 @@ class LoadBalanceLoss(torch.autograd.Function):
|
|
| 2753 |
|
| 2754 |
|
| 2755 |
|
|
|
|
| 2756 |
class MoSRAHRouter(nn.Module):
|
| 2757 |
"""Token-choice router for MoSRAH sparse attention.
|
| 2758 |
|
|
@@ -2775,6 +2792,10 @@ class MoSRAHRouter(nn.Module):
|
|
| 2775 |
self.num_mosrah_heads = config.num_mosrah_heads
|
| 2776 |
self.num_selected_heads = config.num_selected_heads
|
| 2777 |
self.load_balance_p = config.load_balance_p
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2778 |
|
| 2779 |
# W_r: routing projection, no bias (paper specifies xW_r, no additional term).
|
| 2780 |
self.routing_projection = nn.Linear(
|
|
@@ -2786,8 +2807,69 @@ class MoSRAHRouter(nn.Module):
|
|
| 2786 |
# via the LoadBalanceLoss custom backward.
|
| 2787 |
self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
|
| 2788 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2789 |
def forward(
|
| 2790 |
-
self,
|
|
|
|
|
|
|
|
|
|
| 2791 |
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 2792 |
"""Route input tokens to K expert heads each and compute routing probabilities.
|
| 2793 |
|
|
@@ -2796,7 +2878,7 @@ class MoSRAHRouter(nn.Module):
|
|
| 2796 |
active_mask: Current-chunk active mask of shape (batch, seq_len), where
|
| 2797 |
True means the token is semantically live. Dead tokens do not
|
| 2798 |
contribute to routing frequencies, load_balance_loss, or max_vio.
|
| 2799 |
-
|
| 2800 |
Returns:
|
| 2801 |
selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
|
| 2802 |
Each token's K selected head indices, determined by TopK on biased scores.
|
|
@@ -2821,18 +2903,21 @@ class MoSRAHRouter(nn.Module):
|
|
| 2821 |
# Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
|
| 2822 |
# selection. expert_bias is added to logits before softmax so that the bias
|
| 2823 |
# shifts selection probability without rescaling the unbiased distribution.
|
|
|
|
|
|
|
| 2824 |
biased_routing_scores = F.softmax( # R̂, (B, N, L)
|
| 2825 |
-
|
| 2826 |
)
|
| 2827 |
|
| 2828 |
# selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
|
|
|
|
| 2829 |
selected_heads = biased_routing_scores.topk(K, dim=-1).indices
|
|
|
|
| 2830 |
|
| 2831 |
# Routing probabilities P: gathered from unbiased R at selected_heads indices,
|
| 2832 |
# then renormalized so they sum to 1 per token. Gathering from routing_scores
|
| 2833 |
# (not biased_routing_scores) is the invariant that keeps the gradient path from
|
| 2834 |
# the output back to the router weights free of expert_bias influence.
|
| 2835 |
-
gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
|
| 2836 |
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
|
| 2837 |
|
| 2838 |
# Per-item routing frequencies f_{b,l}: for each batch item b and head l, what
|
|
@@ -3062,8 +3147,9 @@ class MoSRAHLayer(nn.Module):
|
|
| 3062 |
# B*N*K True entries) and the packed active mask (live slots only);
|
| 3063 |
# active_mask is rebound to the packed form after this point.
|
| 3064 |
# -------------------------------------------------------------------
|
|
|
|
| 3065 |
selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
|
| 3066 |
-
hidden_states, active_mask
|
| 3067 |
)
|
| 3068 |
|
| 3069 |
setup = setup_packing(selected_heads)
|
|
@@ -3282,7 +3368,7 @@ class DecoderLayer(nn.Module):
|
|
| 3282 |
self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
|
| 3283 |
self.attention = SHRAMHybridLayer(config)
|
| 3284 |
self.mlp = SwiGLUMLP(config)
|
| 3285 |
-
|
| 3286 |
def num_mosrah_parameters(self) -> int:
|
| 3287 |
"""Return the total number of trainable MoSRAH parameters in this decoder layer."""
|
| 3288 |
return self.attention.num_mosrah_parameters()
|
|
@@ -3318,8 +3404,8 @@ class DecoderLayer(nn.Module):
|
|
| 3318 |
active_mask=active_mask,
|
| 3319 |
cache=cache,
|
| 3320 |
)
|
| 3321 |
-
hidden_states = x + attn_out
|
| 3322 |
-
output = hidden_states + self.mlp(self.mlp_norm(hidden_states))
|
| 3323 |
return output, load_balance_loss, max_vio
|
| 3324 |
|
| 3325 |
|
|
|
|
| 44 |
from torch.nn.attention.flex_attention import create_block_mask
|
| 45 |
from torch.nn.attention.flex_attention import flex_attention
|
| 46 |
import torch.nn.functional as F
|
| 47 |
+
from typing import Optional
|
| 48 |
|
| 49 |
|
| 50 |
|
|
|
|
| 548 |
# boolean-mask transfer is correct without any explicit count verification.
|
| 549 |
self.keys[dest_mask] = key_states[active_mask]
|
| 550 |
self.values[dest_mask] = value_states[active_mask]
|
| 551 |
+
self._counts[:] = post_counts[:]
|
|
|
|
| 552 |
|
| 553 |
return self.keys, self.values, self._make_active_mask()
|
| 554 |
|
|
|
|
| 1405 |
"""Decoder layer — a single transformer block.
|
| 1406 |
|
| 1407 |
Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
|
| 1408 |
+
gated residual connections around both sublayers:
|
| 1409 |
|
| 1410 |
normed_attn = RMSNorm(x)
|
| 1411 |
attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
|
| 1412 |
+
h = x + residual_gate * attn_out
|
| 1413 |
|
| 1414 |
normed_mlp = RMSNorm(h)
|
| 1415 |
mlp_out = SwiGLUMLP(normed_mlp)
|
| 1416 |
+
out = h + residual_gate * mlp_out
|
| 1417 |
+
|
| 1418 |
+
A single shared residual_gate vector (shape: embedding_width, init: zeros) gates
|
| 1419 |
+
both sublayer contributions. At initialisation the layer is a pure identity, which
|
| 1420 |
+
prevents variance explosion through depth regardless of how HuggingFace initialises
|
| 1421 |
+
the projection weights. The gate is a trainable parameter and opens during training.
|
| 1422 |
|
| 1423 |
Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
|
| 1424 |
through unnormalised residuals at depth, and each sublayer receives a stable,
|
|
|
|
| 2349 |
batch_size,
|
| 2350 |
sequence_length * num_selected_heads,
|
| 2351 |
)
|
| 2352 |
+
num_elements = batch_size*sequence_length*num_selected_heads
|
| 2353 |
permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
|
| 2354 |
inverse_permutation = torch.argsort(permutation, dim=-1)
|
| 2355 |
|
|
|
|
| 2357 |
"flattened_selected_heads": flattened_selected_heads,
|
| 2358 |
"permutation": permutation,
|
| 2359 |
"inverse_permutation": inverse_permutation,
|
| 2360 |
+
"num_elements" : num_elements,
|
| 2361 |
}
|
| 2362 |
|
| 2363 |
|
|
|
|
| 2499 |
(batch_size, num_experts, packed_length, *extra_shape),
|
| 2500 |
fill_value=padding_value,
|
| 2501 |
)
|
| 2502 |
+
|
| 2503 |
packed_tensor[unpacking_mask] = sorted_tensor.reshape(-1, *extra_shape)
|
| 2504 |
packed_entries[key] = packed_tensor
|
| 2505 |
|
|
|
|
| 2544 |
batch_size, sequence_length, num_selected_heads = selected_heads.shape
|
| 2545 |
hidden_dim = expert_outputs.shape[-1]
|
| 2546 |
|
| 2547 |
+
coords = torch.nonzero_static(
|
| 2548 |
+
unpacking_mask,
|
| 2549 |
+
size=setup["num_elements"],
|
| 2550 |
+
) # shape: (B*N*K, 3)
|
| 2551 |
+
|
| 2552 |
+
active_outputs = expert_outputs[
|
| 2553 |
+
coords[:, 0],
|
| 2554 |
+
coords[:, 1],
|
| 2555 |
+
coords[:, 2],
|
| 2556 |
+
] # shape: (B*N*K, d)
|
| 2557 |
+
|
| 2558 |
sorted_token_choice_outputs = active_outputs.reshape(
|
| 2559 |
batch_size,
|
| 2560 |
sequence_length * num_selected_heads,
|
|
|
|
| 2564 |
dim=1,
|
| 2565 |
index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
|
| 2566 |
)
|
|
|
|
| 2567 |
return restored_outputs.reshape(
|
| 2568 |
batch_size,
|
| 2569 |
sequence_length,
|
|
|
|
| 2769 |
|
| 2770 |
|
| 2771 |
|
| 2772 |
+
|
| 2773 |
class MoSRAHRouter(nn.Module):
|
| 2774 |
"""Token-choice router for MoSRAH sparse attention.
|
| 2775 |
|
|
|
|
| 2792 |
self.num_mosrah_heads = config.num_mosrah_heads
|
| 2793 |
self.num_selected_heads = config.num_selected_heads
|
| 2794 |
self.load_balance_p = config.load_balance_p
|
| 2795 |
+
if config.use_cache:
|
| 2796 |
+
self.capacity = config.mosrah_cache_length
|
| 2797 |
+
else:
|
| 2798 |
+
self.capacity = config.mosrah_packed_length
|
| 2799 |
|
| 2800 |
# W_r: routing projection, no bias (paper specifies xW_r, no additional term).
|
| 2801 |
self.routing_projection = nn.Linear(
|
|
|
|
| 2807 |
# via the LoadBalanceLoss custom backward.
|
| 2808 |
self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
|
| 2809 |
|
| 2810 |
+
@staticmethod
|
| 2811 |
+
def balance_capacity(logits: torch.Tensor,
|
| 2812 |
+
used_capacity: torch.Tensor | None,
|
| 2813 |
+
capacity: int,
|
| 2814 |
+
)->torch.Tensor:
|
| 2815 |
+
"""
|
| 2816 |
+
Balances capacity limits so that if choosing an
|
| 2817 |
+
expert would go over capacity, the expert is simply
|
| 2818 |
+
not chosen instead
|
| 2819 |
+
:param logits: The logits to balance. (B, N, L)
|
| 2820 |
+
:param used_capacity: The used capacity, if it exists. (B, L)
|
| 2821 |
+
:param capacity: The maximum available capacity. Int.
|
| 2822 |
+
:return: Modified logits.
|
| 2823 |
+
"""
|
| 2824 |
+
|
| 2825 |
+
if used_capacity is None:
|
| 2826 |
+
# Presume we are in training mode.
|
| 2827 |
+
|
| 2828 |
+
# Looking up capacity limits only
|
| 2829 |
+
# matters if it is, in fact, possible
|
| 2830 |
+
# to exceed capacity limits.
|
| 2831 |
+
if logits.shape[-2] < capacity:
|
| 2832 |
+
return logits
|
| 2833 |
+
|
| 2834 |
+
# Look up the kthvalue and use that as
|
| 2835 |
+
# the threshold to mask when below.
|
| 2836 |
+
# Note we negate then negate again to sort
|
| 2837 |
+
# in ascending order.
|
| 2838 |
+
response = torch.kthvalue(-logits, capacity, dim=-2)
|
| 2839 |
+
threshold = -response.values
|
| 2840 |
+
threshold = threshold.unsqueeze(-2) #(B, 1, L)
|
| 2841 |
+
else:
|
| 2842 |
+
# We are operating in inference mode.
|
| 2843 |
+
# We have to use padding to accommodate the
|
| 2844 |
+
# response physically not being long enough
|
| 2845 |
+
# to reach capacity
|
| 2846 |
+
|
| 2847 |
+
# Note that padding at zero and shifting
|
| 2848 |
+
# the indexes prevents dereferencing a symint,
|
| 2849 |
+
# as a version that just padded at 0, 1 and set to
|
| 2850 |
+
# length + 1 would do. This prevents a graph break.
|
| 2851 |
+
remaining_capacity = capacity - used_capacity # 0 means all used, can be at most capacity
|
| 2852 |
+
response_length = logits.shape[-2]
|
| 2853 |
+
index = torch.clamp(remaining_capacity, 0, response_length+1)
|
| 2854 |
+
|
| 2855 |
+
# Sort, and add padding. Anything asking for a sequence position
|
| 2856 |
+
# outside the current sequence will get a threshold of -1e8, so everything is included.
|
| 2857 |
+
# If we are asking for the value at index zero, we get 1e8: the head is full and we include
|
| 2858 |
+
# nothing.
|
| 2859 |
+
ordered_logits = torch.sort(logits, dim=-2, descending=True).values
|
| 2860 |
+
ordered_logits = F.pad(ordered_logits, (0,0, 1, 0), value=1e8)
|
| 2861 |
+
ordered_logits = F.pad(ordered_logits, (0, 0, 0, 1), value=-1e8)
|
| 2862 |
+
|
| 2863 |
+
threshold = ordered_logits.gather(-2, index.unsqueeze(-2)) #(B, 1, L)
|
| 2864 |
+
|
| 2865 |
+
mask = threshold > logits
|
| 2866 |
+
logits = logits.masked_fill(mask, -1e8)
|
| 2867 |
+
return logits
|
| 2868 |
def forward(
|
| 2869 |
+
self,
|
| 2870 |
+
x: torch.Tensor,
|
| 2871 |
+
active_mask: torch.Tensor,
|
| 2872 |
+
used_capacity: torch.Tensor | None
|
| 2873 |
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 2874 |
"""Route input tokens to K expert heads each and compute routing probabilities.
|
| 2875 |
|
|
|
|
| 2878 |
active_mask: Current-chunk active mask of shape (batch, seq_len), where
|
| 2879 |
True means the token is semantically live. Dead tokens do not
|
| 2880 |
contribute to routing frequencies, load_balance_loss, or max_vio.
|
| 2881 |
+
used_capacity: Used for capacity management during inference; None when no cache is present (training).
|
| 2882 |
Returns:
|
| 2883 |
selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
|
| 2884 |
Each token's K selected head indices, determined by TopK on biased scores.
|
|
|
|
| 2903 |
# Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
|
| 2904 |
# selection. expert_bias is added to logits before softmax so that the bias
|
| 2905 |
# shifts selection probability without rescaling the unbiased distribution.
|
| 2906 |
+
biased_logits = logits + self.expert_bias
|
| 2907 |
+
biased_logits = self.balance_capacity(biased_logits, used_capacity, self.capacity)
|
| 2908 |
biased_routing_scores = F.softmax( # R̂, (B, N, L)
|
| 2909 |
+
biased_logits, dim=-1
|
| 2910 |
)
|
| 2911 |
|
| 2912 |
# selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
|
| 2913 |
+
# and routing logits directly
|
| 2914 |
selected_heads = biased_routing_scores.topk(K, dim=-1).indices
|
| 2915 |
+
gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
|
| 2916 |
|
| 2917 |
# Routing probabilities P: gathered from unbiased R at selected_heads indices,
|
| 2918 |
# then renormalized so they sum to 1 per token. Gathering from routing_scores
|
| 2919 |
# (not biased_routing_scores) is the invariant that keeps the gradient path from
|
| 2920 |
# the output back to the router weights free of expert_bias influence.
|
|
|
|
| 2921 |
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
|
| 2922 |
|
| 2923 |
# Per-item routing frequencies f_{b,l}: for each batch item b and head l, what
|
|
|
|
| 3147 |
# B*N*K True entries) and the packed active mask (live slots only);
|
| 3148 |
# active_mask is rebound to the packed form after this point.
|
| 3149 |
# -------------------------------------------------------------------
|
| 3150 |
+
used_capacity = cache.get_heads_lengths() if cache is not None else None
|
| 3151 |
selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
|
| 3152 |
+
hidden_states, active_mask, used_capacity
|
| 3153 |
)
|
| 3154 |
|
| 3155 |
setup = setup_packing(selected_heads)
|
|
|
|
| 3368 |
self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
|
| 3369 |
self.attention = SHRAMHybridLayer(config)
|
| 3370 |
self.mlp = SwiGLUMLP(config)
|
| 3371 |
+
self.residual_gate = nn.Parameter(torch.zeros([config.embedding_width]))
|
| 3372 |
def num_mosrah_parameters(self) -> int:
|
| 3373 |
"""Return the total number of trainable MoSRAH parameters in this decoder layer."""
|
| 3374 |
return self.attention.num_mosrah_parameters()
|
|
|
|
| 3404 |
active_mask=active_mask,
|
| 3405 |
cache=cache,
|
| 3406 |
)
|
| 3407 |
+
hidden_states = x + self.residual_gate*attn_out
|
| 3408 |
+
output = hidden_states + self.residual_gate*self.mlp(self.mlp_norm(hidden_states))
|
| 3409 |
return output, load_balance_loss, max_vio
|
| 3410 |
|
| 3411 |
|