Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions for using smithblack-0/SHRAM-dev with libraries, inference providers, notebooks, and local apps.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto")
- Local Apps
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "smithblack-0/SHRAM-dev"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM-dev",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "smithblack-0/SHRAM-dev" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM-dev",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "smithblack-0/SHRAM-dev" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM-dev",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
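The vLLM and SGLang servers above both expose an OpenAI-compatible completions endpoint. As a minimal standard-library sketch, the helper below (`build_completion_request` is a hypothetical name, not part of either project) builds the same POST request that the curl commands send. It assumes a server started as shown above: port 8000 for the `vllm serve` command, port 30000 for the SGLang commands. Building the request does not contact the server, so it can be inspected offline before sending.

```python
import json
import urllib.request


def build_completion_request(base_url, model, prompt, max_tokens=512, temperature=0.5):
    """Build (but do not send) a POST to an OpenAI-compatible /v1/completions endpoint.

    Hypothetical helper: mirrors the JSON body used by the curl examples.
    """
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Port 8000 matches the vLLM command above; use 30000 for the SGLang server.
req = build_completion_request("http://localhost:8000", "smithblack-0/SHRAM-dev", "Once upon a time,")

# Sending requires a running server:
# with urllib.request.urlopen(req) as resp:
#     body = json.load(resp)
#     print(body["choices"][0]["text"])
```

Keeping request construction separate from sending makes it easy to switch between the two servers by changing only `base_url`.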
"""MoSRAH sparse KV cache: single-layer implementation.

MoSRAH routes each token to K of L available expert heads, so its KV cache is indexed
by head rather than by sequence position. The routing is dynamic and produces a ragged
distribution of token counts across (batch, head) slots: different batch items may
route different numbers of tokens to the same head, and different heads accumulate at
different rates. DynamicCache cannot represent this correctly: it concatenates along
the sequence dimension and assumes uniform token counts across the batch. MoSRAHCache
therefore uses a custom buffer design.

Keys and values are stored in the CacheLayerMixin-standard self.keys and self.values
attributes as (B, L, T, u) tensors, where B is the batch size, L is the number of expert
heads (num_mosrah_heads), T is the static buffer capacity (mosrah_cache_length), and u
is the bottlenecked head embedding width (head_dim). A (B, L) integer count tensor,
_counts, tracks the valid occupancy of each (batch, head) slot. Buffer capacity is
exposed through the buffer_capacity property and is derived directly from self.keys
rather than tracked as a separate variable.

The primary interface is update(key_states, value_states, active_mask), which accepts
the expert-choice layout, stores only active entries in causal order, and returns the
full accumulated (keys, values, active_mask) for immediate use by BEA. The returned
active_mask identifies valid cached positions; everything beyond each slot's count is
junk data that downstream attention must exclude.

BEA applies RoPE and calls update() with post-RoPE keys. If the caller needs the
pre-update occupancy for position computation (Unit 10.A), it must read the counts
exposed by get_heads_lengths() before calling update(): update() replaces the stored
counts with the post-update counts, so the pre-update values cannot be read from the
cache afterward.

All buffers are allocated at construction time. MoSRAHCache is constructed by
ShramLayerCache, which has access to the batch size, device, and all model config
parameters needed to fully specify the storage layout upfront.
"""
import torch
from transformers.cache_utils import CacheLayerMixin


class MoSRAHCache(CacheLayerMixin):
    """KV cache for the MoSRAH sparse attention path (single decoder layer).

    Subclasses CacheLayerMixin to fill the HuggingFace per-layer cache role.
    Stores keys and values in the mixin-standard self.keys and self.values attributes
    using a custom (B, L, T, u) layout rather than delegating to DynamicCache,
    which cannot represent MoSRAH's ragged per-(batch, head) token counts correctly.

    All storage is allocated at construction time, and is_initialized is True
    immediately. The caller (ShramLayerCache) provides the batch size, device, and
    model config parameters, so no lazy allocation is needed.

    Input is expected in expert-choice layout: (B, L, T, u) key/value tensors with a
    (B, L, T) boolean active_mask. Only positions where active_mask is True are written.
    This matches the packed representation produced by expert packing in the MoSRAH
    forward pass, where BEA has already applied RoPE before calling update().

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
            second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first dimension
            of all storage tensors.
        device: Device on which to allocate all tensors. Should match the model device.
        mosrah_cache_length: Static sequence capacity per (batch, head) slot. Equal to
            config.mosrah_cache_length. The buffer never grows; if any slot would exceed
            this capacity, update() raises in both eager and compiled modes. Increase
            mosrah_overallocation_factor in ShramConfig to resolve an overflow.
    """

    is_compileable = True
    is_sliding = False

    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        mosrah_cache_length: int,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device
        self.mosrah_cache_length = mosrah_cache_length

        # Allocate primary storage into the mixin-standard self.keys / self.values so
        # that inherited methods (offload, prefetch) operate on real tensors. _counts
        # tracks valid occupancy per (batch, head) slot.
        self.keys: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )

        # Storage is fully allocated at construction, so the cache is initialized.
        self.is_initialized = True

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------

    @property
    def buffer_capacity(self) -> int:
        """Number of slots allocated per (batch, head) pair.

        Equal to mosrah_cache_length as supplied at construction. Derived from
        self.keys so it remains consistent with the actual buffer shape.
        """
        return self.keys.shape[2]

    # ---------------------------------------------------------------------------
    # Primary API
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states into the buffer and return the full cache state.

        Accepts expert-choice layout: key_states and value_states are (B, L, T, u);
        active_mask is (B, L, T) bool with True marking real tokens. Only active
        positions are written; inactive positions are ignored.

        Uses a fixed-shape destination mask constructed from per-slot write intervals
        to transfer active tokens into the buffer without any data-dependent shape
        operations. The packing machinery left-justifies active tokens within each
        packed slot, so the destination positions form a contiguous range starting
        at the current slot count; no cumsum or torch.where is needed.

        Returns the full accumulated (keys, values, active_mask) across the cached
        sparse sequence. The returned active_mask is True exactly for slots with
        t < counts[b, l]; everything beyond is junk data that BEA must exclude.

        Note: if the caller needs the pre-update occupancy for position computation
        (Unit 10.A), it must call get_heads_lengths() before update(). update()
        replaces the stored counts with the post-update counts, so the pre-update
        values cannot be read from the cache afterward.

        Args:
            key_states: Shape (B, L, T, u); post-RoPE key vectors in expert-choice layout.
            value_states: Shape (B, L, T, u); value vectors in expert-choice layout.
            active_mask: Shape (B, L, T) bool; True for real tokens, False for padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.

        Returns:
            Tuple of (keys, values, active_mask):
                keys: (B, L, mosrah_cache_length, u) float; full key buffer including junk slots.
                values: (B, L, mosrah_cache_length, u) float; full value buffer including junk slots.
                active_mask: (B, L, mosrah_cache_length) bool; True iff slot t has been written.
        """
        incoming_delta = active_mask.long().sum(dim=2)  # (B, L)
        post_counts = self._counts + incoming_delta
        self._check_no_overflow(post_counts.max(), self.mosrah_cache_length)

        # Build a fixed-shape destination mask in cache space. Active tokens within
        # each (b, l) slot are left-justified by the packing machinery, so they occupy
        # positions 0..s-1 in their packed slot. The corresponding cache positions are
        # write_start[b,l]..write_start[b,l]+write_count[b,l]-1. Broadcasting a
        # time arange against these per-slot intervals selects exactly the target
        # positions without any data-dependent shape query.
        write_start = self._counts.unsqueeze(-1)  # cache position where new tokens begin
        write_count = incoming_delta.unsqueeze(-1)  # number of new tokens arriving per slot
        time_arange = torch.arange(
            self.mosrah_cache_length, device=active_mask.device
        )
        dest_mask = (time_arange >= write_start) & (time_arange < write_start + write_count)
        # dest_mask: (B, L, mosrah_cache_length)

        # Transfer key and value vectors. write_count is the per-slot active count,
        # so dest_mask and active_mask have equal True counts per (b, l) slot, and
        # boolean indexing enumerates both in the same (b, l, t) order. The transfer
        # is therefore correct without any explicit count verification.
        self.keys[dest_mask] = key_states[active_mask]
        self.values[dest_mask] = value_states[active_mask]
        self._counts = post_counts
        return self.keys, self.values, self._make_active_mask()

    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor consumed by BEA for attention
        masking and by position computation (Unit 10.A) to derive semantic-sequence
        positions.

        Note: in the MoSRAH forward pass, call this before update() if the pre-update
        occupancy is needed. update() replaces the stored counts with the post-update
        counts, so calling this afterward returns the post-update occupancy.

        Returns:
            Integer tensor of shape (B, L) where entry [b, h] is the number of valid
            tokens stored in the (b, h) slot. Zero for slots with no writes yet.
        """
        return self._counts

    # ---------------------------------------------------------------------------
    # CacheLayerMixin: overridden coordination methods
    # ---------------------------------------------------------------------------

    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
        and is_initialized remains True; only the contents are cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam search
        must reorder all three together, or the occupancy counts and buffer contents
        will correspond to different beam hypotheses.

        Overrides the parent because the parent's implementation calls get_seq_length(),
        which is not supported for this cache.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by the beam
                search algorithm.
        """
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self._counts = self._counts[beam_idx]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry repeats times.

        Used at beam search initialisation to expand the cache from batch size B to
        B * repeats, matching the expanded beam candidate batch. Applied atomically
        across keys, values, and _counts; batch_size is updated to reflect the new size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate entries.
        Applied atomically across keys, values, and _counts; batch_size is updated
        to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self._counts = self._counts[indices]
        self.batch_size = indices.shape[0]

    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not know
        about. All three tensors are moved together so device state remains consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not know
        about. _counts is synced to self.keys.device after the parent moves keys and
        values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op: storage is fully allocated at construction time."""
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin: unsupported abstract methods
    # ---------------------------------------------------------------------------

    def get_seq_length(self) -> int:  # type: ignore[override]
        """Not supported: no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently; (batch, head) slots have different
        lengths depending on routing history, so there is no meaningful scalar summary.
        Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "MoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Return the static per-(batch, head) slot capacity of this cache.

        Equal to mosrah_cache_length as supplied at construction, which is derived
        from config.mosrah_cache_length. Required by the HuggingFace static cache
        contract; generation machinery uses this to size attention masks.
        """
        return self.mosrah_cache_length

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported: MoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "MoSRAHCache does not support get_mask_sizes()."
        )

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------

    def _make_active_mask(self) -> torch.Tensor:
        """Construct the (B, L, T) active mask from current counts.

        Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
        has been written. Positions at or beyond the count are junk and must be
        excluded by downstream attention.
        """
        cap = self.buffer_capacity
        return (
            torch.arange(cap, device=self.keys.device)
            .expand(self.batch_size, self.num_mosrah_heads, cap)
            < self._counts.unsqueeze(-1)
        )

    @staticmethod
    def _check_no_overflow(max_count: torch.Tensor, capacity: int) -> None:
        """Raise if any (batch, head) slot would exceed the static buffer capacity.

        Uses the 19.F.1 pattern: branches on whether the graph is being compiled.
        In compiled mode, `.item()` folds into the graph when capture_scalar_outputs=True
        and `torch._check` issues a compile-time assertion. In eager mode, a plain
        RuntimeError is raised with a descriptive message.

        Args:
            max_count: Scalar tensor holding the maximum post-update count across all slots.
            capacity: The static buffer capacity (mosrah_cache_length).
        """
        if torch.compiler.is_compiling():
            torch._check(max_count.item() <= capacity)
        else:
            if max_count.item() > capacity:
                raise RuntimeError(
                    f"MoSRAHCache overflow: a (batch, head) slot would reach "
                    f"{max_count.item()} tokens but the static buffer capacity is "
                    f"{capacity}. Increase mosrah_overallocation_factor in ShramConfig."
                )
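The scatter performed by `MoSRAHCache.update()` can be reproduced on toy tensors to see how ragged per-slot counts evolve. The sketch below is a standalone re-creation under stated assumptions, not the module's own code: `scatter_active` and `active_from_counts` are hypothetical names, only keys are handled (values follow the identical path), and the compile-aware `_check_no_overflow` is replaced by a plain eager check. With one batch item, two heads, and capacity 4, head 0 receives two tokens and head 1 one token in the first call; in the second call only head 1 receives a token.

```python
import torch


def scatter_active(cache_keys, counts, key_states, active_mask):
    """Hypothetical re-creation of the update() scatter, keys only.

    cache_keys:  (B, L, C, u) buffer, C = static capacity (written in place)
    counts:      (B, L) long, occupancy before this call
    key_states:  (B, L, T, u) incoming keys in expert-choice layout
    active_mask: (B, L, T) bool, True for real tokens
    Returns the post-update counts.
    """
    capacity = cache_keys.shape[2]
    incoming = active_mask.long().sum(dim=2)  # (B, L) tokens arriving per slot
    new_counts = counts + incoming
    if int(new_counts.max()) > capacity:  # eager-only stand-in for _check_no_overflow
        raise RuntimeError("slot overflow")
    start = counts.unsqueeze(-1)  # (B, L, 1) first free position per slot
    t = torch.arange(capacity, device=cache_keys.device)  # (C,)
    dest_mask = (t >= start) & (t < start + incoming.unsqueeze(-1))  # (B, L, C)
    # Both boolean selections enumerate True entries in (b, l, t) order, so the
    # i-th active token of slot (b, l) lands in the i-th destination of that slot.
    cache_keys[dest_mask] = key_states[active_mask]
    return new_counts


def active_from_counts(counts, capacity):
    """(B, L, C) mask that is True exactly where t < counts[b, l]."""
    t = torch.arange(capacity, device=counts.device)
    return t < counts.unsqueeze(-1)


B, L, C, u = 1, 2, 4, 1
keys = torch.zeros(B, L, C, u)
counts = torch.zeros(B, L, dtype=torch.long)

# Call 1: head 0 gets tokens 1.0 and 2.0; head 1 gets 3.0 plus one padding entry.
step1 = torch.tensor([[[[1.0], [2.0]], [[3.0], [0.0]]]])  # (1, 2, 2, 1)
mask1 = torch.tensor([[[True, True], [True, False]]])  # (1, 2, 2)
counts = scatter_active(keys, counts, step1, mask1)

# Call 2: only head 1 receives a token (4.0).
step2 = torch.tensor([[[[0.0]], [[4.0]]]])  # (1, 2, 1, 1)
mask2 = torch.tensor([[[False], [True]]])  # (1, 2, 1)
counts = scatter_active(keys, counts, step2, mask2)

print(counts.tolist())  # [[2, 2]]
print(keys[..., 0].tolist())  # [[[1.0, 2.0, 0.0, 0.0], [3.0, 4.0, 0.0, 0.0]]]
print(active_from_counts(counts, C).tolist())
```

Both heads end with two valid entries even though they filled at different steps, and positions 2 and 3 of each slot remain junk that the active mask excludes. Because the overflow check runs before any write, a call that would push a slot past capacity leaves the buffer untouched.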