Upload folder using huggingface_hub
Browse files — modeling_tiny_gpt.py: +9 additions, −3 deletions
modeling_tiny_gpt.py
Status: CHANGED
|
@@ -173,6 +173,9 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
|
|
| 173 |
Returns:
|
| 174 |
torch.Tensor: The tensor after applying Rotary Position Embeddings.
|
| 175 |
"""
|
|
|
|
|
|
|
|
|
|
| 176 |
x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
|
| 177 |
|
| 178 |
freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
|
|
@@ -181,7 +184,8 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
|
|
| 181 |
x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
|
| 182 |
x_out=x_out.reshape(*x.shape)
|
| 183 |
|
| 184 |
-
|
|
|
|
| 185 |
|
| 186 |
|
| 187 |
|
|
@@ -296,7 +300,7 @@ class SimpleMultiHeadAttention(nn.Module):
|
|
| 296 |
k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
|
| 297 |
|
| 298 |
q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
|
| 299 |
-
k_rotary = apply_rotary_embeddings(k_rotary,
|
| 300 |
|
| 301 |
q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
| 302 |
k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
|
@@ -444,12 +448,14 @@ class tiny_gpt(nn.Module):
|
|
| 444 |
self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
|
| 445 |
|
| 446 |
self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
|
|
|
|
|
|
|
| 447 |
|
| 448 |
|
| 449 |
def forward(self,x:torch.Tensor,start_pos:int):
|
| 450 |
batch_size,seq_len=x.shape
|
| 451 |
h=self.tok_embedding(x)
|
| 452 |
-
freqs_complex=self.
|
| 453 |
|
| 454 |
for layer in self.layers:
|
| 455 |
h = layer(h,freqs_complex=freqs_complex,start_pos=start_pos)
|
|
|
|
| 173 |
Returns:
|
| 174 |
torch.Tensor: The tensor after applying Rotary Position Embeddings.
|
| 175 |
"""
|
| 176 |
+
# Ensure freq_complex is on the same device as x
|
| 177 |
+
freq_complex = freq_complex.to(x.device)
|
| 178 |
+
|
| 179 |
x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
|
| 180 |
|
| 181 |
freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
|
|
|
|
| 184 |
x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
|
| 185 |
x_out=x_out.reshape(*x.shape)
|
| 186 |
|
| 187 |
+
# Keep the output on the same device as the input, not the device parameter
|
| 188 |
+
return x_out.type_as(x)
|
| 189 |
|
| 190 |
|
| 191 |
|
|
|
|
| 300 |
k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
|
| 301 |
|
| 302 |
q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
|
| 303 |
+
k_rotary = apply_rotary_embeddings(k_rotary, freqs_complex, device=self.device)
|
| 304 |
|
| 305 |
q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
| 306 |
k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
|
|
|
| 448 |
self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
|
| 449 |
|
| 450 |
self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
|
| 451 |
+
# Register as buffer so it moves with the model
|
| 452 |
+
self.register_buffer('freqs_complex_buffer', self.freqs_complex)
|
| 453 |
|
| 454 |
|
| 455 |
def forward(self,x:torch.Tensor,start_pos:int):
|
| 456 |
batch_size,seq_len=x.shape
|
| 457 |
h=self.tok_embedding(x)
|
| 458 |
+
freqs_complex=self.freqs_complex_buffer[start_pos:start_pos+seq_len]
|
| 459 |
|
| 460 |
for layer in self.layers:
|
| 461 |
h = layer(h,freqs_complex=freqs_complex,start_pos=start_pos)
|