drbh committed
Commit · 9a1816c
Parent(s): b08f6c9

fix: adjust layer params in source

torch-ext/megablocks/layers.py  +15 -18
torch-ext/megablocks/layers.py CHANGED

@@ -683,26 +683,23 @@ def moe_forward(
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        moe_top_k = getattr(self, "moe_top_k", 4)
-        moe_num_experts = getattr(self, "moe_num_experts", 128)
-        gradient_scale = getattr(self, "gradient_scale", None)
-        alpha = getattr(self, "alpha", 1.0)
-        moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
-        moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
-        moe_normalize_expert_weights = getattr(
-            self, "moe_normalize_expert_weights", None
-        )
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         has_parallel = hasattr(self, "expert_parallel_group")
-        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        expert_parallel_group = torch.distributed.group.WORLD
         forward_fn = parallel_forward_once if has_parallel else forward_once
-        sort_end_bit = max(
-            int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
-        )
-        mlp_impl = getattr(self, "mlp_impl", "grouped")
-
-        output, expert_weights_out, _ = moe_forward(
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward(
             x=x,
             router_weight=self.router.weight,
             moe_top_k=moe_top_k,
@@ -725,4 +722,4 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
-        return output, expert_weights_out
+        return output, expert_weights_out
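The substance of the change is where forward() looks for its configuration: the old code read everything off the MegaBlocksMoeMLP module itself, while the new code reads router parameters from self.router and expert parameters from self.experts. A minimal sketch of the new resolution, assuming hypothetical Router/Experts stand-ins (not megablocks' real classes, which carry more state):

import torch

# Hypothetical stand-ins modeling only the attributes the updated
# forward() reads via getattr.
class Router(torch.nn.Module):
    def __init__(self, hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(num_experts, hidden_size))
        self.top_k = top_k

class Experts(torch.nn.Module):
    def __init__(self, num_experts: int, hidden_size: int):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_size = hidden_size

mlp = torch.nn.Module()
mlp.router = Router(hidden_size=1024, num_experts=128, top_k=4)
mlp.experts = Experts(num_experts=128, hidden_size=1024)

# Old lookup: falls back to the default unless moe_top_k was set on the
# parent module by hand.
print(getattr(mlp, "moe_top_k", 4))             # 4 (fallback)

# New lookup: reads the attribute from the submodule that owns it.
print(getattr(mlp.router, "top_k", 4))          # 4 (actual attribute)
print(getattr(mlp.experts, "num_experts", 128)) # 128 (actual attribute)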
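The rewritten hunk also hard-codes expert_parallel_group to torch.distributed.group.WORLD, and the switch from _ to *_ in the unpacking makes the call tolerant of moe_forward returning more than three values. The collapsed sort_end_bit expression computes the bit width needed to represent expert indices (presumably the end bit for radix-sorting token-to-expert assignments), clamped to at least 1. A quick check of the arithmetic, using only the line itself:

import torch

def sort_end_bit(moe_num_experts: int) -> int:
    # Bits needed for indices 0..moe_num_experts-1, never less than 1.
    return max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)

for n in (1, 2, 64, 128, 130):
    print(n, sort_end_bit(n))
# 1 -> 1, 2 -> 1, 64 -> 6, 128 -> 7, 130 -> 8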