Marmik committed on
Commit 054320c · verified · 1 Parent(s): 1cae0b0

Upload folder using huggingface_hub

Files changed (1)
  1. modeling_tiny_mixtral.py +8 -3
modeling_tiny_mixtral.py CHANGED
@@ -179,6 +179,9 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
     Returns:
         torch.Tensor: The tensor after applying Rotary Position Embeddings.
     """
+    # Ensure freq_complex is on the same device as x
+    freq_complex = freq_complex.to(x.device)
+
     x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2

     freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
@@ -302,7 +305,7 @@ class SimpleMultiHeadAttention(nn.Module):
         k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)

         q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
-        k_rotary = apply_rotary_embeddings(k_rotary, freq_complex=freqs_complex, device=self.device)
+        k_rotary = apply_rotary_embeddings(k_rotary, freqs_complex, device=self.device)

         q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
         k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
@@ -456,7 +459,7 @@ class SparseMOE(nn.Module):
         alpha = 0.01 # auxiliary loss weight (you can make this configurable)
         load_balancing_loss = alpha * self.num_experts * torch.sum(f_i * P_i)

-        return final_output
+        return final_output, load_balancing_loss

         ##final_loss = task_loss + router_loss_weight * router_loss

@@ -536,12 +539,14 @@ class tiny_mixtral(nn.Module):
         self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)

         self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
+        # Register as buffer so it moves with the model
+        self.register_buffer('freqs_complex_buffer', self.freqs_complex)


     def forward(self,x:torch.Tensor,start_pos:int):
         batch_size,seq_len=x.shape
         h=self.tok_embedding(x)
-        freqs_complex=self.freqs_complex[start_pos:start_pos+seq_len]
+        freqs_complex=self.freqs_complex_buffer[start_pos:start_pos+seq_len]
         total_load_balancing_loss = 0

         for layer in self.layers:
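The two rotary-frequency changes attack the same device-mismatch problem from both ends: apply_rotary_embeddings now moves freq_complex onto x's device defensively, and tiny_mixtral registers the precomputed table via register_buffer so it follows the module when the model is moved. A minimal standalone sketch of that buffer behaviour (the Toy module below is illustrative, not part of modeling_tiny_mixtral.py):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for the precomputed complex rotary frequencies
        freqs = torch.randn(8, 4, dtype=torch.complex64)
        self.plain_freqs = freqs                      # plain attribute: ignored by .to()/.cuda()
        self.register_buffer('freqs_buffer', freqs)   # buffer: moves with the module

toy = Toy()
if torch.cuda.is_available():
    toy.to('cuda')
    print(toy.plain_freqs.device)   # cpu  -> would still need an explicit .to(x.device)
    print(toy.freqs_buffer.device)  # cuda:0

Relatedly, SparseMOE now returns (final_output, load_balancing_loss), presumably feeding the total_load_balancing_loss accumulator already set up before the layer loop in tiny_mixtral.forward; the commented-out line ##final_loss = task_loss + router_loss_weight * router_loss hints at how that auxiliary term is meant to be folded into the training objective.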