Marmik committed on
Commit
d82dc42
·
verified ·
1 Parent(s): 9f01bf1

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_tiny_mixtral.py +10 -3
modeling_tiny_mixtral.py CHANGED
@@ -22,6 +22,7 @@ class ModelConfig:
22
  n_layers:int = 5 #number of layers # 5
23
  max_seq_len:int = 1024 #maximum sequence length
24
  n_experts:int = 8 #number of experts # 8
 
25
  top_k:int = 2 #top k # 2
26
  # do not change
27
  attn_dropout:float = 0.0 #attention dropout
@@ -179,6 +180,9 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
179
  Returns:
180
  torch.Tensor: The tensor after applying Rotary Position Embeddings.
181
  """
 
 
 
182
  x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
183
 
184
  freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
@@ -187,7 +191,8 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
187
  x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
188
  x_out=x_out.reshape(*x.shape)
189
 
190
- return x_out.type_as(x).to(device)
 
191
 
192
 
193
 
@@ -302,7 +307,7 @@ class SimpleMultiHeadAttention(nn.Module):
302
  k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
303
 
304
  q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
305
- k_rotary = apply_rotary_embeddings(k_rotary, freq_complex=freqs_complex, device=self.device)
306
 
307
  q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
308
  k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
@@ -536,12 +541,14 @@ class tiny_mixtral(nn.Module):
536
  self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
537
 
538
  self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
 
 
539
 
540
 
541
  def forward(self,x:torch.Tensor,start_pos:int):
542
  batch_size,seq_len=x.shape
543
  h=self.tok_embedding(x)
544
- freqs_complex=self.freqs_complex[start_pos:start_pos+seq_len]
545
  total_load_balancing_loss = 0
546
 
547
  for layer in self.layers:
 
22
  n_layers:int = 5 #number of layers # 5
23
  max_seq_len:int = 1024 #maximum sequence length
24
  n_experts:int = 8 #number of experts # 8
25
+
26
  top_k:int = 2 #top k # 2
27
  # do not change
28
  attn_dropout:float = 0.0 #attention dropout
 
180
  Returns:
181
  torch.Tensor: The tensor after applying Rotary Position Embeddings.
182
  """
183
+ # Ensure freq_complex is on the same device as x
184
+ freq_complex = freq_complex.to(x.device)
185
+
186
  x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
187
 
188
  freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
 
191
  x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
192
  x_out=x_out.reshape(*x.shape)
193
 
194
+ # Keep the output on the same device as the input, not the device parameter
195
+ return x_out.type_as(x)
196
 
197
 
198
 
 
307
  k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
308
 
309
  q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
310
+ k_rotary = apply_rotary_embeddings(k_rotary, freqs_complex, device=self.device)
311
 
312
  q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
313
  k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
 
541
  self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
542
 
543
  self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
544
+ # Register as buffer so it moves with the model
545
+ self.register_buffer('freqs_complex_buffer', self.freqs_complex)
546
 
547
 
548
  def forward(self,x:torch.Tensor,start_pos:int):
549
  batch_size,seq_len=x.shape
550
  h=self.tok_embedding(x)
551
+ freqs_complex=self.freqs_complex_buffer[start_pos:start_pos+seq_len]
552
  total_load_balancing_loss = 0
553
 
554
  for layer in self.layers: