Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM-dev with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "smithblack-0/SHRAM-dev" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "smithblack-0/SHRAM-dev" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "smithblack-0/SHRAM-dev" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
| """Transformer backbone for Shram. | |
| ShramModel is a pure PyTorch module: a sequence of DecoderLayer blocks followed | |
| by a final RMSNorm. It accepts pre-embedded hidden states and returns contextual | |
| representations. It has no knowledge of tokens, vocabulary, generation, or the | |
| HuggingFace causal-LM wrapper contract. | |
| Keeping the embedding out of the backbone is the correct convention and makes | |
| the backbone genuinely modality-agnostic. The token interface — embedding lookup, | |
| LM head, weight tying, and generation-facing naming conventions — belongs on the | |
| task wrapper (ShramForCausalLM), which is the only class that knows this | |
| backbone is being used for language modelling. | |
| The final RMSNorm is necessary because the decoder stack uses pre-norm throughout: | |
| each sublayer normalises its own input, leaving the residual stream itself | |
| unnormalised. After many layers of accumulated residuals, that stream arrives at | |
| the top with uncontrolled magnitude. The final norm brings it to a well-scaled | |
| state before any projection. Without it, the LM head would receive signals of | |
| arbitrary scale. | |
| Caching is caller-managed. If a ShramCache is provided, ShramModel threads the | |
| corresponding per-layer ShramLayerCache into each DecoderLayer and returns the | |
| same top-level ShramCache object in the output dict. If None is provided, no | |
| caching occurs. | |
| Returns a plain dict with keys: | |
| - "last_hidden_state": normed backbone output, shape (batch, seq_len, hidden_size) | |
| - "past_key_values": the ShramCache object passed in, or None | |
| - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None | |
| - "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses | |
| - "max_vio": detached scalar maximum routing-imbalance across all decoder layers | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from .__cache__shram_cache import ShramCache | |
| from .configuration import ShramConfig | |
| from .decoder_layer import DecoderLayer | |
| class ShramModel(nn.Module): | |
| """Pure transformer backbone: decoder stack and final normalisation. | |
| Accepts pre-embedded hidden states of shape (batch, seq_len, hidden_size) | |
| and returns contextual representations of the same shape. No token embedding, | |
| vocabulary projection, or causal-LM lifecycle concerns. | |
| RoPE is applied inside each attention layer. Positional information is | |
| encoded in the relationship between Q and K, not added to the residual | |
| stream, so the backbone is agnostic to how positions are represented. | |
| Args: | |
| config: Model configuration. Must be a ``ShramConfig`` instance. | |
| """ | |
| def __init__(self, config: ShramConfig) -> None: | |
| super().__init__() | |
| self.config = config | |
| self.layers = nn.ModuleList( | |
| [DecoderLayer(config) for _ in range(config.num_decoder_layers)] | |
| ) | |
| self.norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps) | |
| def num_mosrah_parameters(self) -> int: | |
| """Return the total number of trainable MoSRAH parameters across all decoder layers.""" | |
| return sum(layer.num_mosrah_parameters() for layer in self.layers) | |
| def forward( | |
| self, | |
| inputs_embeds: torch.Tensor, | |
| position_ids: torch.Tensor, | |
| active_mask: torch.Tensor, | |
| cache: ShramCache | None = None, | |
| output_hidden_states: bool = False, | |
| ) -> dict: | |
| """Run the transformer stack over a batch of pre-embedded sequences. | |
| Args: | |
| inputs_embeds: Pre-embedded input of shape (batch, seq_len, hidden_size). | |
| position_ids: Absolute positions of shape (batch, seq_len). Required. | |
| Must be provided explicitly by the caller — this module does not | |
| infer positions from cache state. | |
| active_mask: Current-chunk active mask of shape (batch, seq_len), | |
| where True means the token is semantically live. Forwarded | |
| unchanged to every decoder layer. | |
| cache: Optional top-level ShramCache. When provided, each DecoderLayer | |
| receives its own layer-local cache via ``cache.layers[layer_idx]``. | |
| The top-level cache object is updated in place and returned unchanged. | |
| output_hidden_states: When True, the output dict includes a tuple of | |
| per-layer hidden states: (inputs_embeds, layer_0_out, ..., layer_N_out), | |
| collected before the final norm. | |
| Returns: | |
| Plain dict with keys: | |
| - ``"last_hidden_state"``: normed backbone output, | |
| shape (batch, seq_len, hidden_size). | |
| - ``"past_key_values"``: the cache object passed in, or None. | |
| - ``"hidden_states"``: tuple of per-layer activations (including | |
| inputs_embeds as position 0) if ``output_hidden_states`` is True, | |
| else None. Collected before the final norm so each entry reflects the | |
| unnormalised residual stream at that depth. | |
| - ``"load_balance_loss"``: scalar sum of per-layer SHRAM | |
| load-balance losses. | |
| - ``"max_vio"``: detached scalar maximum routing-imbalance across | |
| all decoder layers. Zero means perfectly balanced routing across | |
| every layer; higher values identify the worst-case head imbalance. | |
| """ | |
| hidden_states = inputs_embeds | |
| all_hidden_states = (hidden_states,) if output_hidden_states else None | |
| total_load_balance_loss = inputs_embeds.new_zeros(()) | |
| max_vio = inputs_embeds.new_zeros(()) | |
| for layer_idx, layer in enumerate(self.layers): | |
| layer_cache = None if cache is None else cache.layers[layer_idx] | |
| hidden_states, layer_load_balance_loss, layer_max_vio = layer( | |
| hidden_states, | |
| position_ids, | |
| active_mask, | |
| cache=layer_cache, | |
| ) | |
| total_load_balance_loss = total_load_balance_loss + layer_load_balance_loss | |
| max_vio = torch.maximum(max_vio, layer_max_vio) | |
| if output_hidden_states: | |
| all_hidden_states = all_hidden_states + (hidden_states,) | |
| hidden_states = self.norm(hidden_states) | |
| return { | |
| "last_hidden_state": hidden_states, | |
| "past_key_values": cache, | |
| "hidden_states": all_hidden_states, | |
| "load_balance_loss": total_load_balance_loss, | |
| "max_vio": max_vio, | |
| } |