drbh committed
Commit 13afbbe · Parent: 9354548

fix: add parallel forward functional logic

torch-ext/megablocks/layers.py: +195 -20
Before: context and removed lines (marked with "-") from the previous version of the file.

@@ -121,7 +121,15 @@ def scale_grad(
 
 
 # Forward pass for the MLP layer
-def mlp_forward(
     # Scale weights
     w1 = scale_grad(w1, gradient_scale)
     w2 = scale_grad(w2, gradient_scale)
@@ -144,8 +152,6 @@ def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float =
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
-## START: Load Balancing Loss (unused at the moment)
-
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 
@@ -234,9 +240,6 @@ def batched_load_balancing_loss(args):
     return scale * torch.dot(tokens_per_expert, expert_scores)
 
 
-## END Load Balancing Loss
-
-
 # Calculate the expert capacity based on tokens, top_k, number of experts,
 # expert parallel group, capacity factor, and whether expert model parallelism is used.
 def expert_capacity(
@@ -410,7 +413,6 @@ def forward_once(
     return x, tokens_per_expert
 
 
-# TODO: replace with functional logic once aligned with ref
 def parallel_forward_once(
     x: torch.Tensor,
     expert_weights: torch.Tensor,
@@ -429,15 +431,180 @@ def parallel_forward_once(
     moe_expert_model_parallelism: bool = True,
     hidden_size: int = 1152,
 ):
-
 
 
-
-
-
 
     def forward(
-        # self,
         x: torch.Tensor,
         router_weight: torch.Tensor,
         moe_top_k: int,
@@ -446,7 +613,6 @@ class MyReplacementLayer(torch.nn.Module):
         moe_normalize_expert_weights: int = None,
         uniform_expert_assignment: bool = False,
         training: bool = False,
-        #
         w1: torch.Tensor = None,
         w2: torch.Tensor = None,
         w1_bias: torch.Tensor = None,
@@ -522,7 +688,6 @@ class MyReplacementLayer(torch.nn.Module):
         return x, expert_weights, router_scores
 
 
-
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(
@@ -536,11 +701,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         w2 = self.experts.down_proj.data
         w1_bias = self.experts.gate_up_proj_bias.data
         w2_bias = self.experts.down_proj_bias.data
-        expert_parallel_group = None
 
-
         hidden_size = self.experts.hidden_size
-
         output, expert_weights_out, router_scores = MyReplacementLayer.forward(
             x=x,
             router_weight=router_weight,
@@ -559,8 +734,8 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             sort_end_bit=sort_end_bit,
             expert_parallel_group=expert_parallel_group,
             moe_capacity_factor=1.0,
-            moe_expert_model_parallelism=
-            forward_fn=
             hidden_size=hidden_size,
         )
-        return output, expert_weights_out
After: the same hunks in the updated file, with added lines marked with "+".

@@ -121,7 +121,15 @@ def scale_grad(
 
 
 # Forward pass for the MLP layer
+def mlp_forward(
+    x: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_bias: torch.Tensor,
+    w2_bias: torch.Tensor,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+):
     # Scale weights
     w1 = scale_grad(w1, gradient_scale)
     w2 = scale_grad(w2, gradient_scale)
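For orientation, here is a minimal, self-contained sketch (not code from this file) of the batched per-expert matmul pattern the new mlp_forward signature feeds: tokens already grouped per expert are multiplied against that expert's weights with torch.bmm, mirroring the torch.bmm(x, w2) + w2_bias[..., None, :] return visible in the next hunk. The tensor sizes and the sigmoid-gated activation with alpha = 1.702 are assumptions for illustration only.

import torch

num_experts, capacity, hidden, ffn = 4, 8, 32, 64   # illustrative sizes only
x = torch.randn(num_experts, capacity, hidden)       # tokens already grouped per expert
w1 = torch.randn(num_experts, hidden, ffn)
w1_bias = torch.randn(num_experts, ffn)
w2 = torch.randn(num_experts, ffn, hidden)
w2_bias = torch.randn(num_experts, hidden)

h = torch.bmm(x, w1) + w1_bias[..., None, :]         # [E, capacity, ffn]
h = h * torch.sigmoid(1.702 * h)                     # assumed activation, alpha = 1.702
out = torch.bmm(h, w2) + w2_bias[..., None, :]       # [E, capacity, hidden]
print(out.shape)                                     # torch.Size([4, 8, 32])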
@@ -144,8 +152,6 @@ def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float =
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 

@@ -234,9 +240,6 @@ def batched_load_balancing_loss(args):
     return scale * torch.dot(tokens_per_expert, expert_scores)
 
 
 # Calculate the expert capacity based on tokens, top_k, number of experts,
 # expert parallel group, capacity factor, and whether expert model parallelism is used.
 def expert_capacity(
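The body of expert_capacity is outside this diff. As a hedged sketch of what such a helper typically computes (an assumption, not this commit's code), the average number of tokens routed to each expert is scaled by the capacity factor, with the expert-parallel world size folded in when expert model parallelism is enabled:

def expert_capacity_sketch(tokens: int, top_k: int, num_experts: int,
                           world_size: int, capacity_factor: float) -> int:
    # average token slots each expert must provide, scaled by the capacity factor
    tokens_per_expert = top_k * tokens * world_size / num_experts
    return int(capacity_factor * tokens_per_expert)

# e.g. 1024 tokens with top-2 routing over 8 experts on a single rank:
print(expert_capacity_sketch(1024, 2, 8, 1, 1.0))  # 256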
@@ -410,7 +413,6 @@ def forward_once(
     return x, tokens_per_expert
 
 
 def parallel_forward_once(
     x: torch.Tensor,
     expert_weights: torch.Tensor,
@@ -429,15 +431,180 @@ def parallel_forward_once(
     moe_expert_model_parallelism: bool = True,
     hidden_size: int = 1152,
 ):
+    # Flatten inputs
+    expert_weights = expert_weights.flatten()
+    top_experts = top_experts.flatten()
+
+    with torch.no_grad():
+        # Step 1: Local permutation setup
+        indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+            top_experts, sort_end_bit, num_experts
+        )
 
+        # Calculate sharding parameters
+        world_size = dist.get_world_size(expert_parallel_group)
+        hidden_sharding_deg = hidden_sharding_degree(
+            world_size, num_experts, hidden_size
+        )
+        experts_per_rank_val = experts_per_rank(num_experts, world_size)
 
+        # Replicate token counts for hidden sharding
+        repeated_tokens_per_expert = ops.repeat(
+            tokens_per_expert, (hidden_sharding_deg,)
+        )
+
+        # Exchange token counts across devices
+        parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+        # print("world_size:", world_size)
+        # print("experts_per_rank_val:", experts_per_rank_val)
+
+        # Ensure CUB knows which device to use
+        tpe_handle = dist.all_to_all_single(
+            parallel_tokens_per_expert,
+            repeated_tokens_per_expert,
+            group=expert_parallel_group,
+            async_op=True,
+        )
+
+    # Step 2: Local permutation - group tokens by target device
+    x = x.view(-1, x.shape[-1])  # [sl * bs, hs]
+    x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+    # Step 3: Compute communication counts and exchange tokens
+    with torch.no_grad():
+        tpe_handle.wait()
+
+        # Reshape for per-device calculations
+        repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+            world_size, experts_per_rank_val
+        )
+        parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+            world_size, experts_per_rank_val
+        )
+
+        # Calculate send/recv counts
+        send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+        # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+        parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+        recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+        tokens_received = sum(recv_counts)
+
+    # Replicate for hidden sharding
+    x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+    # Cross-device token exchange
+    parallel_x, parallel_x_handle = ops.all_to_all(
+        x,
+        recv_counts,
+        send_counts,
+        expert_parallel_group,
+        async_op=True
+    )
 
+    with torch.no_grad():
+        # Step 4: Setup for local expert computation
+        replicate_bins = ops.inclusive_cumsum(
+            parallel_tokens_per_expert.flatten(),
+            0
+        )
+        replicate_bins = (
+            replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+        )
+
+        # Create expert indices for received tokens
+        parallel_top_expert = torch.remainder(
+            torch.arange(
+                num_experts * hidden_sharding_deg,
+                dtype=torch.int32,
+                device=indices.device,
+            ),
+            experts_per_rank_val,
+        )
+        parallel_top_expert = ops.replicate(
+            parallel_top_expert.unsqueeze(dim=0),
+            replicate_bins,
+            tokens_received,
+        ).flatten()
+
+        # Sort tokens by expert assignment
+        parallel_bin_ids, parallel_indices = ops.sort(
+            parallel_top_expert,
+            sort_end_bit,
+        )
+
+        # Calculate bins for local experts
+        parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+            dim=0, dtype=torch.int
+        )
+        parallel_bins = ops.inclusive_cumsum(
+            parallel_tokens_per_expert,
+            0
+        )
+        parallel_bins = (
+            parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+        )
+
+        # Calculate expert capacity
+        expert_capacity = expert_capacity_fn(
+            tokens_received,
+            top_k,
+            experts_per_rank_val,
+            expert_parallel_group,
+            moe_capacity_factor,
+            moe_expert_model_parallelism,
+        )
+        if expert_capacity == 0:
+            expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+    # Locally permute the tokens and perform the expert computation.
+    # Block to make sure that the cross-device permutation is complete.
+    # if self.args.mlp_impl == 'grouped':
+
+    # TODO: dont always assume grouped MLP
+    if True:
+        # GroupedMLP requires counts on CPU. We can use the tensor already
+        # moved to CPU for the prior all_to_all, which avoids an extra
+        # device synchronization.
+        parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+            dim=0,
+            dtype=torch.int,
+        )
+
+    # Step 5: Expert computation
+    parallel_x_handle.wait()
+
+    parallel_x = permute_and_compute(
+        parallel_x,
+        parallel_tokens_per_expert,
+        parallel_indices,
+        parallel_bin_ids,
+        None,  # expert_weights
+        parallel_bins,
+        expert_capacity,
+        top_k=1,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+    )
+
+    # Step 6: Reverse communication - send results back
+    x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
+
+    # Step 7: Reduce across hidden sharding dimension
+    shape = (hidden_sharding_deg, -1, hidden_size)
+    x = x.view(shape).sum(dim=0)
+
+    # Step 8: Final local unpermutation
+    x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+    return x, tokens_per_expert.flatten()
+
+
+class MyReplacementLayer(torch.nn.Module):
     def forward(
         x: torch.Tensor,
         router_weight: torch.Tensor,
         moe_top_k: int,
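To make Step 3 above concrete, here is a toy single-process illustration (made-up sizes, no process group) of how the all_to_all send counts are derived: per-expert token counts are viewed as [world_size, experts_per_rank] and summed over the local-expert axis, yielding one count per peer rank.

import torch

world_size, experts_per_rank = 4, 2                  # 8 global experts, illustrative only
tokens_per_expert = torch.tensor([3, 1, 0, 5, 2, 2, 4, 0], dtype=torch.int32)

per_rank = tokens_per_expert.view(world_size, experts_per_rank)
send_counts = per_rank.sum(dim=-1).tolist()          # tokens this rank sends to each peer
print(send_counts)        # [4, 5, 4, 4]
print(sum(send_counts))   # 17 tokens exchanged in total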
@@ -446,7 +613,6 @@ class MyReplacementLayer(torch.nn.Module):
         moe_normalize_expert_weights: int = None,
         uniform_expert_assignment: bool = False,
         training: bool = False,
         w1: torch.Tensor = None,
         w2: torch.Tensor = None,
         w1_bias: torch.Tensor = None,

@@ -522,7 +688,6 @@ class MyReplacementLayer(torch.nn.Module):
         return x, expert_weights, router_scores
 
 
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(
@@ -536,11 +701,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         w2 = self.experts.down_proj.data
         w1_bias = self.experts.gate_up_proj_bias.data
         w2_bias = self.experts.down_proj_bias.data
 
+        # check if the expert_parallel_group attribute is set
+        if hasattr(self, "expert_parallel_group"):
+            expert_parallel_group = self.expert_parallel_group
+            moe_expert_model_parallelism = True
+            forward_fn = parallel_forward_once
+        else:
+            expert_parallel_group = None
+            moe_expert_model_parallelism = False
+            forward_fn = forward_once
+
+        sort_end_bit = max(
+            int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+        )
         hidden_size = self.experts.hidden_size
         output, expert_weights_out, router_scores = MyReplacementLayer.forward(
             x=x,
             router_weight=router_weight,
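The sort_end_bit computed above is the number of bits a radix sort needs to cover the expert index range. A standalone check of the same expression for a few expert counts (wrapped in a helper here purely for illustration):

import torch

def sort_end_bit(moe_num_experts: int) -> int:
    # bits needed to represent the largest expert index, clamped to at least 1
    return max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)

print(sort_end_bit(1), sort_end_bit(8), sort_end_bit(128))  # 1 3 7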
@@ -559,8 +734,8 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             sort_end_bit=sort_end_bit,
             expert_parallel_group=expert_parallel_group,
             moe_capacity_factor=1.0,
+            moe_expert_model_parallelism=moe_expert_model_parallelism,
+            forward_fn=forward_fn,
             hidden_size=hidden_size,
         )
+        return output, expert_weights_out
|