Upload folder using huggingface_hub
Browse files
modeling_tiny_mixtral.py +10 -3
modeling_tiny_mixtral.py
CHANGED
|
@@ -22,6 +22,7 @@ class ModelConfig:
|
|
| 22 |
n_layers:int = 5 #number of layers # 5
|
| 23 |
max_seq_len:int = 1024 #maximum sequence length
|
| 24 |
n_experts:int = 8 #number of experts # 8
|
|
|
|
| 25 |
top_k:int = 2 #top k # 2
|
| 26 |
# do not change
|
| 27 |
attn_dropout:float = 0.0 #attention dropout
|
|
@@ -179,6 +180,9 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
|
|
| 179 |
Returns:
|
| 180 |
torch.Tensor: The tensor after applying Rotary Position Embeddings.
|
| 181 |
"""
|
|
|
|
|
|
|
|
|
|
| 182 |
x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
|
| 183 |
|
| 184 |
freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
|
|
@@ -187,7 +191,8 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
|
|
| 187 |
x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
|
| 188 |
x_out=x_out.reshape(*x.shape)
|
| 189 |
|
| 190 |
-
|
|
|
|
| 191 |
|
| 192 |
|
| 193 |
|
|
@@ -302,7 +307,7 @@ class SimpleMultiHeadAttention(nn.Module):
|
|
| 302 |
k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
|
| 303 |
|
| 304 |
q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
|
| 305 |
-
k_rotary = apply_rotary_embeddings(k_rotary,
|
| 306 |
|
| 307 |
q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
| 308 |
k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
|
@@ -536,12 +541,14 @@ class tiny_mixtral(nn.Module):
|
|
| 536 |
self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
|
| 537 |
|
| 538 |
self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
|
|
|
|
|
|
|
| 539 |
|
| 540 |
|
| 541 |
def forward(self,x:torch.Tensor,start_pos:int):
|
| 542 |
batch_size,seq_len=x.shape
|
| 543 |
h=self.tok_embedding(x)
|
| 544 |
-
freqs_complex=self.
|
| 545 |
total_load_balancing_loss = 0
|
| 546 |
|
| 547 |
for layer in self.layers:
|
|
|
|
| 22 |
n_layers:int = 5 #number of layers # 5
|
| 23 |
max_seq_len:int = 1024 #maximum sequence length
|
| 24 |
n_experts:int = 8 #number of experts # 8
|
| 25 |
+
|
| 26 |
top_k:int = 2 #top k # 2
|
| 27 |
# do not change
|
| 28 |
attn_dropout:float = 0.0 #attention dropout
|
|
|
|
| 180 |
Returns:
|
| 181 |
torch.Tensor: The tensor after applying Rotary Position Embeddings.
|
| 182 |
"""
|
| 183 |
+
# Ensure freq_complex is on the same device as x
|
| 184 |
+
freq_complex = freq_complex.to(x.device)
|
| 185 |
+
|
| 186 |
x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
|
| 187 |
|
| 188 |
freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
|
|
|
|
| 191 |
x_out=torch.view_as_real(x_rotated) #(N,seq_len,h,head_dim/2,2)
|
| 192 |
x_out=x_out.reshape(*x.shape)
|
| 193 |
|
| 194 |
+
# Keep the output on the same device as the input, not the device parameter
|
| 195 |
+
return x_out.type_as(x)
|
| 196 |
|
| 197 |
|
| 198 |
|
|
|
|
| 307 |
k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
|
| 308 |
|
| 309 |
q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
|
| 310 |
+
k_rotary = apply_rotary_embeddings(k_rotary, freqs_complex, device=self.device)
|
| 311 |
|
| 312 |
q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
| 313 |
k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
|
|
|
| 541 |
self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
|
| 542 |
|
| 543 |
self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
|
| 544 |
+
# Register as buffer so it moves with the model
|
| 545 |
+
self.register_buffer('freqs_complex_buffer', self.freqs_complex)
|
| 546 |
|
| 547 |
|
| 548 |
def forward(self,x:torch.Tensor,start_pos:int):
|
| 549 |
batch_size,seq_len=x.shape
|
| 550 |
h=self.tok_embedding(x)
|
| 551 |
+
freqs_complex=self.freqs_complex_buffer[start_pos:start_pos+seq_len]
|
| 552 |
total_load_balancing_loss = 0
|
| 553 |
|
| 554 |
for layer in self.layers:
|