Marmik committed on
Commit
9f01bf1
·
verified ·
1 Parent(s): 52a0934

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.json +6 -2
  2. generation_config.json +1 -1
  3. modeling_tiny_mixtral.py +595 -0
config.json CHANGED
@@ -2,6 +2,10 @@
2
  "architectures": [
3
  "TinyMixtralForCausalLM"
4
  ],
 
 
 
 
5
  "attn_dropout": 0.0,
6
  "attn_eps": 1e-06,
7
  "d_head": 64,
@@ -10,7 +14,7 @@
10
  "dropout": 0.0,
11
  "ffn_eps": 1e-06,
12
  "max_seq_len": 1024,
13
- "model_type": "tiny_mixtral_5l_active",
14
  "n_experts": 8,
15
  "n_heads": 12,
16
  "n_layers": 5,
@@ -18,6 +22,6 @@
18
  "top_k": null,
19
  "top_k_experts": 2,
20
  "torch_dtype": "float32",
21
- "transformers_version": "4.47.1",
22
  "vocab_size": 50257
23
  }
 
2
  "architectures": [
3
  "TinyMixtralForCausalLM"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "modeling_tiny_mixtral.TinyMixtralConfig",
7
+ "AutoModelForCausalLM": "modeling_tiny_mixtral.TinyMixtralForCausalLM"
8
+ },
9
  "attn_dropout": 0.0,
10
  "attn_eps": 1e-06,
11
  "d_head": 64,
 
14
  "dropout": 0.0,
15
  "ffn_eps": 1e-06,
16
  "max_seq_len": 1024,
17
+ "model_type": "tiny_mixtral_5l_total",
18
  "n_experts": 8,
19
  "n_heads": 12,
20
  "n_layers": 5,
 
22
  "top_k": null,
23
  "top_k_experts": 2,
24
  "torch_dtype": "float32",
25
+ "transformers_version": "4.53.2",
26
  "vocab_size": 50257
27
  }
generation_config.json CHANGED
@@ -1,4 +1,4 @@
1
  {
2
  "_from_model_config": true,
3
- "transformers_version": "4.47.1"
4
  }
 
1
  {
2
  "_from_model_config": true,
3
+ "transformers_version": "4.53.2"
4
  }
modeling_tiny_mixtral.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
4
+ from transformers.modeling_outputs import MoECausalLMOutputWithPast
5
+ from dataclasses import dataclass
6
+ import torch
7
+ from typing import Optional
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ import sys
11
+ import os
12
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
13
+
14
+
15
+ @dataclass
16
+ class ModelConfig:
17
+ """config for tiny mixtral inference"""
18
+ vocab_size:int = 50_257 # 50_256
19
+ d_model: int = 768 #embedding size # 768
20
+ d_head: int = 64 #head size
21
+ n_heads:int = 12 #number of heads # 12
22
+ n_layers:int = 5 #number of layers # 5
23
+ max_seq_len:int = 1024 #maximum sequence length
24
+ n_experts:int = 8 #number of experts # 8
25
+ top_k:int = 2 #top k # 2
26
+ # do not change
27
+ attn_dropout:float = 0.0 #attention dropout
28
+ dropout:float = 0.0 #dropout
29
+ norm_eps:float = 1e-6
30
+ attn_eps:float = 1e-6
31
+ ffn_eps:float = 1e-6
32
+ device:str = 'cuda' if torch.cuda.is_available() else 'cpu'
33
+
34
+
35
+ @dataclass
36
+ class ModelArgs:
37
+ vocab_size:int = 50_256 # 50_256
38
+ d_model: int = 768 #embedding size # 768
39
+ d_head: int = 64 #head size
40
+ n_heads:int = 12 #number of heads # 12
41
+ n_kv_heads:int = 8 #number of key-value heads # 8
42
+ n_layers:int = 5 #number of layers # 5
43
+ train_epochs:int = 2 #number of epochs # 1-2
44
+ batch_size:int = 256 #batch size # 256
45
+ val_epochs:int = 1 #number of validation epochs # 1
46
+ window_size:int = 128 #window size # 128
47
+ seq_len:int = 512 #sequence length # 512
48
+ max_seq_len:int = 1024 #maximum sequence length
49
+ max_lr:float = 5e-4 #maximum learning rate
50
+ n_experts:int = 8 #number of experts # 8
51
+ top_k:int = 2 #top k # 2
52
+ val_steps:int = 300 #validation steps # 250-500
53
+ save_steps:int = 1000 #save steps # 1000 is fine for 1B toks
54
+ # do not change
55
+ clip:int = 1 #gradient clipping
56
+ attn_dropout:float = 0.1 #attention dropout
57
+ dropout:float = 0.1 #dropout
58
+ beta1:float = 0.9 #beta1
59
+ beta2:float = 0.999 #beta2
60
+ device:str = 'cuda' if torch.cuda.is_available() else 'cpu'
61
+ wandb_project:str = 'moe-active'
62
+ norm_eps:float = 1e-6
63
+ attn_eps:float = 1e-6
64
+ ffn_eps:float = 1e-6
65
+
66
+
67
+ class TinyMixtralConfig(PretrainedConfig):
68
+ model_type = "tiny_mixtral_5l_active"
69
+ def __init__(self,
70
+ vocab_size = ModelConfig.vocab_size,
71
+ d_model = ModelConfig.d_model,
72
+ d_head = ModelConfig.d_head,
73
+ n_heads = ModelConfig.n_heads,
74
+ n_layers = ModelConfig.n_layers,
75
+ max_seq_len = ModelConfig.max_seq_len,
76
+ n_experts = ModelConfig.n_experts,
77
+ top_k_experts = ModelConfig.top_k,
78
+ norm_eps = ModelConfig.norm_eps,
79
+ attn_eps = ModelConfig.attn_eps,
80
+ ffn_eps = ModelConfig.ffn_eps,
81
+ device = ModelConfig.device,
82
+ **kwargs
83
+ ):
84
+ kwargs["auto_map"] = {
85
+ "AutoConfig": "modeling_tiny_mixtral.TinyMixtralConfig",
86
+ "AutoModelForCausalLM": "modeling_tiny_mixtral.TinyMixtralForCausalLM"
87
+ }
88
+ # Remove top_k from kwargs if it exists to avoid conflict
89
+ # kwargs.pop('top_k', None)
90
+ super().__init__(**kwargs)
91
+ self.vocab_size = vocab_size
92
+ self.d_model = d_model
93
+ self.d_head = d_head
94
+ self.n_heads = n_heads
95
+ self.n_layers = n_layers
96
+ self.max_seq_len = max_seq_len
97
+ self.n_experts = n_experts
98
+ self.top_k_experts = top_k_experts
99
+ self.norm_eps = norm_eps
100
+ self.attn_eps = attn_eps
101
+ self.ffn_eps = ffn_eps
102
+ self.device = device
103
+
104
+
105
+ class RMSNorm(nn.Module):
106
+ def __init__(self,dim:int,eps:float=1e-6):
107
+ """
108
+ Initializes the RMSNorm module.
109
+
110
+ Args:
111
+ dim (int): The dimensionality of the input feature space.
112
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
113
+ """
114
+ super().__init__()
115
+ self.eps=eps
116
+ self.w=nn.Parameter(torch.ones(dim))
117
+
118
+ def norm(self,x:torch.Tensor):
119
+ """
120
+ Computes the root mean square normalization of the input tensor.
121
+
122
+ Args:
123
+ x (torch.Tensor): The input tensor.
124
+
125
+ Returns:
126
+ torch.Tensor: The normalized tensor.
127
+ """
128
+ return x * torch.rsqrt(torch.mean(x**2,-1, keepdim=True) + self.eps)
129
+ def forward(self,x:torch.Tensor):
130
+ """
131
+ Forward pass of the RMSNorm module.
132
+
133
+ Args:
134
+ x (torch.Tensor): The input tensor.
135
+
136
+ Returns:
137
+ torch.Tensor: The normalized tensor.
138
+ """
139
+ return self.w * self.norm(x.float()).type_as(x)
140
+
141
+
142
+
143
+
144
+ #----Rotary Embeddings---
145
+
146
+ def precompute_theta_pos_frequencies(d_head:int,seq_len:int,device:str,theta:float=10000.0):
147
+ """
148
+ Precomputes the position frequencies for Rotary Position Embeddings.
149
+
150
+ Args:
151
+ d_head (int): The number of dimensions in the attention head.
152
+ seq_len (int): The sequence length of the input sequence.
153
+ device (str): The device on which to create the tensor.
154
+ theta (float, optional): The base for the exponential decay. Defaults to 10000.0.
155
+
156
+ Returns:
157
+ torch.Tensor: A tensor of shape (seq_len, d_head/2) containing the complex position frequencies.
158
+ """
159
+ assert d_head%2==0,"d_head must be even"
160
+ #theta_i=1000^-2(i-1)/d_head for i [1,2...d_head/2]
161
+ theta_nr=torch.arange(0,d_head,2,device=device)
162
+ theta=1.0/(theta**(theta_nr/d_head)).to(device)
163
+
164
+ m=torch.arange(seq_len,device=device)
165
+ m_theta=torch.outer(m,theta).float()
166
+ freq_complex=torch.polar(torch.ones_like(m_theta),m_theta)
167
+
168
+ return freq_complex #(seq_len,d_head/2)
169
+
170
+
171
+ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str):
172
+ """
173
+ Applies Rotary Position Embeddings to the input tensor.
174
+
175
+ Args:
176
+ x (torch.Tensor): The input tensor of shape (batch_size, seq_len, d_head).
177
+ freq_complex (torch.Tensor): The complex position frequencies tensor of shape (seq_len, d_head/2).
178
+
179
+ Returns:
180
+ torch.Tensor: The tensor after applying Rotary Position Embeddings.
181
+ """
182
+ x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
183
+
184
+ freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
185
+
186
+ x_rotated=x_complex * freq_complex #(N,seq_len,h,head_dim/2)
187
+ x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
188
+ x_out=x_out.reshape(*x.shape)
189
+
190
+ return x_out.type_as(x).to(device)
191
+
192
+
193
+
194
+ class SubLayerConnection(nn.Module):
195
+ def __init__(self,size,dropout):
196
+ """
197
+ Initializes the SubLayerConnection module.
198
+
199
+ Args:
200
+ size (int): The size of the input for the layer normalization.
201
+ dropout (float): The dropout rate to be applied after the sublayer.
202
+ """
203
+ super(SubLayerConnection,self).__init__()
204
+ self.norm=nn.LayerNorm(size)
205
+ self.dropout=nn.Dropout(dropout)
206
+
207
+ def forward(self,x,sublayer):
208
+ """
209
+ Computes the output of the SubLayerConnection module.
210
+
211
+ Args:
212
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
213
+ sublayer (nn.Module): The sublayer module to be applied to the input tensor.
214
+
215
+ Returns:
216
+ torch.Tensor: The output tensor of shape (batch_size, seq_len, d_model).
217
+ """
218
+
219
+ return x + self.dropout(sublayer(self.norm(x)))
220
+
221
+
222
+ def clones(module,N):
223
+ """
224
+ Creates a list of N copies of the given nn.Module.
225
+
226
+ Args:
227
+ nn.Module: The nn.Module to be cloned.
228
+ N (int): The number of copies to be made.
229
+
230
+ Returns:
231
+ nn.ModuleList: A list of N identical nn.Module objects.
232
+ """
233
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
234
+
235
+
236
+
237
+
238
+ class SimpleMultiHeadAttention(nn.Module):
239
+ """Simple multi-head attention without GQA, sliding window, or KV cache"""
240
+
241
+ def __init__(self, dim: int, num_heads: int, device, dropout: float = 0.0, bias: bool = False):
242
+ """
243
+ Initialize the SimpleMultiHeadAttention module.
244
+
245
+ Args:
246
+ dim (int): The dimensionality of the input and output features.
247
+ num_heads (int): The number of attention heads.
248
+ device: The device to use (cpu or cuda).
249
+ dropout (float, optional): Dropout rate. Defaults to 0.0.
250
+ bias (bool, optional): Whether to use bias in linear layers. Defaults to False.
251
+ """
252
+ super().__init__()
253
+ assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
254
+
255
+ self.dim = dim
256
+ self.num_heads = num_heads
257
+ self.head_dim = dim // num_heads
258
+ self.device = device
259
+ self.dropout = dropout
260
+
261
+ # Combined projection for queries, keys, and values
262
+ self.c_attn = nn.Linear(dim, 3 * dim, bias=bias)
263
+ # Output projection
264
+ self.c_proj = nn.Linear(dim, dim, bias=bias)
265
+
266
+ # Dropout layers
267
+ self.attn_dropout = nn.Dropout(dropout)
268
+ self.resid_dropout = nn.Dropout(dropout)
269
+
270
+ # Use flash attention if available
271
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
272
+ if not self.flash:
273
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
274
+
275
+ def forward(self, x: torch.Tensor, freqs_complex: torch.Tensor = None, start_pos: int = 0):
276
+ """
277
+ Compute multi-head attention.
278
+
279
+ Args:
280
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
281
+ freqs_complex (torch.Tensor, optional): Complex position frequencies for RoPE. Defaults to None.
282
+ start_pos (int, optional): Starting position (unused in simple attention). Defaults to 0.
283
+
284
+ Returns:
285
+ torch.Tensor: Output tensor of shape (batch_size, seq_len, dim).
286
+ """
287
+ batch_size, seq_len, _ = x.shape
288
+
289
+ # Calculate query, key, values for all heads in batch
290
+ q, k, v = self.c_attn(x).split(self.dim, dim=2)
291
+
292
+ # Reshape and transpose for multi-head attention
293
+ # (batch_size, seq_len, num_heads, head_dim) -> (batch_size, num_heads, seq_len, head_dim)
294
+ q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
295
+ k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
296
+ v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
297
+
298
+ # Apply rotary embeddings if provided
299
+ if freqs_complex is not None:
300
+ # Note: apply_rotary_embeddings expects (batch, seq_len, num_heads, head_dim)
301
+ q_rotary = q.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
302
+ k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
303
+
304
+ q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
305
+ k_rotary = apply_rotary_embeddings(k_rotary, freq_complex=freqs_complex, device=self.device)
306
+
307
+ q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
308
+ k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
309
+
310
+ # Compute attention
311
+ if self.flash:
312
+ # Use flash attention for efficiency
313
+ y = torch.nn.functional.scaled_dot_product_attention(
314
+ q, k, v,
315
+ attn_mask=None,
316
+ dropout_p=self.dropout if self.training else 0,
317
+ is_causal=True
318
+ )
319
+ else:
320
+ # Manual implementation of attention
321
+ # Compute attention scores
322
+ attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
323
+
324
+ # Apply causal mask
325
+ causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=self.device))
326
+ causal_mask = causal_mask.view(1, 1, seq_len, seq_len)
327
+ attn_scores = attn_scores.masked_fill(causal_mask == 0, float('-inf'))
328
+
329
+ # Apply softmax
330
+ attn_weights = F.softmax(attn_scores, dim=-1)
331
+ attn_weights = self.attn_dropout(attn_weights)
332
+
333
+ # Apply attention to values
334
+ y = torch.matmul(attn_weights, v)
335
+
336
+ # Reshape back to (batch_size, seq_len, dim)
337
+ y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim)
338
+
339
+ # Output projection
340
+ y = self.resid_dropout(self.c_proj(y))
341
+
342
+ return y
343
+
344
+ def reset_cache(self):
345
+ """Reset cache (no-op for simple attention)"""
346
+ pass
347
+
348
+
349
+
350
+ class SwiGLUFFN(nn.Module):
351
+ def __init__(self,input_dim:int,hidden_dim:int):
352
+ """
353
+ Initializes the SwiGLUFFN module.
354
+
355
+ Args:
356
+ input_dim (int): The dimensionality of the input features.
357
+ hidden_dim (int): The dimensionality of the hidden layer.
358
+
359
+ Initializes three linear layers:
360
+ - `w_1`: Projects input features to the hidden dimension.
361
+ - `w_2`: Projects input features to the hidden dimension using a separate path.
362
+ - `out`: Projects the transformed hidden representation back to the input dimension.
363
+ """
364
+ super().__init__()
365
+ self.w_1=nn.Linear(input_dim,hidden_dim)
366
+ self.w_2=nn.Linear(input_dim,hidden_dim)
367
+ self.out=nn.Linear(hidden_dim,input_dim)
368
+ def forward(self,x:torch.Tensor):
369
+ """
370
+ Computes the output of the SwiGLUFFN module.
371
+ """
372
+ return self.out(self.w_1(x) * F.silu(self.w_2(x)))
373
+
374
+
375
+
376
+
377
+ class SparseMOE(nn.Module):
378
+ def __init__(self,d_model:int,d_hidden:int,num_experts:int=8,top_k:int=2):
379
+ """
380
+ Initializes the SparseMOE module.
381
+
382
+ Args:
383
+ d_model (int): The dimensionality of the input features.
384
+ d_hidden (int): The dimensionality of the hidden layer in each expert.
385
+ num_experts (int, optional): The number of expert networks. Defaults to 8.
386
+ top_k (int, optional): The number of experts to be selected for each input. Defaults to 2.
387
+
388
+ The module contains a list of expert networks, each an instance of the SwiGLUFFN module,
389
+ and a router to compute the selection distribution over the experts.
390
+ """
391
+
392
+ super().__init__()
393
+ self.d_model=d_model
394
+ self.d_hidden=d_hidden
395
+ self.num_experts=num_experts
396
+ self.top_k=top_k
397
+ self.experts=nn.ModuleList([SwiGLUFFN(input_dim=d_model,hidden_dim=d_hidden) for _ in range(num_experts)])
398
+ self.router=nn.Linear(d_model,num_experts)
399
+
400
+ def forward(self,x:torch.Tensor):
401
+ """
402
+ Computes the output of the SparseMOE module.
403
+
404
+ Args:
405
+ x (torch.Tensor): Input tensor of shape (batch_size,seq_len,d_model).
406
+
407
+ Returns:
408
+ tuple: Output tensor of shape (batch_size,seq_len,d_model) and the load balancing loss
409
+ """
410
+ batch_size,seq_len,d_model=x.shape
411
+
412
+ x_flat=x.view(-1,self.d_model) # (batch_size * seq_len, d_model)
413
+
414
+ #Step 1: get router scores for each token
415
+ router_logits=self.router(x_flat)
416
+ router_probs=F.softmax(router_logits,dim=-1)
417
+
418
+ #Step 2: get top-k experts
419
+ topk_probs,topk_indices=torch.topk(router_probs,self.top_k,dim=-1) #(batch_size*seq_len, top_k)
420
+
421
+ #Step 3: compute weighted sum of top-k experts
422
+ expert_outputs=[]
423
+ for i in range(self.top_k):
424
+ expert_idx=topk_indices[:,i]
425
+ outputs=torch.zeros_like(x_flat)
426
+
427
+ for expert_id in range(self.num_experts):
428
+ mask=(expert_id==expert_idx)
429
+ if mask.any():
430
+ selected_x=x_flat[mask]
431
+ expert_out=self.experts[expert_id](selected_x)
432
+ outputs[mask]=expert_out
433
+
434
+ weighted_output = topk_probs[:, i].unsqueeze(-1) * outputs
435
+ expert_outputs.append(weighted_output)
436
+
437
+ final_output = sum(expert_outputs)
438
+
439
+ final_output = final_output.view(batch_size, seq_len, d_model)
440
+
441
+ # router_probs_mean = router_probs.mean(dim=0)
442
+ # load_balancing_loss = (router_probs_mean * router_probs_mean).sum() * self.num_experts
443
+
444
+ # Step 4: Compute load balancing loss (Equation 4 from paper)
445
+ # f_i is the fraction of tokens dispatched to expert i
446
+ f_i = torch.zeros(self.num_experts, device=x.device)
447
+ for i in range(self.num_experts):
448
+ # Count how many tokens are assigned to expert i across all top-k selections
449
+ mask = (topk_indices == i).any(dim=-1) # tokens that use expert i
450
+ f_i[i] = mask.float().mean()
451
+
452
+ # P_i is the fraction of router probability allocated to expert i
453
+ P_i = router_probs.mean(dim=0) # average probability per expert across all tokens
454
+
455
+ # Load balancing loss: α * N * Σ(f_i * P_i)
456
+ alpha = 0.01 # auxiliary loss weight (you can make this configurable)
457
+ load_balancing_loss = alpha * self.num_experts * torch.sum(f_i * P_i)
458
+
459
+ return final_output, load_balancing_loss
460
+
461
+ ##final_loss = task_loss + router_loss_weight * router_loss
462
+
463
+
464
+
465
+ class layer(nn.Module):
466
+ def __init__(self,d_model:int,n_heads:int,num_experts:int,top_k:int,device,attn_eps:float,dropout:float,ffn_eps:float=1e-6):
467
+ """
468
+ Initialize the layer.
469
+
470
+ Args:
471
+ d_model (int): The dimensionality of the input and output features.
472
+ n_heads (int): The number of attention heads.
473
+ num_experts (int): The number of expert networks.
474
+ top_k (int): The number of experts to be selected for each input.
475
+ device (str): The device to use (cpu or cuda).
476
+ attn_eps (float): The small value added to the denominator in the attention normalization for numerical stability.
477
+ dropout (float): The dropout rate to be applied after the sublayer.
478
+ ffn_eps (float, optional): The small value added to the denominator in the feed-forward normalization for numerical stability. Defaults to 1e-6.
479
+ """
480
+ super().__init__()
481
+ self.d_model=d_model
482
+ self.n_heads=n_heads
483
+ self.device=device
484
+
485
+ self.attention=SimpleMultiHeadAttention(dim=self.d_model,num_heads=self.n_heads,device=self.device,
486
+ dropout=dropout,bias=False)
487
+
488
+ self.ffn=SparseMOE(d_model=self.d_model,d_hidden=self.d_model // 2, num_experts=num_experts,top_k=top_k) # for matching total params just do d_hidden = d_model // 2 & d_hidden = d_model * 2 for matching active params
489
+
490
+ self.attn_norm=RMSNorm(dim=d_model,eps=attn_eps)
491
+ self.ffn_norm=RMSNorm(dim=d_model,eps=ffn_eps)
492
+
493
+
494
+
495
+ def forward(self,x:torch.Tensor,freqs_complex:torch.Tensor,start_pos:int):
496
+
497
+ """
498
+ Computes the output of the layer.
499
+
500
+ Args:
501
+ x (torch.Tensor): The input tensor of shape (batch_size, seq_len, d_model).
502
+ freqs_complex (torch.Tensor): The complex position frequencies tensor of shape (seq_len, d_head/2).
503
+ start_pos (int): The starting position of the sequence.
504
+
505
+ Returns:
506
+ tuple: (output tensor of shape (batch_size, seq_len, d_model), load_balancing_loss)
507
+ """
508
+ # print(x.shape)
509
+ # print(freqs_complex.shape)
510
+
511
+ h=x + self.attention(self.attn_norm(x),freqs_complex=freqs_complex,start_pos=start_pos)
512
+ ffn_output,router_loss=self.ffn(self.ffn_norm(h))
513
+ out=h+ffn_output
514
+
515
+
516
+ return out, router_loss
517
+
518
+
519
+ class tiny_mixtral(nn.Module):
520
+ def __init__(self,args:ModelArgs):
521
+ super(tiny_mixtral, self).__init__()
522
+ self.args=args
523
+ self.vocab_size=args.vocab_size
524
+ self.n_layers=args.n_layers
525
+ self.tok_embedding=nn.Embedding(self.vocab_size,args.d_model)
526
+ self.layers=clones(layer(d_model=args.d_model,
527
+ n_heads=args.n_heads,
528
+ num_experts=args.n_experts,
529
+ top_k=args.top_k,
530
+ device=args.device,
531
+ attn_eps=args.attn_eps,
532
+ dropout=args.attn_dropout,
533
+ ffn_eps=args.ffn_eps),self.n_layers)
534
+ self.norm=RMSNorm(args.d_model,eps=args.norm_eps)
535
+
536
+ self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
537
+
538
+ self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
539
+
540
+
541
+ def forward(self,x:torch.Tensor,start_pos:int):
542
+ batch_size,seq_len=x.shape
543
+ h=self.tok_embedding(x)
544
+ freqs_complex=self.freqs_complex[start_pos:start_pos+seq_len]
545
+ total_load_balancing_loss = 0
546
+
547
+ for layer in self.layers:
548
+ h, load_balancing_loss = layer(h,freqs_complex=freqs_complex,start_pos=start_pos)
549
+ total_load_balancing_loss += load_balancing_loss
550
+
551
+ h=self.norm(h)
552
+ out=self.output(h).float()
553
+
554
+ return out, total_load_balancing_loss
555
+
556
+
557
+ class TinyMixtralForCausalLM(PreTrainedModel, GenerationMixin):
558
+ config_class = TinyMixtralConfig
559
+ base_model_prefix = "moe_model"
560
+
561
+ def __init__(self, config):
562
+ super().__init__(config)
563
+ args = ModelConfig(
564
+ vocab_size=config.vocab_size,
565
+ d_model=config.d_model,
566
+ d_head=config.d_head,
567
+ n_heads=config.n_heads,
568
+ n_layers=config.n_layers,
569
+ max_seq_len=config.max_seq_len,
570
+ n_experts=config.n_experts,
571
+ top_k=config.top_k_experts,
572
+ norm_eps=config.norm_eps,
573
+ attn_eps=config.attn_eps,
574
+ ffn_eps=config.ffn_eps,
575
+ device=config.device,
576
+ )
577
+ self.model = tiny_mixtral(args=args)
578
+ self.config = config
579
+ self.post_init()
580
+
581
+ def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
582
+
583
+ outputs, load_balancing_loss = self.model(input_ids, start_pos=0)
584
+
585
+ return MoECausalLMOutputWithPast(
586
+ loss=None,
587
+ logits=outputs,
588
+ aux_loss=load_balancing_loss,
589
+ attentions=None,
590
+ )
591
+
592
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
593
+ return {
594
+ "input_ids": input_ids,
595
+ }