Update modeling_super_linear.py
Browse files — modeling_super_linear.py (+3 −2)
modeling_super_linear.py
CHANGED
|
@@ -354,6 +354,9 @@ class SparseNoisyMoE(nn.Module):
|
|
| 354 |
output = torch.sum(self.topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)
|
| 355 |
|
| 356 |
load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
|
|
|
|
|
|
|
|
|
|
| 357 |
|
| 358 |
if get_prob:
|
| 359 |
expert_probs = F.softmax(self.gate_outputs, dim=1)
|
|
@@ -509,8 +512,6 @@ class superLinear(nn.Module):
|
|
| 509 |
else:
|
| 510 |
out, self.moe_loss = self.moe(x)
|
| 511 |
|
| 512 |
-
print(out.shape)
|
| 513 |
-
|
| 514 |
|
| 515 |
if self.auto_regressive and self.max_horizon < self.inf_pred_len:
|
| 516 |
# TODO(review): leftover commented-out debug print — remove
|
|
|
|
| 354 |
output = torch.sum(self.topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)
|
| 355 |
|
| 356 |
load_balancing_loss = self.calculate_load_balancing_loss(self.gate_outputs, batch_size)
|
| 357 |
+
|
| 358 |
+
expert_probs = F.softmax(self.gate_outputs, dim=1)
|
| 359 |
+
print(expert_probs.shape)
|
| 360 |
|
| 361 |
if get_prob:
|
| 362 |
expert_probs = F.softmax(self.gate_outputs, dim=1)
|
|
|
|
| 512 |
else:
|
| 513 |
out, self.moe_loss = self.moe(x)
|
| 514 |
|
|
|
|
|
|
|
| 515 |
|
| 516 |
if self.auto_regressive and self.max_horizon < self.inf_pred_len:
|
| 517 |
# TODO(review): leftover commented-out debug print — remove
|