Update modeling_super_linear.py
Browse files- modeling_super_linear.py +2 -1
modeling_super_linear.py
CHANGED
|
@@ -357,6 +357,7 @@ class SparseNoisyMoE(nn.Module):
|
|
| 357 |
|
| 358 |
if get_prob:
|
| 359 |
expert_probs = F.softmax(self.gate_outputs, dim=1)
|
|
|
|
| 360 |
return output, load_balancing_loss, expert_probs
|
| 361 |
|
| 362 |
return output, load_balancing_loss
|
|
@@ -491,7 +492,7 @@ class superLinear(nn.Module):
|
|
| 491 |
return cycle
|
| 492 |
|
| 493 |
|
| 494 |
-
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=
|
| 495 |
if len(x_enc.shape) > 2:
|
| 496 |
x = x_enc.permute(0, 2, 1)
|
| 497 |
B, V, L = x.shape
|
|
|
|
| 357 |
|
| 358 |
if get_prob:
|
| 359 |
expert_probs = F.softmax(self.gate_outputs, dim=1)
|
| 360 |
+
print(expert_probs.shape)
|
| 361 |
return output, load_balancing_loss, expert_probs
|
| 362 |
|
| 363 |
return output, load_balancing_loss
|
|
|
|
| 492 |
return cycle
|
| 493 |
|
| 494 |
|
| 495 |
+
def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None, freq=[None], get_prob=True):
|
| 496 |
if len(x_enc.shape) > 2:
|
| 497 |
x = x_enc.permute(0, 2, 1)
|
| 498 |
B, V, L = x.shape
|