Marmik commited on
Commit
4f3fec5
·
verified ·
1 Parent(s): 8be79ab

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. config.json +2 -2
  2. modeling_tiny_mixtral.py +589 -0
config.json CHANGED
@@ -2,12 +2,12 @@
2
  "architectures": [
3
  "TinyMixtralForCausalLM"
4
  ],
5
- "attn_dropout": 0.1,
6
  "attn_eps": 1e-06,
7
  "d_head": 64,
8
  "d_model": 768,
9
  "device": "cpu",
10
- "dropout": 0.1,
11
  "ffn_eps": 1e-06,
12
  "max_seq_len": 1024,
13
  "model_type": "tiny_mixtral_5l_active",
 
2
  "architectures": [
3
  "TinyMixtralForCausalLM"
4
  ],
5
+ "attn_dropout": 0.0,
6
  "attn_eps": 1e-06,
7
  "d_head": 64,
8
  "d_model": 768,
9
  "device": "cpu",
10
+ "dropout": 0.0,
11
  "ffn_eps": 1e-06,
12
  "max_seq_len": 1024,
13
  "model_type": "tiny_mixtral_5l_active",
modeling_tiny_mixtral.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
4
+ from transformers.modeling_outputs import MoECausalLMOutputWithPast
5
+ from dataclasses import dataclass
6
+ import torch
7
+ from typing import Optional
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ import sys
11
+ import os
12
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
13
+
14
+
15
+ @dataclass
16
+ class ModelConfig:
17
+ """config for tiny mixtral inference"""
18
+ vocab_size:int = 50_257 # 50_256
19
+ d_model: int = 768 #embedding size # 768
20
+ d_head: int = 64 #head size
21
+ n_heads:int = 12 #number of heads # 12
22
+ n_layers:int = 5 #number of layers # 5
23
+ max_seq_len:int = 1024 #maximum sequence length
24
+ n_experts:int = 8 #number of experts # 8
25
+ top_k:int = 2 #top k # 2
26
+ # do not change
27
+ attn_dropout:float = 0.0 #attention dropout
28
+ dropout:float = 0.0 #dropout
29
+ norm_eps:float = 1e-6
30
+ attn_eps:float = 1e-6
31
+ ffn_eps:float = 1e-6
32
+ device:str = 'cuda' if torch.cuda.is_available() else 'cpu'
33
+
34
+
35
+ @dataclass
36
+ class ModelArgs:
37
+ vocab_size:int = 50_256 # 50_256
38
+ d_model: int = 768 #embedding size # 768
39
+ d_head: int = 64 #head size
40
+ n_heads:int = 12 #number of heads # 12
41
+ n_kv_heads:int = 8 #number of key-value heads # 8
42
+ n_layers:int = 5 #number of layers # 5
43
+ train_epochs:int = 2 #number of epochs # 1-2
44
+ batch_size:int = 256 #batch size # 256
45
+ val_epochs:int = 1 #number of validation epochs # 1
46
+ window_size:int = 128 #window size # 128
47
+ seq_len:int = 512 #sequence length # 512
48
+ max_seq_len:int = 1024 #maximum sequence length
49
+ max_lr:float = 5e-4 #maximum learning rate
50
+ n_experts:int = 8 #number of experts # 8
51
+ top_k:int = 2 #top k # 2
52
+ val_steps:int = 300 #validation steps # 250-500
53
+ save_steps:int = 1000 #save steps # 1000 is fine for 1B toks
54
+ # do not change
55
+ clip:int = 1 #gradient clipping
56
+ attn_dropout:float = 0.1 #attention dropout
57
+ dropout:float = 0.1 #dropout
58
+ beta1:float = 0.9 #beta1
59
+ beta2:float = 0.999 #beta2
60
+ device:str = 'cuda' if torch.cuda.is_available() else 'cpu'
61
+ wandb_project:str = 'moe-active'
62
+ norm_eps:float = 1e-6
63
+ attn_eps:float = 1e-6
64
+ ffn_eps:float = 1e-6
65
+
66
+
67
+ class TinyMixtralConfig(PretrainedConfig):
68
+ model_type = "tiny_mixtral_5l_active"
69
+ def __init__(self,
70
+ vocab_size = ModelConfig.vocab_size,
71
+ d_model = ModelConfig.d_model,
72
+ d_head = ModelConfig.d_head,
73
+ n_heads = ModelConfig.n_heads,
74
+ n_layers = ModelConfig.n_layers,
75
+ max_seq_len = ModelConfig.max_seq_len,
76
+ n_experts = ModelConfig.n_experts,
77
+ top_k_experts = ModelConfig.top_k,
78
+ norm_eps = ModelConfig.norm_eps,
79
+ attn_eps = ModelConfig.attn_eps,
80
+ ffn_eps = ModelConfig.ffn_eps,
81
+ device = ModelConfig.device,
82
+ **kwargs
83
+ ):
84
+ super().__init__(top_k=None,**kwargs)
85
+ self.vocab_size = vocab_size
86
+ self.d_model = d_model
87
+ self.d_head = d_head
88
+ self.n_heads = n_heads
89
+ self.n_layers = n_layers
90
+ self.max_seq_len = max_seq_len
91
+ self.n_experts = n_experts
92
+ self.top_k_experts = top_k_experts
93
+ self.norm_eps = norm_eps
94
+ self.attn_eps = attn_eps
95
+ self.ffn_eps = ffn_eps
96
+ self.device = device
97
+
98
+
99
+ class RMSNorm(nn.Module):
100
+ def __init__(self,dim:int,eps:float=1e-6):
101
+ """
102
+ Initializes the RMSNorm module.
103
+
104
+ Args:
105
+ dim (int): The dimensionality of the input feature space.
106
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
107
+ """
108
+ super().__init__()
109
+ self.eps=eps
110
+ self.w=nn.Parameter(torch.ones(dim))
111
+
112
+ def norm(self,x:torch.Tensor):
113
+ """
114
+ Computes the root mean square normalization of the input tensor.
115
+
116
+ Args:
117
+ x (torch.Tensor): The input tensor.
118
+
119
+ Returns:
120
+ torch.Tensor: The normalized tensor.
121
+ """
122
+ return x * torch.rsqrt(torch.mean(x**2,-1, keepdim=True) + self.eps)
123
+ def forward(self,x:torch.Tensor):
124
+ """
125
+ Forward pass of the RMSNorm module.
126
+
127
+ Args:
128
+ x (torch.Tensor): The input tensor.
129
+
130
+ Returns:
131
+ torch.Tensor: The normalized tensor.
132
+ """
133
+ return self.w * self.norm(x.float()).type_as(x)
134
+
135
+
136
+
137
+
138
+ #----Rotary Embeddings---
139
+
140
+ def precompute_theta_pos_frequencies(d_head:int,seq_len:int,device:str,theta:float=10000.0):
141
+ """
142
+ Precomputes the position frequencies for Rotary Position Embeddings.
143
+
144
+ Args:
145
+ d_head (int): The number of dimensions in the attention head.
146
+ seq_len (int): The sequence length of the input sequence.
147
+ device (str): The device on which to create the tensor.
148
+ theta (float, optional): The base for the exponential decay. Defaults to 10000.0.
149
+
150
+ Returns:
151
+ torch.Tensor: A tensor of shape (seq_len, d_head/2) containing the complex position frequencies.
152
+ """
153
+ assert d_head%2==0,"d_head must be even"
154
+ #theta_i=1000^-2(i-1)/d_head for i [1,2...d_head/2]
155
+ theta_nr=torch.arange(0,d_head,2,device=device)
156
+ theta=1.0/(theta**(theta_nr/d_head)).to(device)
157
+
158
+ m=torch.arange(seq_len,device=device)
159
+ m_theta=torch.outer(m,theta).float()
160
+ freq_complex=torch.polar(torch.ones_like(m_theta),m_theta)
161
+
162
+ return freq_complex #(seq_len,d_head/2)
163
+
164
+
165
+ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str):
166
+ """
167
+ Applies Rotary Position Embeddings to the input tensor.
168
+
169
+ Args:
170
+ x (torch.Tensor): The input tensor of shape (batch_size, seq_len, d_head).
171
+ freq_complex (torch.Tensor): The complex position frequencies tensor of shape (seq_len, d_head/2).
172
+
173
+ Returns:
174
+ torch.Tensor: The tensor after applying Rotary Position Embeddings.
175
+ """
176
+ x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
177
+
178
+ freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
179
+
180
+ x_rotated=x_complex * freq_complex #(N,seq_len,h,head_dim/2)
181
+ x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
182
+ x_out=x_out.reshape(*x.shape)
183
+
184
+ return x_out.type_as(x).to(device)
185
+
186
+
187
+
188
+ class SubLayerConnection(nn.Module):
189
+ def __init__(self,size,dropout):
190
+ """
191
+ Initializes the SubLayerConnection module.
192
+
193
+ Args:
194
+ size (int): The size of the input for the layer normalization.
195
+ dropout (float): The dropout rate to be applied after the sublayer.
196
+ """
197
+ super(SubLayerConnection,self).__init__()
198
+ self.norm=nn.LayerNorm(size)
199
+ self.dropout=nn.Dropout(dropout)
200
+
201
+ def forward(self,x,sublayer):
202
+ """
203
+ Computes the output of the SubLayerConnection module.
204
+
205
+ Args:
206
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
207
+ sublayer (nn.Module): The sublayer module to be applied to the input tensor.
208
+
209
+ Returns:
210
+ torch.Tensor: The output tensor of shape (batch_size, seq_len, d_model).
211
+ """
212
+
213
+ return x + self.dropout(sublayer(self.norm(x)))
214
+
215
+
216
+ def clones(module,N):
217
+ """
218
+ Creates a list of N copies of the given nn.Module.
219
+
220
+ Args:
221
+ nn.Module: The nn.Module to be cloned.
222
+ N (int): The number of copies to be made.
223
+
224
+ Returns:
225
+ nn.ModuleList: A list of N identical nn.Module objects.
226
+ """
227
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
228
+
229
+
230
+
231
+
232
+ class SimpleMultiHeadAttention(nn.Module):
233
+ """Simple multi-head attention without GQA, sliding window, or KV cache"""
234
+
235
+ def __init__(self, dim: int, num_heads: int, device, dropout: float = 0.0, bias: bool = False):
236
+ """
237
+ Initialize the SimpleMultiHeadAttention module.
238
+
239
+ Args:
240
+ dim (int): The dimensionality of the input and output features.
241
+ num_heads (int): The number of attention heads.
242
+ device: The device to use (cpu or cuda).
243
+ dropout (float, optional): Dropout rate. Defaults to 0.0.
244
+ bias (bool, optional): Whether to use bias in linear layers. Defaults to False.
245
+ """
246
+ super().__init__()
247
+ assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
248
+
249
+ self.dim = dim
250
+ self.num_heads = num_heads
251
+ self.head_dim = dim // num_heads
252
+ self.device = device
253
+ self.dropout = dropout
254
+
255
+ # Combined projection for queries, keys, and values
256
+ self.c_attn = nn.Linear(dim, 3 * dim, bias=bias)
257
+ # Output projection
258
+ self.c_proj = nn.Linear(dim, dim, bias=bias)
259
+
260
+ # Dropout layers
261
+ self.attn_dropout = nn.Dropout(dropout)
262
+ self.resid_dropout = nn.Dropout(dropout)
263
+
264
+ # Use flash attention if available
265
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
266
+ if not self.flash:
267
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
268
+
269
+ def forward(self, x: torch.Tensor, freqs_complex: torch.Tensor = None, start_pos: int = 0):
270
+ """
271
+ Compute multi-head attention.
272
+
273
+ Args:
274
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
275
+ freqs_complex (torch.Tensor, optional): Complex position frequencies for RoPE. Defaults to None.
276
+ start_pos (int, optional): Starting position (unused in simple attention). Defaults to 0.
277
+
278
+ Returns:
279
+ torch.Tensor: Output tensor of shape (batch_size, seq_len, dim).
280
+ """
281
+ batch_size, seq_len, _ = x.shape
282
+
283
+ # Calculate query, key, values for all heads in batch
284
+ q, k, v = self.c_attn(x).split(self.dim, dim=2)
285
+
286
+ # Reshape and transpose for multi-head attention
287
+ # (batch_size, seq_len, num_heads, head_dim) -> (batch_size, num_heads, seq_len, head_dim)
288
+ q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
289
+ k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
290
+ v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
291
+
292
+ # Apply rotary embeddings if provided
293
+ if freqs_complex is not None:
294
+ # Note: apply_rotary_embeddings expects (batch, seq_len, num_heads, head_dim)
295
+ q_rotary = q.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
296
+ k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
297
+
298
+ q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
299
+ k_rotary = apply_rotary_embeddings(k_rotary, freq_complex=freqs_complex, device=self.device)
300
+
301
+ q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
302
+ k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
303
+
304
+ # Compute attention
305
+ if self.flash:
306
+ # Use flash attention for efficiency
307
+ y = torch.nn.functional.scaled_dot_product_attention(
308
+ q, k, v,
309
+ attn_mask=None,
310
+ dropout_p=self.dropout if self.training else 0,
311
+ is_causal=True
312
+ )
313
+ else:
314
+ # Manual implementation of attention
315
+ # Compute attention scores
316
+ attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
317
+
318
+ # Apply causal mask
319
+ causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=self.device))
320
+ causal_mask = causal_mask.view(1, 1, seq_len, seq_len)
321
+ attn_scores = attn_scores.masked_fill(causal_mask == 0, float('-inf'))
322
+
323
+ # Apply softmax
324
+ attn_weights = F.softmax(attn_scores, dim=-1)
325
+ attn_weights = self.attn_dropout(attn_weights)
326
+
327
+ # Apply attention to values
328
+ y = torch.matmul(attn_weights, v)
329
+
330
+ # Reshape back to (batch_size, seq_len, dim)
331
+ y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim)
332
+
333
+ # Output projection
334
+ y = self.resid_dropout(self.c_proj(y))
335
+
336
+ return y
337
+
338
+ def reset_cache(self):
339
+ """Reset cache (no-op for simple attention)"""
340
+ pass
341
+
342
+
343
+
344
+ class SwiGLUFFN(nn.Module):
345
+ def __init__(self,input_dim:int,hidden_dim:int):
346
+ """
347
+ Initializes the SwiGLUFFN module.
348
+
349
+ Args:
350
+ input_dim (int): The dimensionality of the input features.
351
+ hidden_dim (int): The dimensionality of the hidden layer.
352
+
353
+ Initializes three linear layers:
354
+ - `w_1`: Projects input features to the hidden dimension.
355
+ - `w_2`: Projects input features to the hidden dimension using a separate path.
356
+ - `out`: Projects the transformed hidden representation back to the input dimension.
357
+ """
358
+ super().__init__()
359
+ self.w_1=nn.Linear(input_dim,hidden_dim)
360
+ self.w_2=nn.Linear(input_dim,hidden_dim)
361
+ self.out=nn.Linear(hidden_dim,input_dim)
362
+ def forward(self,x:torch.Tensor):
363
+ """
364
+ Computes the output of the SwiGLUFFN module.
365
+ """
366
+ return self.out(self.w_1(x) * F.silu(self.w_2(x)))
367
+
368
+
369
+
370
+
371
+ class SparseMOE(nn.Module):
372
+ def __init__(self,d_model:int,d_hidden:int,num_experts:int=8,top_k:int=2):
373
+ """
374
+ Initializes the SparseMOE module.
375
+
376
+ Args:
377
+ d_model (int): The dimensionality of the input features.
378
+ d_hidden (int): The dimensionality of the hidden layer in each expert.
379
+ num_experts (int, optional): The number of expert networks. Defaults to 8.
380
+ top_k (int, optional): The number of experts to be selected for each input. Defaults to 2.
381
+
382
+ The module contains a list of expert networks, each an instance of the SwiGLUFFN module,
383
+ and a router to compute the selection distribution over the experts.
384
+ """
385
+
386
+ super().__init__()
387
+ self.d_model=d_model
388
+ self.d_hidden=d_hidden
389
+ self.num_experts=num_experts
390
+ self.top_k=top_k
391
+ self.experts=nn.ModuleList([SwiGLUFFN(input_dim=d_model,hidden_dim=d_hidden) for _ in range(num_experts)])
392
+ self.router=nn.Linear(d_model,num_experts)
393
+
394
+ def forward(self,x:torch.Tensor):
395
+ """
396
+ Computes the output of the SparseMOE module.
397
+
398
+ Args:
399
+ x (torch.Tensor): Input tensor of shape (batch_size,seq_len,d_model).
400
+
401
+ Returns:
402
+ tuple: Output tensor of shape (batch_size,seq_len,d_model) and the load balancing loss
403
+ """
404
+ batch_size,seq_len,d_model=x.shape
405
+
406
+ x_flat=x.view(-1,self.d_model) # (batch_size * seq_len, d_model)
407
+
408
+ #Step 1: get router scores for each token
409
+ router_logits=self.router(x_flat)
410
+ router_probs=F.softmax(router_logits,dim=-1)
411
+
412
+ #Step 2: get top-k experts
413
+ topk_probs,topk_indices=torch.topk(router_probs,self.top_k,dim=-1) #(batch_size*seq_len, top_k)
414
+
415
+ #Step 3: compute weighted sum of top-k experts
416
+ expert_outputs=[]
417
+ for i in range(self.top_k):
418
+ expert_idx=topk_indices[:,i]
419
+ outputs=torch.zeros_like(x_flat)
420
+
421
+ for expert_id in range(self.num_experts):
422
+ mask=(expert_id==expert_idx)
423
+ if mask.any():
424
+ selected_x=x_flat[mask]
425
+ expert_out=self.experts[expert_id](selected_x)
426
+ outputs[mask]=expert_out
427
+
428
+ weighted_output = topk_probs[:, i].unsqueeze(-1) * outputs
429
+ expert_outputs.append(weighted_output)
430
+
431
+ final_output = sum(expert_outputs)
432
+
433
+ final_output = final_output.view(batch_size, seq_len, d_model)
434
+
435
+ # router_probs_mean = router_probs.mean(dim=0)
436
+ # load_balancing_loss = (router_probs_mean * router_probs_mean).sum() * self.num_experts
437
+
438
+ # Step 4: Compute load balancing loss (Equation 4 from paper)
439
+ # f_i is the fraction of tokens dispatched to expert i
440
+ f_i = torch.zeros(self.num_experts, device=x.device)
441
+ for i in range(self.num_experts):
442
+ # Count how many tokens are assigned to expert i across all top-k selections
443
+ mask = (topk_indices == i).any(dim=-1) # tokens that use expert i
444
+ f_i[i] = mask.float().mean()
445
+
446
+ # P_i is the fraction of router probability allocated to expert i
447
+ P_i = router_probs.mean(dim=0) # average probability per expert across all tokens
448
+
449
+ # Load balancing loss: α * N * Σ(f_i * P_i)
450
+ alpha = 0.01 # auxiliary loss weight (you can make this configurable)
451
+ load_balancing_loss = alpha * self.num_experts * torch.sum(f_i * P_i)
452
+
453
+ return final_output, load_balancing_loss
454
+
455
+ ##final_loss = task_loss + router_loss_weight * router_loss
456
+
457
+
458
+
459
+ class layer(nn.Module):
460
+ def __init__(self,d_model:int,n_heads:int,num_experts:int,top_k:int,device,attn_eps:float,dropout:float,ffn_eps:float=1e-6):
461
+ """
462
+ Initialize the layer.
463
+
464
+ Args:
465
+ d_model (int): The dimensionality of the input and output features.
466
+ n_heads (int): The number of attention heads.
467
+ num_experts (int): The number of expert networks.
468
+ top_k (int): The number of experts to be selected for each input.
469
+ device (str): The device to use (cpu or cuda).
470
+ attn_eps (float): The small value added to the denominator in the attention normalization for numerical stability.
471
+ dropout (float): The dropout rate to be applied after the sublayer.
472
+ ffn_eps (float, optional): The small value added to the denominator in the feed-forward normalization for numerical stability. Defaults to 1e-6.
473
+ """
474
+ super().__init__()
475
+ self.d_model=d_model
476
+ self.n_heads=n_heads
477
+ self.device=device
478
+
479
+ self.attention=SimpleMultiHeadAttention(dim=self.d_model,num_heads=self.n_heads,device=self.device,
480
+ dropout=dropout,bias=False)
481
+
482
+ self.ffn=SparseMOE(d_model=self.d_model,d_hidden=self.d_model * 2, num_experts=num_experts,top_k=top_k) # for matching total params just do d_hidden = d_model // 2 & d_hidden = d_model * 2 for matching active params
483
+
484
+ self.attn_norm=RMSNorm(dim=d_model,eps=attn_eps)
485
+ self.ffn_norm=RMSNorm(dim=d_model,eps=ffn_eps)
486
+
487
+
488
+
489
+ def forward(self,x:torch.Tensor,freqs_complex:torch.Tensor,start_pos:int):
490
+
491
+ """
492
+ Computes the output of the layer.
493
+
494
+ Args:
495
+ x (torch.Tensor): The input tensor of shape (batch_size, seq_len, d_model).
496
+ freqs_complex (torch.Tensor): The complex position frequencies tensor of shape (seq_len, d_head/2).
497
+ start_pos (int): The starting position of the sequence.
498
+
499
+ Returns:
500
+ tuple: (output tensor of shape (batch_size, seq_len, d_model), load_balancing_loss)
501
+ """
502
+ # print(x.shape)
503
+ # print(freqs_complex.shape)
504
+
505
+ h=x + self.attention(self.attn_norm(x),freqs_complex=freqs_complex,start_pos=start_pos)
506
+ ffn_output,router_loss=self.ffn(self.ffn_norm(h))
507
+ out=h+ffn_output
508
+
509
+
510
+ return out, router_loss
511
+
512
+
513
+ class tiny_mixtral(nn.Module):
514
+ def __init__(self,args:ModelArgs):
515
+ super(tiny_mixtral, self).__init__()
516
+ self.args=args
517
+ self.vocab_size=args.vocab_size
518
+ self.n_layers=args.n_layers
519
+ self.tok_embedding=nn.Embedding(self.vocab_size,args.d_model)
520
+ self.layers=clones(layer(d_model=args.d_model,
521
+ n_heads=args.n_heads,
522
+ num_experts=args.n_experts,
523
+ top_k=args.top_k,
524
+ device=args.device,
525
+ attn_eps=args.attn_eps,
526
+ dropout=args.attn_dropout,
527
+ ffn_eps=args.ffn_eps),self.n_layers)
528
+ self.norm=RMSNorm(args.d_model,eps=args.norm_eps)
529
+
530
+ self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
531
+
532
+ self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
533
+
534
+
535
+ def forward(self,x:torch.Tensor,start_pos:int):
536
+ batch_size,seq_len=x.shape
537
+ h=self.tok_embedding(x)
538
+ freqs_complex=self.freqs_complex[start_pos:start_pos+seq_len]
539
+ total_load_balancing_loss = 0
540
+
541
+ for layer in self.layers:
542
+ h, load_balancing_loss = layer(h,freqs_complex=freqs_complex,start_pos=start_pos)
543
+ total_load_balancing_loss += load_balancing_loss
544
+
545
+ h=self.norm(h)
546
+ out=self.output(h).float()
547
+
548
+ return out, total_load_balancing_loss
549
+
550
+
551
+ class TinyMixtralForCausalLM(PreTrainedModel, GenerationMixin):
552
+ config_class = TinyMixtralConfig
553
+ base_model_prefix = "moe_model"
554
+
555
+ def __init__(self, config):
556
+ super().__init__(config)
557
+ args = ModelConfig(
558
+ vocab_size=config.vocab_size,
559
+ d_model=config.d_model,
560
+ d_head=config.d_head,
561
+ n_heads=config.n_heads,
562
+ n_layers=config.n_layers,
563
+ max_seq_len=config.max_seq_len,
564
+ n_experts=config.n_experts,
565
+ top_k=config.top_k_experts,
566
+ norm_eps=config.norm_eps,
567
+ attn_eps=config.attn_eps,
568
+ ffn_eps=config.ffn_eps,
569
+ device=config.device,
570
+ )
571
+ self.model = tiny_mixtral(args=args)
572
+ self.config = config
573
+ self.post_init()
574
+
575
+ def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
576
+
577
+ outputs, load_balancing_loss = self.model(input_ids, start_pos=0)
578
+
579
+ return MoECausalLMOutputWithPast(
580
+ loss=None,
581
+ logits=outputs,
582
+ aux_loss=load_balancing_loss,
583
+ attentions=None,
584
+ )
585
+
586
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
587
+ return {
588
+ "input_ids": input_ids,
589
+ }