"""Token-choice router for the MoSRAH sparse attention path.

This module implements the routing mechanism described in Appendix A.Routing of the
paper. Given an input hidden state x, the router produces two outputs used downstream:

- selected_heads (I): which K of the L available expert heads each token routes to,
  determined by TopK over biased routing scores.
- routing_probs (P): the weights used for the weighted output reduction, gathered from
  *unbiased* routing scores at the selected indices and renormalized. The learned expert
  bias b must not influence P.

This separation is architecturally critical: expert_bias drives selection (and thus load
balancing) but does not corrupt the gradient path from the output through routing_probs
back to the routing projection weights.

The router also computes and returns the load balance loss via the LoadBalanceLoss custom
autograd operator (see load_balance_loss.py). This loss is a scalar that the training
loop can weight and add to the language modeling loss.

The router additionally computes and returns MaxVio, a detached scalar summarising
routing imbalance for the current forward pass:

    MaxVio = L · max_l(f_l − 1/L)

where f_l is the realised routing frequency of head l and 1/L is the perfectly balanced
target. MaxVio is a monitoring quantity only; it never contributes gradients.

Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .configuration import ShramConfig
from .__attention__load_balance_loss import LoadBalanceLoss


class MoSRAHRouter(nn.Module):
    """Token-choice router for MoSRAH sparse attention.

    Each input token independently selects K of the L available expert heads. Selection
    is driven by biased routing scores to enable load balancing, but the routing
    probabilities used for output reduction are computed from unbiased scores so that
    the expert bias does not interfere with the gradient path to the router weights.

    The routing projection W_r has no bias term: the paper specifies xW_r with no
    additional projection bias. The only bias-like parameter is expert_bias (b), which
    has an entirely separate role and update mechanism.

    Args:
        config: Model configuration. Must expose ``embedding_width`` (the hidden size),
            ``num_mosrah_heads`` (L), ``num_selected_heads`` (K), and
            ``load_balance_p`` (the exponent of the p-mean over batch items).
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.num_mosrah_heads = config.num_mosrah_heads
        self.num_selected_heads = config.num_selected_heads
        self.load_balance_p = config.load_balance_p

        # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
        self.routing_projection = nn.Linear(
            config.embedding_width, config.num_mosrah_heads, bias=False
        )

        # b: learned per-head bias for load balancing. Initialized to zero so that all
        # heads start with equal selection probability. Updated by the main optimizer
        # via the LoadBalanceLoss custom backward.
        self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))

    def forward(
        self, x: torch.Tensor, active_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route input tokens to K expert heads each and compute routing probabilities.

        Args:
            x: Input hidden states of shape (batch, seq_len, hidden_size).
            active_mask: Current-chunk active mask of shape (batch, seq_len), where
                True means the token is semantically live. Dead tokens do not
                contribute to routing frequencies, load_balance_loss, or max_vio.

        Returns:
            selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
                Each token's K selected head indices, determined by TopK on biased
                scores.
            routing_probs: Routing probabilities P of shape (batch, seq_len,
                num_selected_heads). Gathered from unbiased scores at selected_heads
                indices and renormalized to sum to 1 per token.
            load_balance_loss: Scalar load balance loss for this forward pass. The
                training loop scales this by a weight and adds it to the main loss.
            max_vio: Detached scalar routing-imbalance summary for this forward pass.
                Equal to L · max_l(f_l − 1/L). Zero means perfect balance. Not a loss;
                never contributes gradients.
        """
        B, N, _ = x.shape
        L = self.num_mosrah_heads
        K = self.num_selected_heads

        # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
        # compute routing_probs; expert_bias must not influence them.
        logits = self.routing_projection(x)  # (B, N, L)
        routing_scores = F.softmax(logits, dim=-1)  # R, (B, N, L)

        # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
        # selection. expert_bias is added to logits before softmax so that the bias
        # shifts selection probability without rescaling the unbiased distribution.
        biased_routing_scores = F.softmax(  # R̂, (B, N, L)
            logits + self.expert_bias, dim=-1
        )

        # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
        selected_heads = biased_routing_scores.topk(K, dim=-1).indices

        # Routing probabilities P: gathered from unbiased R at selected_heads indices,
        # then renormalized so they sum to 1 per token. Gathering from routing_scores
        # (not biased_routing_scores) is the invariant that keeps the gradient path from
        # the output back to the router weights free of expert_bias influence.
        gathered = routing_scores.gather(dim=-1, index=selected_heads)  # V, (B, N, K)
        routing_probs = gathered / gathered.sum(dim=-1, keepdim=True)  # P, (B, N, K)

        # Per-item routing frequencies f_{b,l}: for each batch item b and head l, the
        # fraction of that item's K assignments from active tokens that go to head l.
        # Dead tokens are excluded before reduction. Normalization is per batch item so
        # each item's frequencies sum to 1 independently of other items in the batch.
        # Note: an item with no active tokens divides 0 by 0 here and yields NaN.
        assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
        assignment_mask.scatter_(-1, selected_heads, 1.0)
        active_assignments = assignment_mask * active_mask.unsqueeze(-1)
        per_item_counts = active_assignments.sum(dim=1)  # (B, L)
        per_item_total = active_mask.sum(dim=1, keepdim=True) * K  # (B, 1)
        per_item_freqs = per_item_counts / per_item_total  # (B, L)

        # p-mean of per_item_freqs over the batch dimension produces routing_freqs (L,).
        # Compared with the arithmetic mean, the p-mean weights the aggregate toward
        # the worst-case batch item, making the load balance signal sensitive to
        # per-item spikes that cause packing overflow.
        p = self.load_balance_p
        routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p)  # (L,)

        # Load balance loss via custom autograd. expert_bias is an input so PyTorch
        # registers it as a graph node; the custom backward writes the DeepSeek-style
        # correction gradient to expert_bias.grad for the optimizer to consume.
        load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)

        # MaxVio is a detached monitoring scalar following the paper's formula
        # L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
        max_vio = self._compute_max_vio(routing_freqs, L)

        return selected_heads, routing_probs, load_balance_loss, max_vio

    @staticmethod
    def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
        """Compute the MaxVio routing-imbalance scalar.

        MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
        head l and 1/L is the perfectly balanced target. Follows the paper's definition
        (Wang et al.) applied to routing_freqs. A value of zero indicates perfect
        balance; a value of 0.5 means the most overloaded head received 50% more routed
        tokens than ideal.

        The result is detached from the autograd graph: MaxVio is a monitoring scalar
        and must never contribute gradients to any parameter.

        Args:
            routing_freqs: Per-head routing frequencies of shape (L,).
            num_heads: Total number of MoSRAH heads L.

        Returns:
            Detached scalar MaxVio tensor.
        """
        return (num_heads * (routing_freqs - 1.0 / num_heads).max()).detach()
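
The routing arithmetic in `forward` can be traced on toy tensors without the rest of the model. The following is a minimal sketch, not the module itself: the sizes, the random matrix `w_r` standing in for `routing_projection`, the exaggerated `expert_bias` values, and `p = 4.0` are all assumptions chosen for illustration, and the `LoadBalanceLoss` step is left out because its implementation lives in a separate file.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, N, D = 2, 6, 8       # batch, sequence length, hidden width (toy sizes)
L, K, p = 4, 2, 4.0     # heads, heads per token, p-mean exponent (assumed value)

x = torch.randn(B, N, D)
w_r = torch.randn(D, L, requires_grad=True)   # hypothetical stand-in for W_r
# Exaggerated bias so that selection visibly differs from the unbiased ranking.
expert_bias = torch.tensor([3.0, 0.0, 0.0, -3.0], requires_grad=True)

active_mask = torch.ones(B, N, dtype=torch.bool)
active_mask[1, 4:] = False                    # last two tokens of item 1 are dead

logits = x @ w_r                                              # (B, N, L)
routing_scores = F.softmax(logits, dim=-1)                    # unbiased R
biased_scores = F.softmax(logits + expert_bias, dim=-1)       # biased R-hat
selected_heads = biased_scores.topk(K, dim=-1).indices        # I, (B, N, K)

# P is gathered from the unbiased scores and renormalized per token.
gathered = routing_scores.gather(dim=-1, index=selected_heads)
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True)  # (B, N, K)

# Per-item frequencies over active tokens only.
assignment = torch.zeros(B, N, L).scatter_(-1, selected_heads, 1.0)
counts = (assignment * active_mask.unsqueeze(-1)).sum(dim=1)   # (B, L)
per_item_freqs = counts / (active_mask.sum(dim=1, keepdim=True) * K)

# p-mean over the batch, then the detached MaxVio summary.
routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p)  # (L,)
max_vio = (L * (routing_freqs - 1.0 / L).max()).detach()

# A loss built only from routing_probs reaches w_r but never expert_bias,
# because TopK indices are integers and carry no gradient.
weights = torch.arange(1, K + 1, dtype=torch.float32)
(routing_probs * weights).sum().backward()

print("per-item freqs:\n", per_item_freqs)
print("routing_freqs:", routing_freqs)
print("max_vio:", max_vio.item())
print("expert_bias.grad is None:", expert_bias.grad is None)
```

In this sketch `expert_bias.grad` stays `None` after the backward pass, which is the separation the module docstring describes: the bias changes which heads are selected but is absent from the graph that produces `routing_probs`. In the real module the bias receives its gradient only through `LoadBalanceLoss`.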