Pull request #11: add greedy
by saada - opened
Files changed: modeling_deepseek.py (+4 -0)
modeling_deepseek.py CHANGED
|
@@ -461,6 +461,10 @@ class MoEGate(nn.Module):
|
|
Before (old side):
| 461 |                 tmp_scores, k=self.top_k, dim=-1, sorted=False
| 462 |             )
| 463 |             topk_weight = scores.gather(1, topk_idx)
| 464 |         else:
| 465 |             raise NotImplementedError(
| 466 |                 f"insupportable TopK function for MoE gating: {self.topk_method}"
|
|
|
|
After (new side; added lines are marked with +):
| 461 |                   tmp_scores, k=self.top_k, dim=-1, sorted=False
| 462 |               )
| 463 |               topk_weight = scores.gather(1, topk_idx)
| 464 | +         elif self.topk_method == "greedy":
| 465 | +             topk_weight, topk_idx = torch.topk(
| 466 | +                 scores, k=self.top_k, dim=-1, sorted=False
| 467 | +             )
| 468 |           else:
| 469 |               raise NotImplementedError(
| 470 |                   f"insupportable TopK function for MoE gating: {self.topk_method}"