Instructions to use BeardedMonster/SabiYarn-125M-topic with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use BeardedMonster/SabiYarn-125M-topic with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="BeardedMonster/SabiYarn-125M-topic", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("BeardedMonster/SabiYarn-125M-topic", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use BeardedMonster/SabiYarn-125M-topic with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "BeardedMonster/SabiYarn-125M-topic"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "BeardedMonster/SabiYarn-125M-topic",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```
Use Docker
docker model run hf.co/BeardedMonster/SabiYarn-125M-topic
- SGLang
How to use BeardedMonster/SabiYarn-125M-topic with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "BeardedMonster/SabiYarn-125M-topic" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "BeardedMonster/SabiYarn-125M-topic",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```
Use Docker images
```bash
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
  --model-path "BeardedMonster/SabiYarn-125M-topic" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "BeardedMonster/SabiYarn-125M-topic",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```
- Docker Model Runner
How to use BeardedMonster/SabiYarn-125M-topic with Docker Model Runner:
docker model run hf.co/BeardedMonster/SabiYarn-125M-topic
| """ | |
| SabiYarn Model Implementation - Optimized Version | |
| Memory-efficient with performance optimizations for generation. | |
| Matches original implementation exactly but with memory optimizations. | |
| """ | |
| from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoModelForCausalLM | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| # use package-relative import to avoid colliding with unrelated `model` packages | |
| from .configuration import GPTJXMoEConfig | |
| from typing import Optional, List, Tuple | |
| from torch import nn | |
| import torch | |
| import torch.nn.functional as F | |
| import math | |
class LayerNorm(nn.Module):
    """Layer normalization over the trailing dimension with an optional bias.

    ``weight`` is initialised to ones. ``bias`` is initialised to zeros and is
    only created when the ``bias`` argument is truthy; otherwise the attribute
    is ``None`` and no shift is applied.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        """Normalize ``input`` over its last ``weight.shape`` dims (eps = 1e-5)."""
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with an optional key/value cache.

    Reads ``n_embd``, ``n_heads``, ``bias`` and ``dropout`` from ``config``.
    Uses ``F.scaled_dot_product_attention`` when the installed PyTorch provides
    it, and a manual softmax(QK^T)V implementation otherwise.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_heads == 0
        # key, query, value projections for all heads, computed in one matmul
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_heads = config.n_heads
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_heads
        self.dropout = config.dropout
        # fused attention kernel is only available in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    def _expand_attn_mask(self, attn_mask, T, total_len, device):
        """Convert a user mask into a boolean (B, n_heads, T, total_len) mask.

        Accepted inputs:
          * 2D ``(B, total_len)``: used as-is.
          * 2D ``(B, T)`` with ``T < total_len``: left-padded with ``True`` for
            the cached positions.
          * 4D ``(B, 1 | n_heads, *, total_len)``: the query axis is cropped to
            its last ``T`` rows and a singleton head axis is broadcast.

        Raises:
            ValueError: for any other rank, 2D length, or head count.
        """
        attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            B_mask, S = attn_mask.shape
            if S != total_len:
                if S != T:
                    raise ValueError(f"Unsupported attention_mask shape: {attn_mask.shape}, expected (B, {T}) or (B, {total_len})")
                # Mask only covers the current tokens: every cached position is visible.
                past_mask = torch.ones(B_mask, total_len - T, device=device, dtype=torch.bool)
                attn_mask = torch.cat([past_mask, attn_mask], dim=1)
            # (B, total_len) -> (B, 1, 1, total_len) -> (B, n_heads, T, total_len)
            return attn_mask[:, None, None, :].expand(B_mask, self.n_heads, T, total_len)

        if attn_mask.dim() == 4:
            if attn_mask.size(-2) != T:
                # Keep only the rows that belong to the current queries.
                attn_mask = attn_mask[..., -T:, :]
            if attn_mask.size(1) == 1:
                attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1)
            elif attn_mask.size(1) != self.n_heads:
                raise ValueError(f"Mask head dimension {attn_mask.size(1)} doesn't match n_heads {self.n_heads}")
            return attn_mask

        raise ValueError(f"Unsupported attention_mask dimension: {attn_mask.dim()}, expected 2 or 4")

    def forward(self, x, attn_mask=None, past_key_value=None, use_cache=False):
        """
        Forward pass with optional KV cache support.

        Args:
            x: (B, T, C) input embeddings
            attn_mask: Optional 2D or 4D attention mask (non-zero / True = attend)
            past_key_value: Optional tuple of (past_k, past_v) each (B, nh, past_len, hs)
            use_cache: Whether to return cache for next step

        Returns:
            If use_cache: (output, (k, v)) where output is (B, T, C) and k, v are
            detached tensors of shape (B, nh, total_len, hs)
            Else: output (B, T, C)

        Raises:
            ValueError: if ``attn_mask`` has an unsupported shape.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # Project once, then split into per-head q/k/v of shape (B, nh, T, hs).
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # Prepend cached keys/values so the queries can see earlier positions.
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)  # (B, nh, past_len + T, hs)
            v = torch.cat([past_v, v], dim=2)
        total_len = k.size(2)

        if attn_mask is not None:
            # NOTE(review): an explicit mask is applied as given; no causal mask is
            # combined with it here. Confirm callers pass a mask that already
            # encodes whatever causality they need.
            mask = self._expand_attn_mask(attn_mask, T, total_len, x.device)
        elif past_key_value is not None or not self.flash:
            # Query i sits at absolute position (total_len - T + i), so the
            # lower-triangular mask must be shifted right by the cache length.
            # Without the offset a cached decode step (T == 1) could only see key 0.
            mask = torch.ones(T, total_len, device=x.device, dtype=torch.bool)
            mask = mask.tril(diagonal=total_len - T).view(1, 1, T, total_len)
        else:
            # Fused kernel without cache: let it apply causality itself.
            mask = None

        if self.flash:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0,
                is_causal=mask is None,
            )
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
            att = att.masked_fill(~mask, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, total_len) x (B, nh, total_len, hs) -> (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))

        if use_cache:
            return y, (k.detach(), v.detach())
        return y
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    The hidden width is ``4 * config.n_embd``; ``config.bias`` toggles the bias
    on both linear layers and ``config.dropout`` sets the dropout probability.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.c_fc = nn.Linear(width, 4 * width, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * width, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the feed-forward stack; the output has the same shape as ``x``."""
        hidden = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
class BlockJ(nn.Module):
    """Transformer block whose residual update also adds a normalized copy of the input.

    The attention sub-layer computes ``x + attn(ln_1(x)) + j(ln_1(x))``; the
    feed-forward sub-layer then adds ``mlp(ln_2(.))``. The feed-forward module
    is an ``MoE`` when ``config.use_moe`` is truthy and a dense ``MLP`` otherwise.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        # The second positional argument of LayerNorm is ``bias``; passing
        # ``config.n_embd`` (a non-zero int) means this norm always has a bias,
        # independent of ``config.bias``.
        self.j = LayerNorm(config.n_embd, config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)

        # Use MoE if configured, otherwise use dense MLP
        self.use_moe = bool(getattr(config, 'use_moe', False))
        if self.use_moe:
            self.mlp = MoE(
                num_experts_per_tok=config.num_experts_per_tok,
                num_experts=config.num_experts,
                emb_dim=config.n_embd,
                moe_dim=config.moe_dim,
                dropout=config.dropout,
            )
        else:
            self.mlp = MLP(config)

    def forward(self, x, attn_mask=None, past_key_value=None, use_cache=False):
        """
        Run the block, optionally reading from and returning a KV cache.

        Args:
            x: (B, T, C) input embeddings
            attn_mask: Optional attention mask, forwarded to the attention layer
            past_key_value: Optional tuple of (past_k, past_v) for the attention layer
            use_cache: Whether to return the updated cache

        Returns:
            If use_cache: (output, (k, v)) where output is (B, T, C)
            Else: output (B, T, C)
        """
        normed = self.ln_1(x)

        attn_result = self.attn(
            normed,
            attn_mask=attn_mask,
            past_key_value=past_key_value,
            use_cache=bool(use_cache),
        )
        if use_cache:
            attn_out, new_past = attn_result
        else:
            attn_out, new_past = attn_result, None

        # Residual + attention output + extra normalized branch.
        x = x + attn_out + self.j(normed)
        # Feed-forward (dense or MoE) with its own residual.
        x = x + self.mlp(self.ln_2(x))

        if use_cache:
            return x, new_past
        return x
class MoE(nn.Module):
    """
    Mixture-of-experts feed-forward layer with top-k routing.

    Each expert is a two-layer GELU MLP stored in banked parameters:
    ``fc_bank`` (E, D, H) and ``proj_bank`` (E, H, D), with no bias terms.
    A bias-free linear ``gate`` produces per-token routing logits; the top
    ``k`` experts are combined with softmax weights computed over their logits.
    After every forward pass the layer stores routing statistics on itself
    (``_router_probs``, ``_aux_lb``, ``_expert_utilization``).
    """

    def __init__(self, num_experts_per_tok: int, num_experts: int, emb_dim: int, moe_dim: int, dropout: float = 0.0, dtype=torch.float32):
        super().__init__()
        self.k = int(num_experts_per_tok)
        self.E = int(num_experts)
        self.D = int(emb_dim)
        self.H = int(moe_dim)
        self.dropout = dropout

        # Router: one logit per expert for every token.
        self.gate = nn.Linear(self.D, self.E, bias=False, dtype=dtype)
        # Expert weights, banked along the leading expert axis.
        self.fc_bank = nn.Parameter(torch.empty(self.E, self.D, self.H, dtype=dtype))
        self.proj_bank = nn.Parameter(torch.empty(self.E, self.H, self.D, dtype=dtype))
        self.gelu = nn.GELU()
        if dropout > 0.0:
            self.dropout_layer = nn.Dropout(dropout)
        else:
            self.dropout_layer = nn.Identity()

        self._init_parameters()

    def expert_utilization(self, logits):
        """
        Record routing statistics for the given (B, T, E) router logits.

        Stores the fraction of tokens routed to each expert in
        ``_expert_utilization``, the detached mean router probability per expert
        in ``_router_probs`` and ``E * sum(load * mean_prob)`` in ``_aux_lb``.
        The load-balancing term follows https://arxiv.org/abs/2101.03961
        """
        top_idx = logits.topk(self.k, dim=-1).indices
        # One-hot over experts, summed across the k picks -> (B, T, E)
        picked = F.one_hot(top_idx, num_classes=self.E).sum(dim=2)
        load = picked.float().mean(dim=(0, 1))
        # average router probability per expert, shape [E]
        mean_prob = torch.softmax(logits, dim=-1).float().mean(dim=(0, 1))

        self._router_probs = mean_prob.detach()
        self._aux_lb = self.E * torch.sum(load * mean_prob)
        self._expert_utilization = load

    def _init_parameters(self):
        """Draw gate and expert weights from zero-mean normal distributions."""
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.fc_bank, mean=0.0, std=0.02)
        # The projection bank uses a smaller std (0.02 / sqrt(2)).
        nn.init.normal_(self.proj_bank, mean=0.0, std=0.02 / math.sqrt(2))

    def forward(self, x):
        """Route each token of ``x`` (B, T, D) through its top-k experts; returns (B, T, D)."""
        B, T, D = x.shape
        assert D == self.D, f"Expected emb_dim={self.D}, got {D}"

        router_logits = self.gate(x)  # B, T, E
        if self.training:
            # Gaussian noise on the logits during training only.
            router_logits = router_logits + torch.randn_like(router_logits) * 1e-1

        top_logits, top_idx = router_logits.topk(self.k, dim=-1)
        weights = F.softmax(top_logits, dim=-1)

        # Every expert is evaluated for every token; the selection happens afterwards.
        hidden = self.gelu(torch.einsum("btd,edh->bteh", x, self.fc_bank))  # B, T, E, H
        expert_out = torch.einsum("bteh,ehd->bted", hidden, self.proj_bank)  # B, T, E, D

        # Pick the k chosen experts' outputs along the expert axis.
        index = top_idx.view(B, T, -1, 1).expand(-1, -1, -1, self.D)  # B, T, K, D
        chosen = torch.gather(expert_out, dim=2, index=index)  # B, T, K, D

        # Weighted sum over the selected experts.
        mixed = (chosen * weights.unsqueeze(-1)).sum(dim=2)  # B, T, D
        mixed = self.dropout_layer(mixed)

        self.expert_utilization(router_logits)
        return mixed
class GPTJXMoEForCausalLM(PreTrainedModel):
    """Causal language model built from ``BlockJ`` layers (dense MLP or MoE).

    The token embedding matrix is shared with the output projection
    (``lm_head``). Positions use a learned embedding table of size
    ``config.block_size``. ``forward`` supports a per-layer key/value cache
    expressed as a sequence of ``(k, v)`` tuples.
    """

    config_class = GPTJXMoEConfig
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["BlockJ"]
    # _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([BlockJ(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the input embedding and the output projection share one tensor.
        self.transformer.wte.weight = self.lm_head.weight

        # No causal-mask buffer is stored; attention layers build masks on the fly.
        # NOTE(review): ``_init_weights`` is not defined in this class, so this
        # presumably resolves to the PreTrainedModel implementation. Confirm.
        self.apply(self._init_weights)
        # Re-initialise every ``c_proj.weight`` with a depth-scaled std.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.

        For the non-embedding count (default) the position embeddings are
        subtracted. The token embeddings are kept because they are shared with
        the output projection.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def get_expert_utilization(self):
        """
        Collect routing statistics recorded by the MoE layers.

        Returns:
            ``(expert_utilization_per_layer, lb_loss)``: a list with one CPU
            tensor per MoE layer that has run a forward pass, and the mean of
            their auxiliary load-balancing terms. ``(None, None)`` when
            ``config.use_moe`` is falsy.
        """
        if not getattr(self.config, 'use_moe', False):
            return None, None

        lb_loss, expert_utilization_per_layer = 0, []
        moe_layers = 0
        for block in self.transformer.h:
            # ``_aux_lb`` only exists after the layer has run at least once.
            if hasattr(block, 'use_moe') and block.use_moe and hasattr(block.mlp, '_aux_lb'):
                lb_loss += block.mlp._aux_lb
                expert_utilization_per_layer.append(block.mlp._expert_utilization.detach().cpu())
                moe_layers += 1

        if moe_layers > 0:
            lb_loss = lb_loss / moe_layers
        return expert_utilization_per_layer, lb_loss

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module (``lm_head`` is not re-tied here)."""
        self.transformer.wte = new_embeddings

    def forward(
        self,
        input_ids,
        targets=None,
        attn_mask=None,
        attention_mask=None,  # HF standard name
        past_key_values=None,
        position_ids=None,
        use_cache=None,
        output_hidden_states: Optional[bool] = None,
        **kwargs
    ):
        """
        Forward pass with KV cache support for efficient generation.

        Args:
            input_ids: (B, T) Token indices
            targets: Optional (B, T) target token indices; when given, the
                cross-entropy loss is computed with ``ignore_index=-100``
            attn_mask: Optional attention mask (legacy name)
            attention_mask: Optional attention mask (HF standard name, takes precedence)
            past_key_values: Optional sequence of (k, v) tuples, one per layer.
                An empty sequence is treated the same as ``None``.
            position_ids: Optional position indices, 1D (T,) or 2D (B, T) /
                (1, T). When ``None`` they run from ``past_len`` to
                ``past_len + T``.
            use_cache: Whether to return past_key_values for the next step
                (defaults to ``config.use_kv_cache``, or False if unset)
            output_hidden_states: Whether to return the final hidden states

        Returns:
            CausalLMOutputWithPast with logits and optionally past_key_values.
            Without ``targets``, and when a non-empty cache was supplied with
            caching enabled, logits cover only the last position.
        """
        device = input_ids.device
        b, t = input_ids.size()

        # Use attention_mask if provided (HF standard), otherwise fall back to attn_mask
        if attention_mask is not None:
            attn_mask = attention_mask

        use_kv_cache = use_cache if use_cache is not None else getattr(self.config, 'use_kv_cache', False)

        # An empty cache holds no state. Normalising it to None avoids indexing
        # into it per layer below (which would raise IndexError).
        if past_key_values is not None and len(past_key_values) == 0:
            past_key_values = None
        past_len = past_key_values[0][0].size(2) if past_key_values is not None else 0

        if position_ids is None:
            pos = torch.arange(past_len, past_len + t, dtype=torch.long, device=device)
        else:
            pos = position_ids

        total_len = past_len + t
        assert total_len <= self.config.block_size, f"Cannot forward sequence of length {total_len}, block size is only {self.config.block_size}"

        tok_emb = self.transformer.wte(input_ids)  # (b, t, n_embd)
        # nn.Embedding accepts index tensors of any shape, so 2D position_ids are
        # looked up row by row rather than collapsed to the first row.
        pos_emb = self.transformer.wpe(pos)
        if pos_emb.dim() == 2:
            pos_emb = pos_emb.unsqueeze(0)  # (1, t, n_embd), broadcasts over the batch
        x = self.transformer.drop(tok_emb + pos_emb)

        # A 2D mask that only covers the current tokens is left-padded with ones
        # so that it also covers the cached positions.
        if attn_mask is not None and past_key_values is not None and use_kv_cache:
            if attn_mask.dim() == 2 and attn_mask.size(1) == t and total_len > t:
                past_mask = torch.ones(b, total_len - t, device=device, dtype=attn_mask.dtype)
                attn_mask = torch.cat([past_mask, attn_mask], dim=1)

        new_past_key_values = [] if use_kv_cache else None
        for i, block in enumerate(self.transformer.h):
            layer_past = past_key_values[i] if past_key_values is not None else None
            if use_kv_cache:
                x, new_past = block(x, attn_mask=attn_mask, past_key_value=layer_past, use_cache=True)
                new_past_key_values.append(new_past)
            else:
                x = block(x, attn_mask=attn_mask, past_key_value=layer_past, use_cache=False)

        x = self.transformer.ln_f(x)

        if targets is not None:
            # Training: logits for all positions plus the loss.
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
        else:
            if use_kv_cache and past_key_values is not None:
                logits = self.lm_head(x[:, [-1], :])  # only the last token
            else:
                logits = self.lm_head(x)  # all tokens
            loss = None

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=tuple(new_past_key_values) if use_kv_cache else None,
            hidden_states=x if output_hidden_states else None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        past_key_values=None,
        position_ids=None,
        use_cache=None,
        **kwargs
    ):
        """
        Build the keyword arguments for one generation step.

        When caching is enabled and a non-empty cache is supplied, only the
        last token (and last position) is fed to the model. An empty cache is
        treated as "no cache yet" so the full prompt is still processed on the
        first step. All non-None extra ``kwargs`` are passed through.
        """
        use_kv_cache = use_cache if use_cache is not None else getattr(self.config, 'use_kv_cache', False)
        # Only a cache that actually contains layers counts as a cached step.
        has_past = past_key_values is not None and len(past_key_values) > 0
        cached_step = bool(has_past and use_kv_cache)

        model_inputs = {"input_ids": input_ids}

        # ---- 1. KV cache ----
        if cached_step:
            model_inputs["input_ids"] = input_ids[:, -1:]
            model_inputs["past_key_values"] = past_key_values

        # ---- 2. Attention mask (passed through unchanged) ----
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask

        # ---- 3. Position ids ----
        if position_ids is not None:
            if cached_step:
                # Only the last position is needed alongside the cache.
                position_ids = position_ids[:, -1].unsqueeze(-1)
            model_inputs["position_ids"] = position_ids
        elif cached_step:
            past_len = past_key_values[0][0].size(2)
            model_inputs["position_ids"] = torch.tensor([[past_len]], device=input_ids.device, dtype=torch.long)

        # ---- 4. Remaining arguments ----
        if use_cache is not None:
            model_inputs["use_cache"] = use_cache
        for k, v in kwargs.items():
            if v is not None:
                model_inputs[k] = v

        return model_inputs

    def _reorder_cache(
        self,
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        beam_idx: torch.Tensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Reorder the cache along the batch axis for beam search.

        Args:
            past_key_values: List of (k, v) tuples from previous steps
            beam_idx: 1D index tensor selecting which batch rows to keep

        Returns:
            A new list of (k, v) tuples indexed by ``beam_idx`` on dim 0
        """
        reordered_past = []
        for k, v in past_key_values:
            # The index must live on the same device as the cached tensors.
            beam_idx_dev = beam_idx.to(k.device)
            reordered_past.append((
                k.index_select(0, beam_idx_dev),
                v.index_select(0, beam_idx_dev),
            ))
        return reordered_past

    def crop_block_size(self, block_size):
        """Shrink the maximum context length and truncate the position table to match."""
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            # Only relevant for attention modules that keep a mask buffer named ``bias``.
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]

    def load_dense_weights_into_moe(self, dense_state_dict, strict=False):
        """
        Load a dense-MLP checkpoint into this model, replicating MLP weights across experts.

        For every ``transformer.h.{i}.mlp.c_fc`` / ``c_proj`` entry the weight
        is transposed, cropped to ``config.moe_dim`` along the hidden axis and
        copied to all experts (``fc_bank`` / ``proj_bank``). The router weight
        of each converted layer is set to zeros. All other entries are copied
        unchanged. When ``config.use_moe`` is falsy the state dict is loaded as is.

        NOTE(review): dense biases are emitted under ``fc_bias`` / ``proj_bias``
        keys, but the ``MoE`` module in this file declares no such parameters.
        With ``strict=False`` they are reported as unexpected keys, and
        ``strict=True`` would raise. Confirm whether dropping the biases is intended.

        Returns:
            The result of ``load_state_dict``.
        """
        if not getattr(self.config, 'use_moe', False):
            return self.load_state_dict(dense_state_dict, strict=strict)

        print("Converting Dense Checkpoint -> MoE Checkpoint...")
        moe_state_dict = {}
        num_experts = self.config.num_experts
        moe_dim = self.config.moe_dim

        for key, value in dense_state_dict.items():
            if 'mlp.c_fc' not in key and 'mlp.c_proj' not in key:
                # Attention, LayerNorm, embeddings, ...: copy directly.
                moe_state_dict[key] = value
                continue

            # key format: transformer.h.{i}.mlp.c_fc.{weight/bias}
            layer_idx = key.split('.')[2]
            layer_key_prefix = f"transformer.h.{layer_idx}.mlp"
            is_bias = 'bias' in key

            if 'c_fc' in key:
                if not is_bias:
                    # Dense (H, D) -> transpose to (D, H) -> crop H -> replicate to (E, D, H)
                    w_T = value.t()[:, :moe_dim]
                    moe_state_dict[f"{layer_key_prefix}.fc_bank"] = w_T.unsqueeze(0).expand(num_experts, -1, -1).clone()
                else:
                    # Dense (H) -> crop -> replicate to (E, H)
                    b = value[:moe_dim]
                    moe_state_dict[f"{layer_key_prefix}.fc_bias"] = b.unsqueeze(0).expand(num_experts, -1).clone()
            else:
                if not is_bias:
                    # Dense (D, H) -> transpose to (H, D) -> crop H -> replicate to (E, H, D)
                    w_T = value.t()[:moe_dim, :]
                    moe_state_dict[f"{layer_key_prefix}.proj_bank"] = w_T.unsqueeze(0).expand(num_experts, -1, -1).clone()
                else:
                    # Dense (D) -> replicate to (E, D); output axis needs no cropping
                    moe_state_dict[f"{layer_key_prefix}.proj_bias"] = value.unsqueeze(0).expand(num_experts, -1).clone()

            # Zero router weights give identical logits for all experts, so the
            # top-k softmax weights are uniform over the replicated experts.
            gate_key = f"{layer_key_prefix}.gate.weight"
            if gate_key not in moe_state_dict:
                moe_state_dict[gate_key] = torch.zeros(num_experts, self.config.n_embd)

        print("Loading constructed state dict...")
        return self.load_state_dict(moe_state_dict, strict=strict)
# Make the model discoverable through the transformers Auto* factories:
# the "sabiyarn" model type maps to the config class, and the config class
# maps to the model class for both AutoModel and AutoModelForCausalLM.
AutoConfig.register("sabiyarn", GPTJXMoEConfig)
for _auto_class in (AutoModel, AutoModelForCausalLM):
    _auto_class.register(GPTJXMoEConfig, GPTJXMoEForCausalLM)