Upload GPTJXMoEForCausalLM
Browse files- config.json +3 -3
- configuration.py +56 -0
- modeling.py +755 -0
config.json
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
{
|
| 2 |
-
"_name_or_path": "BeardedMonster/
|
| 3 |
"architectures": [
|
| 4 |
"GPTJXMoEForCausalLM"
|
| 5 |
],
|
| 6 |
"auto_map": {
|
| 7 |
-
"AutoConfig": "
|
| 8 |
-
"AutoModelForCausalLM": "
|
| 9 |
},
|
| 10 |
"bias": false,
|
| 11 |
"block_size": 32768,
|
|
|
|
| 1 |
{
|
| 2 |
+
"_name_or_path": "BeardedMonster/MOE",
|
| 3 |
"architectures": [
|
| 4 |
"GPTJXMoEForCausalLM"
|
| 5 |
],
|
| 6 |
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration.GPTJXMoEConfig",
|
| 8 |
+
"AutoModelForCausalLM": "modeling.GPTJXMoEForCausalLM"
|
| 9 |
},
|
| 10 |
"bias": false,
|
| 11 |
"block_size": 32768,
|
configuration.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForCausalLM
|
| 3 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 4 |
+
from typing import List, Optional, Tuple
|
| 5 |
+
from torch import nn
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
repo_name = "BeardedMonster/SabiYarn-125M"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class GPTJXMoEConfig(PretrainedConfig):
    """Configuration for the SabiYarn GPT-JX model with optional MoE layers.

    Stores the architecture hyperparameters consumed by
    ``GPTJXMoEForCausalLM``; any extra keyword arguments are forwarded to
    ``PretrainedConfig``.
    """

    model_type = "sabiyarn"

    def __init__(
        self,
        block_size: int = 32768,
        vocab_size: int = 52050,  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
        n_layer: int = 12,
        n_heads: int = 12,
        n_embd: int = 768,
        dropout: float = 0.0,
        max_batch_size: int = 1,
        use_kv_cache: bool = True,
        bias: bool = False,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
        kv_cache_dtype: str = "float32",  # "float32" or "float16" for memory savings
        # MoE hyperparameters
        use_moe: bool = False,  # whether to use MoE instead of a dense MLP
        num_experts: int = 4,
        num_experts_per_tok: int = 2,  # top-k experts each token is routed to
        moe_dim: Optional[int] = None,  # expert hidden dim; defaults to 4 * n_embd like the dense MLP
        **kwargs,
    ):
        # core transformer geometry
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_heads = n_heads
        self.n_embd = n_embd
        self.dropout = dropout
        self.bias = bias

        # inference / cache behaviour
        self.use_kv_cache = use_kv_cache
        self.max_batch_size = max_batch_size
        self.kv_cache_dtype = kv_cache_dtype  # memory optimization: allow float16 cache

        # mixture-of-experts routing
        self.use_moe = use_moe
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        # default expert width matches the dense MLP's 4x expansion
        self.moe_dim = moe_dim if moe_dim is not None else (4 * n_embd)

        super().__init__(**kwargs)
|
| 56 |
+
|
modeling.py
ADDED
|
@@ -0,0 +1,755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SabiYarn Model Implementation - Optimized Version
|
| 3 |
+
Memory-efficient with performance optimizations for generation.
|
| 4 |
+
Matches original implementation exactly but with memory optimizations.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoModelForCausalLM
|
| 8 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 9 |
+
# use package-relative import to avoid colliding with unrelated `model` packages
|
| 10 |
+
from .configuration import GPTJXMoEConfig
|
| 11 |
+
from typing import Optional, List, Tuple
|
| 12 |
+
from torch import nn
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LayerNorm(nn.Module):
    """Layer normalization with an optional bias term.

    ``nn.LayerNorm`` cannot drop only its bias, so this thin wrapper holds a
    learned weight and, when requested, a learned bias, delegating the actual
    normalization to ``F.layer_norm``.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # weight.shape doubles as the normalized_shape argument
        return F.layer_norm(input, self.weight.shape, weight=self.weight, bias=self.bias, eps=1e-5)
|
| 28 |
+
|
| 29 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional KV cache.

    Q, K and V for all heads are produced by a single fused projection
    (``c_attn``). When PyTorch >= 2.0 is available, attention runs through
    ``scaled_dot_product_attention``; otherwise a manual softmax(QK^T)V path
    is used. The previously duplicated mask-normalization code for the two
    paths is factored into ``_normalize_attn_mask`` (behavior unchanged).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_heads == 0
        # key, query, value projections for all heads, computed in one matmul
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_heads = config.n_heads
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_heads
        self.dropout = config.dropout
        # flash attention (SDPA) requires PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    def _normalize_attn_mask(self, attn_mask, T, total_len, past_key_value, device):
        """Normalize a user-supplied mask to a bool (B, n_heads, T, total_len) tensor.

        True means "attend". 2D (B, S) padding masks are extended over the
        cached prefix with ones when a KV cache is present; 4D masks are
        sliced/broadcast to the query length and head count.

        NOTE(review): for 2D padding masks no causal structure is added here
        (this matches the original behavior) — confirm that callers supply
        causality via 4D masks when one is needed during training.

        Raises:
            ValueError: if the mask shape cannot be reconciled with (T, total_len).
        """
        attn_mask = attn_mask.to(torch.bool)
        if attn_mask.dim() == 2:
            B_mask = attn_mask.size(0)
            S = attn_mask.size(1)
            if S == total_len:
                pass  # mask already covers the full (past + current) sequence
            elif S == T:
                # mask covers only current tokens - prepend ones for cached tokens
                if past_key_value is not None:
                    past_len = total_len - T
                    past_mask = torch.ones(B_mask, past_len, device=device, dtype=attn_mask.dtype)
                    attn_mask = torch.cat([past_mask, attn_mask], dim=1)
            else:
                raise ValueError(f"Unsupported attention_mask shape: {attn_mask.shape}, expected (B, {T}) or (B, {total_len})")

            if attn_mask.size(1) != total_len:
                raise ValueError(f"Mask length mismatch: got {attn_mask.size(1)}, expected {total_len}")

            # (B, total_len) -> (B, 1, T, total_len) -> (B, nh, T, total_len)
            attn_mask = attn_mask.view(B_mask, 1, 1, total_len)
            attn_mask = attn_mask.expand(B_mask, 1, T, total_len)
            attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1)
        elif attn_mask.dim() == 4:
            # already 4D - slice to the query length and broadcast over heads
            if attn_mask.size(-2) != T:
                attn_mask = attn_mask[..., -T:, :]
            if attn_mask.size(1) == 1:
                attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1)
            elif attn_mask.size(1) != self.n_heads:
                raise ValueError(f"Mask head dimension {attn_mask.size(1)} doesn't match n_heads {self.n_heads}")
        else:
            raise ValueError(f"Unsupported attention_mask dimension: {attn_mask.dim()}, expected 2 or 4")
        return attn_mask

    def forward(self, x, attn_mask=None, past_key_value=None, use_cache=False):
        """Attention forward pass with optional KV cache.

        Args:
            x: (B, T, C) input embeddings.
            attn_mask: optional mask, 2D (B, S) or 4D (B, 1|nh, T, S); True/1 = attend.
            past_key_value: optional (past_k, past_v), each (B, nh, past_len, hs).
            use_cache: whether to also return the updated (k, v) cache.

        Returns:
            (B, T, C) output, or ``(output, (k, v))`` when ``use_cache`` is True,
            where k, v are (B, nh, total_len, hs).
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dim (n_embd)

        # fused QKV projection, then split heads out into the batch dimension
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)

        # append to the KV cache if one was provided
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)  # (B, nh, past_len + T, hs)
            v = torch.cat([past_v, v], dim=2)

        total_len = k.size(2)
        drop_p = self.dropout if self.training else 0

        if self.flash:
            if attn_mask is not None:
                attn_mask = self._normalize_attn_mask(attn_mask, T, total_len, past_key_value, x.device)
                y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_p, is_causal=False)
            elif past_key_value is None:
                # no cache: let SDPA build the causal mask internally
                y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True)
            else:
                # with cache, q and k lengths differ, so is_causal can't be used;
                # build the rectangular causal mask explicitly
                causal_mask = torch.tril(torch.ones(T, total_len, device=x.device, dtype=torch.bool))
                y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask.view(1, 1, T, total_len), dropout_p=drop_p, is_causal=False)
        else:
            # manual attention: (B, nh, T, hs) x (B, nh, hs, total_len) -> (B, nh, T, total_len)
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))

            if attn_mask is not None:
                attn_mask = self._normalize_attn_mask(attn_mask, T, total_len, past_key_value, x.device)
                att = att.masked_fill(~attn_mask, float('-inf'))
            else:
                # causal mask created on-the-fly: no stored buffer, works for any length
                causal_mask = torch.tril(torch.ones(T, total_len, device=x.device, dtype=torch.bool))
                att = att.masked_fill(~causal_mask.view(1, 1, T, total_len), float('-inf'))

            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, total_len) x (B, nh, total_len, hs) -> (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble heads side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))

        if use_cache:
            return y, (k.detach(), v.detach())
        return y
|
| 213 |
+
|
| 214 |
+
class MLP(nn.Module):
    """Position-wise feed-forward block: 4x linear expansion, GELU, projection, dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # expand -> activate -> project -> regularize
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
|
| 229 |
+
|
| 230 |
+
class BlockJ(nn.Module):
    """Transformer block: pre-LN attention with a parallel 'j' LayerNorm branch,
    followed by a pre-LN MLP (or MoE) with residual connections."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        # NOTE(review): the second positional argument of LayerNorm is `bias`;
        # passing config.n_embd (truthy) means `j` always carries a bias
        # parameter regardless of config.bias. Looks unintended, but it is
        # preserved here for checkpoint/state_dict compatibility — confirm
        # against trained weights before changing.
        self.j = LayerNorm(config.n_embd, config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)

        # route the feed-forward sublayer through MoE when configured
        if getattr(config, 'use_moe', False):
            self.mlp = MoE(
                num_experts_per_tok=config.num_experts_per_tok,
                num_experts=config.num_experts,
                emb_dim=config.n_embd,
                moe_dim=config.moe_dim,
                dropout=config.dropout,
            )
            self.use_moe = True
        else:
            self.mlp = MLP(config)
            self.use_moe = False

    def forward(self, x, attn_mask=None, past_key_value=None, use_cache=False):
        """Block forward pass with optional KV cache.

        Args:
            x: (B, T, C) input embeddings.
            attn_mask: optional attention mask forwarded to the attention layer.
            past_key_value: optional (past_k, past_v) for the attention layer.
            use_cache: whether to also return the updated attention cache.

        Returns:
            (B, T, C) output, or ``(output, (k, v))`` when ``use_cache`` is True.
        """
        residual = x
        normed = self.ln_1(x)

        attn_result = self.attn(normed, attn_mask=attn_mask, past_key_value=past_key_value, use_cache=use_cache)
        if use_cache:
            attn_out, new_past = attn_result
        else:
            attn_out, new_past = attn_result, None

        # residual + attention + the extra 'j' path over the same normed input
        x = residual + attn_out + self.j(normed)
        x = x + self.mlp(self.ln_2(x))

        return (x, new_past) if use_cache else x
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class MoE(nn.Module):
    """Mixture-of-experts feed-forward layer with dense compute and top-k routing.

    Each expert mirrors the dense MLP block (fc -> GELU -> proj) via one weight
    bank per stage. Every expert is evaluated for every token (einsum over the
    full bank) and the top-k expert outputs are mixed with softmaxed router
    scores. Routing statistics and the load-balancing loss are recorded as
    side effects on the module.
    """

    def __init__(self, num_experts_per_tok: int, num_experts: int, emb_dim: int, moe_dim: int, dropout: float = 0.0, dtype=torch.float32):
        super().__init__()
        self.k = int(num_experts_per_tok)
        self.E = int(num_experts)
        self.D = int(emb_dim)
        self.H = int(moe_dim)
        self.dropout = dropout

        # router; the name "gate" is kept for checkpoint compatibility
        self.gate = nn.Linear(self.D, self.E, bias=False, dtype=dtype)
        # expert weight banks mirroring the dense MLP's c_fc / c_proj
        self.fc_bank = nn.Parameter(torch.empty(self.E, self.D, self.H, dtype=dtype))
        self.proj_bank = nn.Parameter(torch.empty(self.E, self.H, self.D, dtype=dtype))
        self.gelu = nn.GELU()
        self.dropout_layer = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        self._init_parameters()

    def expert_utilization(self, logits):
        """Record per-expert load and the load-balancing auxiliary loss.

        Implements the balancing loss of the Switch Transformer
        (https://arxiv.org/abs/2101.03961) as side effects on self
        (``_router_probs``, ``_aux_lb``, ``_expert_utilization``).
        """
        _, selected = logits.topk(self.k, dim=-1)
        selected = F.one_hot(selected, num_classes=self.E).sum(dim=2)  # (B, T, E)

        # average number of times each expert is selected per token
        load = selected.float().mean(dim=(0, 1))

        # average router probability per expert, shape [E]
        P = torch.softmax(logits, dim=-1).float().mean(dim=(0, 1))
        self._router_probs = P.detach()
        self._aux_lb = self.E * torch.sum(load * P)

        self._expert_utilization = load

    def _init_parameters(self):
        """Initialize router and expert banks to match the dense MLP's init scheme."""
        # small gate weights so routing starts near-uniform
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)
        # fc_bank like c_fc; proj_bank scaled down like c_proj for stability
        nn.init.normal_(self.fc_bank, mean=0.0, std=0.02)
        nn.init.normal_(self.proj_bank, mean=0.0, std=0.02 / math.sqrt(2))

    def forward(self, x):
        B, T, D = x.shape
        assert D == self.D, f"Expected emb_dim={self.D}, got {D}"

        logits = self.gate(x)  # (B, T, E)
        if self.training:
            # small routing noise encourages exploration during training
            logits = logits + torch.randn_like(logits) * 1e-1

        topk_logits, selected = logits.topk(self.k, dim=-1)
        topk_probs = F.softmax(topk_logits, dim=-1)

        # dense expert compute mirroring MLP: c_fc -> GELU -> c_proj
        hidden = self.gelu(torch.einsum("btd,edh->bteh", x, self.fc_bank))  # (B, T, E, H)
        expert_out = torch.einsum("bteh,ehd->bted", hidden, self.proj_bank)  # (B, T, E, D)

        # pick the top-k experts' outputs per token
        idx = selected.view(B, T, -1, 1).expand(-1, -1, -1, self.D)  # (B, T, K, D)
        picked = torch.gather(expert_out, dim=2, index=idx)  # (B, T, K, D)

        # mix by router probabilities, then regularize like the dense MLP
        mixed = (picked * topk_probs.unsqueeze(-1)).sum(dim=2)  # (B, T, D)
        out = self.dropout_layer(mixed)

        self.expert_utilization(logits)
        return out
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class GPTJXMoEForCausalLM(PreTrainedModel):
|
| 379 |
+
config_class = GPTJXMoEConfig
|
| 380 |
+
base_model_prefix = "transformer"
|
| 381 |
+
is_parallelizable = True
|
| 382 |
+
supports_gradient_checkpointing = True
|
| 383 |
+
_no_split_modules = ["BlockJ"]
|
| 384 |
+
# _skip_keys_device_placement = "past_key_values"
|
| 385 |
+
_supports_flash_attn_2 = True
|
| 386 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
    def __init__(self, config):
        """Build the decoder-only transformer stack.

        Args:
            config: ``GPTJXMoEConfig`` with the model hyperparameters
                (vocab_size and block_size must be set).
        """
        super().__init__(config)
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        # token embeddings, learned position embeddings, dropout, blocks, final LN
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([BlockJ(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying: token embedding table shares the lm_head matrix
        self.transformer.wte.weight = self.lm_head.weight

        # No need to store causal mask buffer - masks are created on-the-fly when needed
        # Flash Attention handles causality internally with is_causal=True
        # For manual attention, torch.tril() creates masks efficiently on-the-fly
        # This approach scales to any context length (1M+ tokens) without memory overhead

        # NOTE(review): _init_weights is defined outside this view — presumably
        # the standard GPT-style init; confirm in the full file.
        self.apply(self._init_weights)

        # depth-scaled init for residual projections (GPT-2 style 1/sqrt(2*n_layer))
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
|
| 417 |
+
|
| 418 |
+
def get_num_params(self, non_embedding=True):
|
| 419 |
+
"""
|
| 420 |
+
Return the number of parameters in the model.
|
| 421 |
+
For non-embedding count (default), the position embeddings get subtracted.
|
| 422 |
+
The token embeddings would too, except due to the parameter sharing these
|
| 423 |
+
params are actually used as weights in the final layer, so we include them.
|
| 424 |
+
"""
|
| 425 |
+
n_params = sum(p.numel() for p in self.parameters())
|
| 426 |
+
if non_embedding:
|
| 427 |
+
n_params -= self.transformer.wpe.weight.numel()
|
| 428 |
+
return n_params
|
| 429 |
+
|
| 430 |
+
def get_expert_utilization(self):
|
| 431 |
+
"""
|
| 432 |
+
Get expert utilization statistics for MoE layers.
|
| 433 |
+
Returns expert utilization per layer and load balancing loss.
|
| 434 |
+
Only works when use_moe=True in config.
|
| 435 |
+
"""
|
| 436 |
+
if not getattr(self.config, 'use_moe', False):
|
| 437 |
+
return None, None
|
| 438 |
+
|
| 439 |
+
lb_loss, expert_utilization_per_layer = 0, []
|
| 440 |
+
moe_layers = 0
|
| 441 |
+
for block in self.transformer.h:
|
| 442 |
+
if hasattr(block, 'use_moe') and block.use_moe and hasattr(block.mlp, '_aux_lb'):
|
| 443 |
+
lb_loss += block.mlp._aux_lb
|
| 444 |
+
expert_utilization_per_layer.append(block.mlp._expert_utilization.detach().cpu())
|
| 445 |
+
moe_layers += 1
|
| 446 |
+
|
| 447 |
+
if moe_layers > 0:
|
| 448 |
+
lb_loss = lb_loss / moe_layers
|
| 449 |
+
return expert_utilization_per_layer, lb_loss
|
| 450 |
+
|
| 451 |
+
    def get_input_embeddings(self):
        # HF API hook: expose the token-embedding table (tied with lm_head).
        return self.transformer.wte
|
| 453 |
+
|
| 454 |
+
    def set_input_embeddings(self, new_embeddings):
        # HF API hook: replace the token-embedding table (e.g. after resizing).
        self.transformer.wte = new_embeddings
|
| 456 |
+
|
| 457 |
+
def forward(
    self,
    input_ids,
    targets=None,
    attn_mask=None,
    attention_mask=None,  # HF standard name
    past_key_values=None,
    position_ids=None,
    use_cache=None,
    output_hidden_states: Optional[bool] = None,
    **kwargs
):
    """
    Forward pass with KV cache support for efficient generation.

    Args:
        input_ids: (B, T) Token indices
        targets: Optional (B, T) target token indices for training
        attn_mask: Optional attention mask (legacy name)
        attention_mask: Optional attention mask (HF standard name, takes precedence)
        past_key_values: Optional list of (k, v) tuples from previous steps for KV cache
        position_ids: Optional (B, T) position indices (if None, computed from past_key_values)
        use_cache: Whether to return past_key_values for next step (defaults to config.use_kv_cache)
        output_hidden_states: Whether to return hidden states

    Returns:
        CausalLMOutputWithPast with logits and optionally past_key_values
    """
    device = input_ids.device
    b, t = input_ids.size()

    # Use attention_mask if provided (HF standard), otherwise fall back to attn_mask
    if attention_mask is not None:
        attn_mask = attention_mask

    # Determine if we're using KV cache; explicit use_cache wins over config
    use_kv_cache = use_cache if use_cache is not None else getattr(self.config, 'use_kv_cache', False)

    # Compute past sequence length from the cached keys.
    # NOTE(review): assumes cached key tensors are (B, n_heads, seq, head_dim),
    # i.e. sequence length on dim 2 — confirm against the attention block.
    past_len = 0
    if past_key_values is not None:
        past_len = past_key_values[0][0].size(2) if len(past_key_values) > 0 else 0

    # Handle position_ids
    if position_ids is None:
        # Compute position IDs: from past_len to past_len + t
        pos = torch.arange(past_len, past_len + t, dtype=torch.long, device=device)
    else:
        pos = position_ids

    # Validate sequence length (past + current must fit in the context window)
    total_len = past_len + t
    assert total_len <= self.config.block_size, f"Cannot forward sequence of length {total_len}, block size is only {self.config.block_size}"

    # forward the GPT model itself
    tok_emb = self.transformer.wte(input_ids)  # token embeddings of shape (b, t, n_embd)

    # Handle position embeddings: wpe expects 1D position indices
    if pos.dim() == 2:
        # If position_ids is 2D (B, T), extract first row.
        # NOTE(review): this assumes every row of a 2D position_ids tensor is
        # identical (true for standard left-to-right decoding, NOT for
        # per-sample left-padding) — confirm callers never pass ragged positions.
        pos_1d = pos[0] if pos.size(0) > 0 else pos.squeeze(0)
    else:
        pos_1d = pos

    pos_emb = self.transformer.wpe(pos_1d)  # position embeddings of shape (t, n_embd)
    if pos_emb.dim() == 2:
        pos_emb = pos_emb.unsqueeze(0).expand(b, -1, -1)  # Expand to (b, t, n_embd)
    x = self.transformer.drop(tok_emb + pos_emb)

    # Expand attention_mask to cover full sequence (past + current) if needed.
    # HF's generation API may provide mask only for current tokens.
    if attn_mask is not None and past_key_values is not None and use_kv_cache:
        # Check if mask needs expansion
        if attn_mask.dim() == 2:
            mask_len = attn_mask.size(1)
            if mask_len == t and total_len > t:
                # Mask only covers current tokens, expand with ones for past tokens
                # (assumes past tokens are all valid, i.e. no left padding in the cache)
                past_len = total_len - t
                past_mask = torch.ones(b, past_len, device=device, dtype=attn_mask.dtype)
                attn_mask = torch.cat([past_mask, attn_mask], dim=1)

    # Process through transformer layers with KV cache
    new_past_key_values = [] if use_kv_cache else None

    for i, block in enumerate(self.transformer.h):
        layer_past = past_key_values[i] if past_key_values is not None else None

        # Blocks return (hidden, new_kv) when caching, bare hidden otherwise.
        if use_kv_cache:
            x, new_past = block(x, attn_mask=attn_mask, past_key_value=layer_past, use_cache=True)
            new_past_key_values.append(new_past)
        else:
            x = block(x, attn_mask=attn_mask, past_key_value=layer_past, use_cache=False)

    x = self.transformer.ln_f(x)

    # Compute logits and loss
    if targets is not None:
        # Training: compute logits for all positions; -100 targets are ignored
        logits = self.lm_head(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
    else:
        # Inference: only compute logits for last position when using cache, all positions otherwise
        if use_kv_cache and past_key_values is not None:
            logits = self.lm_head(x[:, [-1], :])  # Only last token (list index keeps the T dim)
        else:
            logits = self.lm_head(x)  # All tokens
        loss = None

    # NOTE(review): hidden_states is returned as a bare tensor here, while the
    # HF convention is a tuple of per-layer tensors — confirm downstream
    # consumers expect this shape before changing it.
    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=tuple(new_past_key_values) if use_kv_cache else None,
        hidden_states=x if output_hidden_states else None,
        attentions=None,
    )
|
| 572 |
+
|
| 573 |
+
def prepare_inputs_for_generation(
    self,
    input_ids,
    attention_mask=None,
    past_key_values=None,
    position_ids=None,
    use_cache=None,
    **kwargs
):
    """
    Prepare inputs for one generation step (called by HF's generation API).

    When a KV cache is active only the last token is fed to the model and
    position ids are advanced to the cached sequence length.

    Args:
        input_ids: (B, T) full prompt-so-far token indices.
        attention_mask: Optional mask; passed through (forward() expands it
            over past tokens itself).
        past_key_values: Optional per-layer (k, v) tuples from previous steps.
        position_ids: Optional explicit positions; cropped to the last step
            when caching.
        use_cache: Overrides config.use_kv_cache when given.

    Returns:
        dict of keyword arguments for forward().
    """
    # Determine if we should use cache
    use_kv_cache = use_cache if use_cache is not None else getattr(self.config, 'use_kv_cache', False)

    # Base model inputs
    model_inputs = {
        "input_ids": input_ids,
    }

    # ---- 1. Handle KV cache (past_key_values) ----
    if past_key_values is not None and use_kv_cache:
        # Only feed the last token when using cached keys/values
        model_inputs["input_ids"] = input_ids[:, -1:]
        model_inputs["past_key_values"] = past_key_values

    # ---- 2. Handle attention mask ----
    # forward() extends a current-tokens-only mask over the past itself,
    # so the mask is passed through unchanged here.
    if attention_mask is not None:
        model_inputs["attention_mask"] = attention_mask

    # ---- 3. Handle position_ids correctly ----
    if position_ids is not None:
        if past_key_values is not None and use_kv_cache:
            # Only use the last position when using cache
            position_ids = position_ids[:, -1].unsqueeze(-1)
        model_inputs["position_ids"] = position_ids
    elif past_key_values is not None and use_kv_cache:
        # Compute position_ids from past_key_values length.
        past_len = past_key_values[0][0].size(2) if len(past_key_values) > 0 else 0
        # BUGFIX: build one position per batch row. The previous code
        # hard-coded a (1, 1) tensor ([[past_len]]), which broke batched
        # generation and beam search (batch size > 1).
        model_inputs["position_ids"] = torch.full(
            (input_ids.size(0), 1), past_len, dtype=torch.long, device=input_ids.device
        )

    # ---- 4. Forward arbitrary extra kwargs safely ----
    # For example: use_cache, output_attentions, token_type_ids, etc.
    if use_cache is not None:
        model_inputs["use_cache"] = use_cache

    for k, v in kwargs.items():
        if v is not None:
            model_inputs[k] = v

    return model_inputs
|
| 631 |
+
|
| 632 |
+
def _reorder_cache(
|
| 633 |
+
self,
|
| 634 |
+
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
|
| 635 |
+
beam_idx: torch.Tensor,
|
| 636 |
+
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
| 637 |
+
"""
|
| 638 |
+
Reorder cache for beam search.
|
| 639 |
+
|
| 640 |
+
Required by HF for beam search to work correctly.
|
| 641 |
+
Selects which beam samples to keep based on beam_idx.
|
| 642 |
+
|
| 643 |
+
Args:
|
| 644 |
+
past_key_values: List of (k, v) tuples from previous steps
|
| 645 |
+
beam_idx: (batch_size,) tensor indicating which beams to keep
|
| 646 |
+
|
| 647 |
+
Returns:
|
| 648 |
+
Reordered past_key_values
|
| 649 |
+
"""
|
| 650 |
+
reordered_past = []
|
| 651 |
+
for layer_past in past_key_values:
|
| 652 |
+
k, v = layer_past
|
| 653 |
+
device = k.device
|
| 654 |
+
beam_idx_dev = beam_idx.to(device)
|
| 655 |
+
reordered_past.append((
|
| 656 |
+
k.index_select(0, beam_idx_dev),
|
| 657 |
+
v.index_select(0, beam_idx_dev)
|
| 658 |
+
))
|
| 659 |
+
return reordered_past
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def crop_block_size(self, block_size):
    """
    Shrink the model's maximum context length in place.

    Truncates the position-embedding table and any per-layer causal-bias
    buffers down to the new length. Only shrinking is supported.

    Args:
        block_size: New maximum sequence length (must not exceed the current one).
    """
    assert block_size <= self.config.block_size
    self.config.block_size = block_size
    # Keep only the first block_size rows of the positional table.
    cropped = self.transformer.wpe.weight[:block_size]
    self.transformer.wpe.weight = nn.Parameter(cropped)
    for layer in self.transformer.h:
        # Attention blocks that precompute a causal-mask buffer need the same
        # truncation; flash-style blocks without a 'bias' buffer are skipped.
        if hasattr(layer.attn, 'bias'):
            layer.attn.bias = layer.attn.bias[:, :, :block_size, :block_size]
|
| 669 |
+
|
| 670 |
+
def load_dense_weights_into_moe(self, dense_state_dict, strict=False):
    """
    Migrate Dense MLP weights to MoE experts.
    Ensures exact mathematical equivalence by cloning weights/biases to ALL experts.

    Args:
        dense_state_dict: state dict of a dense (non-MoE) checkpoint whose MLP
            keys follow ``transformer.h.{i}.mlp.c_fc.{weight,bias}`` and
            ``transformer.h.{i}.mlp.c_proj.{weight,bias}``.
        strict: forwarded to ``load_state_dict``. Defaults to False because the
            constructed dict may not cover every MoE-only parameter.

    Returns:
        The result of ``load_state_dict`` (missing/unexpected key report).
    """
    # Plain dense model: nothing to convert, load directly.
    if not getattr(self.config, 'use_moe', False):
        return self.load_state_dict(dense_state_dict, strict=strict)

    print("Converting Dense Checkpoint -> MoE Checkpoint...")
    moe_state_dict = {}

    # Get config details
    num_experts = self.config.num_experts
    # NOTE(review): the slicing below assumes moe_dim <= dense hidden dim;
    # columns beyond moe_dim are silently dropped — confirm this is intended.
    moe_dim = self.config.moe_dim

    for key, value in dense_state_dict.items():
        # Identify MLP weights
        if 'mlp.c_fc' in key or 'mlp.c_proj' in key:

            # Extract layer index and type (weight/bias)
            # key format: transformer.h.{i}.mlp.c_fc.{weight/bias}
            parts = key.split('.')
            layer_idx = parts[2]
            layer_key_prefix = f"transformer.h.{layer_idx}.mlp"

            is_bias = 'bias' in key
            is_fc = 'c_fc' in key

            # --- Handle c_fc (Input -> Hidden) ---
            if is_fc:
                if not is_bias:
                    # Weight: Dense is (H, D) -> MoE needs (E, D, H)
                    # 1. Transpose to (D, H)
                    w_T = value.t()
                    # 2. Slice to moe_dim if necessary
                    w_T = w_T[:, :moe_dim]
                    # 3. Expand to (E, D, H); clone() materializes per-expert copies
                    new_val = w_T.unsqueeze(0).expand(num_experts, -1, -1).clone()
                    moe_state_dict[f"{layer_key_prefix}.fc_bank"] = new_val
                else:
                    # Bias: Dense is (H) -> MoE needs (E, H)
                    b = value[:moe_dim]
                    new_val = b.unsqueeze(0).expand(num_experts, -1).clone()
                    moe_state_dict[f"{layer_key_prefix}.fc_bias"] = new_val

            # --- Handle c_proj (Hidden -> Output) ---
            else:
                if not is_bias:
                    # Weight: Dense is (D, H) -> MoE needs (E, H, D)
                    # 1. Transpose to (H, D)
                    w_T = value.t()
                    # 2. Slice source dimension (H) if necessary
                    w_T = w_T[:moe_dim, :]
                    # 3. Expand to (E, H, D)
                    new_val = w_T.unsqueeze(0).expand(num_experts, -1, -1).clone()
                    moe_state_dict[f"{layer_key_prefix}.proj_bank"] = new_val
                else:
                    # Bias: Dense is (D) -> MoE needs (E, D)
                    # Bias is on the output, so dimension is D, usually doesn't need slicing
                    new_val = value.unsqueeze(0).expand(num_experts, -1).clone()
                    moe_state_dict[f"{layer_key_prefix}.proj_bias"] = new_val

            # --- Initialize Gate (if not yet initialized) ---
            # We initialize gate to zero to ensure uniform routing probability initially,
            # which guarantees average of identical experts == single expert.
            # NOTE(review): uniform probs require the gate to have no bias term
            # (or a zero one) — confirm against the MoE layer definition.
            gate_key = f"{layer_key_prefix}.gate.weight"
            if gate_key not in moe_state_dict:
                # Zeros = equal probability for all experts
                moe_state_dict[gate_key] = torch.zeros(num_experts, self.config.n_embd)

        else:
            # Copy non-MLP keys directly (Attn, LayerNorm, Embeddings)
            moe_state_dict[key] = value

    print("Loading constructed state dict...")
    return self.load_state_dict(moe_state_dict, strict=strict)
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
# Register the custom config/model with the transformers Auto* factories so
# from_pretrained(...) can resolve model_type "sabiyarn" without trust_remote_code
# plumbing inside this repo.
AutoConfig.register("sabiyarn", GPTJXMoEConfig)
# NOTE(review): AutoModel is not among the imports visible in the sibling
# configuration.py header (only AutoConfig / AutoModelForCausalLM) — confirm
# this file imports AutoModel, otherwise the next line raises NameError at import.
AutoModel.register(GPTJXMoEConfig,GPTJXMoEForCausalLM)
AutoModelForCausalLM.register(GPTJXMoEConfig, GPTJXMoEForCausalLM)
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
|