| | """ |
| | SabiYarn Model Implementation - Optimized Version |
| | Memory-efficient with performance optimizations for generation. |
| | Matches original implementation exactly but with memory optimizations. |
| | """ |
| |
|
| | from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoModelForCausalLM |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | |
| | from .configuration import GPTJXMoEConfig |
| | from typing import Optional, List, Tuple |
| | from torch import nn |
| | import torch |
| | import torch.nn.functional as F |
| | import math |
| |
|
| |
|
class LayerNorm(nn.Module):
    """LayerNorm with an optional bias term (nn.LayerNorm can't simply do bias=False)."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # Normalize over the trailing `ndim` features with eps=1e-5.
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
| |
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional KV caching.

    Uses torch's scaled_dot_product_attention (fused/flash kernels) when this
    torch build provides it, otherwise an explicit softmax fallback.

    Bug fixes vs. the previous version:
      * A user-supplied 2D padding mask no longer disables causal masking:
        the expanded padding mask is now combined (logical AND) with the
        causal pattern, so prefill with an attention_mask stays autoregressive.
      * The causal mask used during cached decoding was built as
        ``torch.tril(torch.ones(T, total_len))`` with no diagonal offset,
        which let query position i attend only keys j <= i instead of
        j <= past_len + i (for T=1 decode that meant attending ONLY key 0).
        The mask is now offset by the cache length.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_heads == 0
        # Fused projection producing query, key and value in one matmul.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection back to the residual stream.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_heads = config.n_heads
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_heads
        self.dropout = config.dropout

        # Prefer the fused SDPA kernel when available (torch >= 2.0).
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    def _causal_pattern(self, T, total_len, device):
        """Boolean (T, total_len) mask: query i may attend keys j <= past_len + i.

        The ``diagonal=past_len`` offset aligns the query window with the end
        of the (possibly cached) key sequence.
        """
        past_len = total_len - T
        return torch.ones(T, total_len, device=device, dtype=torch.bool).tril(diagonal=past_len)

    def _normalize_mask(self, attn_mask, T, total_len, past_key_value, device):
        """Expand a user mask to boolean (B, n_heads, T, total_len) AND-ed with causality.

        Accepts:
          * 2D key-padding masks of shape (B, T) or (B, total_len), 1 = attend.
            A (B, T) mask with a cache is extended with all-ones over the
            cached prefix (those positions were validated when first processed).
          * 4D masks of shape (B, 1 or n_heads, T', total_len); only the last
            T query rows are used.

        Raises:
            ValueError: on any unsupported mask shape.
        """
        attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            B_mask = attn_mask.size(0)
            S = attn_mask.size(1)
            if S == total_len:
                pass  # already covers cached + new positions
            elif S == T:
                if past_key_value is not None:
                    # Mask covers only the new tokens; prepend ones for cache.
                    past_len = total_len - T
                    past_mask = torch.ones(B_mask, past_len, device=device, dtype=attn_mask.dtype)
                    attn_mask = torch.cat([past_mask, attn_mask], dim=1)
            else:
                raise ValueError(f"Unsupported attention_mask shape: {attn_mask.shape}, expected (B, {T}) or (B, {total_len})")

            if attn_mask.size(1) != total_len:
                raise ValueError(f"Mask length mismatch: got {attn_mask.size(1)}, expected {total_len}")

            # (B, total_len) -> (B, n_heads, T, total_len) key-padding broadcast.
            attn_mask = attn_mask.view(B_mask, 1, 1, total_len)
            attn_mask = attn_mask.expand(B_mask, self.n_heads, T, total_len)
        elif attn_mask.dim() == 4:
            if attn_mask.size(-2) != T:
                # Keep only the rows for the current query window.
                attn_mask = attn_mask[..., -T:, :]
            if attn_mask.size(1) == 1:
                attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1)
            elif attn_mask.size(1) != self.n_heads:
                raise ValueError(f"Mask head dimension {attn_mask.size(1)} doesn't match n_heads {self.n_heads}")
        else:
            raise ValueError(f"Unsupported attention_mask dimension: {attn_mask.dim()}, expected 2 or 4")

        # FIX: combine with the causal pattern (previously skipped entirely,
        # which allowed attention to future tokens during masked prefill).
        causal = self._causal_pattern(T, total_len, device)
        return attn_mask & causal.view(1, 1, T, total_len)

    def forward(self, x, attn_mask=None, past_key_value=None, use_cache=False):
        """
        Forward pass with optional KV cache support.

        Args:
            x: (B, T, C) input embeddings
            attn_mask: Optional attention mask (2D key-padding or 4D)
            past_key_value: Optional tuple of (past_k, past_v) each (B, nh, past_len, hs)
            use_cache: Whether to return cache for next step

        Returns:
            If use_cache: (output, (k, v)) where output is (B, T, C) and
            k, v are (B, nh, total_len, hs). Else: output (B, T, C).
        """
        B, T, C = x.size()

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # Append cached keys/values along the sequence axis.
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        total_len = k.size(2)
        drop_p = self.dropout if self.training else 0

        mask = None
        if attn_mask is not None:
            mask = self._normalize_mask(attn_mask, T, total_len, past_key_value, x.device)

        if self.flash:
            if mask is not None:
                y = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, attn_mask=mask, dropout_p=drop_p, is_causal=False)
            elif past_key_value is None:
                # No padding mask, no cache: let the kernel apply causality.
                y = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True)
            else:
                # Cached decode without a padding mask: explicit causal mask
                # offset by the cache length (FIX: was an un-offset tril).
                causal = self._causal_pattern(T, total_len, x.device)
                y = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, attn_mask=causal.view(1, 1, T, total_len),
                    dropout_p=drop_p, is_causal=False)
        else:
            # Manual attention fallback.
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
            if mask is not None:
                att = att.masked_fill(~mask, float('-inf'))
            else:
                causal = self._causal_pattern(T, total_len, x.device)
                att = att.masked_fill(~causal.view(1, 1, T, total_len), float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v

        # (B, nh, T, hs) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))

        if use_cache:
            # Detach so cached tensors don't keep the autograd graph alive.
            return y, (k.detach(), v.detach())
        return y
| |
|
class MLP(nn.Module):
    """Standard transformer feed-forward block: linear -> GELU -> linear -> dropout.

    The hidden width is the conventional 4x expansion of the embedding size.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Single expression: expand, activate, project back, regularize.
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
| |
|
class BlockJ(nn.Module):
    """Transformer block with an extra learned 'j' branch.

    The ln_1 output feeds both the attention layer and a second LayerNorm
    ('j'); both results are added to the residual:

        x = x + attn(ln_1(x)) + j(ln_1(x))
        x = x + mlp(ln_2(x))

    The feed-forward is an MoE layer when config.use_moe is set, otherwise a
    plain MLP.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        # NOTE(review): config.n_embd is passed as the *bias* flag of
        # LayerNorm here — it is merely truthy. Presumably `bias=config.bias`
        # was intended; kept as-is to preserve checkpoint compatibility.
        # TODO confirm.
        self.j = LayerNorm(config.n_embd, config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)

        if getattr(config, 'use_moe', False):
            self.mlp = MoE(
                num_experts_per_tok=config.num_experts_per_tok,
                num_experts=config.num_experts,
                emb_dim=config.n_embd,
                moe_dim=config.moe_dim,
                dropout=config.dropout,
            )
            self.use_moe = True
        else:
            self.mlp = MLP(config)
            self.use_moe = False

    def forward(self, x, attn_mask=None, past_key_value=None, use_cache=False):
        """
        Apply the block, optionally threading a KV cache through attention.

        Args:
            x: (B, T, C) input embeddings
            attn_mask: Optional attention mask
            past_key_value: Optional (past_k, past_v) tuple for the attention layer
            use_cache: Whether to return cache for the next step

        Returns:
            (output, (k, v)) when use_cache, else output; output is (B, T, C).
        """
        residual = x
        normed = self.ln_1(x)

        attn_result = self.attn(
            normed, attn_mask=attn_mask, past_key_value=past_key_value, use_cache=use_cache
        )
        if use_cache:
            attn_out, new_past = attn_result
        else:
            attn_out, new_past = attn_result, None

        out = residual + attn_out + self.j(normed)
        out = out + self.mlp(self.ln_2(out))

        return (out, new_past) if use_cache else out
| |
|
class MoE(nn.Module):
    """
    Mixture-of-Experts feed-forward layer with a GELU MLP per expert.

    Each token is routed to the top-k experts chosen by a linear gate. The
    forward pass evaluates ALL experts densely and then gathers the selected
    outputs — simple and vectorized, but compute/memory scale with the total
    number of experts. (The original docstring said "swiglue"; the code uses
    GELU, not SwiGLU.)
    """

    def __init__(self, num_experts_per_tok: int, num_experts: int, emb_dim: int, moe_dim: int, dropout: float = 0.0, dtype=torch.float32):
        super().__init__()
        self.k = int(num_experts_per_tok)  # experts selected per token
        self.E = int(num_experts)          # total number of experts
        self.D = int(emb_dim)              # model embedding dim
        self.H = int(moe_dim)              # per-expert hidden dim
        self.dropout = dropout

        # Router: token embedding -> one logit per expert.
        self.gate = nn.Linear(self.D, self.E, bias=False, dtype=dtype)
        # Expert weight banks: (E, D, H) up-projection and (E, H, D) down-projection.
        # NOTE(review): no fc_bias/proj_bias parameters exist here, yet
        # load_dense_weights_into_moe emits such keys — with strict=False they
        # are silently dropped; confirm whether biases were intended.
        self.fc_bank = nn.Parameter(torch.empty(self.E, self.D, self.H, dtype=dtype))
        self.proj_bank = nn.Parameter(torch.empty(self.E, self.H, self.D, dtype=dtype))
        self.gelu = nn.GELU()
        self.dropout_layer = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        self._init_parameters()

    def expert_utilization(self, logits):
        """
        Compute per-expert utilization and the load-balancing auxiliary loss.
        Details of this load balancer can be found in https://arxiv.org/abs/2101.03961
        (Switch Transformers). Results are stashed on the module for later
        collection by get_expert_utilization on the parent model.
        """
        # One-hot over the top-k choices, summed over the k axis -> (B, T, E) counts.
        _, selected = logits.topk(self.k, dim=-1)
        selected = F.one_hot(selected, num_classes=self.E).sum(dim=2)

        # Fraction of tokens dispatched to each expert.
        load = torch.mean(selected.float(), dim=(0,1))

        # Mean router probability per expert (over batch and time).
        P = torch.softmax(logits, dim=-1).float().mean(dim=(0,1))
        self._router_probs = P.detach()
        # Aux loss: E * sum(load_e * prob_e); minimized by uniform routing.
        self._aux_lb = self.E * torch.sum(load * P)

        self._expert_utilization = load

    def _init_parameters(self):
        """Initialize MoE parameters following standard practices."""
        # Gate: small normal init so initial routing is near-uniform.
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)

        # Up-projection bank: GPT-style 0.02 std.
        nn.init.normal_(self.fc_bank, mean=0.0, std=0.02)

        # Down-projection bank: scaled down (residual-branch convention).
        nn.init.normal_(self.proj_bank, mean=0.0, std=0.02 / math.sqrt(2))

    def forward(self, x):
        # x: (B, T, D) token embeddings.
        B, T, D = x.shape
        assert D == self.D, f"Expected emb_dim={self.D}, got {D}"

        logits = self.gate(x)  # (B, T, E) router logits

        # Exploration noise (std 0.1) during training to diversify routing.
        if self.training:
            logits = logits + torch.randn_like(logits) * 1e-1

        # Top-k routing weights, renormalized over the selected experts only.
        topk_logits, selected = logits.topk(self.k, dim=-1)
        topk_probs = F.softmax(topk_logits, dim=-1)

        # Dense evaluation of EVERY expert: (B, T, E, H).
        h = torch.einsum("btd,edh->bteh", x, self.fc_bank)

        h = self.gelu(h)

        # Down-project each expert's hidden state: (B, T, E, D).
        y = torch.einsum("bteh,ehd->bted", h, self.proj_bank)

        # Gather the k selected experts' outputs: (B, T, k, D).
        gather_idx = selected.view(B, T, -1, 1).expand(-1, -1, -1, self.D)
        y = torch.gather(y, dim=2, index=gather_idx)

        # Probability-weighted mixture of the selected experts.
        y = (y * topk_probs.unsqueeze(-1)).sum(dim=2)

        y = self.dropout_layer(y)

        # Record routing stats / aux loss (on the noisy logits when training).
        self.expert_utilization(logits)
        return y
| | |
| | |
class GPTJXMoEForCausalLM(PreTrainedModel):
    """GPT-J-style causal language model ("SabiYarn") with optional MoE blocks.

    A stack of BlockJ layers with learned absolute position embeddings, tied
    input/output embeddings, and KV-cache support wired into the Hugging Face
    generation API.
    """

    config_class = GPTJXMoEConfig
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["BlockJ"]

    _supports_flash_attn_2 = True
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        """Build embeddings, transformer blocks and the (tied) LM head."""
        super().__init__(config)
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([BlockJ(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding shares storage with the LM head.
        self.transformer.wte.weight = self.lm_head.weight

        # NOTE(review): _init_weights is not defined in this file — presumably
        # inherited from PreTrainedModel or defined in the package; confirm.
        self.apply(self._init_weights)
        # GPT-2-style scaled init for residual projections (1/sqrt(2*n_layer)).
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def get_expert_utilization(self):
        """
        Get expert utilization statistics for MoE layers.
        Returns expert utilization per layer and load balancing loss
        (averaged over MoE layers). Only works when use_moe=True in config;
        returns (None, None) otherwise. Stats exist only after a forward pass.
        """
        if not getattr(self.config, 'use_moe', False):
            return None, None

        lb_loss, expert_utilization_per_layer = 0, []
        moe_layers = 0
        for block in self.transformer.h:
            # Only count blocks that have actually recorded stats (_aux_lb is
            # set by MoE.expert_utilization during forward).
            if hasattr(block, 'use_moe') and block.use_moe and hasattr(block.mlp, '_aux_lb'):
                lb_loss += block.mlp._aux_lb
                expert_utilization_per_layer.append(block.mlp._expert_utilization.detach().cpu())
                moe_layers += 1

        if moe_layers > 0:
            lb_loss = lb_loss / moe_layers
        return expert_utilization_per_layer, lb_loss

    def get_input_embeddings(self):
        """Return the token embedding module (HF accessor)."""
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module (HF accessor)."""
        self.transformer.wte = new_embeddings

    def forward(
        self,
        input_ids,
        targets=None,
        attn_mask=None,
        attention_mask=None,
        past_key_values=None,
        position_ids=None,
        use_cache=None,
        output_hidden_states: Optional[bool] = None,
        **kwargs
    ):
        """
        Forward pass with KV cache support for efficient generation.

        Args:
            input_ids: (B, T) Token indices
            targets: Optional (B, T) target token indices for training
            attn_mask: Optional attention mask (legacy name)
            attention_mask: Optional attention mask (HF standard name, takes precedence)
            past_key_values: Optional list of (k, v) tuples from previous steps for KV cache
            position_ids: Optional (B, T) position indices (if None, computed from past_key_values)
            use_cache: Whether to return past_key_values for next step (defaults to config.use_kv_cache)
            output_hidden_states: Whether to return hidden states

        Returns:
            CausalLMOutputWithPast with logits and optionally past_key_values
        """
        device = input_ids.device
        b, t = input_ids.size()

        # HF-standard name takes precedence over the legacy alias.
        if attention_mask is not None:
            attn_mask = attention_mask

        use_kv_cache = use_cache if use_cache is not None else getattr(self.config, 'use_kv_cache', False)

        # Length of the cached prefix (0 on the first forward pass).
        past_len = 0
        if past_key_values is not None:
            past_len = past_key_values[0][0].size(2) if len(past_key_values) > 0 else 0

        if position_ids is None:
            # Positions continue where the cached prefix ended.
            pos = torch.arange(past_len, past_len + t, dtype=torch.long, device=device)
        else:
            pos = position_ids

        total_len = past_len + t
        assert total_len <= self.config.block_size, f"Cannot forward sequence of length {total_len}, block size is only {self.config.block_size}"

        tok_emb = self.transformer.wte(input_ids)

        # Position embeddings are looked up from a 1D index vector.
        # NOTE(review): a 2D position_ids uses only row 0 for the whole batch —
        # assumes all rows share positions (breaks with left padding); confirm.
        if pos.dim() == 2:
            pos_1d = pos[0] if pos.size(0) > 0 else pos.squeeze(0)
        else:
            pos_1d = pos

        pos_emb = self.transformer.wpe(pos_1d)
        if pos_emb.dim() == 2:
            pos_emb = pos_emb.unsqueeze(0).expand(b, -1, -1)
        x = self.transformer.drop(tok_emb + pos_emb)

        # Extend a new-tokens-only 2D padding mask to also cover the cached
        # prefix (cached positions are assumed attendable).
        if attn_mask is not None and past_key_values is not None and use_kv_cache:
            if attn_mask.dim() == 2:
                mask_len = attn_mask.size(1)
                if mask_len == t and total_len > t:
                    past_len = total_len - t
                    past_mask = torch.ones(b, past_len, device=device, dtype=attn_mask.dtype)
                    attn_mask = torch.cat([past_mask, attn_mask], dim=1)

        new_past_key_values = [] if use_kv_cache else None

        for i, block in enumerate(self.transformer.h):
            layer_past = past_key_values[i] if past_key_values is not None else None

            if use_kv_cache:
                x, new_past = block(x, attn_mask=attn_mask, past_key_value=layer_past, use_cache=True)
                new_past_key_values.append(new_past)
            else:
                x = block(x, attn_mask=attn_mask, past_key_value=layer_past, use_cache=False)

        x = self.transformer.ln_f(x)

        if targets is not None:
            # Training: full-sequence logits and token-level cross-entropy.
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
        else:
            # Cached decoding: only the last position's logits are needed.
            if use_kv_cache and past_key_values is not None:
                logits = self.lm_head(x[:, [-1], :])
            else:
                logits = self.lm_head(x)
            loss = None

        # NOTE(review): HF convention is a tuple of per-layer hidden states;
        # here the final hidden-state tensor is returned directly — confirm
        # downstream consumers expect that.
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=tuple(new_past_key_values) if use_kv_cache else None,
            hidden_states=x if output_hidden_states else None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        past_key_values=None,
        position_ids=None,
        use_cache=None,
        **kwargs
    ):
        """
        Prepare inputs for generation with KV cache support.
        This method is called by HF's generation API.
        """
        use_kv_cache = use_cache if use_cache is not None else getattr(self.config, 'use_kv_cache', False)

        model_inputs = {
            "input_ids": input_ids,
        }

        # With a cache, only the newest token needs to be fed forward.
        if past_key_values is not None and use_kv_cache:
            model_inputs["input_ids"] = input_ids[:, -1:]
            model_inputs["past_key_values"] = past_key_values

        # Forward the full-length attention mask; forward() extends it over
        # the cached prefix itself.
        if attention_mask is not None:
            if past_key_values is not None and use_kv_cache:
                pass
            model_inputs["attention_mask"] = attention_mask

        # Position ids: keep only the last column when decoding with a cache,
        # or synthesize it from the cache length.
        if position_ids is not None:
            if past_key_values is not None and use_kv_cache:
                position_ids = position_ids[:, -1].unsqueeze(-1)
            model_inputs["position_ids"] = position_ids
        elif past_key_values is not None and use_kv_cache:
            # NOTE(review): shape (1, 1) regardless of batch size — assumes
            # batch size 1 or broadcasting downstream; confirm.
            past_len = past_key_values[0][0].size(2) if len(past_key_values) > 0 else 0
            model_inputs["position_ids"] = torch.tensor([[past_len]], device=input_ids.device, dtype=torch.long)

        if use_cache is not None:
            model_inputs["use_cache"] = use_cache

        # Pass through any remaining non-None generation kwargs untouched.
        for k, v in kwargs.items():
            if v is not None:
                model_inputs[k] = v

        return model_inputs

    def _reorder_cache(
        self,
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        beam_idx: torch.Tensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Reorder cache for beam search.

        Required by HF for beam search to work correctly.
        Selects which beam samples to keep based on beam_idx.

        Args:
            past_key_values: List of (k, v) tuples from previous steps
            beam_idx: (batch_size,) tensor indicating which beams to keep

        Returns:
            Reordered past_key_values
        """
        reordered_past = []
        for layer_past in past_key_values:
            k, v = layer_past
            device = k.device
            # beam_idx may live on a different device than the cache tensors.
            beam_idx_dev = beam_idx.to(device)
            reordered_past.append((
                k.index_select(0, beam_idx_dev),
                v.index_select(0, beam_idx_dev)
            ))
        return reordered_past

    def crop_block_size(self, block_size):
        """Shrink the maximum context length in-place by truncating wpe weights."""
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            # Only applies if an attention module registered a 'bias' buffer
            # (the SDPA-based attention in this file does not).
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    def load_dense_weights_into_moe(self, dense_state_dict, strict=False):
        """
        Migrate Dense MLP weights to MoE experts.
        Ensures exact mathematical equivalence by cloning weights/biases to ALL experts.
        """
        if not getattr(self.config, 'use_moe', False):
            return self.load_state_dict(dense_state_dict, strict=strict)

        print("Converting Dense Checkpoint -> MoE Checkpoint...")
        moe_state_dict = {}

        num_experts = self.config.num_experts
        moe_dim = self.config.moe_dim

        for key, value in dense_state_dict.items():
            if 'mlp.c_fc' in key or 'mlp.c_proj' in key:
                # Keys look like "transformer.h.<idx>.mlp.c_fc.weight",
                # so parts[2] is the layer index.
                parts = key.split('.')
                layer_idx = parts[2]
                layer_key_prefix = f"transformer.h.{layer_idx}.mlp"

                is_bias = 'bias' in key
                is_fc = 'c_fc' in key

                if is_fc:
                    if not is_bias:
                        # nn.Linear weight is (out, in); the bank stores (in, out).
                        w_T = value.t()
                        # NOTE(review): truncating to moe_dim is only lossless
                        # when moe_dim == 4*n_embd; otherwise columns are dropped.
                        w_T = w_T[:, :moe_dim]
                        # Clone the same weights into every expert.
                        new_val = w_T.unsqueeze(0).expand(num_experts, -1, -1).clone()
                        moe_state_dict[f"{layer_key_prefix}.fc_bank"] = new_val
                    else:
                        b = value[:moe_dim]
                        new_val = b.unsqueeze(0).expand(num_experts, -1).clone()
                        # NOTE(review): MoE defines no fc_bias parameter, so this
                        # key is silently dropped under strict=False; confirm intent.
                        moe_state_dict[f"{layer_key_prefix}.fc_bias"] = new_val

                else:
                    if not is_bias:
                        w_T = value.t()
                        w_T = w_T[:moe_dim, :]
                        new_val = w_T.unsqueeze(0).expand(num_experts, -1, -1).clone()
                        moe_state_dict[f"{layer_key_prefix}.proj_bank"] = new_val
                    else:
                        # NOTE(review): MoE defines no proj_bias parameter either.
                        new_val = value.unsqueeze(0).expand(num_experts, -1).clone()
                        moe_state_dict[f"{layer_key_prefix}.proj_bias"] = new_val

                # Zero-initialize the gate so all experts route uniformly at first.
                gate_key = f"{layer_key_prefix}.gate.weight"
                if gate_key not in moe_state_dict:
                    moe_state_dict[gate_key] = torch.zeros(num_experts, self.config.n_embd)

            else:
                # Non-MLP weights map over unchanged.
                moe_state_dict[key] = value

        print("Loading constructed state dict...")
        return self.load_state_dict(moe_state_dict, strict=strict)
| | |
| |
|
# Register with Hugging Face's Auto* factories so that
# AutoModelForCausalLM.from_pretrained(...) resolves model_type "sabiyarn".
AutoConfig.register("sabiyarn", GPTJXMoEConfig)
AutoModel.register(GPTJXMoEConfig,GPTJXMoEForCausalLM)
AutoModelForCausalLM.register(GPTJXMoEConfig, GPTJXMoEForCausalLM)
| | |
| |
|
| |
|
| |
|
| |
|
| |
|