Instructions to use JetLM/SDAR-1.7B-Chat with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use JetLM/SDAR-1.7B-Chat with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="JetLM/SDAR-1.7B-Chat", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("JetLM/SDAR-1.7B-Chat", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use JetLM/SDAR-1.7B-Chat with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "JetLM/SDAR-1.7B-Chat"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "JetLM/SDAR-1.7B-Chat",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker
docker model run hf.co/JetLM/SDAR-1.7B-Chat
- SGLang
How to use JetLM/SDAR-1.7B-Chat with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "JetLM/SDAR-1.7B-Chat" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "JetLM/SDAR-1.7B-Chat",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "JetLM/SDAR-1.7B-Chat" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "JetLM/SDAR-1.7B-Chat",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
- Docker Model Runner
How to use JetLM/SDAR-1.7B-Chat with Docker Model Runner:
docker model run hf.co/JetLM/SDAR-1.7B-Chat
remove LossKwargs
#1
by kashif HF Staff - opened
- modeling_sdar.py +65 -23
modeling_sdar.py
CHANGED
|
@@ -43,7 +43,7 @@ from transformers.modeling_outputs import (
|
|
| 43 |
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 44 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 45 |
from transformers.processing_utils import Unpack
|
| 46 |
-
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
| 47 |
from .configuration_sdar import SDARConfig
|
| 48 |
|
| 49 |
from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
|
|
@@ -261,22 +261,41 @@ class SDARAttention(nn.Module):
|
|
| 261 |
query_states, key_states = apply_rotary_pos_emb(
|
| 262 |
query_states, key_states, cos, sin)
|
| 263 |
|
| 264 |
-
|
| 265 |
-
|
|
|
|
|
|
|
| 266 |
key_states, value_states = past_key_value.update(
|
| 267 |
key_states, value_states, self.layer_idx)
|
| 268 |
-
elif past_key_value is not None and not kwargs.get("store_kv", False) and len(past_key_value) > self.layer_idx:
|
| 269 |
-
# only retrive, do not store kv
|
| 270 |
-
past_key_states, past_value_states = past_key_value[self.layer_idx]
|
| 271 |
-
key_states = torch.cat(
|
| 272 |
-
[past_key_states, key_states], dim=-2
|
| 273 |
-
)
|
| 274 |
-
value_states = torch.cat(
|
| 275 |
-
[past_value_states, value_states], dim=-2
|
| 276 |
-
)
|
| 277 |
|
| 278 |
attention_mask = attention_mask.bool() if attention_mask is not None else None
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
query_states = query_states.transpose(1, 2)
|
| 281 |
key_states = key_states.transpose(1, 2)
|
| 282 |
value_states = value_states.transpose(1, 2)
|
|
@@ -329,7 +348,6 @@ class SDARDecoderLayer(GradientCheckpointingLayer):
|
|
| 329 |
past_key_value: Optional[Cache] = None,
|
| 330 |
output_attentions: Optional[bool] = False,
|
| 331 |
use_cache: Optional[bool] = False,
|
| 332 |
-
store_kv: Optional[bool] = False,
|
| 333 |
cache_position: Optional[torch.LongTensor] = None,
|
| 334 |
# necessary, but kept here for BC
|
| 335 |
position_embeddings: Optional[Tuple[torch.Tensor,
|
|
@@ -347,7 +365,6 @@ class SDARDecoderLayer(GradientCheckpointingLayer):
|
|
| 347 |
past_key_value=past_key_value,
|
| 348 |
output_attentions=output_attentions,
|
| 349 |
use_cache=use_cache,
|
| 350 |
-
store_kv=store_kv,
|
| 351 |
cache_position=cache_position,
|
| 352 |
position_embeddings=position_embeddings,
|
| 353 |
**kwargs,
|
|
@@ -394,9 +411,27 @@ class SDARPreTrainedModel(PreTrainedModel):
|
|
| 394 |
module.weight.data[module.padding_idx].zero_()
|
| 395 |
elif isinstance(module, SDARRMSNorm):
|
| 396 |
module.weight.data.fill_(1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
|
| 399 |
class SDARRotaryEmbedding(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
def __init__(self, config: SDARConfig, device=None):
|
| 401 |
super().__init__()
|
| 402 |
# BC: "rope_type" was originally "type"
|
|
@@ -409,12 +444,18 @@ class SDARRotaryEmbedding(nn.Module):
|
|
| 409 |
self.original_max_seq_len = config.max_position_embeddings
|
| 410 |
|
| 411 |
self.config = config
|
| 412 |
-
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 413 |
|
| 414 |
-
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 415 |
-
self.config, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 417 |
-
self.original_inv_freq = self.inv_freq
|
| 418 |
|
| 419 |
@torch.no_grad()
|
| 420 |
# power user: used with advanced RoPE types (e.g. dynamic rope)
|
|
@@ -440,7 +481,10 @@ class SDARRotaryEmbedding(nn.Module):
|
|
| 440 |
class SDARModel(SDARPreTrainedModel):
|
| 441 |
def __init__(self, config: SDARConfig):
|
| 442 |
super().__init__(config)
|
| 443 |
-
|
|
|
|
|
|
|
|
|
|
| 444 |
self.vocab_size = config.vocab_size
|
| 445 |
|
| 446 |
self.embed_tokens = nn.Embedding(
|
|
@@ -472,7 +516,6 @@ class SDARModel(SDARPreTrainedModel):
|
|
| 472 |
past_key_values: Optional[Cache] = None,
|
| 473 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 474 |
use_cache: Optional[bool] = None,
|
| 475 |
-
store_kv: Optional[bool] = None,
|
| 476 |
output_attentions: Optional[bool] = None,
|
| 477 |
output_hidden_states: Optional[bool] = None,
|
| 478 |
cache_position: Optional[torch.LongTensor] = None,
|
|
@@ -539,7 +582,6 @@ class SDARModel(SDARPreTrainedModel):
|
|
| 539 |
past_key_value=past_key_values,
|
| 540 |
output_attentions=output_attentions,
|
| 541 |
use_cache=use_cache,
|
| 542 |
-
store_kv=store_kv,
|
| 543 |
cache_position=cache_position,
|
| 544 |
position_embeddings=position_embeddings,
|
| 545 |
**flash_attn_kwargs,
|
|
@@ -734,7 +776,7 @@ class SDARModel(SDARPreTrainedModel):
|
|
| 734 |
return causal_mask
|
| 735 |
|
| 736 |
|
| 737 |
-
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
|
| 738 |
...
|
| 739 |
|
| 740 |
|
|
|
|
| 43 |
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 44 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 45 |
from transformers.processing_utils import Unpack
|
| 46 |
+
from transformers.utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
| 47 |
from .configuration_sdar import SDARConfig
|
| 48 |
|
| 49 |
from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
|
|
|
|
| 261 |
query_states, key_states = apply_rotary_pos_emb(
|
| 262 |
query_states, key_states, cos, sin)
|
| 263 |
|
| 264 |
+
# Standard transformers v5 cache convention: when a cache is provided, always `.update()` it.
|
| 265 |
+
# Callers that want a read-only forward should pass `past_key_values=None`, or use
|
| 266 |
+
# `DynamicCache.crop(prev_seq_len)` to roll back the append after reading the logits.
|
| 267 |
+
if past_key_value is not None:
|
| 268 |
key_states, value_states = past_key_value.update(
|
| 269 |
key_states, value_states, self.layer_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
attention_mask = attention_mask.bool() if attention_mask is not None else None
|
| 272 |
+
|
| 273 |
+
# I-DLM / strict-causal mode: rely on PyTorch's built-in `is_causal=True` path so GQA
|
| 274 |
+
# broadcasting works cleanly with a KV cache (query q_len ≠ key k_len). We compute a
|
| 275 |
+
# per-query offset such that `is_causal=True` masks against key position `q + offset`,
|
| 276 |
+
# matching the Dream-shifted causal-LM convention.
|
| 277 |
+
use_regular_causal = bool(getattr(self.config, "use_regular_causal", False))
|
| 278 |
+
if use_regular_causal:
|
| 279 |
+
q_len = query_states.shape[-2]
|
| 280 |
+
k_len = key_states.shape[-2]
|
| 281 |
+
if q_len == k_len:
|
| 282 |
+
attn_output = F.scaled_dot_product_attention(
|
| 283 |
+
query=query_states, key=key_states, value=value_states,
|
| 284 |
+
is_causal=True, scale=self.scaling, enable_gqa=True,
|
| 285 |
+
)
|
| 286 |
+
else:
|
| 287 |
+
# Non-square causal: build a (q_len, k_len) mask where row `i` attends to key
|
| 288 |
+
# positions `0..k_len - q_len + i`. Works for any cache state.
|
| 289 |
+
offset = k_len - q_len
|
| 290 |
+
rows = torch.arange(q_len, device=query_states.device).unsqueeze(1)
|
| 291 |
+
cols = torch.arange(k_len, device=query_states.device).unsqueeze(0)
|
| 292 |
+
causal_mask = cols <= rows + offset # [q_len, k_len]
|
| 293 |
+
attn_output = F.scaled_dot_product_attention(
|
| 294 |
+
query=query_states, key=key_states, value=value_states,
|
| 295 |
+
attn_mask=causal_mask, is_causal=False, scale=self.scaling, enable_gqa=True,
|
| 296 |
+
)
|
| 297 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 298 |
+
elif attention_mask is not None and torch.all(attention_mask): # decoding
|
| 299 |
query_states = query_states.transpose(1, 2)
|
| 300 |
key_states = key_states.transpose(1, 2)
|
| 301 |
value_states = value_states.transpose(1, 2)
|
|
|
|
| 348 |
past_key_value: Optional[Cache] = None,
|
| 349 |
output_attentions: Optional[bool] = False,
|
| 350 |
use_cache: Optional[bool] = False,
|
|
|
|
| 351 |
cache_position: Optional[torch.LongTensor] = None,
|
| 352 |
# necessary, but kept here for BC
|
| 353 |
position_embeddings: Optional[Tuple[torch.Tensor,
|
|
|
|
| 365 |
past_key_value=past_key_value,
|
| 366 |
output_attentions=output_attentions,
|
| 367 |
use_cache=use_cache,
|
|
|
|
| 368 |
cache_position=cache_position,
|
| 369 |
position_embeddings=position_embeddings,
|
| 370 |
**kwargs,
|
|
|
|
| 411 |
module.weight.data[module.padding_idx].zero_()
|
| 412 |
elif isinstance(module, SDARRMSNorm):
|
| 413 |
module.weight.data.fill_(1.0)
|
| 414 |
+
# Delegate rotary-embedding buffer re-init to the base PreTrainedModel, which handles
|
| 415 |
+
# transformers v5's meta-device load by recomputing inv_freq via compute_default_rope_parameters.
|
| 416 |
+
else:
|
| 417 |
+
super()._init_weights(module)
|
| 418 |
|
| 419 |
|
| 420 |
class SDARRotaryEmbedding(nn.Module):
|
| 421 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 422 |
+
|
| 423 |
+
@staticmethod
|
| 424 |
+
def compute_default_rope_parameters(config, device=None, seq_len=None):
|
| 425 |
+
# transformers v5 removed "default" from ROPE_INIT_FUNCTIONS; match the Qwen3 implementation.
|
| 426 |
+
base = getattr(config, "rope_theta", None)
|
| 427 |
+
if base is None:
|
| 428 |
+
base = config.rope_parameters["rope_theta"]
|
| 429 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 430 |
+
inv_freq = 1.0 / (
|
| 431 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 432 |
+
)
|
| 433 |
+
return inv_freq, 1.0
|
| 434 |
+
|
| 435 |
def __init__(self, config: SDARConfig, device=None):
|
| 436 |
super().__init__()
|
| 437 |
# BC: "rope_type" was originally "type"
|
|
|
|
| 444 |
self.original_max_seq_len = config.max_position_embeddings
|
| 445 |
|
| 446 |
self.config = config
|
|
|
|
| 447 |
|
| 448 |
+
if self.rope_type == "default":
|
| 449 |
+
inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
|
| 450 |
+
else:
|
| 451 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 452 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 453 |
+
|
| 454 |
+
# Register both as buffers — transformers v5's `_move_missing_keys_from_meta_to_device`
|
| 455 |
+
# replaces non-persistent buffers with `torch.empty_like` (uninitialized / zeros); the base
|
| 456 |
+
# `_init_weights` then re-copies into them IF they're buffers with `original_inv_freq` present.
|
| 457 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 458 |
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
| 459 |
|
| 460 |
@torch.no_grad()
|
| 461 |
# power user: used with advanced RoPE types (e.g. dynamic rope)
|
|
|
|
| 481 |
class SDARModel(SDARPreTrainedModel):
|
| 482 |
def __init__(self, config: SDARConfig):
|
| 483 |
super().__init__(config)
|
| 484 |
+
# transformers v5 configs may not have pad_token_id; fall back to eos_token_id.
|
| 485 |
+
self.padding_idx = getattr(config, "pad_token_id", None)
|
| 486 |
+
if self.padding_idx is None:
|
| 487 |
+
self.padding_idx = getattr(config, "eos_token_id", None)
|
| 488 |
self.vocab_size = config.vocab_size
|
| 489 |
|
| 490 |
self.embed_tokens = nn.Embedding(
|
|
|
|
| 516 |
past_key_values: Optional[Cache] = None,
|
| 517 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 518 |
use_cache: Optional[bool] = None,
|
|
|
|
| 519 |
output_attentions: Optional[bool] = None,
|
| 520 |
output_hidden_states: Optional[bool] = None,
|
| 521 |
cache_position: Optional[torch.LongTensor] = None,
|
|
|
|
| 582 |
past_key_value=past_key_values,
|
| 583 |
output_attentions=output_attentions,
|
| 584 |
use_cache=use_cache,
|
|
|
|
| 585 |
cache_position=cache_position,
|
| 586 |
position_embeddings=position_embeddings,
|
| 587 |
**flash_attn_kwargs,
|
|
|
|
| 776 |
return causal_mask
|
| 777 |
|
| 778 |
|
| 779 |
+
class KwargsForCausalLM(FlashAttentionKwargs):
|
| 780 |
...
|
| 781 |
|
| 782 |
|