BeardedMonster commited on
Commit
04ff41b
·
verified ·
1 Parent(s): 77a941e

Upload GPTJXMoEForCausalLM

Browse files
Files changed (3) hide show
  1. config.json +3 -3
  2. configuration.py +56 -0
  3. modeling.py +755 -0
config.json CHANGED
@@ -1,11 +1,11 @@
1
  {
2
- "_name_or_path": "BeardedMonster/sabiyarn_moe",
3
  "architectures": [
4
  "GPTJXMoEForCausalLM"
5
  ],
6
  "auto_map": {
7
- "AutoConfig": "BeardedMonster/sabiyarnMoE--configuration.GPTJXMoEConfig",
8
- "AutoModelForCausalLM": "BeardedMonster/sabiyarnMoE--modeling.GPTJXMoEForCausalLM"
9
  },
10
  "bias": false,
11
  "block_size": 32768,
 
1
  {
2
+ "_name_or_path": "BeardedMonster/MOE",
3
  "architectures": [
4
  "GPTJXMoEForCausalLM"
5
  ],
6
  "auto_map": {
7
+ "AutoConfig": "configuration.GPTJXMoEConfig",
8
+ "AutoModelForCausalLM": "modeling.GPTJXMoEForCausalLM"
9
  },
10
  "bias": false,
11
  "block_size": 32768,
configuration.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForCausalLM
3
+ from transformers.modeling_outputs import CausalLMOutputWithPast
4
+ from typing import List, Optional, Tuple
5
+ from torch import nn
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import math
9
+
10
+ repo_name = "BeardedMonster/SabiYarn-125M"
11
+
12
+
13
class GPTJXMoEConfig(PretrainedConfig):
    """Configuration class for SabiYarn model.

    Holds the hyperparameters consumed by ``GPTJXMoEForCausalLM``: the
    transformer dimensions, KV-cache behaviour, and the optional
    Mixture-of-Experts settings. ``model_type`` ties this config to the
    "sabiyarn" key used with transformers' Auto* registration.
    """

    model_type = "sabiyarn"

    def __init__(
        self,
        block_size: int = 32768,
        vocab_size: int = 52050,  # NOTE(review): inherited comment claimed "padded to multiple of 64" (GPT-2 style), but 52050 is not — confirm intended size
        n_layer: int = 12,
        n_heads: int = 12,
        n_embd: int = 768,
        dropout: float = 0.0,
        max_batch_size: int = 1,
        use_kv_cache: bool = True,
        bias: bool = False,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
        kv_cache_dtype: str = "float32",  # "float32" or "float16" for memory savings
        # MoE hyperparameters
        use_moe: bool = False,  # Whether to use MoE instead of dense MLP
        num_experts: int = 4,  # Number of experts in MoE layer
        num_experts_per_tok: int = 2,  # Number of experts to route each token to (top-k)
        moe_dim: int = None,  # MoE hidden dimension (defaults to 4 * n_embd like MLP)
        **kwargs
    ):
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_heads = n_heads
        self.n_embd = n_embd
        self.dropout = dropout
        self.bias = bias
        self.use_kv_cache = use_kv_cache
        self.max_batch_size = max_batch_size
        self.kv_cache_dtype = kv_cache_dtype  # Memory optimization: use float16 for cache

        # MoE configuration
        self.use_moe = use_moe
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        # Default moe_dim to match MLP expansion (4x)
        self.moe_dim = moe_dim if moe_dim is not None else (4 * n_embd)

        # Forward remaining kwargs (e.g. torch_dtype, auto_map) to PretrainedConfig.
        super().__init__(**kwargs)
+
modeling.py ADDED
@@ -0,0 +1,755 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SabiYarn Model Implementation - Optimized Version
3
+ Memory-efficient with performance optimizations for generation.
4
+ Matches original implementation exactly but with memory optimizations.
5
+ """
6
+
7
+ from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoModelForCausalLM
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ # use package-relative import to avoid colliding with unrelated `model` packages
10
+ from .configuration import GPTJXMoEConfig
11
+ from typing import Optional, List, Tuple
12
+ from torch import nn
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import math
16
+
17
+
18
class LayerNorm(nn.Module):
    """Layer normalization with an optional bias term.

    PyTorch's built-in ``nn.LayerNorm`` always carries a bias; this variant
    lets callers drop it (``bias=False``) while keeping a learnable scale.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        # Learnable scale, initialized to identity.
        self.weight = nn.Parameter(torch.ones(ndim))
        # Bias is only registered when requested; None keeps it out of the
        # state dict entirely.
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, self.weight, self.bias, 1e-5)
28
+
29
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional KV cache.

    Uses the fused ``scaled_dot_product_attention`` kernel (PyTorch >= 2.0)
    when available, otherwise a manual softmax implementation. Both paths
    share one mask builder so padding masks and causality are always applied
    together.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_heads == 0
        # key, query, value projections for all heads, computed in one matmul
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_heads = config.n_heads
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_heads
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    def _build_mask(self, attn_mask, B, T, total_len, device):
        """Build a boolean attention mask of shape broadcastable to
        (B, n_heads, T, total_len) that combines causality with an optional
        padding mask.

        Args:
            attn_mask: None, a (B, S) padding mask with S in {T, total_len},
                or a 4-D mask (B, 1 or n_heads, *, total_len). Nonzero/True
                means "may attend".
            B, T: batch size and current query length.
            total_len: key length including any cached past tokens.
            device: device for freshly created mask tensors.

        Raises:
            ValueError: on unsupported mask rank or mismatched sizes.
        """
        # Causal structure. Query i sits at absolute position
        # (total_len - T + i), so it may attend to keys j <= total_len - T + i.
        # The diagonal offset is what makes this correct with a KV cache:
        # without it, a decode step (T=1) would only see the very first key.
        causal = torch.tril(
            torch.ones(T, total_len, device=device, dtype=torch.bool),
            diagonal=total_len - T,
        ).view(1, 1, T, total_len)
        if attn_mask is None:
            return causal

        attn_mask = attn_mask.to(torch.bool)
        if attn_mask.dim() == 2:
            # (B, S) padding mask from the tokenizer / generation loop.
            B_mask, S = attn_mask.size()
            if S == T and total_len > T:
                # Mask only covers the current tokens; cached past tokens are
                # assumed valid and padded with ones.
                past_mask = torch.ones(B_mask, total_len - T, device=device, dtype=torch.bool)
                attn_mask = torch.cat([past_mask, attn_mask], dim=1)
            elif S != total_len and S != T:
                raise ValueError(f"Unsupported attention_mask shape: {attn_mask.shape}, expected (B, {T}) or (B, {total_len})")
            attn_mask = attn_mask.view(B_mask, 1, 1, total_len).expand(B_mask, self.n_heads, T, total_len)
        elif attn_mask.dim() == 4:
            # Already 4-D: slice the query axis to the current length and
            # broadcast a singleton head dimension if needed.
            if attn_mask.size(-2) != T:
                attn_mask = attn_mask[..., -T:, :]
            if attn_mask.size(1) == 1:
                attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1)
            elif attn_mask.size(1) != self.n_heads:
                raise ValueError(f"Mask head dimension {attn_mask.size(1)} doesn't match n_heads {self.n_heads}")
        else:
            raise ValueError(f"Unsupported attention_mask dimension: {attn_mask.dim()}, expected 2 or 4")

        if attn_mask.size(-1) != total_len:
            raise ValueError(f"Mask length mismatch: got {attn_mask.size(-1)}, expected {total_len}")
        # Intersect with the causal mask. (The previous implementation applied
        # only the padding mask with is_causal=False, so prefill with an
        # attention_mask was not causal.)
        return attn_mask & causal

    def forward(self, x, attn_mask=None, past_key_value=None, use_cache=False):
        """
        Forward pass with optional KV cache support.

        Args:
            x: (B, T, C) input embeddings
            attn_mask: Optional attention mask (see ``_build_mask``)
            past_key_value: Optional tuple of (past_k, past_v) each (B, nh, past_len, hs)
            use_cache: Whether to return cache for next step

        Returns:
            If use_cache: (output, (k, v)) where output is (B, T, C) and k, v are (B, nh, total_len, hs)
            Else: output (B, T, C)
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)

        # Concatenate with past KV cache if provided
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)  # (B, nh, past_len + T, hs)
            v = torch.cat([past_v, v], dim=2)

        total_len = k.size(2)
        drop_p = self.dropout if self.training else 0

        if self.flash and attn_mask is None and past_key_value is None:
            # Fast path: the fused kernel applies causality internally.
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True)
        else:
            mask = self._build_mask(attn_mask, B, T, total_len, x.device)
            if self.flash:
                y = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, attn_mask=mask, dropout_p=drop_p, is_causal=False)
            else:
                # manual implementation of attention
                att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
                att = att.masked_fill(~mask, float('-inf'))
                att = F.softmax(att, dim=-1)
                att = self.attn_dropout(att)
                y = att @ v  # (B, nh, T, total_len) x (B, nh, total_len, hs) -> (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))

        # Return cache if requested
        if use_cache:
            return y, (k.detach(), v.detach())
        return y
213
+
214
class MLP(nn.Module):
    """Position-wise feed-forward block: 4x linear expansion, GELU,
    projection back to the embedding size, then dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # fc -> GELU -> proj -> dropout, as a single expression
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
229
+
230
class BlockJ(nn.Module):
    """Transformer block with a GPT-J-style extra shortcut.

    The residual update is ``x + attn(ln_1(x)) + j(ln_1(x))`` followed by the
    (dense or MoE) feed-forward sub-block on ``ln_2``.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        # NOTE(review): LayerNorm's second positional parameter is the *bias
        # flag*; passing config.n_embd (a nonzero int) means this norm always
        # gets a bias parameter regardless of config.bias. Likely a typo for
        # bias=config.bias, but changing it would alter the checkpoint's
        # state-dict layout — confirm before fixing.
        self.j = LayerNorm(config.n_embd, config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)

        # Use MoE if configured, otherwise use dense MLP
        if getattr(config, 'use_moe', False):
            self.mlp = MoE(
                num_experts_per_tok=config.num_experts_per_tok,
                num_experts=config.num_experts,
                emb_dim=config.n_embd,
                moe_dim=config.moe_dim,
                dropout=config.dropout
            )
            self.use_moe = True
        else:
            self.mlp = MLP(config)
            self.use_moe = False

    def forward(self, x, attn_mask=None, past_key_value=None, use_cache=False):
        """
        Forward pass with optional KV cache support.

        Args:
            x: (B, T, C) input embeddings
            attn_mask: Optional attention mask
            past_key_value: Optional tuple of (past_k, past_v) for attention layer
            use_cache: Whether to return cache for next step

        Returns:
            If use_cache: (output, (k, v)) where output is (B, T, C)
            Else: output (B, T, C)
        """
        h = x
        x_ln = self.ln_1(x)

        # Attention with optional KV cache; both branches add the parallel
        # self.j(x_ln) shortcut on top of the attention output.
        if use_cache:
            attn_out, new_past = self.attn(x_ln, attn_mask=attn_mask, past_key_value=past_key_value, use_cache=True)
            x = h + attn_out + self.j(x_ln)
        else:
            attn_out = self.attn(x_ln, attn_mask=attn_mask, past_key_value=past_key_value, use_cache=False)
            x = h + attn_out + self.j(x_ln)

        # Feed-forward (dense MLP or MoE) residual branch.
        x = x + self.mlp(self.ln_2(x))

        if use_cache:
            return x, new_past
        return x
283
+
284
+
285
class MoE(nn.Module):
    """
    A Mixture-of-Experts feed-forward layer mirroring the dense MLP
    (c_fc -> GELU -> c_proj structure, despite the original comment about
    "swiglu" — the activation used here is GELU).

    Every expert is evaluated densely for every token (einsum over all E
    experts), then the top-k experts per token are gathered and combined with
    softmax-renormalized router weights. Routing statistics and the
    Switch-Transformer load-balancing loss are stashed on the module as
    ``_expert_utilization`` / ``_aux_lb`` side effects.
    """

    def __init__(self, num_experts_per_tok: int, num_experts: int, emb_dim: int, moe_dim: int, dropout: float = 0.0, dtype=torch.float32):
        super().__init__()
        self.k = int(num_experts_per_tok)  # experts routed per token (top-k)
        self.E = int(num_experts)          # number of experts
        self.D = int(emb_dim)              # model embedding dim
        self.H = int(moe_dim)              # expert hidden dim
        self.dropout = dropout

        self.gate = nn.Linear(self.D, self.E, bias=False, dtype=dtype)  # named "gate" so checkpoints load correctly
        # Match MLP structure: c_fc -> GELU -> c_proj
        self.fc_bank = nn.Parameter(torch.empty(self.E, self.D, self.H, dtype=dtype))  # Equivalent to c_fc: (n_embd -> 4*n_embd)
        self.proj_bank = nn.Parameter(torch.empty(self.E, self.H, self.D, dtype=dtype))  # Equivalent to c_proj: (4*n_embd -> n_embd)
        self.gelu = nn.GELU()  # Match MLP activation
        self.dropout_layer = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        # Initialize parameters
        self._init_parameters()


    def expert_utilization(self, logits):
        """
        This function compute expert utilization per token and also compute load balancer loss.
        Details of this load balancer can be found in https://arxiv.org/abs/2101.03961
        """

        _, selected = logits.topk(self.k, dim=-1)
        selected = F.one_hot(selected, num_classes=self.E).sum(dim=2)  # B, T, E

        # fraction of routing slots assigned to each expert
        load = torch.mean(selected.float(), dim=(0,1))

        # average router probability per expert
        P = torch.softmax(logits, dim=-1).float().mean(dim=(0,1))  # [E]
        self._router_probs = P.detach()  # per-expert avg prob
        # Switch-Transformer auxiliary loss: E * sum(load * P)
        self._aux_lb = self.E * torch.sum(load * P)


        self._expert_utilization = load

    def _init_parameters(self):
        """Initialize MoE parameters following standard practices."""
        # Initialize gate with small values to start with uniform routing
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)

        # Initialize expert banks to match MLP initialization
        # fc_bank: standard normal (like c_fc in MLP)
        nn.init.normal_(self.fc_bank, mean=0.0, std=0.02)

        # proj_bank: smaller initialization for stability (like c_proj in MLP)
        nn.init.normal_(self.proj_bank, mean=0.0, std=0.02 / math.sqrt(2))

    def forward(self, x):
        # x: (B, T, D) token embeddings
        B, T, D = x.shape
        assert D == self.D, f"Expected emb_dim={self.D}, got {D}"

        logits = self.gate(x)  # B, T, E

        # Exploration noise during training only; 1e-1 scale is hard-coded.
        if self.training:
            logits = logits + torch.randn_like(logits) * 1e-1


        topk_logits, selected = logits.topk(self.k, dim=-1)
        topk_probs = F.softmax(topk_logits, dim=-1)

        # Match MLP structure exactly: c_fc -> GELU -> c_proj
        # NOTE(review): this computes ALL experts for every token and then
        # discards all but top-k — correct but O(E) compute per token.
        # Step 1: c_fc equivalent: x @ fc_bank -> (B, T, E, H)
        h = torch.einsum("btd,edh->bteh", x, self.fc_bank)  # B, T, E, H

        # Step 2: GELU activation (matching MLP)
        h = self.gelu(h)  # B, T, E, H

        # Step 3: c_proj equivalent: h @ proj_bank -> (B, T, E, D)
        y = torch.einsum("bteh,ehd->bted", h, self.proj_bank)  # B, T, E, D

        # Step 4: Select top-k experts and combine
        gather_idx = selected.view(B, T, -1, 1).expand(-1, -1, -1, self.D)  # B, T, K, D
        y = torch.gather(y, dim=2, index=gather_idx)  # B, T, K, D

        # Step 5: Weighted sum of selected experts
        y = (y * topk_probs.unsqueeze(-1)).sum(dim=2)  # B, T, D

        # Step 6: Apply dropout like MLP
        y = self.dropout_layer(y)

        # Side effect: record routing stats / aux loss for get_expert_utilization().
        self.expert_utilization(logits)
        return y
376
+
377
+
378
class GPTJXMoEForCausalLM(PreTrainedModel):
    """Causal-LM wrapper around the SabiYarn transformer (optionally MoE),
    with KV-cache generation support and lm_head/wte weight tying."""

    config_class = GPTJXMoEConfig
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["BlockJ"]
    # _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _tied_weights_keys = ["lm_head.weight"]


    def __init__(self, config):
        super().__init__(config)
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([BlockJ(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: token embedding shares the lm_head matrix.
        self.transformer.wte.weight = self.lm_head.weight

        # No need to store causal mask buffer - masks are created on-the-fly when needed
        # Flash Attention handles causality internally with is_causal=True
        # For manual attention, torch.tril() creates masks efficiently on-the-fly
        # This approach scales to any context length (1M+ tokens) without memory overhead

        self.apply(self._init_weights)

        # GPT-2-style scaled init for residual projections.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # NOTE(review): prints at construction time; consider logging instead.
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def get_expert_utilization(self):
        """
        Get expert utilization statistics for MoE layers.
        Returns expert utilization per layer and load balancing loss
        (averaged over MoE layers). Only works when use_moe=True in config;
        returns (None, None) otherwise. Stats exist only after a forward pass.
        """
        if not getattr(self.config, 'use_moe', False):
            return None, None

        lb_loss, expert_utilization_per_layer = 0, []
        moe_layers = 0
        for block in self.transformer.h:
            if hasattr(block, 'use_moe') and block.use_moe and hasattr(block.mlp, '_aux_lb'):
                lb_loss += block.mlp._aux_lb
                expert_utilization_per_layer.append(block.mlp._expert_utilization.detach().cpu())
                moe_layers += 1

        if moe_layers > 0:
            lb_loss = lb_loss / moe_layers
        return expert_utilization_per_layer, lb_loss

    def get_input_embeddings(self):
        # Required HF hook; returns the (tied) token embedding.
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        # Required HF hook for resize_token_embeddings etc.
        self.transformer.wte = new_embeddings

    def forward(
        self,
        input_ids,
        targets=None,
        attn_mask=None,
        attention_mask=None,  # HF standard name
        past_key_values=None,
        position_ids=None,
        use_cache=None,
        output_hidden_states: Optional[bool] = None,
        **kwargs
    ):
        """
        Forward pass with KV cache support for efficient generation.

        Args:
            input_ids: (B, T) Token indices
            targets: Optional (B, T) target token indices for training
            attn_mask: Optional attention mask (legacy name)
            attention_mask: Optional attention mask (HF standard name, takes precedence)
            past_key_values: Optional list of (k, v) tuples from previous steps for KV cache
            position_ids: Optional (B, T) position indices (if None, computed from past_key_values)
            use_cache: Whether to return past_key_values for next step (defaults to config.use_kv_cache)
            output_hidden_states: Whether to return hidden states

        Returns:
            CausalLMOutputWithPast with logits and optionally past_key_values.
            NOTE(review): hidden_states is set to the final-layer tensor, not
            the HF-conventional tuple of per-layer states — confirm downstream
            consumers expect this.
        """
        device = input_ids.device
        b, t = input_ids.size()

        # Use attention_mask if provided (HF standard), otherwise fall back to attn_mask
        if attention_mask is not None:
            attn_mask = attention_mask

        # Determine if we're using KV cache
        use_kv_cache = use_cache if use_cache is not None else getattr(self.config, 'use_kv_cache', False)

        # Compute past sequence length if using cache
        past_len = 0
        if past_key_values is not None:
            past_len = past_key_values[0][0].size(2) if len(past_key_values) > 0 else 0

        # Handle position_ids
        if position_ids is None:
            # Compute position IDs: from past_len to past_len + t
            pos = torch.arange(past_len, past_len + t, dtype=torch.long, device=device)
        else:
            pos = position_ids

        # Validate sequence length
        total_len = past_len + t
        assert total_len <= self.config.block_size, f"Cannot forward sequence of length {total_len}, block size is only {self.config.block_size}"

        # forward the GPT model itself
        tok_emb = self.transformer.wte(input_ids)  # token embeddings of shape (b, t, n_embd)

        # Handle position embeddings: wpe expects 1D position indices
        if pos.dim() == 2:
            # If position_ids is 2D (B, T), extract first row (assuming all sequences have same positions)
            pos_1d = pos[0] if pos.size(0) > 0 else pos.squeeze(0)
        else:
            pos_1d = pos

        pos_emb = self.transformer.wpe(pos_1d)  # position embeddings of shape (t, n_embd)
        if pos_emb.dim() == 2:
            pos_emb = pos_emb.unsqueeze(0).expand(b, -1, -1)  # Expand to (b, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)

        # Expand attention_mask to cover full sequence (past + current) if needed
        # HF's generation API may provide mask only for current tokens
        if attn_mask is not None and past_key_values is not None and use_kv_cache:
            # Check if mask needs expansion
            if attn_mask.dim() == 2:
                mask_len = attn_mask.size(1)
                if mask_len == t and total_len > t:
                    # Mask only covers current tokens, expand with ones for past tokens
                    past_len = total_len - t
                    past_mask = torch.ones(b, past_len, device=device, dtype=attn_mask.dtype)
                    attn_mask = torch.cat([past_mask, attn_mask], dim=1)

        # Process through transformer layers with KV cache
        new_past_key_values = [] if use_kv_cache else None

        for i, block in enumerate(self.transformer.h):
            layer_past = past_key_values[i] if past_key_values is not None else None

            if use_kv_cache:
                x, new_past = block(x, attn_mask=attn_mask, past_key_value=layer_past, use_cache=True)
                new_past_key_values.append(new_past)
            else:
                x = block(x, attn_mask=attn_mask, past_key_value=layer_past, use_cache=False)

        x = self.transformer.ln_f(x)

        # Compute logits and loss
        if targets is not None:
            # Training: compute logits for all positions
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
        else:
            # Inference: only compute logits for last position when using cache, all positions otherwise
            if use_kv_cache and past_key_values is not None:
                logits = self.lm_head(x[:, [-1], :])  # Only last token
            else:
                logits = self.lm_head(x)  # All tokens
            loss = None

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=tuple(new_past_key_values) if use_kv_cache else None,
            hidden_states=x if output_hidden_states else None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        past_key_values=None,
        position_ids=None,
        use_cache=None,
        **kwargs
    ):
        """
        Prepare inputs for generation with KV cache support.
        This method is called by HF's generation API.
        """
        # Determine if we should use cache
        use_kv_cache = use_cache if use_cache is not None else getattr(self.config, 'use_kv_cache', False)

        # Base model inputs
        model_inputs = {
            "input_ids": input_ids,
        }

        # ---- 1. Handle KV cache (past_key_values) ----
        if past_key_values is not None and use_kv_cache:
            # Only feed the last token when using cached keys/values
            model_inputs["input_ids"] = input_ids[:, -1:]
            model_inputs["past_key_values"] = past_key_values

        # ---- 2. Handle attention mask ----
        if attention_mask is not None:
            # When using cache, attention_mask should cover the full sequence (past + current)
            if past_key_values is not None and use_kv_cache:
                # Extend attention mask to include past tokens
                # HF generation will handle this, but we ensure it's passed through
                pass
            model_inputs["attention_mask"] = attention_mask

        # ---- 3. Handle position_ids correctly ----
        # HF relies on this for models like GPT-J, GPT-NeoX, Llama, etc.
        if position_ids is not None:
            if past_key_values is not None and use_kv_cache:
                # Only use the last position when using cache
                position_ids = position_ids[:, -1].unsqueeze(-1)
            model_inputs["position_ids"] = position_ids
        elif past_key_values is not None and use_kv_cache:
            # Compute position_ids from past_key_values length
            past_len = past_key_values[0][0].size(2) if len(past_key_values) > 0 else 0
            # NOTE(review): shape (1, 1) assumes batch size 1 — forward() only
            # uses the first row of 2D position_ids, so larger batches still
            # work, but verify against batched generation.
            model_inputs["position_ids"] = torch.tensor([[past_len]], device=input_ids.device, dtype=torch.long)

        # ---- 4. Forward arbitrary extra kwargs safely ----
        # For example: use_cache, output_attentions, token_type_ids, etc.
        if use_cache is not None:
            model_inputs["use_cache"] = use_cache

        for k, v in kwargs.items():
            if v is not None:
                model_inputs[k] = v

        return model_inputs

    def _reorder_cache(
        self,
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        beam_idx: torch.Tensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Reorder cache for beam search.

        Required by HF for beam search to work correctly.
        Selects which beam samples to keep based on beam_idx.

        Args:
            past_key_values: List of (k, v) tuples from previous steps
            beam_idx: (batch_size,) tensor indicating which beams to keep

        Returns:
            Reordered past_key_values
        """
        reordered_past = []
        for layer_past in past_key_values:
            k, v = layer_past
            device = k.device
            beam_idx_dev = beam_idx.to(device)
            reordered_past.append((
                k.index_select(0, beam_idx_dev),
                v.index_select(0, beam_idx_dev)
            ))
        return reordered_past


    def crop_block_size(self, block_size):
        # Shrink the maximum context by truncating position embeddings.
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            # NOTE(review): CausalSelfAttention defines no 'bias' buffer, so
            # this hasattr branch is currently dead code kept for
            # compatibility with mask-buffer variants.
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    def load_dense_weights_into_moe(self, dense_state_dict, strict=False):
        """
        Migrate Dense MLP weights to MoE experts.
        Ensures exact mathematical equivalence by cloning weights/biases to ALL experts.

        NOTE(review): the 'fc_bias'/'proj_bias' keys written below have no
        matching parameters on MoE (it defines only gate/fc_bank/proj_bank);
        with strict=False they are silently ignored — confirm whether bias
        support was intended.
        """
        if not getattr(self.config, 'use_moe', False):
            return self.load_state_dict(dense_state_dict, strict=strict)

        print("Converting Dense Checkpoint -> MoE Checkpoint...")
        moe_state_dict = {}

        # Get config details
        num_experts = self.config.num_experts
        moe_dim = self.config.moe_dim

        for key, value in dense_state_dict.items():
            # Identify MLP weights
            if 'mlp.c_fc' in key or 'mlp.c_proj' in key:

                # Extract layer index and type (weight/bias)
                # key format: transformer.h.{i}.mlp.c_fc.{weight/bias}
                parts = key.split('.')
                layer_idx = parts[2]
                layer_key_prefix = f"transformer.h.{layer_idx}.mlp"

                is_bias = 'bias' in key
                is_fc = 'c_fc' in key

                # --- Handle c_fc (Input -> Hidden) ---
                if is_fc:
                    if not is_bias:
                        # Weight: Dense is (H, D) -> MoE needs (E, D, H)
                        # 1. Transpose to (D, H)
                        w_T = value.t()
                        # 2. Slice to moe_dim if necessary
                        w_T = w_T[:, :moe_dim]
                        # 3. Expand to (E, D, H)
                        new_val = w_T.unsqueeze(0).expand(num_experts, -1, -1).clone()
                        moe_state_dict[f"{layer_key_prefix}.fc_bank"] = new_val
                    else:
                        # Bias: Dense is (H) -> MoE needs (E, H)
                        b = value[:moe_dim]
                        new_val = b.unsqueeze(0).expand(num_experts, -1).clone()
                        moe_state_dict[f"{layer_key_prefix}.fc_bias"] = new_val

                # --- Handle c_proj (Hidden -> Output) ---
                else:
                    if not is_bias:
                        # Weight: Dense is (D, H) -> MoE needs (E, H, D)
                        # 1. Transpose to (H, D)
                        w_T = value.t()
                        # 2. Slice source dimension (H) if necessary
                        w_T = w_T[:moe_dim, :]
                        # 3. Expand to (E, H, D)
                        new_val = w_T.unsqueeze(0).expand(num_experts, -1, -1).clone()
                        moe_state_dict[f"{layer_key_prefix}.proj_bank"] = new_val
                    else:
                        # Bias: Dense is (D) -> MoE needs (E, D)
                        # Bias is on the output, so dimension is D, usually doesn't need slicing
                        new_val = value.unsqueeze(0).expand(num_experts, -1).clone()
                        moe_state_dict[f"{layer_key_prefix}.proj_bias"] = new_val

                # --- Initialize Gate (if not yet initialized) ---
                # We initialize gate to zero to ensure uniform routing probability initially,
                # which guarantees average of identical experts == single expert.
                gate_key = f"{layer_key_prefix}.gate.weight"
                if gate_key not in moe_state_dict:
                    # Zeros = equal probability for all experts
                    moe_state_dict[gate_key] = torch.zeros(num_experts, self.config.n_embd)

            else:
                # Copy non-MLP keys directly (Attn, LayerNorm, Embeddings)
                moe_state_dict[key] = value

        print("Loading constructed state dict...")
        return self.load_state_dict(moe_state_dict, strict=strict)
746
+
747
+
748
+ AutoConfig.register("sabiyarn", GPTJXMoEConfig)
749
+ AutoModel.register(GPTJXMoEConfig,GPTJXMoEForCausalLM)
750
+ AutoModelForCausalLM.register(GPTJXMoEConfig, GPTJXMoEForCausalLM)
751
+
752
+
753
+
754
+
755
+