Marmik committed on
Commit
06846c5
·
verified ·
1 Parent(s): 97a6b05

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_tiny_gpt.py +9 -3
modeling_tiny_gpt.py CHANGED
@@ -173,6 +173,9 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
173
  Returns:
174
  torch.Tensor: The tensor after applying Rotary Position Embeddings.
175
  """
 
 
 
176
  x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
177
 
178
  freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
@@ -181,7 +184,8 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
181
  x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
182
  x_out=x_out.reshape(*x.shape)
183
 
184
- return x_out.type_as(x).to(device)
 
185
 
186
 
187
 
@@ -296,7 +300,7 @@ class SimpleMultiHeadAttention(nn.Module):
296
  k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
297
 
298
  q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
299
- k_rotary = apply_rotary_embeddings(k_rotary, freq_complex=freqs_complex, device=self.device)
300
 
301
  q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
302
  k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
@@ -444,12 +448,14 @@ class tiny_gpt(nn.Module):
444
  self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
445
 
446
  self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
 
 
447
 
448
 
449
  def forward(self,x:torch.Tensor,start_pos:int):
450
  batch_size,seq_len=x.shape
451
  h=self.tok_embedding(x)
452
- freqs_complex=self.freqs_complex[start_pos:start_pos+seq_len]
453
 
454
  for layer in self.layers:
455
  h = layer(h,freqs_complex=freqs_complex,start_pos=start_pos)
 
173
  Returns:
174
  torch.Tensor: The tensor after applying Rotary Position Embeddings.
175
  """
176
+ # Ensure freq_complex is on the same device as x
177
+ freq_complex = freq_complex.to(x.device)
178
+
179
  x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
180
 
181
  freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
 
184
  x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
185
  x_out=x_out.reshape(*x.shape)
186
 
187
+ # Keep the output on the same device as the input, not the device parameter
188
+ return x_out.type_as(x)
189
 
190
 
191
 
 
300
  k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
301
 
302
  q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
303
+ k_rotary = apply_rotary_embeddings(k_rotary, freqs_complex, device=self.device)
304
 
305
  q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
306
  k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
 
448
  self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
449
 
450
  self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
451
+ # Register as buffer so it moves with the model
452
+ self.register_buffer('freqs_complex_buffer', self.freqs_complex)
453
 
454
 
455
  def forward(self,x:torch.Tensor,start_pos:int):
456
  batch_size,seq_len=x.shape
457
  h=self.tok_embedding(x)
458
+ freqs_complex=self.freqs_complex_buffer[start_pos:start_pos+seq_len]
459
 
460
  for layer in self.layers:
461
  h = layer(h,freqs_complex=freqs_complex,start_pos=start_pos)