Upload folder using huggingface_hub
Browse files- modeling_tiny_mixtral.py +8 -3
modeling_tiny_mixtral.py
CHANGED
|
@@ -179,6 +179,9 @@ def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str)
|
|
| 179 |
Returns:
|
| 180 |
torch.Tensor: The tensor after applying Rotary Position Embeddings.
|
| 181 |
"""
|
|
|
|
|
|
|
|
|
|
| 182 |
x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
|
| 183 |
|
| 184 |
freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
|
|
@@ -302,7 +305,7 @@ class SimpleMultiHeadAttention(nn.Module):
|
|
| 302 |
k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
|
| 303 |
|
| 304 |
q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
|
| 305 |
-
k_rotary = apply_rotary_embeddings(k_rotary,
|
| 306 |
|
| 307 |
q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
| 308 |
k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
|
@@ -456,7 +459,7 @@ class SparseMOE(nn.Module):
|
|
| 456 |
alpha = 0.01 # auxiliary loss weight (you can make this configurable)
|
| 457 |
load_balancing_loss = alpha * self.num_experts * torch.sum(f_i * P_i)
|
| 458 |
|
| 459 |
-
return final_output
|
| 460 |
|
| 461 |
##final_loss = task_loss + router_loss_weight * router_loss
|
| 462 |
|
|
@@ -536,12 +539,14 @@ class tiny_mixtral(nn.Module):
|
|
| 536 |
self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
|
| 537 |
|
| 538 |
self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
|
|
|
|
|
|
|
| 539 |
|
| 540 |
|
| 541 |
def forward(self,x:torch.Tensor,start_pos:int):
|
| 542 |
batch_size,seq_len=x.shape
|
| 543 |
h=self.tok_embedding(x)
|
| 544 |
-
freqs_complex=self.
|
| 545 |
total_load_balancing_loss = 0
|
| 546 |
|
| 547 |
for layer in self.layers:
|
|
|
|
| 179 |
Returns:
|
| 180 |
torch.Tensor: The tensor after applying Rotary Position Embeddings.
|
| 181 |
"""
|
| 182 |
+
# Ensure freq_complex is on the same device as x
|
| 183 |
+
freq_complex = freq_complex.to(x.device)
|
| 184 |
+
|
| 185 |
x_complex=torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2)) #N,seq_len,h,head_dim/2,2
|
| 186 |
|
| 187 |
freq_complex=freq_complex.unsqueeze(0).unsqueeze(2) # 1,seq_len,1,head_dim/2
|
|
|
|
| 305 |
k_rotary = k.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
|
| 306 |
|
| 307 |
q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
|
| 308 |
+
k_rotary = apply_rotary_embeddings(k_rotary, freqs_complex, device=self.device)
|
| 309 |
|
| 310 |
q = q_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
| 311 |
k = k_rotary.transpose(1, 2) # Back to (batch_size, num_heads, seq_len, head_dim)
|
|
|
|
| 459 |
alpha = 0.01 # auxiliary loss weight (you can make this configurable)
|
| 460 |
load_balancing_loss = alpha * self.num_experts * torch.sum(f_i * P_i)
|
| 461 |
|
| 462 |
+
return final_output, load_balancing_loss
|
| 463 |
|
| 464 |
##final_loss = task_loss + router_loss_weight * router_loss
|
| 465 |
|
|
|
|
| 539 |
self.output=nn.Linear(in_features=args.d_model,out_features=self.vocab_size)
|
| 540 |
|
| 541 |
self.freqs_complex=precompute_theta_pos_frequencies(d_head=args.d_model//args.n_heads,seq_len=args.max_seq_len,device=args.device)
|
| 542 |
+
# Register as buffer so it moves with the model
|
| 543 |
+
self.register_buffer('freqs_complex_buffer', self.freqs_complex)
|
| 544 |
|
| 545 |
|
| 546 |
def forward(self,x:torch.Tensor,start_pos:int):
|
| 547 |
batch_size,seq_len=x.shape
|
| 548 |
h=self.tok_embedding(x)
|
| 549 |
+
freqs_complex=self.freqs_complex_buffer[start_pos:start_pos+seq_len]
|
| 550 |
total_load_balancing_loss = 0
|
| 551 |
|
| 552 |
for layer in self.layers:
|