firdavsus commited on
Commit
943bd92
·
verified ·
1 Parent(s): a14a7af

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ full_diagnostics.png filter=lfs diff=lfs merge=lfs -text
37
+ training_curves_with_eval.png filter=lfs diff=lfs merge=lfs -text
38
+ weight_histograms.png filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/load-checkpoint.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import login, upload_folder
2
+
3
+
4
+ login()
5
+
6
+
7
+ upload_folder(folder_path=".", repo_id="firdavsus/LLM_D4", repo_type="model")
LLM_2.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import inspect
3
+ from dataclasses import dataclass
4
+ from contextlib import nullcontext
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn import functional as F
9
+ from typing import Tuple
10
+ import inspect
11
+
12
+ from transformers.modeling_outputs import CausalLMOutput
13
+ from manager import MANAGER
14
+
15
+ torch.manual_seed(101)
16
+
17
+ def precompute_freqs_cis(config):
18
+ # We now return cos and sin directly instead of a complex polar tensor
19
+ freqs = 1.0 / (config.theta ** (torch.arange(0, config.d_rotate, 2)[: (config.d_rotate // 2)].float() / config.d_rotate))
20
+ t = torch.arange(config.block_size, device=freqs.device)
21
+ freqs = torch.outer(t, freqs).float() # [seq_len, d_rotate/2]
22
+
23
+ # Cos and Sin are what Inductor can easily optimize
24
+ cos = torch.cos(freqs)
25
+ sin = torch.sin(freqs)
26
+
27
+ # Repeat along the last dimension to match the d_rotate size
28
+ # [seq_len, d_rotate/2] -> [seq_len, d_rotate]
29
+ cos = torch.repeat_interleave(cos, 2, dim=-1)
30
+ sin = torch.repeat_interleave(sin, 2, dim=-1)
31
+ return cos, sin
32
+
33
+ def rotate_half(x):
34
+ """Rotates half the hidden dims of the input."""
35
+ # x: [..., d_rotate]
36
+ # Split into [x1, x2, x3, x4...] -> x1, x2 are pairs
37
+ # We use the interleaving pattern: [-x2, x1, -x4, x3...]
38
+ x1 = x[..., 0::2]
39
+ x2 = x[..., 1::2]
40
+ return torch.stack((-x2, x1), dim=-1).flatten(-2)
41
+
42
+ def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
43
+ # Reshape freqs for broadcasting: [seq_len, d_rotate] -> [1, seq_len, 1, d_rotate]
44
+ # This matches (batch, seq, head, dim)
45
+ cos = freqs_cos[:xq.shape[1]].view(1, xq.shape[1], 1, xq.shape[-1])
46
+ sin = freqs_sin[:xq.shape[1]].view(1, xq.shape[1], 1, xq.shape[-1])
47
+
48
+ # The RoPE formula: x_out = x * cos + rotate_half(x) * sin
49
+ xq_out = (xq * cos) + (rotate_half(xq) * sin)
50
+ xk_out = (xk * cos) + (rotate_half(xk) * sin)
51
+
52
+ return xq_out.type_as(xq), xk_out.type_as(xk)
53
+
54
+ class MultiHeadLatentAttention(nn.Module):
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.d_model = config.n_embd
58
+ self.num_head = config.n_head
59
+ self.d_head = self.d_model // self.num_head
60
+
61
+ self.d_c = config.d_c
62
+ self.d_c1 = config.d_c1
63
+ self.d_rotate = config.d_rotate
64
+
65
+ # ==========================================
66
+ # FUSION 1: All Projections from 'x'
67
+ # Replaces DQ_proj, DKV_proj, and RK_proj
68
+ # ==========================================
69
+ self.W_down = nn.Linear(
70
+ self.d_model,
71
+ self.d_c1 + self.d_c + self.d_rotate,
72
+ bias=config.bias
73
+ )
74
+ self.W_down.is_attention = True
75
+
76
+ # ==========================================
77
+ # FUSION 2: All Q Up-Projections from 'C_Q'
78
+ # Replaces UQ_proj and RQ_proj
79
+ # ==========================================
80
+ self.W_up_q = nn.Linear(
81
+ self.d_c1,
82
+ self.d_model + (self.num_head * self.d_rotate),
83
+ bias=config.bias
84
+ )
85
+ self.W_up_q.is_attention = True
86
+
87
+ # ==========================================
88
+ # FUSION 3: All KV Up-Projections from 'C_KV'
89
+ # Replaces UK_proj and UV_proj (STILL STRICTLY SEPARATE WEIGHTS)
90
+ # ==========================================
91
+ self.W_up_kv = nn.Linear(
92
+ self.d_c,
93
+ self.d_model + self.d_model, # d_model for K, d_model for V
94
+ bias=config.bias
95
+ )
96
+ self.W_up_kv.is_attention = True
97
+
98
+ self.q_norm = nn.RMSNorm(self.d_c1)
99
+ self.kv_norm = nn.RMSNorm(self.d_c)
100
+
101
+ # Output projection and Regularization
102
+ self.output_proj = nn.Linear(self.d_model, self.d_model, bias=config.bias)
103
+ self.output_proj.output_proj_marker = True
104
+ self.output_proj.is_attention = True
105
+
106
+ self.dropout = nn.Dropout(config.dropout)
107
+ self.attn_dropout_p = config.dropout
108
+
109
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
110
+ cos, sin = precompute_freqs_cis(config)
111
+ self.register_buffer("freqs_cos", cos, persistent=False)
112
+ self.register_buffer("freqs_sin", sin, persistent=False)
113
+
114
+ def forward(self, x):
115
+ batch_size, seq_len, _ = x.size()
116
+
117
+ # ---------------------------------------------------------
118
+ # 1. KERNEL 1: Down-project everything at once
119
+ # ---------------------------------------------------------
120
+ down_out = self.W_down(x)
121
+ # Split into the 3 exact latents your math requires
122
+ C_Q, C_KV, K_rotate = down_out.split(
123
+ [self.d_c1, self.d_c, self.d_rotate], dim=-1
124
+ )
125
+
126
+ C_Q = self.q_norm(C_Q)
127
+ C_KV = self.kv_norm(C_KV)
128
+
129
+ # ---------------------------------------------------------
130
+ # 2. KERNEL 2: Up-project Query content and RoPE
131
+ # ---------------------------------------------------------
132
+ q_up_out = self.W_up_q(C_Q)
133
+ Q_state, Q_rotate = q_up_out.split(
134
+ [self.d_model, self.num_head * self.d_rotate], dim=-1
135
+ )
136
+ Q_state = Q_state.view(batch_size, seq_len, self.num_head, self.d_head)
137
+ Q_rotate = Q_rotate.view(batch_size, seq_len, self.num_head, self.d_rotate)
138
+
139
+ # ---------------------------------------------------------
140
+ # 3. KERNEL 3: Up-project Key and Value content independently
141
+ # ---------------------------------------------------------
142
+ kv_up_out = self.W_up_kv(C_KV)
143
+ K_state, V_state = kv_up_out.split(
144
+ [self.d_model, self.d_model], dim=-1
145
+ )
146
+ K_state = K_state.view(batch_size, seq_len, self.num_head, self.d_head)
147
+ V_state = V_state.view(batch_size, seq_len, self.num_head, self.d_head)
148
+
149
+ # Prepare shared RoPE Key
150
+ K_rotate = K_rotate.view(batch_size, seq_len, 1, self.d_rotate).expand(-1, -1, self.num_head, -1)
151
+
152
+ # ---------------------------------------------------------
153
+ # 4. Apply RoPE, Concatenate, and Attention
154
+ # ---------------------------------------------------------
155
+ Q_rotate, K_rotate = apply_rotary_emb(
156
+ Q_rotate,
157
+ K_rotate,
158
+ self.freqs_cos,
159
+ self.freqs_sin
160
+ )
161
+
162
+ Q = torch.cat([Q_state, Q_rotate], dim=-1).transpose(1, 2)
163
+ K = torch.cat([K_state, K_rotate], dim=-1).transpose(1, 2)
164
+ V = V_state.transpose(1, 2)
165
+
166
+ if self.flash:
167
+ att_output = F.scaled_dot_product_attention(
168
+ Q, K, V,
169
+ dropout_p=self.attn_dropout_p if self.training else 0.0,
170
+ is_causal=True
171
+ )
172
+ else:
173
+ scaler = 1.0 / math.sqrt(self.d_head + self.d_rotate)
174
+ att_matrix = (Q @ K.transpose(-2, -1)) * scaler
175
+ mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).view(1, 1, seq_len, seq_len)
176
+ att_matrix = att_matrix.masked_fill(mask == 0, float('-inf'))
177
+ att_score = self.dropout(F.softmax(att_matrix, dim=-1))
178
+ att_output = att_score @ V
179
+
180
+ att_output = att_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
181
+
182
+ return self.output_proj(att_output)
183
+
184
+ class Router(nn.Module):
185
+ def __init__(self, config):
186
+ super().__init__()
187
+
188
+ # router settings
189
+ self.top_k = config.top_k
190
+ self.n_exp = config.n_exp
191
+ assert self.top_k >= 1 and self.top_k <= config.n_exp
192
+ self.use_noisy_top_k = config.use_noisy_top_k
193
+ self.train_capacity = config.train_capacity
194
+ self.eval_capacity = config.eval_capacity
195
+ self.min_capacity = config.min_capacity
196
+ self.router_use_full_prec = config.router_use_full_prec
197
+
198
+ # auxiliary / load balancing loss settings
199
+ self.use_aux_loss = config.use_aux_loss
200
+ self.use_router_z_loss = config.use_router_z_loss
201
+
202
+ # linear projection for (noisy) softmax gating
203
+ # no bias is used, see page 4 eq (4) in (https://arxiv.org/abs/1701.06538)
204
+ self.w_g = nn.Linear(config.n_embd, config.n_exp, bias=False)
205
+ self.w_g.router_marker = True
206
+ self.w_noise = nn.Linear(config.n_embd, config.n_exp, bias=False) if self.use_noisy_top_k else None
207
+
208
+ def forward(self, x):
209
+ # optionally run the router in full precision to avoid instability during training
210
+ # see discussion on pg. 9 here: https://arxiv.org/abs/2101.03961
211
+ # setting enabled to False in autocast automatically puts everything in float32
212
+ device_type = 'cuda' if torch.cuda.is_available() else 'cpu' # for later use in torch.autocast
213
+ ctx = nullcontext() if not self.router_use_full_prec else torch.amp.autocast(device_type=device_type, enabled=False)
214
+
215
+ with ctx:
216
+ B, T, _ = x.size()
217
+ num_tokens = B * T
218
+
219
+ # eq (4) in (https://arxiv.org/abs/1701.06538)
220
+ logits = self.w_g(x) # [B, T, n_exp]
221
+ if self.use_noisy_top_k:
222
+ # optionally add noise into the router
223
+ noise = F.softplus(self.w_noise(x))
224
+ noise *= torch.randn_like(noise)
225
+ logits += noise
226
+
227
+ # router z loss, computed on logits (before softmax)
228
+ # this loss prevents router logits from becoming too large
229
+ if self.use_router_z_loss:
230
+ z_loss = self.compute_router_z_loss(logits)
231
+ MANAGER.add_router_z_loss(z_loss)
232
+
233
+ # find top k experts for each token
234
+ top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1) # [B, T, k]
235
+
236
+ # normalize expert probabilities
237
+ # Question: should we normalize over all experts or just top-k?
238
+ # we choose to normalize over top-k, other option is commented out below
239
+
240
+ # Shazeer et al (https://arxiv.org/abs/1701.06538) does only topk
241
+ # see page 4 eq (3)-(5), the code for this is commented out below
242
+ router_probs = torch.full_like(logits, float('-inf')) # [B, T, n_exp]
243
+ router_probs.scatter_(-1, top_k_indices, top_k_logits)
244
+ router_probs = F.softmax(router_probs, dim=-1)
245
+
246
+ # # normalize all router logits (not just top-k) via softmax
247
+ router_probs = F.softmax(logits, dim=-1)
248
+
249
+ # compute auxiliary load balancing loss
250
+ # this loss encourages equal probability assigned to each expert
251
+ # and equal load balancing of tokens assigned to each expert
252
+ if self.use_aux_loss:
253
+ aux_loss = self.compute_aux_loss(router_probs, top_k_indices)
254
+ MANAGER.add_aux_loss(aux_loss)
255
+
256
+ # compute expert capacity
257
+ exp_capacity = self.get_capacity(num_tokens)
258
+
259
+ # make a multi-hot mask of chosen experts, size [B, T, n_exp]
260
+ # entries are 0 if expert not chosen and 1 if expert chosen
261
+ exp_mask = F.one_hot(top_k_indices, num_classes=self.n_exp) # [B, T, k, n_exp]
262
+ exp_mask = exp_mask.view(num_tokens, self.top_k, self.n_exp) # [B * T, k, n_exp]
263
+ exp_mask = exp_mask.permute(1, 0, 2) # [k, B * T, n_exp]
264
+
265
+ # compute cumulative sum of each token over experts, this stores
266
+ # the index of each token within the batch of each expert
267
+ # NOTE: cumsum should count all top-1 first, top-2 second, etc.
268
+ # so that we prioritize top experts when dropping tokens (this is
269
+ # done by putting k dimension first for the reshape operation)
270
+ exp_rank = exp_mask.reshape(self.top_k * num_tokens, self.n_exp) # [k * B * T, n_exp]
271
+ exp_rank = torch.cumsum(exp_rank, dim=0) - 1 # cumulative sum of expert selections [k * B * T, n_exp]
272
+ exp_rank = exp_rank.reshape(self.top_k, num_tokens, self.n_exp) # [k, B * T, n_exp]
273
+
274
+ # mask out (set to zero) entries that go beyond expert capacity
275
+ # compute amount of used capacity by taking a sum over mask
276
+ exp_mask *= torch.lt(exp_rank, exp_capacity) # [k, B * T, n_exp]
277
+ used_capacity = torch.sum(exp_mask, dim=(0, 1)) # [n_exp]
278
+
279
+ # mask rank to only include tokens that are selected
280
+ # perform a sum so each row only contains index of token
281
+ # for the expert that is selected in that row
282
+ # result is a matrix that contains the position of each token
283
+ # in the batch of its corresponding expert
284
+ exp_rank = torch.sum(exp_mask * exp_rank, dim=-1) # [k, B * T]
285
+
286
+ # mask probabilities to only include selected experts
287
+ router_probs = router_probs.view(num_tokens, self.n_exp)[None, :] # [1, B * T, n_exp]
288
+ exp_weights = exp_mask * router_probs # [k, B * T, n_exp]
289
+
290
+ # convert rank into one-hot vectors over the available capacity
291
+ # stores the position of each token within the capacity of the selected expert
292
+ exp_rank_sc = F.one_hot(exp_rank, num_classes=exp_capacity) # [k, B * T, exp_capacity]
293
+
294
+ # create a vector that stores, for each token, the weight of selected
295
+ # experts at token's position in the capacity of that expert
296
+ # size of tensor is [B * T, n_exp, exp_capacity]
297
+ cb_weight = torch.sum(exp_weights.unsqueeze(3) * exp_rank_sc.unsqueeze(2), dim=0)
298
+ sec_mask = cb_weight.bool() # binary mask of selected experts for each token
299
+ return used_capacity, cb_weight, sec_mask
300
+
301
+ def compute_aux_loss(self, expert_probs: torch.Tensor, indices: torch.Tensor):
302
+ """
303
+ Computes Switch Transformer auxiliary loss (https://arxiv.org/abs/2101.03961)
304
+ See equations (4)-(6) on page 7
305
+ """
306
+
307
+ # equation (5): compute ratio of tokens allocated to each expert
308
+ # total number of tokens is defined as total tokens in batch * k
309
+ # (k = 1) for the Switch Transformer
310
+ with torch.no_grad():
311
+ one_hot_indices = F.one_hot(indices, num_classes=self.n_exp) # [B, T, k, n_exp]
312
+ one_hot_indices = torch.sum(one_hot_indices.float(), dim=2) # [B, T, n_exp] (sum over k dimension)
313
+ tokens_per_expert = torch.mean(one_hot_indices.float(), dim=(0, 1))
314
+
315
+ # equation (6): compute ratio of router probability allocated to each expert
316
+ prob_per_expert = torch.mean(expert_probs.float(), dim=(0, 1))
317
+
318
+ # equation (4): take a scaled dot product between prob/token allocation vectors
319
+ # multiply the result by the number of experts
320
+ return self.n_exp * torch.sum(prob_per_expert * tokens_per_expert)
321
+
322
+ def compute_router_z_loss(self, logits: torch.Tensor):
323
+ """
324
+ Computes ST-MoE router z loss (https://arxiv.org/abs/2202.08906)
325
+ See equation (5) on page 7
326
+ """
327
+
328
+ # exponentiate logits, sum logits of each expert, take log, and square
329
+ # code below is the same as:
330
+ # > z_loss = torch.exp(logits)
331
+ # > z_loss = torch.sum(z_loss, dim=-1)
332
+ # > z_loss = torch.log(z_loss) ** 2.0
333
+ z_loss = torch.logsumexp(logits, dim=-1) ** 2.0 # [B, T, n_exp]
334
+
335
+ # sum over all tokens and divide by total number of tokens
336
+ return torch.mean(z_loss)
337
+
338
+ def get_capacity(self, tokens_per_batch):
339
+ # expert capacity is given by (tokens_per_batch / num_experts) * capacity_factor
340
+ # see eq (3) in Switch Transformer (https://arxiv.org/abs/2101.03961)
341
+ capacity_factor = self.train_capacity if self.training else self.eval_capacity
342
+ capacity = math.floor(self.top_k * capacity_factor * tokens_per_batch / self.n_exp)
343
+ capacity += capacity % 2
344
+ capacity = max(capacity, self.min_capacity)
345
+ assert capacity > 0
346
+ return int(capacity)
347
+
348
+ # FEEDFORWARD
349
+ class MLP(nn.Module):
350
+ def __init__(self, config, ffn_dim=None):
351
+ super().__init__()
352
+
353
+ if ffn_dim==None:
354
+ ffn_dim = config.ffn_dim
355
+
356
+ self.fc1 = nn.Linear(config.n_embd, 2 * ffn_dim, bias=config.bias)
357
+ self.fc1.is_swiglu = True
358
+ self.swish = nn.SiLU()
359
+ self.fc2 = nn.Linear(ffn_dim, config.n_embd, bias=config.bias)
360
+ self.fc2.output_proj_marker = True
361
+
362
+ self.dropout1 = nn.Dropout(config.dropout)
363
+ self.dropout2 = nn.Dropout(config.dropout)
364
+
365
+ # nn.init.xavier_uniform_(self.fc1.weight, gain=math.sqrt(2.0))
366
+ # nn.init.xavier_uniform_(self.fc2.weight, gain=1.0)
367
+
368
+ def forward(self, x):
369
+ x = self.fc1(x)
370
+
371
+ # Inline SwiGLU: Split the doubled dimension and apply gate
372
+ x, gate = x.chunk(2, dim=-1)
373
+ x = x * self.swish(gate)
374
+
375
+ x = self.dropout1(x)
376
+ x = self.fc2(x)
377
+ return self.dropout2(x)
378
+
379
+
380
+ class MLPExperts(nn.Module):
381
+ def __init__(self, config):
382
+ super().__init__()
383
+ self.n_exp = config.n_exp
384
+ self.n_embd = config.n_embd
385
+ self.bias = config.bias
386
+
387
+ self.c_fc = nn.Parameter(torch.empty(self.n_exp, self.n_embd, 2 * config.expert_dim))
388
+ self.c_proj = nn.Parameter(torch.empty(self.n_exp, config.expert_dim, self.n_embd))
389
+
390
+ self.swish = nn.SiLU()
391
+ self.dropout = nn.Dropout(config.dropout)
392
+
393
+ def forward(self, x):
394
+ x = torch.bmm(x, self.c_fc)
395
+
396
+ x, gate = x.chunk(2, dim=-1)
397
+ x = x * self.swish(gate)
398
+
399
+ x = torch.bmm(x, self.c_proj)
400
+
401
+ return self.dropout(x)
402
+
403
+ class MOELayer(nn.Module):
404
+ def __init__(self, config):
405
+ super().__init__()
406
+ self.router = Router(config) # (noisy) top k router
407
+ self.experts = MLPExperts(config) # group of MLPs (experts)
408
+
409
+ self.shared_expert = MLP(config, ffn_dim=config.shared_dim)
410
+
411
+ def forward(self, x: torch.Tensor):
412
+ B, T, n_embd = x.size()
413
+ num_tokens = (B * T)
414
+
415
+ shared_out = self.shared_expert(x)
416
+
417
+ used_capacity, exp_weight, exp_mask = self.router(x)
418
+
419
+ x = x.view(num_tokens, n_embd)
420
+
421
+ # [n_exp, exp_capacity, B * T] * [B * T, n_embd] -> [n_exp, exp_capacity, n_embd]
422
+ exp_batches = exp_mask.permute(1, 2, 0).type_as(x) @ x
423
+
424
+ exp_out = self.experts(exp_batches) # [n_exp, exp_capacity, n_embd]
425
+
426
+ # aggregate expert outputs based on router weights
427
+ # eq (2) on page 4 of ST-MoE (https://arxiv.org/abs/2202.08906)
428
+ # similar equations are used for other MoE papers
429
+ exp_weight = exp_weight.view(num_tokens, -1) # [B * T, n_exp * exp_capacity]
430
+ exp_out = exp_out.view(-1, n_embd) # [n_exp * exp_capacity, n_embd]
431
+ output = exp_weight @ exp_out # [B * T, n_embd]
432
+
433
+ moe_out = output.view(B, T, n_embd)
434
+
435
+ return moe_out + shared_out
436
+
437
+ class Block(nn.Module):
438
+
439
+ def __init__(self, config, use_moe=False):
440
+ super().__init__()
441
+ self.ln_1 = nn.RMSNorm(config.n_embd)
442
+ self.attn = MultiHeadLatentAttention(config)
443
+ self.ln_2 = nn.RMSNorm(config.n_embd)
444
+ if use_moe:
445
+ self.mlp = MOELayer(config)
446
+ else:
447
+ self.mlp = MLP(config)
448
+
449
+ def forward(self, x):
450
+ x = x + self.attn(self.ln_1(x))
451
+ x = x + self.mlp(self.ln_2(x))
452
+ return x
453
+
454
+ @dataclass
455
+ class GPTConfig:
456
+ block_size: int = 2048
457
+ vocab_size: int = 50304
458
+ n_layer: int = 24
459
+ n_head: int = 10
460
+ n_embd: int = 640
461
+ dropout: float = 0.0
462
+ ffn_dim: int = 640*4
463
+ bias: bool = False
464
+
465
+ # MLA - High Efficiency
466
+ d_c: int = 192
467
+ d_c1: int = 192
468
+ d_rotate: int = 64
469
+ theta: float = 10000.0
470
+
471
+ # MoE - Maximally Smart
472
+ n_exp: int = 12
473
+ top_k: int = 3
474
+ expert_dim: int = 640
475
+ shared_dim: int = 640
476
+ stride: int = 2
477
+
478
+ # Stability (Standard Production Settings)
479
+ use_aux_loss: bool = True
480
+ use_router_z_loss: bool = True
481
+ use_noisy_top_k: bool = True
482
+ aux_loss_weight: float = 0.01
483
+ router_z_loss_weight: float = 0.001
484
+ train_capacity: float = 1.25
485
+ eval_capacity: float = 2.0
486
+ min_capacity: int = 4
487
+ use_switch_tfm_init: bool = True
488
+ switch_tfm_init_scale: float = 1.0
489
+ router_use_full_prec: bool = True
490
+
491
+ # Training Hyperparameters
492
+ batch_size: int = 8
493
+ grad_acc: int = 128
494
+ num_train_epochs: int = 1
495
+ learning_rate: float = 3e-4
496
+ weight_decay: float = 0.1
497
+ betas: tuple = (0.9, 0.95)
498
+ warm_up: int = 5000
499
+
500
+ eos_token_id = 0
501
+ bos_token_id = 0
502
+ pad_token_id = 0
503
+
504
+ class HybridOptimizer(torch.optim.Optimizer):
505
+ def __init__(self, optimizers):
506
+ self.optimizers = optimizers
507
+ self.param_groups = []
508
+ for opt in self.optimizers:
509
+ self.param_groups.extend(opt.param_groups)
510
+
511
+ def step(self, closure=None):
512
+ loss = None
513
+ if closure is not None:
514
+ loss = closure()
515
+ for opt in self.optimizers:
516
+ opt.step()
517
+ return loss
518
+
519
+ def zero_grad(self, set_to_none=True):
520
+ for opt in self.optimizers:
521
+ opt.zero_grad(set_to_none=set_to_none)
522
+
523
+ def state_dict(self):
524
+ return [opt.state_dict() for opt in self.optimizers]
525
+
526
+ def load_state_dict(self, state_dict):
527
+ for opt, sd in zip(self.optimizers, state_dict):
528
+ opt.load_state_dict(sd)
529
+
530
+ class GPT(nn.Module):
531
+
532
+ def __init__(self, config):
533
+ super().__init__()
534
+ assert config.vocab_size is not None
535
+ assert config.block_size is not None
536
+ self.config = config
537
+
538
+ self.can_return_loss = True
539
+ self.accepts_loss_kwargs = False
540
+
541
+ if config.n_exp == 1:
542
+ blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
543
+ else:
544
+ blocks = []
545
+ for i in range(config.n_layer):
546
+ use_moe = False if (i < config.stride or i > config.n_layer - config.stride-1) else True
547
+ blocks.append(Block(config, use_moe=use_moe))
548
+ blocks = nn.ModuleList(blocks)
549
+
550
+ self.transformer = nn.ModuleDict(dict(
551
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
552
+ h = blocks,
553
+ ln_f = nn.RMSNorm(config.n_embd),
554
+ ))
555
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
556
+ self.transformer.wte.weight = self.lm_head.weight
557
+ self.apply(self._init_weights)
558
+
559
+ print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
560
+
561
+ def get_num_params(self, non_embedding=True):
562
+ n_params = sum(p.numel() for p in self.parameters())
563
+ return n_params
564
+
565
+ @torch.no_grad()
566
+ def _init_weights(self, module):
567
+ # Setup base configuration
568
+ scale = self.config.switch_tfm_init_scale if hasattr(self.config, 'switch_tfm_init_scale') else 1.0
569
+ n_layer = self.config.n_layer
570
+
571
+ if isinstance(module, nn.Linear):
572
+ # Calculate standard fan-in (input dimension)
573
+ w_fan_in = module.weight.shape[-1]
574
+ base_std = (scale / w_fan_in) ** 0.5
575
+
576
+ # Determine specific scaling per layer type
577
+ if hasattr(module, 'router_marker'):
578
+ # Small std for routers ensures balanced initial expert distribution
579
+ final_std = 0.01
580
+ elif hasattr(module, 'output_proj_marker'):
581
+ # Residual scaling: keeps variance from exploding in deep networks
582
+ final_std = base_std / math.sqrt(2 * n_layer)
583
+ elif hasattr(module, 'is_attention'):
584
+ # Attn weights often benefit from a slight dampener
585
+ final_std = base_std * 0.7
586
+ else:
587
+ # Standard hidden/up-projections
588
+ final_std = base_std
589
+
590
+ # Apply truncated normal initialization
591
+ torch.nn.init.trunc_normal_(
592
+ module.weight, mean=0.0, std=final_std, a=-2*final_std, b=2*final_std
593
+ )
594
+
595
+ if module.bias is not None:
596
+ torch.nn.init.zeros_(module.bias)
597
+
598
+ # Handling custom Parameter-based MLPExperts
599
+ elif isinstance(module, MLPExperts):
600
+ # UP-PROJECTION (c_fc)
601
+ c_fc_fan_in = module.c_fc.shape[-2]
602
+ final_fc_std = (scale / c_fc_fan_in) ** 0.5
603
+ torch.nn.init.trunc_normal_(module.c_fc, std=final_fc_std, a=-2*final_fc_std, b=2*final_fc_std)
604
+
605
+ # DOWN-PROJECTION (c_proj)
606
+ c_proj_fan_in = module.c_proj.shape[-2]
607
+ # Residual scaling for MoE outputs
608
+ final_proj_std = ((scale / c_proj_fan_in) ** 0.5) / math.sqrt(2 * n_layer)
609
+ torch.nn.init.trunc_normal_(module.c_proj, std=final_proj_std, a=-2*final_proj_std, b=2*final_proj_std)
610
+
611
+ elif isinstance(module, nn.Embedding):
612
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
613
+
614
+ # elif isinstance(module, nn.RMSNorm):
615
+ # # Initializing to 0.01 as requested
616
+ # # Note: 1.0 is standard, 0.01 will significantly dampen initial signal
617
+ # torch.nn.init.constant_(module.weight, 1.0)
618
+
619
+ def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
620
+ _, t = input_ids.size()
621
+ assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"
622
+
623
+ x = self.transformer.wte(input_ids)
624
+ for block in self.transformer.h:
625
+ x = block(x)
626
+ x = self.transformer.ln_f(x)
627
+
628
+ if labels is not None:
629
+ logits = self.lm_head(x)
630
+
631
+ shift_logits = logits[:, :-1, :].contiguous()
632
+ shift_labels = labels[:, 1:].contiguous()
633
+
634
+ # print("\n\nlabel: ", shift_labels, "\ninput: ", input_ids)
635
+
636
+ loss_fct = nn.CrossEntropyLoss(
637
+ ignore_index=-100,
638
+ label_smoothing=0.1,
639
+ reduction='mean'
640
+ )
641
+
642
+ main_loss = loss_fct(
643
+ shift_logits.view(-1, shift_logits.size(-1)),
644
+ shift_labels.view(-1)
645
+ )
646
+
647
+ loss = main_loss
648
+
649
+ if self.config.n_exp > 1:
650
+ if self.config.use_aux_loss:
651
+ loss += self.config.aux_loss_weight * MANAGER.aggregate_aux_loss()
652
+ MANAGER.reset_aux_loss()
653
+
654
+ if self.config.use_router_z_loss:
655
+ loss += self.config.router_z_loss_weight * MANAGER.aggregate_router_z_loss()
656
+ MANAGER.reset_router_z_loss()
657
+ else:
658
+ logits = self.lm_head(x[:, [-1], :])
659
+ loss = None
660
+
661
+ return CausalLMOutput(loss=loss, logits=logits)
662
+
663
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
664
+ # TODO: add expert config
665
+ # start with all of the candidate parameters
666
+ param_dict = {pn: p for pn, p in self.named_parameters()}
667
+ # filter out those that do not require grad
668
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
669
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
670
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
671
+ # add an extra check for "bias" string to account for bias terms in MoE layers
672
+ decay_params = [p for n, p in param_dict.items() if (p.dim() >= 2 and not n.endswith('bias'))]
673
+ nodecay_params = [p for n, p in param_dict.items() if (p.dim() < 2 or n.endswith('bias'))]
674
+ optim_groups = [
675
+ {'params': decay_params, 'weight_decay': weight_decay},
676
+ {'params': nodecay_params, 'weight_decay': 0.0}
677
+ ]
678
+ num_decay_params = sum(p.numel() for p in decay_params)
679
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
680
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
681
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
682
+ # Create AdamW optimizer and use the fused version if it is available
683
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
684
+ use_fused = fused_available and device_type == 'cuda'
685
+ extra_args = dict(fused=True) if use_fused else dict()
686
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
687
+ print(f"using fused AdamW: {use_fused}")
688
+
689
+ return optimizer
690
+
691
+ @torch.no_grad()
692
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
693
+ for _ in range(max_new_tokens):
694
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
695
+
696
+ # Correctly unpack the dataclass output
697
+ outputs = self(idx_cond)
698
+ logits = outputs.logits[:, -1, :] / temperature
699
+
700
+ if top_k is not None:
701
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
702
+ logits[logits < v[:, [-1]]] = -float('Inf')
703
+
704
+ probs = F.softmax(logits, dim=-1)
705
+
706
+ idx_next = torch.multinomial(probs, num_samples=1)
707
+ idx = torch.cat((idx, idx_next), dim=1)
708
+
709
+ return idx
full_diagnostics.png ADDED

Git LFS Details

  • SHA256: 138d2ab19d71b3409861ca53a592139bc2abf00bdd81d159b5480b940f84c73c
  • Pointer size: 131 Bytes
  • Size of remote file: 706 kB
load.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import login, upload_folder
2
+
3
+
4
+ login()
5
+
6
+
7
+ upload_folder(folder_path=".", repo_id="firdavsus/LLM_D4", repo_type="model")
optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f0add8314988fa54def48bd806d136a4c2fd890195571b0937efafadfa56b61
3
+ size 1027863691
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:671e16a44344b218958d6bea35956ab658a5cf6e1df653fcd89ba32555fea3fa
3
+ size 513935755
rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d652bae1a3c19ced43ea6aa1b59b7afd879cb5f065bc818a85dfcd440c8e7d85
3
+ size 14645
scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85b4acf0da0ea190f06b9e0da812d798a2d827423827367c72b771685bde7cee
3
+ size 1465
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ }
206
+ },
207
+ "bos_token": "<|endoftext|>",
208
+ "clean_up_tokenization_spaces": false,
209
+ "eos_token": "<|endoftext|>",
210
+ "extra_special_tokens": {},
211
+ "model_max_length": 1000000000000000019884624838656,
212
+ "pad_token": "<|endoftext|>",
213
+ "tokenizer_class": "GPTNeoXTokenizer",
214
+ "unk_token": "<|endoftext|>"
215
+ }
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc1189daa1d8edc864213640a8c99b39c6c20b57d961e66a0efd1b89958b652f
3
+ size 5841
training_curves_with_eval.png ADDED

Git LFS Details

  • SHA256: dd4f439debee6d335e719327b3b630981875fc4a2ca232a61653bf16a89c9c7e
  • Pointer size: 131 Bytes
  • Size of remote file: 267 kB
weight_histograms.png ADDED

Git LFS Details

  • SHA256: 45df541c5eba2a1651eebaf31521dd77d8d73cbc465873a5826a0c7227d9c6f3
  • Pointer size: 131 Bytes
  • Size of remote file: 422 kB