Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM-dev with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "smithblack-0/SHRAM-dev"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM-dev",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "smithblack-0/SHRAM-dev" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM-dev",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "smithblack-0/SHRAM-dev" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM-dev",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
Update architecture and tokenizer
Browse files
- config.json +1 -1
- huggingface.py +93 -12
config.json
CHANGED
|
@@ -24,7 +24,7 @@
|
|
| 24 |
"rope_mode": "main_sequence",
|
| 25 |
"tie_word_embeddings": false,
|
| 26 |
"training_sequence_length": 1024,
|
| 27 |
-
"transformers_version": "5.
|
| 28 |
"use_cache": true,
|
| 29 |
"vocab_size": 50277,
|
| 30 |
"window_size": 128
|
|
|
|
| 24 |
"rope_mode": "main_sequence",
|
| 25 |
"tie_word_embeddings": false,
|
| 26 |
"training_sequence_length": 1024,
|
| 27 |
+
"transformers_version": "5.10.1",
|
| 28 |
"use_cache": true,
|
| 29 |
"vocab_size": 50277,
|
| 30 |
"window_size": 128
|
huggingface.py
CHANGED
|
@@ -1284,6 +1284,13 @@ class ShramCache(Cache):
|
|
| 1284 |
layer have materially different update semantics; callers must update sub-caches directly
|
| 1285 |
via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
|
| 1286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1287 |
Args:
|
| 1288 |
config: ShramConfig instance. All layer counts, buffer sizes, and sub-cache
|
| 1289 |
dimensions are derived from config so that a single source of truth governs
|
|
@@ -1310,11 +1317,19 @@ class ShramCache(Cache):
|
|
| 1310 |
]
|
| 1311 |
super().__init__(layers=layers)
|
| 1312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1313 |
# ---------------------------------------------------------------------------
|
| 1314 |
# Cache — composite-meaningful methods
|
| 1315 |
# ---------------------------------------------------------------------------
|
| 1316 |
#
|
| 1317 |
-
# reset():
|
|
|
|
| 1318 |
#
|
| 1319 |
# reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
|
| 1320 |
#
|
|
@@ -1322,6 +1337,40 @@ class ShramCache(Cache):
|
|
| 1322 |
# Since ShramLayerCache.is_initialized is True from construction, this is True
|
| 1323 |
# immediately after ShramCache.__init__ returns.
|
| 1324 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1325 |
def get_seq_length(self, layer_idx: int = 0) -> int: # type: ignore[override]
|
| 1326 |
"""Return the cumulative sequence length for the specified layer.
|
| 1327 |
|
|
@@ -2191,6 +2240,7 @@ class BottleneckedEnsembleAttention(nn.Module):
|
|
| 2191 |
key_length=key_states.shape[2],
|
| 2192 |
device=packed_embeddings.device,
|
| 2193 |
)
|
|
|
|
| 2194 |
attended_states = flex_attention(
|
| 2195 |
rotated_query_states,
|
| 2196 |
key_states,
|
|
@@ -2836,7 +2886,7 @@ class MoSRAHRouter(nn.Module):
|
|
| 2836 |
outputs.
|
| 2837 |
capacity_scalar: Static upper bound on n; used to derive topk k as
|
| 2838 |
min(tensor.shape[dim], capacity_scalar). Must be a Python int
|
| 2839 |
-
|
| 2840 |
|
| 2841 |
Returns:
|
| 2842 |
Boolean mask of the same shape as tensor.
|
|
@@ -4055,19 +4105,49 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 4055 |
return attention_mask.to(dtype=torch.bool)
|
| 4056 |
|
| 4057 |
def _resolve_current_position_ids(
|
| 4058 |
-
|
| 4059 |
-
|
| 4060 |
-
|
| 4061 |
-
|
|
|
|
| 4062 |
) -> torch.LongTensor:
|
| 4063 |
-
"""Resolve concrete current-step position IDs for the backbone.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4064 |
if position_ids is not None:
|
| 4065 |
return position_ids.to(dtype=torch.long)
|
| 4066 |
|
| 4067 |
-
full_position_ids = full_attention_mask.to(dtype=torch.long).cumsum(dim=-1) - 1
|
| 4068 |
-
full_position_ids = full_position_ids.masked_fill(~full_attention_mask, 0)
|
| 4069 |
current_length = input_ids.shape[1]
|
| 4070 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4071 |
|
| 4072 |
def forward(
|
| 4073 |
self,
|
|
@@ -4172,12 +4252,13 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 4172 |
)
|
| 4173 |
current_length: int = input_ids.shape[1]
|
| 4174 |
current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
|
|
|
|
| 4175 |
current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
|
| 4176 |
input_ids=input_ids,
|
| 4177 |
position_ids=position_ids,
|
| 4178 |
-
|
|
|
|
| 4179 |
)
|
| 4180 |
-
shram_cache: ShramCache | None = past_key_values if use_cache else None
|
| 4181 |
|
| 4182 |
if shram_cache is None:
|
| 4183 |
positions_start_sane = torch.all(current_position_ids[:, 0] == 0)
|
|
|
|
| 1284 |
layer have materially different update semantics; callers must update sub-caches directly
|
| 1285 |
via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
|
| 1286 |
|
| 1287 |
+
ShramCache also tracks per-batch cumulative active token counts via
|
| 1288 |
+
``_active_token_counts``. ``total_active_tokens(active_mask)`` returns the accumulated
|
| 1289 |
+
count before the current step and updates the buffer in-place; the caller uses this as a
|
| 1290 |
+
per-batch position bias for contiguous arange-based position ID resolution. All counter
|
| 1291 |
+
updates are in-place to satisfy CUDAGraph fixed-memory requirements. ``reset()``
|
| 1292 |
+
zeroes the buffer along with all layer caches.
|
| 1293 |
+
|
| 1294 |
Args:
|
| 1295 |
config: ShramConfig instance. All layer counts, buffer sizes, and sub-cache
|
| 1296 |
dimensions are derived from config so that a single source of truth governs
|
|
|
|
| 1317 |
]
|
| 1318 |
super().__init__(layers=layers)
|
| 1319 |
|
| 1320 |
+
# Active token counter for position ID resolution (Unit 23.B). Pre-allocated
|
| 1321 |
+
# at construction so all updates remain in-place across forward passes,
|
| 1322 |
+
# satisfying CUDAGraph fixed-memory requirements.
|
| 1323 |
+
self._active_token_counts: torch.Tensor = torch.zeros(
|
| 1324 |
+
batch_size, dtype=torch.long, device=device
|
| 1325 |
+
)
|
| 1326 |
+
|
| 1327 |
# ---------------------------------------------------------------------------
|
| 1328 |
# Cache — composite-meaningful methods
|
| 1329 |
# ---------------------------------------------------------------------------
|
| 1330 |
#
|
| 1331 |
+
# reset(): Overridden. Zeroes _active_token_counts in-place, then delegates to
|
| 1332 |
+
# the inherited implementation to reset all layer caches.
|
| 1333 |
#
|
| 1334 |
# reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
|
| 1335 |
#
|
|
|
|
| 1337 |
# Since ShramLayerCache.is_initialized is True from construction, this is True
|
| 1338 |
# immediately after ShramCache.__init__ returns.
|
| 1339 |
|
| 1340 |
+
def total_active_tokens(self, active_mask: torch.BoolTensor) -> torch.Tensor:
|
| 1341 |
+
"""Return the per-batch accumulated active token count before this step, then update.
|
| 1342 |
+
|
| 1343 |
+
Reads the current per-batch accumulated count as a position bias for the caller,
|
| 1344 |
+
then increments the internal counter in-place by the number of active tokens in
|
| 1345 |
+
``active_mask`` for each batch item. The pre-update count is returned so the
|
| 1346 |
+
caller can offset an arange-based position tensor to the correct starting position
|
| 1347 |
+
for this forward pass.
|
| 1348 |
+
|
| 1349 |
+
All updates are in-place to satisfy CUDAGraph fixed-memory requirements. The
|
| 1350 |
+
counter persists across forward passes until ``reset()`` is called.
|
| 1351 |
+
|
| 1352 |
+
Args:
|
| 1353 |
+
active_mask: Boolean mask of shape ``(B, N)`` for the current forward step,
|
| 1354 |
+
where True marks an active (non-padding) token position.
|
| 1355 |
+
|
| 1356 |
+
Returns:
|
| 1357 |
+
Integer tensor of shape ``(B,)`` — the accumulated count before this update.
|
| 1358 |
+
"""
|
| 1359 |
+
prior_counts = self._active_token_counts.clone()
|
| 1360 |
+
self._active_token_counts.add_(active_mask.sum(dim=-1))
|
| 1361 |
+
return prior_counts
|
| 1362 |
+
|
| 1363 |
+
def reset(self) -> None:
|
| 1364 |
+
"""Clear all layer caches and reset the active token counter.
|
| 1365 |
+
|
| 1366 |
+
Zeroes ``_active_token_counts`` in-place, then delegates to the inherited
|
| 1367 |
+
implementation to reset all ShramLayerCache instances. In-place mutation of
|
| 1368 |
+
the counter is required for CUDAGraph compatibility — the buffer must remain
|
| 1369 |
+
at the same memory address across steps.
|
| 1370 |
+
"""
|
| 1371 |
+
self._active_token_counts.zero_()
|
| 1372 |
+
super().reset()
|
| 1373 |
+
|
| 1374 |
def get_seq_length(self, layer_idx: int = 0) -> int: # type: ignore[override]
|
| 1375 |
"""Return the cumulative sequence length for the specified layer.
|
| 1376 |
|
|
|
|
| 2240 |
key_length=key_states.shape[2],
|
| 2241 |
device=packed_embeddings.device,
|
| 2242 |
)
|
| 2243 |
+
|
| 2244 |
attended_states = flex_attention(
|
| 2245 |
rotated_query_states,
|
| 2246 |
key_states,
|
|
|
|
| 2886 |
outputs.
|
| 2887 |
capacity_scalar: Static upper bound on n; used to derive topk k as
|
| 2888 |
min(tensor.shape[dim], capacity_scalar). Must be a Python int
|
| 2889 |
+
for compile compatibility.
|
| 2890 |
|
| 2891 |
Returns:
|
| 2892 |
Boolean mask of the same shape as tensor.
|
|
|
|
| 4105 |
return attention_mask.to(dtype=torch.bool)
|
| 4106 |
|
| 4107 |
def _resolve_current_position_ids(
|
| 4108 |
+
self,
|
| 4109 |
+
input_ids: torch.Tensor,
|
| 4110 |
+
position_ids: torch.Tensor | None,
|
| 4111 |
+
current_active_mask: torch.BoolTensor,
|
| 4112 |
+
cache: ShramCache | None,
|
| 4113 |
) -> torch.LongTensor:
|
| 4114 |
+
"""Resolve concrete current-step position IDs for the backbone.
|
| 4115 |
+
|
| 4116 |
+
Builds a fresh contiguous allocation via arange + per-batch bias. No cumsum
|
| 4117 |
+
or stride-based views are produced; the returned tensor is always a new
|
| 4118 |
+
allocation safe for Inductor tracing at the FlexAttention boundary.
|
| 4119 |
+
|
| 4120 |
+
When a cache is present, ``total_active_tokens()`` provides the per-batch
|
| 4121 |
+
accumulated active token count as a position bias. Uncached calls use a zero
|
| 4122 |
+
bias. In both cases positions are ``bias + arange(current_length)``, with
|
| 4123 |
+
inactive positions masked to 0.
|
| 4124 |
+
|
| 4125 |
+
Args:
|
| 4126 |
+
input_ids: Current token IDs of shape ``(B, N)``.
|
| 4127 |
+
position_ids: Explicit positions if supplied by the caller; returned
|
| 4128 |
+
unchanged (cast to long). Bias computation is skipped entirely.
|
| 4129 |
+
current_active_mask: Boolean mask of shape ``(B, N)`` for the current step.
|
| 4130 |
+
cache: Active ``ShramCache``, or ``None`` for uncached forward passes.
|
| 4131 |
+
|
| 4132 |
+
Returns:
|
| 4133 |
+
Long tensor of shape ``(B, N)`` — position index per token, 0 for inactive.
|
| 4134 |
+
"""
|
| 4135 |
if position_ids is not None:
|
| 4136 |
return position_ids.to(dtype=torch.long)
|
| 4137 |
|
|
|
|
|
|
|
| 4138 |
current_length = input_ids.shape[1]
|
| 4139 |
+
|
| 4140 |
+
if cache is not None:
|
| 4141 |
+
position_bias = cache.total_active_tokens(current_active_mask)
|
| 4142 |
+
else:
|
| 4143 |
+
position_bias = torch.zeros(
|
| 4144 |
+
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
| 4145 |
+
)
|
| 4146 |
+
|
| 4147 |
+
positions = position_bias.unsqueeze(1) + torch.arange(
|
| 4148 |
+
current_length, device=input_ids.device, dtype=torch.long
|
| 4149 |
+
)
|
| 4150 |
+
return positions.masked_fill(~current_active_mask, 0)
|
| 4151 |
|
| 4152 |
def forward(
|
| 4153 |
self,
|
|
|
|
| 4252 |
)
|
| 4253 |
current_length: int = input_ids.shape[1]
|
| 4254 |
current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
|
| 4255 |
+
shram_cache: ShramCache | None = past_key_values if use_cache else None
|
| 4256 |
current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
|
| 4257 |
input_ids=input_ids,
|
| 4258 |
position_ids=position_ids,
|
| 4259 |
+
current_active_mask=current_active_mask,
|
| 4260 |
+
cache=shram_cache,
|
| 4261 |
)
|
|
|
|
| 4262 |
|
| 4263 |
if shram_cache is None:
|
| 4264 |
positions_start_sane = torch.all(current_position_ids[:, 0] == 0)
|