"""Rotary Position Embeddings (RoPE).

RoPE encodes position in the *relationship* between query and key vectors. A query
at position m and a key at position n are rotated by angles proportional to m and n;
in the attention dot product Q·Kᵀ the two rotations combine into a single rotation by
n − m, so the score depends only on the relative distance, not on absolute positions.
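
A quick numerical check of the relative-distance property, as a standalone sketch.
The helper ``rotate_pairs`` is hypothetical (not part of this module); it assumes the
same half-split pairing as ``_rotate_half`` and an illustrative base of 10000. Two
(m, n) pairs with the same offset n − m = 3 should give the same score up to
floating-point error.

```python
import torch

torch.manual_seed(0)


def rotate_pairs(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    # Hypothetical helper: rotate each pair (x_i, x_{i + u/2}) of a 1-D vector by
    # pos * θ_i with θ_i = base^{-2i/u}, using the half-split layout of _rotate_half.
    u = x.shape[-1]
    freqs = 1.0 / (base ** (torch.arange(0, u, 2, dtype=torch.float64) / u))
    angles = torch.cat((pos * freqs, pos * freqs))
    rotated_half = torch.cat((-x[u // 2:], x[: u // 2]))
    return x * angles.cos() + rotated_half * angles.sin()


q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# Same offset n - m = 3 at two different absolute positions.
score_a = rotate_pairs(q, 2) @ rotate_pairs(k, 5)
score_b = rotate_pairs(q, 100) @ rotate_pairs(k, 103)
print(torch.allclose(score_a, score_b))  # True
```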

Two modes are supported:

    default   Standard RoPE with base frequency b. Each dimension pair d, for
              0 ≤ d < u/2, is assigned frequency θ_d = b^{-2d/u}, where u is the
              head dimension. The attention scaling is A_rope = 1.

    yarn      YaRN frequency interpolation for long-context extrapolation (Peng et al.,
              "YaRN: Efficient Context Window Extension of Large Language Models",
              2023, §A.2). Each dimension pair is classified by its normalized
              frequency r(d) = C_train · θ_d / (2π), the number of full rotations it
              completes over the training context length C_train. Three frequency
              regimes:

              - Low-frequency dimensions (r < α): fully interpolated, θ_d' = θ_d / s.
                These dimensions have long wavelengths relative to the training
                window and must be compressed to avoid out-of-distribution positions.
              - High-frequency dimensions (r > β): left unchanged. Short-wavelength
                dimensions already encode relative position accurately at any scale.
              - Intermediate dimensions (α ≤ r ≤ β): linearly blended via the ramp
                γ(r) = clamp((r − α) / (β − α), 0, 1), giving
                θ_d' = (1 − γ) · θ_d / s + γ · θ_d.

              The attention scaling is A_rope = (0.1·ln(s) + 1)². When s = 1, YaRN
              reduces exactly to standard RoPE.
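
The frequency rules for both modes can be reproduced in a few lines. The sketch below
is a standalone restatement of the arithmetic in ``RotaryEmbedding.__init__``, not an
import of it; the function name ``yarn_freqs`` and the values u = 64, b = 10000,
α = 1, β = 32 and the context lengths are illustrative assumptions, not this model's
configuration. With s = 1 it should return the default-mode θ_d and A_rope = 1; with
s = 4 the highest-frequency pair should be unchanged and the lowest-frequency pair
divided by s.

```python
import math

import torch


def yarn_freqs(head_dim: int, theta: float, max_len: int, s: float,
               alpha: float, beta: float) -> tuple[torch.Tensor, float]:
    # Standalone restatement of the yarn branch of RotaryEmbedding.__init__.
    two_d = torch.arange(0, head_dim, 2, dtype=torch.float32)
    base = 1.0 / (theta ** (two_d / head_dim))              # θ_d = b^{-2d/u}
    c_train = round(max_len / s)                             # C_train
    r = c_train * base / (2.0 * math.pi)                     # rotations over C_train
    gamma = ((r - alpha) / (beta - alpha)).clamp(0.0, 1.0)   # ramp γ(r)
    freqs = (1.0 - gamma) * (base / s) + gamma * base        # θ_d'
    return freqs, (0.1 * math.log(s) + 1.0) ** 2             # (θ_d', A_rope)


# Illustrative values only: u = 64, b = 10000, α = 1, β = 32, C_train = 4096.
freqs_s1, scale_s1 = yarn_freqs(64, 10000.0, max_len=4096, s=1.0, alpha=1.0, beta=32.0)
freqs_s4, scale_s4 = yarn_freqs(64, 10000.0, max_len=16384, s=4.0, alpha=1.0, beta=32.0)
```

Here the highest-frequency pair completes roughly 650 rotations over C_train (so
γ = 1), while the lowest completes well under one (so γ = 0).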

Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit
parameters: there is no shared instance, and no config is read. See Unit 5.A design
decisions.

Cache sharing: all instances with identical parameters share one cos/sin table via a
class-level registry. The first instance that needs a particular (parameters, device,
dtype) combination builds the table; all subsequent instances reference it directly.
This avoids redundant builds across the num_hidden_layers instances that share the
same parametrisation.
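
The registry pattern itself is small. The hypothetical ``SharedTable`` class below
(not part of this module) stands in for RotaryEmbedding with the frequency
computation stripped out. It is a sketch showing that twelve identically-parametrised
instances trigger a single build and alias the same tensor; a different dtype would
get its own entry because dtype is part of the key.

```python
import torch


class SharedTable:
    # Hypothetical stand-in for RotaryEmbedding: keeps only the registry logic.
    _cache: dict = {}
    builds = 0  # counts how many tables were actually computed

    def __init__(self, length: int, device: str = "cpu",
                 dtype: torch.dtype = torch.float32) -> None:
        key = (length, str(device), str(dtype))  # (parameters, device, dtype)
        if key not in SharedTable._cache:
            SharedTable.builds += 1
            SharedTable._cache[key] = torch.arange(length, device=device).to(dtype)
        self.table = SharedTable._cache[key]  # alias, not a copy


layers = [SharedTable(16) for _ in range(12)]  # e.g. twelve decoder layers
```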
"""

import math

import torch
import torch.nn as nn


# ---------------------------------------------------------------------------
# Rotation helper
# ---------------------------------------------------------------------------


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Apply the 90° rotation used in the RoPE update formula.

    Splits the last dimension into two halves [x1, x2] and returns [-x2, x1].
    Combined with ``x * cos + rotate_half(x) * sin``, this implements a 2D rotation
    on each dimension pair (i, i + u/2), where u is the head dimension. This is the
    paper's block-diagonal operator R^u_{Θ,p} up to a fixed permutation of the
    dimensions: pairs are formed across the two halves rather than from consecutive
    dimensions (2i, 2i + 1).
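
To see that the half-split form is a per-pair rotation, the sketch below compares
``x * cos + rotate_half(x) * sin`` with an explicit 2×2 rotation of each pair
(x_i, x_{i + u/2}). ``rotate_half`` here is a local copy of ``_rotate_half``; the
per-pair frequencies and the position are arbitrary illustrative values.

```python
import torch

torch.manual_seed(0)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Local copy of _rotate_half: [x1, x2] -> [-x2, x1] along the last dimension.
    d = x.shape[-1] // 2
    return torch.cat([-x[..., d:], x[..., :d]], dim=-1)


u, pos = 6, 3                                                # illustrative values
x = torch.randn(u, dtype=torch.float64)
theta = torch.tensor([1.0, 0.5, 0.25], dtype=torch.float64)  # one θ_i per pair
angles = torch.cat((theta, theta)) * pos
fast = x * angles.cos() + rotate_half(x) * angles.sin()

# Explicit 2x2 rotation of each pair (x_i, x_{i + u/2}) by angle pos * θ_i.
explicit = x.clone()
for i in range(u // 2):
    c, s = torch.cos(theta[i] * pos), torch.sin(theta[i] * pos)
    a, b = x[i], x[i + u // 2]
    explicit[i] = c * a - s * b
    explicit[i + u // 2] = s * a + c * b
```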
"""
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([-x2, x1], dim=-1)


# ---------------------------------------------------------------------------
# RotaryEmbedding
# ---------------------------------------------------------------------------


class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings with explicit mode and parameter control.

    Each caller constructs its own instance with the exact parameters it needs.
    h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
    config object is read inside this module.

    The cos/sin table is built at construction time (in float32, on ``device`` or on
    CPU if none is given) to cover all positions in ``[0, maximum_sequence_length)``.
    In forward, the table is rebuilt only if the query tensor's dtype or device
    differs from that of the currently cached table.

    Instances with identical parameters share one cos/sin table via the class-level
    ``_cache`` registry, avoiding redundant computation across decoder layers.

    Args:
        mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
        head_dim: Per-head embedding dimension ``u``. Must be even.
        theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
        maximum_sequence_length: Maximum number of positions the table must cover.
            The cos/sin table is preallocated to this length at construction time.
            For ``mode="yarn"``, the training context length C_train is derived
            internally as ``round(maximum_sequence_length / dilation)``.
        dilation: Scale factor ``s = C_target / C_train``, i.e. how far the context
            window is extended beyond the training length. Required for
            ``mode="yarn"``. When ``dilation=1.0``, YaRN reduces to standard RoPE.
        alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are fully
            interpolated. Required for ``mode="yarn"``.
        beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left
            unchanged. Required for ``mode="yarn"``.
        device: Optional device for initial buffer placement.

    Raises:
        NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
        ValueError: If ``mode="yarn"`` and any of ``dilation``, ``alpha``,
            ``beta`` is absent.
    """

    # Maps (freq_key, device_str, dtype_str) → (cos_table, sin_table).
    # Shared across all RotaryEmbedding instances in the process. Keys include device
    # and dtype so that tables built on different devices or in different precisions
    # are stored independently.
    _cache: dict = {}

    def __init__(
        self,
        mode: str,
        head_dim: int,
        theta: float,
        maximum_sequence_length: int,
        dilation: float | None = None,
        alpha: float | None = None,
        beta: float | None = None,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self._validate_mode(mode)
        self._validate_yarn_params(mode, dilation, alpha, beta)
        self.mode = mode
        self._maximum_sequence_length = maximum_sequence_length

        # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
        # d_index ranges over 0, 2, 4, ..., head_dim-2 (i.e. 2d for pair index d),
        # so rotation_freqs has head_dim/2 entries.
        d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        base_freqs = 1.0 / (theta ** (d_index / head_dim))  # θ_d = b^{-2d/u}

        if mode == "default":
            rotation_freqs = base_freqs
            self.attention_scaling: float = 1.0
        else:  # yarn
            s = dilation
            # C_train is the training context length, recovered from the inference
            # context length and the dilation factor. round() guards against floating
            # point error since both underlying quantities are integers.
            c_train: int = round(maximum_sequence_length / dilation)
            # r(d) = C_train · θ_d / (2π): normalized frequency used by the ramp
            # function to classify each dimension into one of three regimes.
            normalized_freqs = c_train * base_freqs / (2.0 * math.pi)
            # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
            # linear blend between α and β.
            blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0)
            # θ_d' = (1 − γ) · θ_d / s + γ · θ_d
            rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs
            # A_rope = (0.1 · ln(s) + 1)²: attention logit scaling returned to the caller.
            self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2

        # freq_key uniquely identifies the parameter set that produced rotation_freqs,
        # including maximum_sequence_length so instances with different table sizes
        # do not collide in the registry.
        if mode == "default":
            self._freq_key: tuple = ("default", head_dim, theta, maximum_sequence_length)
        else:
            self._freq_key = ("yarn", head_dim, theta, maximum_sequence_length, dilation, alpha, beta)

        # rotation_freqs is a non-persistent buffer so it moves with the model across
        # devices via .to() / .cuda() without appearing in saved checkpoints.
        # It is stored per-instance rather than in the shared cache because it is
        # small (head_dim/2 floats), a negligible cost compared to the cos/sin tables
        # it is used to build. The meaningful sharing win is on those tables.
        self.register_buffer("rotation_freqs", rotation_freqs, persistent=False)

        # Cache tensors are plain instance attributes (not registered buffers) so that
        # sharing across identically-parametrised instances survives .to() calls.
        # Module.to() replaces each registered buffer with a converted copy, which
        # would give every instance its own table; plain attributes are left untouched
        # and keep pointing at the shared registry entry. forward() then detects the
        # device/dtype mismatch and re-points them at the registry entry for the new
        # device and dtype.
        self._cos_cached: torch.Tensor | None = None
        self._sin_cached: torch.Tensor | None = None

        # Build the table at construction time. Forward rebuilds only on dtype or
        # device change. If no device is specified, build on CPU as the default.
        build_device = device if device is not None else torch.device("cpu")
        self._build_cache(device=build_device, dtype=torch.float32)

    # ---------------------------------------------------------------------------
    # Validation helpers
    # ---------------------------------------------------------------------------

    @staticmethod
    def _validate_mode(mode: str) -> None:
        """Raise NotImplementedError if mode is not a supported value."""
        if mode not in {"default", "yarn"}:
            raise NotImplementedError(
                f"RoPE mode '{mode}' is not supported. Supported modes: 'default', 'yarn'."
            )

    @staticmethod
    def _validate_yarn_params(
        mode: str,
        dilation: float | None,
        alpha: float | None,
        beta: float | None,
    ) -> None:
        """Raise ValueError if mode='yarn' and any required parameter is absent."""
        if mode != "yarn":
            return
        missing = [
            name for name, val in [
                ("dilation", dilation),
                ("alpha", alpha),
                ("beta", beta),
            ]
            if val is None
        ]
        if missing:
            raise ValueError(f"mode='yarn' requires {missing}.")

    # ---------------------------------------------------------------------------
    # Cache management
    # ---------------------------------------------------------------------------

    def _build_cache(self, device: torch.device, dtype: torch.dtype) -> None:
        """Build the cos/sin table to cover positions [0, maximum_sequence_length).

        Checks the class-level registry first. If a table already exists for this
        exact (parameters, device, dtype) combination it is reused directly;
        otherwise it is computed and stored. The instance attributes are pointed at
        the registry entry so that all layers sharing the same parametrisation
        reference the same tensor.
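
The table construction can be sketched standalone. ``build_cos_sin`` is a hypothetical
helper mirroring the body of this method without the registry: angles are formed in
float32 by an outer product of positions and frequencies, duplicated across the two
halves to match ``_rotate_half``, and only the final cos/sin values are cast to the
requested dtype (presumably so that large position × frequency products are not
rounded in low precision). The frequencies below are illustrative.

```python
import torch


def build_cos_sin(freqs: torch.Tensor, max_len: int,
                  dtype: torch.dtype = torch.float32) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical helper mirroring _build_cache without the registry lookup.
    positions = torch.arange(max_len, dtype=torch.float32)
    angles = torch.outer(positions, freqs.float())    # (max_len, head_dim // 2)
    angles = torch.cat((angles, angles), dim=-1)      # (max_len, head_dim)
    # Cast only at the end, after cos/sin are evaluated in float32.
    return angles.cos().to(dtype), angles.sin().to(dtype)


freqs = torch.tensor([1.0, 0.1, 0.01, 0.001])        # head_dim = 8 (illustrative)
cos, sin = build_cos_sin(freqs, max_len=16, dtype=torch.bfloat16)
```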
"""
        cache_key = (self._freq_key, str(device), str(dtype))
        if cache_key not in RotaryEmbedding._cache:
            positions = torch.arange(
                self._maximum_sequence_length, device=device, dtype=torch.float32
            )
            # outer product → (maximum_sequence_length, head_dim // 2);
            # duplicate to (maximum_sequence_length, head_dim)
            freqs = torch.outer(
                positions,
                self.rotation_freqs.to(device=device, dtype=torch.float32),
            )
            angle_embedding = torch.cat((freqs, freqs), dim=-1)
            RotaryEmbedding._cache[cache_key] = (
                angle_embedding.cos().to(dtype),
                angle_embedding.sin().to(dtype),
            )
        self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key]

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, float]:
        """Apply rotary embeddings to query and key tensors.

        The cos/sin table is built at construction time. It is rebuilt here only
        if ``q``'s dtype or device differs from the cached table, for example
        after moving the model to a different device via ``.cuda()``.

        ``position_ids`` may be any integer tensor shape. Its values must be in
        ``[0, maximum_sequence_length)``:

        - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
        - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).

        When q/k have head dimensions absent from position_ids, broadcast dimensions
        are inserted automatically at dim 1.
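
The shape handling here is ordinary advanced indexing followed by ``unsqueeze``. The
sketch below uses a random stand-in table and arbitrary sizes (all assumptions) to
show the two layouts listed above: for h_l, indexing with (B, N) positions gives
(B, N, head_dim), and one inserted axis at dim 1 lets it broadcast over heads; for
BEA, the indexed table already has q's shape.

```python
import torch

max_len, head_dim = 32, 8
table = torch.randn(max_len, head_dim)          # stand-in for _cos_cached

# h_l layout: position_ids (B, N) and q (B, H, N, head_dim).
pos = torch.arange(5).expand(2, 5)              # B = 2, N = 5
q = torch.randn(2, 4, 5, head_dim)              # H = 4
cos = table[pos]                                # (2, 5, 8)
while cos.ndim < q.ndim:
    cos = cos.unsqueeze(1)                      # (2, 1, 5, 8): broadcasts over H

# BEA layout: position_ids (B, L, T) already match q's leading dimensions.
pos_bea = torch.randint(0, max_len, (2, 3, 7))
q_bea = torch.randn(2, 3, 7, head_dim)
cos_bea = table[pos_bea]                        # (2, 3, 7, 8): no unsqueeze needed
```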

        Args:
            q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim).
            k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim).
            position_ids: Integer positions of shape (batch, *pos_dims).

        Returns:
            Tuple of (q_rotated, k_rotated, attention_scaling). attention_scaling is
            1.0 for default mode; YaRN returns (0.1·ln(s)+1)², which the caller must
            apply to attention logits before softmax.
"""
        wrong_dtype = self._cos_cached.dtype != q.dtype
        wrong_device = self._cos_cached.device != q.device
        if wrong_dtype or wrong_device:
            self._build_cache(device=q.device, dtype=q.dtype)

        cos = self._cos_cached[position_ids]
        sin = self._sin_cached[position_ids]

        # Insert broadcast dimensions for any head axes present in q/k but absent
        # from position_ids. Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once.
        # BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed.
        while cos.ndim < q.ndim:
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)

        q_rotated = q * cos + _rotate_half(q) * sin
        k_rotated = k * cos + _rotate_half(k) * sin
        return q_rotated, k_rotated, self.attention_scaling