Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
| """Rotary Position Embeddings (RoPE). | |
| RoPE encodes position in the *relationship* between query and key vectors. When the | |
| attention dot product Q·Kᵀ is computed, the per-position rotations cancel to produce | |
| a score that depends only on the relative distance — not on absolute positions. | |
| Two modes are supported: | |
| default Standard RoPE with base frequency b. Each dimension pair d is assigned | |
| frequency θ_d = b^{-2d/u} where u is the head dimension. The attention | |
| scaling A_rope = 1. | |
| yarn YaRN frequency interpolation for long-context extrapolation (Peng et al., | |
| "YaRN: Efficient Context Window Extension of Large Language Models", 2023, | |
| §A.2). Three frequency regimes: | |
| - Low-frequency dimensions (r < α): fully interpolated by scale s. | |
| These dimensions have long wavelengths relative to the training window | |
| and must be compressed to avoid out-of-distribution positions. | |
| - High-frequency dimensions (r > β): left unchanged. Short-wavelength | |
| dimensions already encode relative position accurately at any scale. | |
| - Intermediate dimensions (α ≤ r ≤ β): linearly blended via ramp γ(r). | |
| Returns A_rope = (0.1·ln(s)+1)². When s = 1, YaRN reduces exactly to | |
| standard RoPE. | |
| Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit | |
| parameters — no shared instance, no config reading. See Unit 5.A design decisions. | |
| Cache sharing: all instances with identical parameters share one cos/sin table via a | |
| class-level registry. The first instance that needs a particular (parameters, device, | |
| dtype) combination builds the table; all subsequent instances reference it directly. | |
| This avoids redundant builds across the num_hidden_layers instances that share the | |
| same parametrisation. | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| # --------------------------------------------------------------------------- | |
| # Rotation helper | |
| # --------------------------------------------------------------------------- | |
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: | |
| """Apply the 90° rotation used in the RoPE update formula. | |
| Splits the last dimension into two halves [x1, x2] and returns [-x2, x1]. | |
| Combined with ``x * cos + _rotate_half(x) * sin``, this implements a 2D rotation | |
| on each dimension pair (i, i + u/2). This matches the block-diagonal operator | |
| R^u_{Θ,p} in the paper up to a fixed permutation of dimensions. | |
| """ | |
| d = x.shape[-1] // 2 | |
| x1, x2 = x[..., :d], x[..., d:] | |
| return torch.cat([-x2, x1], dim=-1) | |
| # --------------------------------------------------------------------------- | |
| # RotaryEmbedding | |
| # --------------------------------------------------------------------------- | |
| class RotaryEmbedding(nn.Module): | |
| """Rotary Position Embeddings with explicit mode and parameter control. | |
| Each caller constructs its own instance with the exact parameters it needs. | |
| h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No | |
| config object is read inside this module. | |
| The cos/sin table is built at construction time to cover all positions in | |
| ``[0, maximum_sequence_length)``. In forward, the table is rebuilt only if | |
| the query tensor's dtype or device has changed since construction. | |
| Instances with identical parameters share one cos/sin table via the class-level | |
| ``_cache`` registry, avoiding redundant computation across decoder layers. | |
| Args: | |
| mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation. | |
| head_dim: Per-head embedding dimension ``u``. Must be even. | |
| theta: Base frequency ``b`` in θ_d = b^{-2d/u}. | |
| maximum_sequence_length: Maximum number of positions the table must cover. | |
| The cos/sin table is preallocated to this length at construction time. | |
| For ``mode="yarn"``, the training context length C_train is derived | |
| internally as ``round(maximum_sequence_length / dilation)``. | |
| dilation: Scale factor ``s = C_target / C_train`` — how much the context | |
| window is extended beyond training length. Required for ``mode="yarn"``. | |
| When ``dilation=1.0``, YaRN reduces to standard RoPE. | |
| alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are fully | |
| interpolated. Required for ``mode="yarn"``. | |
| beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left | |
| unchanged. Required for ``mode="yarn"``. | |
| device: Optional device for initial buffer placement. | |
| Raises: | |
| NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``. | |
| ValueError: If ``mode="yarn"`` and any of ``dilation``, ``alpha``, | |
| ``beta`` are absent. | |
| """ | |
| # Maps (freq_key, device_str, dtype_str) → (cos_table, sin_table). | |
| # Shared across all RotaryEmbedding instances in the process. Keys include device | |
| # and dtype so that tables built on different devices or in different precisions | |
| # are stored independently. | |
| _cache: dict = {} | |
| def __init__( | |
| self, | |
| mode: str, | |
| head_dim: int, | |
| theta: float, | |
| maximum_sequence_length: int, | |
| dilation: float | None = None, | |
| alpha: float | None = None, | |
| beta: float | None = None, | |
| device: torch.device | None = None, | |
| ) -> None: | |
| super().__init__() | |
| self._validate_mode(mode) | |
| self._validate_yarn_params(mode, dilation, alpha, beta) | |
| self.mode = mode | |
| self._maximum_sequence_length = maximum_sequence_length | |
| # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn). | |
| # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair, | |
| # so rotation_freqs has head_dim/2 entries. | |
| d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) | |
| base_freqs = 1.0 / (theta ** (d_index / head_dim)) # θ_d = b^{-2d/u} | |
| if mode == "default": | |
| rotation_freqs = base_freqs | |
| self.attention_scaling: float = 1.0 | |
| else: # yarn | |
| s = dilation | |
| # C_train is the training context length, recovered from the inference | |
| # context length and the dilation factor. round() guards against floating | |
| # point error since both underlying quantities are integers. | |
| c_train: int = round(maximum_sequence_length / dilation) | |
| # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp | |
| # function to classify each dimension into one of three regimes. | |
| normalized_freqs = c_train * base_freqs / (2.0 * math.pi) | |
| # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged), | |
| # linear blend between α and β. | |
| blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0) | |
| # θ_d' = (1 − γ) · θ_d / s + γ · θ_d | |
| rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs | |
| # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller. | |
| self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2 | |
| # freq_key uniquely identifies the parameter set that produced rotation_freqs, | |
| # including maximum_sequence_length so instances with different table sizes | |
| # do not collide in the registry. | |
| if mode == "default": | |
| self._freq_key: tuple = ("default", head_dim, theta, maximum_sequence_length) | |
| else: | |
| self._freq_key = ("yarn", head_dim, theta, maximum_sequence_length, dilation, alpha, beta) | |
| # rotation_freqs is a non-persistent buffer so it moves with the model across | |
| # devices via .to() / .cuda() without appearing in saved checkpoints. | |
| # It is stored per-instance rather than in the shared cache because it is | |
| # small (head_dim/2 floats) — negligible cost compared to the cos/sin tables | |
| # it is used to build. The meaningful sharing win is on those tables. | |
| self.register_buffer("rotation_freqs", rotation_freqs, persistent=False) | |
| # Cache tensors are plain instance attributes (not registered buffers) so that | |
| # sharing across identically-parametrised instances survives .to() calls. | |
| # On a device change, .to() replaces each registered buffer with a per-instance | |
| # copy but leaves plain attributes untouched, preserving the shared-tensor | |
| # identity that the cache design depends on. | |
| self._cos_cached: torch.Tensor | None = None | |
| self._sin_cached: torch.Tensor | None = None | |
| # Build the table at construction time. Forward rebuilds only on dtype or | |
| # device change. If no device is specified, build on CPU as the default. | |
| build_device = device if device is not None else torch.device("cpu") | |
| self._build_cache(device=build_device, dtype=torch.float32) | |
| # --------------------------------------------------------------------------- | |
| # Validation helpers | |
| # --------------------------------------------------------------------------- | |
| @staticmethod | |
| def _validate_mode(mode: str) -> None: | |
| """Raise NotImplementedError if mode is not a supported value.""" | |
| if mode not in {"default", "yarn"}: | |
| raise NotImplementedError( | |
| f"RoPE mode '{mode}' is not supported. Supported modes: 'default', 'yarn'." | |
| ) | |
| @staticmethod | |
| def _validate_yarn_params( | |
| mode: str, | |
| dilation: float | None, | |
| alpha: float | None, | |
| beta: float | None, | |
| ) -> None: | |
| """Raise ValueError if mode='yarn' and any required parameter is absent.""" | |
| if mode != "yarn": | |
| return | |
| missing = [ | |
| name for name, val in [ | |
| ("dilation", dilation), | |
| ("alpha", alpha), | |
| ("beta", beta), | |
| ] | |
| if val is None | |
| ] | |
| if missing: | |
| raise ValueError(f"mode='yarn' requires {missing}.") | |
| # --------------------------------------------------------------------------- | |
| # Cache management | |
| # --------------------------------------------------------------------------- | |
| def _build_cache(self, device: torch.device, dtype: torch.dtype) -> None: | |
| """Build the cos/sin table to cover positions [0, maximum_sequence_length). | |
| Checks the class-level registry first. If a table already exists for this | |
| exact (parameters, device, dtype) combination it is reused directly; | |
| otherwise it is computed and stored. The instance attributes are pointed at | |
| the registry entry so that all layers sharing the same parametrisation | |
| reference the same tensor. | |
| """ | |
| cache_key = (self._freq_key, str(device), str(dtype)) | |
| if cache_key not in RotaryEmbedding._cache: | |
| positions = torch.arange( | |
| self._maximum_sequence_length, device=device, dtype=torch.float32 | |
| ) | |
| # outer product → (maximum_sequence_length, head_dim // 2); | |
| # duplicate to (maximum_sequence_length, head_dim) | |
| freqs = torch.outer( | |
| positions, | |
| self.rotation_freqs.to(device=device, dtype=torch.float32), | |
| ) | |
| angle_embedding = torch.cat((freqs, freqs), dim=-1) | |
| RotaryEmbedding._cache[cache_key] = ( | |
| angle_embedding.cos().to(dtype), | |
| angle_embedding.sin().to(dtype), | |
| ) | |
| self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key] | |
| def forward( | |
| self, | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| position_ids: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor, float]: | |
| """Apply rotary embeddings to query and key tensors. | |
| The cos/sin table is built at construction time. It is rebuilt here only | |
| if ``q``'s dtype or device differs from the cached table — for example, | |
| after moving the model to a different device via ``.cuda()``. | |
| ``position_ids`` may be an integer tensor of any shape. Its values must be in | |
| ``[0, maximum_sequence_length)``: | |
| - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim). | |
| - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim). | |
| When q/k have head dimensions absent from position_ids, broadcast dimensions | |
| are inserted automatically at dim 1. | |
| Args: | |
| q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim). | |
| k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim). | |
| position_ids: Integer positions of shape (batch, *pos_dims). | |
| Returns: | |
| Tuple of (q_rotated, k_rotated, attention_scaling). attention_scaling is | |
| 1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must | |
| apply to attention logits before softmax. | |
| """ | |
| wrong_dtype = self._cos_cached.dtype != q.dtype | |
| wrong_device = self._cos_cached.device != q.device | |
| if wrong_dtype or wrong_device: | |
| self._build_cache(device=q.device, dtype=q.dtype) | |
| cos = self._cos_cached[position_ids] | |
| sin = self._sin_cached[position_ids] | |
| # Insert broadcast dimensions for any head axes present in q/k but absent | |
| # from position_ids. Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once. | |
| # BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed. | |
| while cos.ndim < q.ndim: | |
| cos = cos.unsqueeze(1) | |
| sin = sin.unsqueeze(1) | |
| q_rotated = q * cos + _rotate_half(q) * sin | |
| k_rotated = k * cos + _rotate_half(k) * sin | |
| return q_rotated, k_rotated, self.attention_scaling | |
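The frequency construction and rotation described in the listing above can be sketched without PyTorch. What follows is a minimal pure-Python illustration, not the module's implementation: the function names `rope_frequencies`, `yarn_frequencies`, `attention_scaling`, `rotate`, and `dot` are hypothetical, and the parameter values in the usage lines are arbitrary choices. It assumes the same formulas the docstrings state (θ_d = b^{-2d/u}, the γ(r) ramp between α and β, and A_rope = (0.1·ln(s)+1)²) and the same half-split pairing used by `_rotate_half`.

```python
import math


def rope_frequencies(head_dim, theta):
    """theta_d = theta ** (-2d / head_dim) for pair index d = 0 .. head_dim/2 - 1."""
    return [theta ** (-(2 * d) / head_dim) for d in range(head_dim // 2)]


def yarn_frequencies(head_dim, theta, max_len, dilation, alpha, beta):
    """Blend interpolated and unchanged frequencies with the linear ramp gamma(r)."""
    c_train = round(max_len / dilation)  # training context length
    blended = []
    for f in rope_frequencies(head_dim, theta):
        r = c_train * f / (2.0 * math.pi)  # rotations completed over c_train
        gamma = min(1.0, max(0.0, (r - alpha) / (beta - alpha)))
        blended.append((1.0 - gamma) * f / dilation + gamma * f)
    return blended


def attention_scaling(dilation):
    """A_rope = (0.1 * ln(s) + 1) ** 2; equals 1.0 when s = 1."""
    return (0.1 * math.log(dilation) + 1.0) ** 2


def rotate(x, pos, freqs):
    """x * cos + rotate_half(x) * sin, pairing dimension i with i + head_dim/2."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i, f in enumerate(freqs):
        c, s = math.cos(pos * f), math.sin(pos * f)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


# Arbitrary illustrative values (assumptions, not taken from the model config).
freqs = rope_frequencies(head_dim=8, theta=10000.0)
yarn = yarn_frequencies(8, 10000.0, max_len=4096, dilation=4.0, alpha=1.0, beta=32.0)
q = [0.3, -1.2, 0.7, 0.5, 1.1, -0.4, 0.2, 0.9]
k = [-0.6, 0.8, 0.1, -0.3, 0.4, 1.0, -0.7, 0.2]

# Both pairs of positions are 3 apart, so the two scores should agree
# up to floating-point rounding.
score_near = dot(rotate(q, 5, freqs), rotate(k, 2, freqs))
score_far = dot(rotate(q, 105, freqs), rotate(k, 102, freqs))
print(score_near, score_far)
print(freqs, yarn, attention_scaling(4.0))
```

With these values the highest-frequency pair falls above β and keeps its base frequency, the lowest-frequency pair falls below α and is divided by the dilation, and the two in between are blended. Setting `dilation=1.0` makes `yarn_frequencies` agree with `rope_frequencies` and `attention_scaling` return 1.0, which is the reduction to standard RoPE that the docstring describes.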