Build uploaded using `kernels`.
Browse files- build/torch210-cxx11-cu126-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} +1 -1
- build/torch210-cxx11-cu126-aarch64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu126-aarch64-linux/xpu_fused_moe.py +57 -4
- build/torch210-cxx11-cu128-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} +1 -1
- build/torch210-cxx11-cu128-aarch64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu128-aarch64-linux/xpu_fused_moe.py +57 -4
- build/torch210-cxx11-cu130-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} +1 -1
- build/torch210-cxx11-cu130-aarch64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu130-aarch64-linux/xpu_fused_moe.py +57 -4
- build/torch29-cxx11-cu126-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} +1 -1
- build/torch29-cxx11-cu126-aarch64-linux/_ops.py +3 -3
- build/torch29-cxx11-cu126-aarch64-linux/xpu_fused_moe.py +57 -4
- build/torch29-cxx11-cu128-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} +1 -1
- build/torch29-cxx11-cu128-aarch64-linux/_ops.py +3 -3
- build/torch29-cxx11-cu128-aarch64-linux/xpu_fused_moe.py +57 -4
- build/torch29-cxx11-cu130-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} +1 -1
- build/torch29-cxx11-cu130-aarch64-linux/_ops.py +3 -3
- build/torch29-cxx11-cu130-aarch64-linux/xpu_fused_moe.py +57 -4
build/torch210-cxx11-cu126-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 15124328
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d43ea617155587acccc47750e126596b0438c63c7ada6f3607a2ed4603337f72
|
| 3 |
size 15124328
|
build/torch210-cxx11-cu126-aarch64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_cuda_6e04dec
|
| 3 |
+
ops = torch.ops._megablocks_cuda_6e04dec
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_cuda_6e04dec::{op_name}"
|
build/torch210-cxx11-cu126-aarch64-linux/xpu_fused_moe.py
CHANGED
|
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
-
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
-
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
|
|
|
|
|
|
| 205 |
is_fp8=False,
|
| 206 |
is_int4=False,
|
| 207 |
is_mxfp4=False):
|
|
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
-
workspace = torch.
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 341 |
workspace=workspace,
|
| 342 |
hidden_size=hidden_size,
|
| 343 |
inter_size=inter_size,
|
|
|
|
|
|
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
expert_first_token_offset_bytes = workspace[
|
|
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
|
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
if torch.compiler.is_compiling():
|
| 356 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
|
|
| 359 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
)
|
|
|
|
|
|
|
|
|
|
| 362 |
else:
|
| 363 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
|
|
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
|
|
| 451 |
is_B_mxfp4=is_mxfp4)
|
| 452 |
|
| 453 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
|
|
|
| 454 |
unpermuted_row_to_permuted_row,
|
|
|
|
| 455 |
num_experts_per_node)
|
| 456 |
return output
|
| 457 |
|
|
@@ -500,6 +514,21 @@ def route_tokens_xpu(
|
|
| 500 |
return logits, expert_weights, expert_indices
|
| 501 |
|
| 502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 504 |
can_torch_compile: bool = True
|
| 505 |
|
|
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 524 |
self.experts, "normalize_expert_weights", None
|
| 525 |
)
|
| 526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 528 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 529 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 598 |
topk_ids=expert_indices,
|
| 599 |
n_experts_per_token=moe_top_k,
|
| 600 |
activation=activation,
|
| 601 |
-
num_experts=
|
|
|
|
|
|
|
| 602 |
is_fp8=is_fp8,
|
| 603 |
is_int4=is_int4,
|
| 604 |
is_mxfp4=is_mxfp4,
|
| 605 |
)
|
| 606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
# Restore original shape
|
| 608 |
output = output.view(in_shape)
|
| 609 |
|
|
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
| 205 |
+
ep_rank=0,
|
| 206 |
+
ep_size=1,
|
| 207 |
is_fp8=False,
|
| 208 |
is_int4=False,
|
| 209 |
is_mxfp4=False):
|
|
|
|
| 331 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 332 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 333 |
|
| 334 |
+
workspace = torch.zeros(map_offset,
|
| 335 |
dtype=torch.uint8,
|
| 336 |
device=hidden_states.device)
|
| 337 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
workspace=workspace,
|
| 344 |
hidden_size=hidden_size,
|
| 345 |
inter_size=inter_size,
|
| 346 |
+
ep_rank=ep_rank,
|
| 347 |
+
ep_size=ep_size,
|
| 348 |
num_experts_on_rank=num_experts_per_node)
|
| 349 |
|
| 350 |
expert_first_token_offset_bytes = workspace[
|
|
|
|
| 355 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 356 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 357 |
src_to_dest_map_size]
|
| 358 |
+
permuted_row_to_unpermuted_row_bytes = workspace[
|
| 359 |
+
ws_map["permuted_row_to_unpermuted_row"][1]:
|
| 360 |
+
ws_map["permuted_row_to_unpermuted_row"][1] +
|
| 361 |
+
permuted_row_to_unpermuted_row_size]
|
| 362 |
|
| 363 |
if torch.compiler.is_compiling():
|
| 364 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
|
|
| 367 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 368 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 369 |
)
|
| 370 |
+
permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
|
| 371 |
+
permuted_row_to_unpermuted_row_bytes, torch.int32
|
| 372 |
+
)
|
| 373 |
else:
|
| 374 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 375 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 376 |
+
permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
|
| 377 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 378 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 379 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 463 |
is_B_mxfp4=is_mxfp4)
|
| 464 |
|
| 465 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
| 466 |
+
permuted_row_to_unpermuted_row,
|
| 467 |
unpermuted_row_to_permuted_row,
|
| 468 |
+
expert_first_token_offset,
|
| 469 |
num_experts_per_node)
|
| 470 |
return output
|
| 471 |
|
|
|
|
| 514 |
return logits, expert_weights, expert_indices
|
| 515 |
|
| 516 |
|
| 517 |
+
def _get_device_mesh(model):
|
| 518 |
+
"""Extract device_mesh from child's unused pre_hook closure for EP support."""
|
| 519 |
+
try:
|
| 520 |
+
hook = next(
|
| 521 |
+
h
|
| 522 |
+
for h in model.experts._forward_pre_hooks.values()
|
| 523 |
+
if "device_mesh" in h.__code__.co_freevars
|
| 524 |
+
)
|
| 525 |
+
return hook.__closure__[
|
| 526 |
+
hook.__code__.co_freevars.index("device_mesh")
|
| 527 |
+
].cell_contents
|
| 528 |
+
except Exception:
|
| 529 |
+
return None
|
| 530 |
+
|
| 531 |
+
|
| 532 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 533 |
can_torch_compile: bool = True
|
| 534 |
|
|
|
|
| 553 |
self.experts, "normalize_expert_weights", None
|
| 554 |
)
|
| 555 |
|
| 556 |
+
# Get EP (Expert Parallelism) parameters
|
| 557 |
+
ep_size = 1
|
| 558 |
+
ep_rank = 0
|
| 559 |
+
expert_parallel_group = getattr(self, "expert_parallel_group", None)
|
| 560 |
+
if expert_parallel_group is None:
|
| 561 |
+
device_mesh = _get_device_mesh(self)
|
| 562 |
+
if device_mesh is not None:
|
| 563 |
+
expert_parallel_group = device_mesh.get_group()
|
| 564 |
+
if expert_parallel_group is not None:
|
| 565 |
+
import torch.distributed as dist
|
| 566 |
+
if dist.is_initialized():
|
| 567 |
+
ep_size = dist.get_world_size(expert_parallel_group)
|
| 568 |
+
ep_rank = dist.get_rank(expert_parallel_group)
|
| 569 |
+
|
| 570 |
+
# Number of experts on this rank
|
| 571 |
+
num_experts_on_rank = moe_num_experts // ep_size
|
| 572 |
+
|
| 573 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 574 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 575 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
|
|
| 644 |
topk_ids=expert_indices,
|
| 645 |
n_experts_per_token=moe_top_k,
|
| 646 |
activation=activation,
|
| 647 |
+
num_experts=num_experts_on_rank,
|
| 648 |
+
ep_rank=ep_rank,
|
| 649 |
+
ep_size=ep_size,
|
| 650 |
is_fp8=is_fp8,
|
| 651 |
is_int4=is_int4,
|
| 652 |
is_mxfp4=is_mxfp4,
|
| 653 |
)
|
| 654 |
|
| 655 |
+
# All-reduce across EP group to combine partial expert outputs
|
| 656 |
+
if ep_size > 1 and expert_parallel_group is not None:
|
| 657 |
+
import torch.distributed as dist
|
| 658 |
+
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
|
| 659 |
+
|
| 660 |
# Restore original shape
|
| 661 |
output = output.view(in_shape)
|
| 662 |
|
build/torch210-cxx11-cu128-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 21088232
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:12705f4547b6a55442c52e081a303d4407202cdc26522f7269c983b627946ab9
|
| 3 |
size 21088232
|
build/torch210-cxx11-cu128-aarch64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_cuda_6e04dec
|
| 3 |
+
ops = torch.ops._megablocks_cuda_6e04dec
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_cuda_6e04dec::{op_name}"
|
build/torch210-cxx11-cu128-aarch64-linux/xpu_fused_moe.py
CHANGED
|
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
-
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
-
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
|
|
|
|
|
|
| 205 |
is_fp8=False,
|
| 206 |
is_int4=False,
|
| 207 |
is_mxfp4=False):
|
|
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
-
workspace = torch.
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 341 |
workspace=workspace,
|
| 342 |
hidden_size=hidden_size,
|
| 343 |
inter_size=inter_size,
|
|
|
|
|
|
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
expert_first_token_offset_bytes = workspace[
|
|
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
|
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
if torch.compiler.is_compiling():
|
| 356 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
|
|
| 359 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
)
|
|
|
|
|
|
|
|
|
|
| 362 |
else:
|
| 363 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
|
|
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
|
|
| 451 |
is_B_mxfp4=is_mxfp4)
|
| 452 |
|
| 453 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
|
|
|
| 454 |
unpermuted_row_to_permuted_row,
|
|
|
|
| 455 |
num_experts_per_node)
|
| 456 |
return output
|
| 457 |
|
|
@@ -500,6 +514,21 @@ def route_tokens_xpu(
|
|
| 500 |
return logits, expert_weights, expert_indices
|
| 501 |
|
| 502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 504 |
can_torch_compile: bool = True
|
| 505 |
|
|
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 524 |
self.experts, "normalize_expert_weights", None
|
| 525 |
)
|
| 526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 528 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 529 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 598 |
topk_ids=expert_indices,
|
| 599 |
n_experts_per_token=moe_top_k,
|
| 600 |
activation=activation,
|
| 601 |
-
num_experts=
|
|
|
|
|
|
|
| 602 |
is_fp8=is_fp8,
|
| 603 |
is_int4=is_int4,
|
| 604 |
is_mxfp4=is_mxfp4,
|
| 605 |
)
|
| 606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
# Restore original shape
|
| 608 |
output = output.view(in_shape)
|
| 609 |
|
|
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
| 205 |
+
ep_rank=0,
|
| 206 |
+
ep_size=1,
|
| 207 |
is_fp8=False,
|
| 208 |
is_int4=False,
|
| 209 |
is_mxfp4=False):
|
|
|
|
| 331 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 332 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 333 |
|
| 334 |
+
workspace = torch.zeros(map_offset,
|
| 335 |
dtype=torch.uint8,
|
| 336 |
device=hidden_states.device)
|
| 337 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
workspace=workspace,
|
| 344 |
hidden_size=hidden_size,
|
| 345 |
inter_size=inter_size,
|
| 346 |
+
ep_rank=ep_rank,
|
| 347 |
+
ep_size=ep_size,
|
| 348 |
num_experts_on_rank=num_experts_per_node)
|
| 349 |
|
| 350 |
expert_first_token_offset_bytes = workspace[
|
|
|
|
| 355 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 356 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 357 |
src_to_dest_map_size]
|
| 358 |
+
permuted_row_to_unpermuted_row_bytes = workspace[
|
| 359 |
+
ws_map["permuted_row_to_unpermuted_row"][1]:
|
| 360 |
+
ws_map["permuted_row_to_unpermuted_row"][1] +
|
| 361 |
+
permuted_row_to_unpermuted_row_size]
|
| 362 |
|
| 363 |
if torch.compiler.is_compiling():
|
| 364 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
|
|
| 367 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 368 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 369 |
)
|
| 370 |
+
permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
|
| 371 |
+
permuted_row_to_unpermuted_row_bytes, torch.int32
|
| 372 |
+
)
|
| 373 |
else:
|
| 374 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 375 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 376 |
+
permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
|
| 377 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 378 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 379 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 463 |
is_B_mxfp4=is_mxfp4)
|
| 464 |
|
| 465 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
| 466 |
+
permuted_row_to_unpermuted_row,
|
| 467 |
unpermuted_row_to_permuted_row,
|
| 468 |
+
expert_first_token_offset,
|
| 469 |
num_experts_per_node)
|
| 470 |
return output
|
| 471 |
|
|
|
|
| 514 |
return logits, expert_weights, expert_indices
|
| 515 |
|
| 516 |
|
| 517 |
+
def _get_device_mesh(model):
|
| 518 |
+
"""Extract device_mesh from child's unused pre_hook closure for EP support."""
|
| 519 |
+
try:
|
| 520 |
+
hook = next(
|
| 521 |
+
h
|
| 522 |
+
for h in model.experts._forward_pre_hooks.values()
|
| 523 |
+
if "device_mesh" in h.__code__.co_freevars
|
| 524 |
+
)
|
| 525 |
+
return hook.__closure__[
|
| 526 |
+
hook.__code__.co_freevars.index("device_mesh")
|
| 527 |
+
].cell_contents
|
| 528 |
+
except Exception:
|
| 529 |
+
return None
|
| 530 |
+
|
| 531 |
+
|
| 532 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 533 |
can_torch_compile: bool = True
|
| 534 |
|
|
|
|
| 553 |
self.experts, "normalize_expert_weights", None
|
| 554 |
)
|
| 555 |
|
| 556 |
+
# Get EP (Expert Parallelism) parameters
|
| 557 |
+
ep_size = 1
|
| 558 |
+
ep_rank = 0
|
| 559 |
+
expert_parallel_group = getattr(self, "expert_parallel_group", None)
|
| 560 |
+
if expert_parallel_group is None:
|
| 561 |
+
device_mesh = _get_device_mesh(self)
|
| 562 |
+
if device_mesh is not None:
|
| 563 |
+
expert_parallel_group = device_mesh.get_group()
|
| 564 |
+
if expert_parallel_group is not None:
|
| 565 |
+
import torch.distributed as dist
|
| 566 |
+
if dist.is_initialized():
|
| 567 |
+
ep_size = dist.get_world_size(expert_parallel_group)
|
| 568 |
+
ep_rank = dist.get_rank(expert_parallel_group)
|
| 569 |
+
|
| 570 |
+
# Number of experts on this rank
|
| 571 |
+
num_experts_on_rank = moe_num_experts // ep_size
|
| 572 |
+
|
| 573 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 574 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 575 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
|
|
| 644 |
topk_ids=expert_indices,
|
| 645 |
n_experts_per_token=moe_top_k,
|
| 646 |
activation=activation,
|
| 647 |
+
num_experts=num_experts_on_rank,
|
| 648 |
+
ep_rank=ep_rank,
|
| 649 |
+
ep_size=ep_size,
|
| 650 |
is_fp8=is_fp8,
|
| 651 |
is_int4=is_int4,
|
| 652 |
is_mxfp4=is_mxfp4,
|
| 653 |
)
|
| 654 |
|
| 655 |
+
# All-reduce across EP group to combine partial expert outputs
|
| 656 |
+
if ep_size > 1 and expert_parallel_group is not None:
|
| 657 |
+
import torch.distributed as dist
|
| 658 |
+
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
|
| 659 |
+
|
| 660 |
# Restore original shape
|
| 661 |
output = output.view(in_shape)
|
| 662 |
|
build/torch210-cxx11-cu130-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 12073200
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ca7f2de93adbb930ffecaea6953cb94c870333295d05eade3c9c17296aa766a0
|
| 3 |
size 12073200
|
build/torch210-cxx11-cu130-aarch64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_cuda_6e04dec
|
| 3 |
+
ops = torch.ops._megablocks_cuda_6e04dec
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_cuda_6e04dec::{op_name}"
|
build/torch210-cxx11-cu130-aarch64-linux/xpu_fused_moe.py
CHANGED
|
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
-
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
-
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
|
|
|
|
|
|
| 205 |
is_fp8=False,
|
| 206 |
is_int4=False,
|
| 207 |
is_mxfp4=False):
|
|
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
-
workspace = torch.
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 341 |
workspace=workspace,
|
| 342 |
hidden_size=hidden_size,
|
| 343 |
inter_size=inter_size,
|
|
|
|
|
|
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
expert_first_token_offset_bytes = workspace[
|
|
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
|
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
if torch.compiler.is_compiling():
|
| 356 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
|
|
| 359 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
)
|
|
|
|
|
|
|
|
|
|
| 362 |
else:
|
| 363 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
|
|
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
|
|
| 451 |
is_B_mxfp4=is_mxfp4)
|
| 452 |
|
| 453 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
|
|
|
| 454 |
unpermuted_row_to_permuted_row,
|
|
|
|
| 455 |
num_experts_per_node)
|
| 456 |
return output
|
| 457 |
|
|
@@ -500,6 +514,21 @@ def route_tokens_xpu(
|
|
| 500 |
return logits, expert_weights, expert_indices
|
| 501 |
|
| 502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 504 |
can_torch_compile: bool = True
|
| 505 |
|
|
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 524 |
self.experts, "normalize_expert_weights", None
|
| 525 |
)
|
| 526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 528 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 529 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 598 |
topk_ids=expert_indices,
|
| 599 |
n_experts_per_token=moe_top_k,
|
| 600 |
activation=activation,
|
| 601 |
-
num_experts=
|
|
|
|
|
|
|
| 602 |
is_fp8=is_fp8,
|
| 603 |
is_int4=is_int4,
|
| 604 |
is_mxfp4=is_mxfp4,
|
| 605 |
)
|
| 606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
# Restore original shape
|
| 608 |
output = output.view(in_shape)
|
| 609 |
|
|
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
| 205 |
+
ep_rank=0,
|
| 206 |
+
ep_size=1,
|
| 207 |
is_fp8=False,
|
| 208 |
is_int4=False,
|
| 209 |
is_mxfp4=False):
|
|
|
|
| 331 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 332 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 333 |
|
| 334 |
+
workspace = torch.zeros(map_offset,
|
| 335 |
dtype=torch.uint8,
|
| 336 |
device=hidden_states.device)
|
| 337 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
workspace=workspace,
|
| 344 |
hidden_size=hidden_size,
|
| 345 |
inter_size=inter_size,
|
| 346 |
+
ep_rank=ep_rank,
|
| 347 |
+
ep_size=ep_size,
|
| 348 |
num_experts_on_rank=num_experts_per_node)
|
| 349 |
|
| 350 |
expert_first_token_offset_bytes = workspace[
|
|
|
|
| 355 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 356 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 357 |
src_to_dest_map_size]
|
| 358 |
+
permuted_row_to_unpermuted_row_bytes = workspace[
|
| 359 |
+
ws_map["permuted_row_to_unpermuted_row"][1]:
|
| 360 |
+
ws_map["permuted_row_to_unpermuted_row"][1] +
|
| 361 |
+
permuted_row_to_unpermuted_row_size]
|
| 362 |
|
| 363 |
if torch.compiler.is_compiling():
|
| 364 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
|
|
| 367 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 368 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 369 |
)
|
| 370 |
+
permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
|
| 371 |
+
permuted_row_to_unpermuted_row_bytes, torch.int32
|
| 372 |
+
)
|
| 373 |
else:
|
| 374 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 375 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 376 |
+
permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
|
| 377 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 378 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 379 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 463 |
is_B_mxfp4=is_mxfp4)
|
| 464 |
|
| 465 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
| 466 |
+
permuted_row_to_unpermuted_row,
|
| 467 |
unpermuted_row_to_permuted_row,
|
| 468 |
+
expert_first_token_offset,
|
| 469 |
num_experts_per_node)
|
| 470 |
return output
|
| 471 |
|
|
|
|
| 514 |
return logits, expert_weights, expert_indices
|
| 515 |
|
| 516 |
|
| 517 |
+
def _get_device_mesh(model):
|
| 518 |
+
"""Extract device_mesh from child's unused pre_hook closure for EP support."""
|
| 519 |
+
try:
|
| 520 |
+
hook = next(
|
| 521 |
+
h
|
| 522 |
+
for h in model.experts._forward_pre_hooks.values()
|
| 523 |
+
if "device_mesh" in h.__code__.co_freevars
|
| 524 |
+
)
|
| 525 |
+
return hook.__closure__[
|
| 526 |
+
hook.__code__.co_freevars.index("device_mesh")
|
| 527 |
+
].cell_contents
|
| 528 |
+
except Exception:
|
| 529 |
+
return None
|
| 530 |
+
|
| 531 |
+
|
| 532 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 533 |
can_torch_compile: bool = True
|
| 534 |
|
|
|
|
| 553 |
self.experts, "normalize_expert_weights", None
|
| 554 |
)
|
| 555 |
|
| 556 |
+
# Get EP (Expert Parallelism) parameters
|
| 557 |
+
ep_size = 1
|
| 558 |
+
ep_rank = 0
|
| 559 |
+
expert_parallel_group = getattr(self, "expert_parallel_group", None)
|
| 560 |
+
if expert_parallel_group is None:
|
| 561 |
+
device_mesh = _get_device_mesh(self)
|
| 562 |
+
if device_mesh is not None:
|
| 563 |
+
expert_parallel_group = device_mesh.get_group()
|
| 564 |
+
if expert_parallel_group is not None:
|
| 565 |
+
import torch.distributed as dist
|
| 566 |
+
if dist.is_initialized():
|
| 567 |
+
ep_size = dist.get_world_size(expert_parallel_group)
|
| 568 |
+
ep_rank = dist.get_rank(expert_parallel_group)
|
| 569 |
+
|
| 570 |
+
# Number of experts on this rank
|
| 571 |
+
num_experts_on_rank = moe_num_experts // ep_size
|
| 572 |
+
|
| 573 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 574 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 575 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
|
|
| 644 |
topk_ids=expert_indices,
|
| 645 |
n_experts_per_token=moe_top_k,
|
| 646 |
activation=activation,
|
| 647 |
+
num_experts=num_experts_on_rank,
|
| 648 |
+
ep_rank=ep_rank,
|
| 649 |
+
ep_size=ep_size,
|
| 650 |
is_fp8=is_fp8,
|
| 651 |
is_int4=is_int4,
|
| 652 |
is_mxfp4=is_mxfp4,
|
| 653 |
)
|
| 654 |
|
| 655 |
+
# All-reduce across EP group to combine partial expert outputs
|
| 656 |
+
if ep_size > 1 and expert_parallel_group is not None:
|
| 657 |
+
import torch.distributed as dist
|
| 658 |
+
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
|
| 659 |
+
|
| 660 |
# Restore original shape
|
| 661 |
output = output.view(in_shape)
|
| 662 |
|
build/torch29-cxx11-cu126-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 15121720
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:581f5d3cd17031f674e6da22c23430881408630004e4ece5a57f9c36583665b5
|
| 3 |
size 15121720
|
build/torch29-cxx11-cu126-aarch64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_cuda_6e04dec
|
| 3 |
+
ops = torch.ops._megablocks_cuda_6e04dec
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_cuda_6e04dec::{op_name}"
|
build/torch29-cxx11-cu126-aarch64-linux/xpu_fused_moe.py
CHANGED
|
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
-
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
-
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
|
|
|
|
|
|
| 205 |
is_fp8=False,
|
| 206 |
is_int4=False,
|
| 207 |
is_mxfp4=False):
|
|
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
-
workspace = torch.
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 341 |
workspace=workspace,
|
| 342 |
hidden_size=hidden_size,
|
| 343 |
inter_size=inter_size,
|
|
|
|
|
|
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
expert_first_token_offset_bytes = workspace[
|
|
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
|
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
if torch.compiler.is_compiling():
|
| 356 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
|
|
| 359 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
)
|
|
|
|
|
|
|
|
|
|
| 362 |
else:
|
| 363 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
|
|
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
|
|
| 451 |
is_B_mxfp4=is_mxfp4)
|
| 452 |
|
| 453 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
|
|
|
| 454 |
unpermuted_row_to_permuted_row,
|
|
|
|
| 455 |
num_experts_per_node)
|
| 456 |
return output
|
| 457 |
|
|
@@ -500,6 +514,21 @@ def route_tokens_xpu(
|
|
| 500 |
return logits, expert_weights, expert_indices
|
| 501 |
|
| 502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 504 |
can_torch_compile: bool = True
|
| 505 |
|
|
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 524 |
self.experts, "normalize_expert_weights", None
|
| 525 |
)
|
| 526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 528 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 529 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 598 |
topk_ids=expert_indices,
|
| 599 |
n_experts_per_token=moe_top_k,
|
| 600 |
activation=activation,
|
| 601 |
-
num_experts=
|
|
|
|
|
|
|
| 602 |
is_fp8=is_fp8,
|
| 603 |
is_int4=is_int4,
|
| 604 |
is_mxfp4=is_mxfp4,
|
| 605 |
)
|
| 606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
# Restore original shape
|
| 608 |
output = output.view(in_shape)
|
| 609 |
|
|
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
| 205 |
+
ep_rank=0,
|
| 206 |
+
ep_size=1,
|
| 207 |
is_fp8=False,
|
| 208 |
is_int4=False,
|
| 209 |
is_mxfp4=False):
|
|
|
|
| 331 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 332 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 333 |
|
| 334 |
+
workspace = torch.zeros(map_offset,
|
| 335 |
dtype=torch.uint8,
|
| 336 |
device=hidden_states.device)
|
| 337 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
workspace=workspace,
|
| 344 |
hidden_size=hidden_size,
|
| 345 |
inter_size=inter_size,
|
| 346 |
+
ep_rank=ep_rank,
|
| 347 |
+
ep_size=ep_size,
|
| 348 |
num_experts_on_rank=num_experts_per_node)
|
| 349 |
|
| 350 |
expert_first_token_offset_bytes = workspace[
|
|
|
|
| 355 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 356 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 357 |
src_to_dest_map_size]
|
| 358 |
+
permuted_row_to_unpermuted_row_bytes = workspace[
|
| 359 |
+
ws_map["permuted_row_to_unpermuted_row"][1]:
|
| 360 |
+
ws_map["permuted_row_to_unpermuted_row"][1] +
|
| 361 |
+
permuted_row_to_unpermuted_row_size]
|
| 362 |
|
| 363 |
if torch.compiler.is_compiling():
|
| 364 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
|
|
| 367 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 368 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 369 |
)
|
| 370 |
+
permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
|
| 371 |
+
permuted_row_to_unpermuted_row_bytes, torch.int32
|
| 372 |
+
)
|
| 373 |
else:
|
| 374 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 375 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 376 |
+
permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
|
| 377 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 378 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 379 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 463 |
is_B_mxfp4=is_mxfp4)
|
| 464 |
|
| 465 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
| 466 |
+
permuted_row_to_unpermuted_row,
|
| 467 |
unpermuted_row_to_permuted_row,
|
| 468 |
+
expert_first_token_offset,
|
| 469 |
num_experts_per_node)
|
| 470 |
return output
|
| 471 |
|
|
|
|
| 514 |
return logits, expert_weights, expert_indices
|
| 515 |
|
| 516 |
|
| 517 |
+
def _get_device_mesh(model):
|
| 518 |
+
"""Extract device_mesh from child's unused pre_hook closure for EP support."""
|
| 519 |
+
try:
|
| 520 |
+
hook = next(
|
| 521 |
+
h
|
| 522 |
+
for h in model.experts._forward_pre_hooks.values()
|
| 523 |
+
if "device_mesh" in h.__code__.co_freevars
|
| 524 |
+
)
|
| 525 |
+
return hook.__closure__[
|
| 526 |
+
hook.__code__.co_freevars.index("device_mesh")
|
| 527 |
+
].cell_contents
|
| 528 |
+
except Exception:
|
| 529 |
+
return None
|
| 530 |
+
|
| 531 |
+
|
| 532 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 533 |
can_torch_compile: bool = True
|
| 534 |
|
|
|
|
| 553 |
self.experts, "normalize_expert_weights", None
|
| 554 |
)
|
| 555 |
|
| 556 |
+
# Get EP (Expert Parallelism) parameters
|
| 557 |
+
ep_size = 1
|
| 558 |
+
ep_rank = 0
|
| 559 |
+
expert_parallel_group = getattr(self, "expert_parallel_group", None)
|
| 560 |
+
if expert_parallel_group is None:
|
| 561 |
+
device_mesh = _get_device_mesh(self)
|
| 562 |
+
if device_mesh is not None:
|
| 563 |
+
expert_parallel_group = device_mesh.get_group()
|
| 564 |
+
if expert_parallel_group is not None:
|
| 565 |
+
import torch.distributed as dist
|
| 566 |
+
if dist.is_initialized():
|
| 567 |
+
ep_size = dist.get_world_size(expert_parallel_group)
|
| 568 |
+
ep_rank = dist.get_rank(expert_parallel_group)
|
| 569 |
+
|
| 570 |
+
# Number of experts on this rank
|
| 571 |
+
num_experts_on_rank = moe_num_experts // ep_size
|
| 572 |
+
|
| 573 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 574 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 575 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
|
|
| 644 |
topk_ids=expert_indices,
|
| 645 |
n_experts_per_token=moe_top_k,
|
| 646 |
activation=activation,
|
| 647 |
+
num_experts=num_experts_on_rank,
|
| 648 |
+
ep_rank=ep_rank,
|
| 649 |
+
ep_size=ep_size,
|
| 650 |
is_fp8=is_fp8,
|
| 651 |
is_int4=is_int4,
|
| 652 |
is_mxfp4=is_mxfp4,
|
| 653 |
)
|
| 654 |
|
| 655 |
+
# All-reduce across EP group to combine partial expert outputs
|
| 656 |
+
if ep_size > 1 and expert_parallel_group is not None:
|
| 657 |
+
import torch.distributed as dist
|
| 658 |
+
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
|
| 659 |
+
|
| 660 |
# Restore original shape
|
| 661 |
output = output.view(in_shape)
|
| 662 |
|
build/torch29-cxx11-cu128-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 21085456
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:81684a3eed6a7fb374cdbba3cf65f1cd46f5392ddc6d4992d37186c3b15f5734
|
| 3 |
size 21085456
|
build/torch29-cxx11-cu128-aarch64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_cuda_6e04dec
|
| 3 |
+
ops = torch.ops._megablocks_cuda_6e04dec
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_cuda_6e04dec::{op_name}"
|
build/torch29-cxx11-cu128-aarch64-linux/xpu_fused_moe.py
CHANGED
|
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
-
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
-
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
|
|
|
|
|
|
| 205 |
is_fp8=False,
|
| 206 |
is_int4=False,
|
| 207 |
is_mxfp4=False):
|
|
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
-
workspace = torch.
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 341 |
workspace=workspace,
|
| 342 |
hidden_size=hidden_size,
|
| 343 |
inter_size=inter_size,
|
|
|
|
|
|
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
expert_first_token_offset_bytes = workspace[
|
|
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
|
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
if torch.compiler.is_compiling():
|
| 356 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
|
|
| 359 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
)
|
|
|
|
|
|
|
|
|
|
| 362 |
else:
|
| 363 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
|
|
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
|
|
| 451 |
is_B_mxfp4=is_mxfp4)
|
| 452 |
|
| 453 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
|
|
|
| 454 |
unpermuted_row_to_permuted_row,
|
|
|
|
| 455 |
num_experts_per_node)
|
| 456 |
return output
|
| 457 |
|
|
@@ -500,6 +514,21 @@ def route_tokens_xpu(
|
|
| 500 |
return logits, expert_weights, expert_indices
|
| 501 |
|
| 502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 504 |
can_torch_compile: bool = True
|
| 505 |
|
|
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 524 |
self.experts, "normalize_expert_weights", None
|
| 525 |
)
|
| 526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 528 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 529 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 598 |
topk_ids=expert_indices,
|
| 599 |
n_experts_per_token=moe_top_k,
|
| 600 |
activation=activation,
|
| 601 |
-
num_experts=
|
|
|
|
|
|
|
| 602 |
is_fp8=is_fp8,
|
| 603 |
is_int4=is_int4,
|
| 604 |
is_mxfp4=is_mxfp4,
|
| 605 |
)
|
| 606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
# Restore original shape
|
| 608 |
output = output.view(in_shape)
|
| 609 |
|
|
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
| 205 |
+
ep_rank=0,
|
| 206 |
+
ep_size=1,
|
| 207 |
is_fp8=False,
|
| 208 |
is_int4=False,
|
| 209 |
is_mxfp4=False):
|
|
|
|
| 331 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 332 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 333 |
|
| 334 |
+
workspace = torch.zeros(map_offset,
|
| 335 |
dtype=torch.uint8,
|
| 336 |
device=hidden_states.device)
|
| 337 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
workspace=workspace,
|
| 344 |
hidden_size=hidden_size,
|
| 345 |
inter_size=inter_size,
|
| 346 |
+
ep_rank=ep_rank,
|
| 347 |
+
ep_size=ep_size,
|
| 348 |
num_experts_on_rank=num_experts_per_node)
|
| 349 |
|
| 350 |
expert_first_token_offset_bytes = workspace[
|
|
|
|
| 355 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 356 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 357 |
src_to_dest_map_size]
|
| 358 |
+
permuted_row_to_unpermuted_row_bytes = workspace[
|
| 359 |
+
ws_map["permuted_row_to_unpermuted_row"][1]:
|
| 360 |
+
ws_map["permuted_row_to_unpermuted_row"][1] +
|
| 361 |
+
permuted_row_to_unpermuted_row_size]
|
| 362 |
|
| 363 |
if torch.compiler.is_compiling():
|
| 364 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
|
|
| 367 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 368 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 369 |
)
|
| 370 |
+
permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
|
| 371 |
+
permuted_row_to_unpermuted_row_bytes, torch.int32
|
| 372 |
+
)
|
| 373 |
else:
|
| 374 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 375 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 376 |
+
permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
|
| 377 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 378 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 379 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 463 |
is_B_mxfp4=is_mxfp4)
|
| 464 |
|
| 465 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
| 466 |
+
permuted_row_to_unpermuted_row,
|
| 467 |
unpermuted_row_to_permuted_row,
|
| 468 |
+
expert_first_token_offset,
|
| 469 |
num_experts_per_node)
|
| 470 |
return output
|
| 471 |
|
|
|
|
| 514 |
return logits, expert_weights, expert_indices
|
| 515 |
|
| 516 |
|
| 517 |
+
def _get_device_mesh(model):
|
| 518 |
+
"""Extract device_mesh from child's unused pre_hook closure for EP support."""
|
| 519 |
+
try:
|
| 520 |
+
hook = next(
|
| 521 |
+
h
|
| 522 |
+
for h in model.experts._forward_pre_hooks.values()
|
| 523 |
+
if "device_mesh" in h.__code__.co_freevars
|
| 524 |
+
)
|
| 525 |
+
return hook.__closure__[
|
| 526 |
+
hook.__code__.co_freevars.index("device_mesh")
|
| 527 |
+
].cell_contents
|
| 528 |
+
except Exception:
|
| 529 |
+
return None
|
| 530 |
+
|
| 531 |
+
|
| 532 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 533 |
can_torch_compile: bool = True
|
| 534 |
|
|
|
|
| 553 |
self.experts, "normalize_expert_weights", None
|
| 554 |
)
|
| 555 |
|
| 556 |
+
# Get EP (Expert Parallelism) parameters
|
| 557 |
+
ep_size = 1
|
| 558 |
+
ep_rank = 0
|
| 559 |
+
expert_parallel_group = getattr(self, "expert_parallel_group", None)
|
| 560 |
+
if expert_parallel_group is None:
|
| 561 |
+
device_mesh = _get_device_mesh(self)
|
| 562 |
+
if device_mesh is not None:
|
| 563 |
+
expert_parallel_group = device_mesh.get_group()
|
| 564 |
+
if expert_parallel_group is not None:
|
| 565 |
+
import torch.distributed as dist
|
| 566 |
+
if dist.is_initialized():
|
| 567 |
+
ep_size = dist.get_world_size(expert_parallel_group)
|
| 568 |
+
ep_rank = dist.get_rank(expert_parallel_group)
|
| 569 |
+
|
| 570 |
+
# Number of experts on this rank
|
| 571 |
+
num_experts_on_rank = moe_num_experts // ep_size
|
| 572 |
+
|
| 573 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 574 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 575 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
|
|
| 644 |
topk_ids=expert_indices,
|
| 645 |
n_experts_per_token=moe_top_k,
|
| 646 |
activation=activation,
|
| 647 |
+
num_experts=num_experts_on_rank,
|
| 648 |
+
ep_rank=ep_rank,
|
| 649 |
+
ep_size=ep_size,
|
| 650 |
is_fp8=is_fp8,
|
| 651 |
is_int4=is_int4,
|
| 652 |
is_mxfp4=is_mxfp4,
|
| 653 |
)
|
| 654 |
|
| 655 |
+
# All-reduce across EP group to combine partial expert outputs
|
| 656 |
+
if ep_size > 1 and expert_parallel_group is not None:
|
| 657 |
+
import torch.distributed as dist
|
| 658 |
+
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
|
| 659 |
+
|
| 660 |
# Restore original shape
|
| 661 |
output = output.view(in_shape)
|
| 662 |
|
build/torch29-cxx11-cu130-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 12070448
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8669b2a5cf6f36ab1d6c518040d4f4e2874d7b1c5880b4424d21f89c60e77c5f
|
| 3 |
size 12070448
|
build/torch29-cxx11-cu130-aarch64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_cuda_6e04dec
|
| 3 |
+
ops = torch.ops._megablocks_cuda_6e04dec
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_cuda_6e04dec::{op_name}"
|
build/torch29-cxx11-cu130-aarch64-linux/xpu_fused_moe.py
CHANGED
|
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
-
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
-
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
|
|
|
|
|
|
| 205 |
is_fp8=False,
|
| 206 |
is_int4=False,
|
| 207 |
is_mxfp4=False):
|
|
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
-
workspace = torch.
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
|
|
| 341 |
workspace=workspace,
|
| 342 |
hidden_size=hidden_size,
|
| 343 |
inter_size=inter_size,
|
|
|
|
|
|
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
expert_first_token_offset_bytes = workspace[
|
|
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
|
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
if torch.compiler.is_compiling():
|
| 356 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
|
|
| 359 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
)
|
|
|
|
|
|
|
|
|
|
| 362 |
else:
|
| 363 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
|
|
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
|
|
| 451 |
is_B_mxfp4=is_mxfp4)
|
| 452 |
|
| 453 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
|
|
|
| 454 |
unpermuted_row_to_permuted_row,
|
|
|
|
| 455 |
num_experts_per_node)
|
| 456 |
return output
|
| 457 |
|
|
@@ -500,6 +514,21 @@ def route_tokens_xpu(
|
|
| 500 |
return logits, expert_weights, expert_indices
|
| 501 |
|
| 502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 504 |
can_torch_compile: bool = True
|
| 505 |
|
|
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 524 |
self.experts, "normalize_expert_weights", None
|
| 525 |
)
|
| 526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 528 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 529 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
|
|
| 598 |
topk_ids=expert_indices,
|
| 599 |
n_experts_per_token=moe_top_k,
|
| 600 |
activation=activation,
|
| 601 |
-
num_experts=
|
|
|
|
|
|
|
| 602 |
is_fp8=is_fp8,
|
| 603 |
is_int4=is_int4,
|
| 604 |
is_mxfp4=is_mxfp4,
|
| 605 |
)
|
| 606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
# Restore original shape
|
| 608 |
output = output.view(in_shape)
|
| 609 |
|
|
|
|
| 31 |
|
| 32 |
_register_if_available(
|
| 33 |
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
|
| 35 |
)
|
| 36 |
|
| 37 |
_register_if_available(
|
| 38 |
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
|
| 40 |
)
|
| 41 |
|
| 42 |
_register_if_available(
|
|
|
|
| 202 |
n_experts_per_token,
|
| 203 |
activation,
|
| 204 |
num_experts,
|
| 205 |
+
ep_rank=0,
|
| 206 |
+
ep_size=1,
|
| 207 |
is_fp8=False,
|
| 208 |
is_int4=False,
|
| 209 |
is_mxfp4=False):
|
|
|
|
| 331 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 332 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 333 |
|
| 334 |
+
workspace = torch.zeros(map_offset,
|
| 335 |
dtype=torch.uint8,
|
| 336 |
device=hidden_states.device)
|
| 337 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
workspace=workspace,
|
| 344 |
hidden_size=hidden_size,
|
| 345 |
inter_size=inter_size,
|
| 346 |
+
ep_rank=ep_rank,
|
| 347 |
+
ep_size=ep_size,
|
| 348 |
num_experts_on_rank=num_experts_per_node)
|
| 349 |
|
| 350 |
expert_first_token_offset_bytes = workspace[
|
|
|
|
| 355 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 356 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 357 |
src_to_dest_map_size]
|
| 358 |
+
permuted_row_to_unpermuted_row_bytes = workspace[
|
| 359 |
+
ws_map["permuted_row_to_unpermuted_row"][1]:
|
| 360 |
+
ws_map["permuted_row_to_unpermuted_row"][1] +
|
| 361 |
+
permuted_row_to_unpermuted_row_size]
|
| 362 |
|
| 363 |
if torch.compiler.is_compiling():
|
| 364 |
expert_first_token_offset = _bytes_to_typed_tensor(
|
|
|
|
| 367 |
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 368 |
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 369 |
)
|
| 370 |
+
permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
|
| 371 |
+
permuted_row_to_unpermuted_row_bytes, torch.int32
|
| 372 |
+
)
|
| 373 |
else:
|
| 374 |
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 375 |
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 376 |
+
permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
|
| 377 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 378 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 379 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 463 |
is_B_mxfp4=is_mxfp4)
|
| 464 |
|
| 465 |
ops.moe_gather(output, gemm2_output, topk_weights,
|
| 466 |
+
permuted_row_to_unpermuted_row,
|
| 467 |
unpermuted_row_to_permuted_row,
|
| 468 |
+
expert_first_token_offset,
|
| 469 |
num_experts_per_node)
|
| 470 |
return output
|
| 471 |
|
|
|
|
| 514 |
return logits, expert_weights, expert_indices
|
| 515 |
|
| 516 |
|
| 517 |
+
def _get_device_mesh(model):
|
| 518 |
+
"""Extract device_mesh from child's unused pre_hook closure for EP support."""
|
| 519 |
+
try:
|
| 520 |
+
hook = next(
|
| 521 |
+
h
|
| 522 |
+
for h in model.experts._forward_pre_hooks.values()
|
| 523 |
+
if "device_mesh" in h.__code__.co_freevars
|
| 524 |
+
)
|
| 525 |
+
return hook.__closure__[
|
| 526 |
+
hook.__code__.co_freevars.index("device_mesh")
|
| 527 |
+
].cell_contents
|
| 528 |
+
except Exception:
|
| 529 |
+
return None
|
| 530 |
+
|
| 531 |
+
|
| 532 |
class MegaBlocksMoeMLP(torch.nn.Module):
|
| 533 |
can_torch_compile: bool = True
|
| 534 |
|
|
|
|
| 553 |
self.experts, "normalize_expert_weights", None
|
| 554 |
)
|
| 555 |
|
| 556 |
+
# Get EP (Expert Parallelism) parameters
|
| 557 |
+
ep_size = 1
|
| 558 |
+
ep_rank = 0
|
| 559 |
+
expert_parallel_group = getattr(self, "expert_parallel_group", None)
|
| 560 |
+
if expert_parallel_group is None:
|
| 561 |
+
device_mesh = _get_device_mesh(self)
|
| 562 |
+
if device_mesh is not None:
|
| 563 |
+
expert_parallel_group = device_mesh.get_group()
|
| 564 |
+
if expert_parallel_group is not None:
|
| 565 |
+
import torch.distributed as dist
|
| 566 |
+
if dist.is_initialized():
|
| 567 |
+
ep_size = dist.get_world_size(expert_parallel_group)
|
| 568 |
+
ep_rank = dist.get_rank(expert_parallel_group)
|
| 569 |
+
|
| 570 |
+
# Number of experts on this rank
|
| 571 |
+
num_experts_on_rank = moe_num_experts // ep_size
|
| 572 |
+
|
| 573 |
# Detect activation type - check for GptOss-style swigluoai activation
|
| 574 |
# GptOssExperts has alpha and limit attributes for swigluoai
|
| 575 |
if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
|
|
|
|
| 644 |
topk_ids=expert_indices,
|
| 645 |
n_experts_per_token=moe_top_k,
|
| 646 |
activation=activation,
|
| 647 |
+
num_experts=num_experts_on_rank,
|
| 648 |
+
ep_rank=ep_rank,
|
| 649 |
+
ep_size=ep_size,
|
| 650 |
is_fp8=is_fp8,
|
| 651 |
is_int4=is_int4,
|
| 652 |
is_mxfp4=is_mxfp4,
|
| 653 |
)
|
| 654 |
|
| 655 |
+
# All-reduce across EP group to combine partial expert outputs
|
| 656 |
+
if ep_size > 1 and expert_parallel_group is not None:
|
| 657 |
+
import torch.distributed as dist
|
| 658 |
+
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
|
| 659 |
+
|
| 660 |
# Restore original shape
|
| 661 |
output = output.view(in_shape)
|
| 662 |
|