Kernels
danieldk (HF Staff) committed
Commit 1a70d6e · verified · 1 parent: af99866

Build uploaded using `kernels`.

Files changed (30)
  1. build/torch210-cxx11-cpu-x86_64-linux/{_megablocks_cpu_dd32462.abi3.so β†’ _megablocks_cpu_6e04dec.abi3.so} +1 -1
  2. build/torch210-cxx11-cpu-x86_64-linux/_ops.py +3 -3
  3. build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py +57 -4
  4. build/torch210-cxx11-cu126-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so β†’ _megablocks_cuda_6e04dec.abi3.so} +1 -1
  5. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  6. build/torch210-cxx11-cu126-x86_64-linux/xpu_fused_moe.py +57 -4
  7. build/torch210-cxx11-cu128-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so β†’ _megablocks_cuda_6e04dec.abi3.so} +1 -1
  8. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  9. build/torch210-cxx11-cu128-x86_64-linux/xpu_fused_moe.py +57 -4
  10. build/torch210-cxx11-cu130-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so β†’ _megablocks_cuda_6e04dec.abi3.so} +1 -1
  11. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
  12. build/torch210-cxx11-cu130-x86_64-linux/xpu_fused_moe.py +57 -4
  13. build/torch210-cxx11-xpu20253-x86_64-linux/{_megablocks_xpu_dd32462.abi3.so β†’ _megablocks_xpu_6e04dec.abi3.so} +2 -2
  14. build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py +3 -3
  15. build/torch210-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py +57 -4
  16. build/torch29-cxx11-cpu-x86_64-linux/{_megablocks_cpu_dd32462.abi3.so β†’ _megablocks_cpu_6e04dec.abi3.so} +1 -1
  17. build/torch29-cxx11-cpu-x86_64-linux/_ops.py +3 -3
  18. build/torch29-cxx11-cpu-x86_64-linux/xpu_fused_moe.py +57 -4
  19. build/torch29-cxx11-cu126-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so β†’ _megablocks_cuda_6e04dec.abi3.so} +1 -1
  20. build/torch29-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  21. build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py +57 -4
  22. build/torch29-cxx11-cu128-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so β†’ _megablocks_cuda_6e04dec.abi3.so} +1 -1
  23. build/torch29-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  24. build/torch29-cxx11-cu128-x86_64-linux/xpu_fused_moe.py +57 -4
  25. build/torch29-cxx11-cu130-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so β†’ _megablocks_cuda_6e04dec.abi3.so} +1 -1
  26. build/torch29-cxx11-cu130-x86_64-linux/_ops.py +3 -3
  27. build/torch29-cxx11-cu130-x86_64-linux/xpu_fused_moe.py +57 -4
  28. build/torch29-cxx11-xpu20252-x86_64-linux/{_megablocks_xpu_dd32462.abi3.so β†’ _megablocks_xpu_6e04dec.abi3.so} +2 -2
  29. build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py +3 -3
  30. build/torch29-cxx11-xpu20252-x86_64-linux/xpu_fused_moe.py +57 -4
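Each path above names a prebuilt variant keyed by Torch version, C++ ABI, compute backend, and platform (e.g. `torch29-cxx11-cu126-x86_64-linux`). A minimal consumer-side sketch, assuming the repo is loaded with the Hugging Face `kernels` library (the repo id below is a placeholder; this page does not name the repository):

```python
# Hedged sketch: loading a kernel uploaded with the `kernels` toolchain.
# get_kernel() selects the build/ variant matching the local Torch version,
# C++ ABI, and accelerator, then imports it as a Python module.
from kernels import get_kernel

megablocks = get_kernel("<org>/megablocks")  # placeholder repo id

# The loaded module exposes the Python surface shown in this commit, e.g.
# MegaBlocksMoeMLP from xpu_fused_moe.py and the ops handle from _ops.py.
```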
build/torch210-cxx11-cpu-x86_64-linux/{_megablocks_cpu_dd32462.abi3.so β†’ _megablocks_cpu_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2d456fb003229668826de9c88890c3bd0af4b2bdc313140017b75e3ed2e553c8
+oid sha256:70b79b772262fee7ee79153a54dc208c9166f4c34680f752b7bc2ce8d8ae1f74
 size 2219080
build/torch210-cxx11-cpu-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_cpu_dd32462
-ops = torch.ops._megablocks_cpu_dd32462
+from . import _megablocks_cpu_6e04dec
+ops = torch.ops._megablocks_cpu_6e04dec
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_cpu_dd32462::{op_name}"
+    return f"_megablocks_cpu_6e04dec::{op_name}"
build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
 
     _register_if_available(
         "fused_moe_prologue",
-        lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
+        lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
     )
 
     _register_if_available(
         "moe_gather",
-        lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
+        lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
     )
 
     _register_if_available(
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
                   n_experts_per_token,
                   activation,
                   num_experts,
+                  ep_rank=0,
+                  ep_size=1,
                   is_fp8=False,
                   is_int4=False,
                   is_mxfp4=False):
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
     config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
     config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
 
-    workspace = torch.empty(map_offset,
+    workspace = torch.zeros(map_offset,
                             dtype=torch.uint8,
                             device=hidden_states.device)
     if topk_ids.dtype == torch.int32:
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
                             workspace=workspace,
                             hidden_size=hidden_size,
                             inter_size=inter_size,
+                            ep_rank=ep_rank,
+                            ep_size=ep_size,
                             num_experts_on_rank=num_experts_per_node)
 
     expert_first_token_offset_bytes = workspace[
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
         ws_map["unpermuted_row_to_permuted_row"][1]:
         ws_map["unpermuted_row_to_permuted_row"][1] +
         src_to_dest_map_size]
+    permuted_row_to_unpermuted_row_bytes = workspace[
+        ws_map["permuted_row_to_unpermuted_row"][1]:
+        ws_map["permuted_row_to_unpermuted_row"][1] +
+        permuted_row_to_unpermuted_row_size]
 
     if torch.compiler.is_compiling():
         expert_first_token_offset = _bytes_to_typed_tensor(
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
         unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
             unpermuted_row_to_permuted_row_bytes, torch.int32
         )
+        permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+            permuted_row_to_unpermuted_row_bytes, torch.int32
+        )
     else:
         expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
         unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+        permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
     gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
                             ws_map["overlapped_gemm1_gemm2_inputs"][1] +
                             permuted_data_size].view(hidden_states.dtype).view(
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
                       is_B_mxfp4=is_mxfp4)
 
     ops.moe_gather(output, gemm2_output, topk_weights,
+                   permuted_row_to_unpermuted_row,
                    unpermuted_row_to_permuted_row,
+                   expert_first_token_offset,
                    num_experts_per_node)
     return output
 
@@ -500,6 +514,21 @@ def route_tokens_xpu(
     return logits, expert_weights, expert_indices
 
 
+def _get_device_mesh(model):
+    """Extract device_mesh from child's unused pre_hook closure for EP support."""
+    try:
+        hook = next(
+            h
+            for h in model.experts._forward_pre_hooks.values()
+            if "device_mesh" in h.__code__.co_freevars
+        )
+        return hook.__closure__[
+            hook.__code__.co_freevars.index("device_mesh")
+        ].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
     can_torch_compile: bool = True
 
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             self.experts, "normalize_expert_weights", None
         )
 
+        # Get EP (Expert Parallelism) parameters
+        ep_size = 1
+        ep_rank = 0
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = _get_device_mesh(self)
+            if device_mesh is not None:
+                expert_parallel_group = device_mesh.get_group()
+        if expert_parallel_group is not None:
+            import torch.distributed as dist
+            if dist.is_initialized():
+                ep_size = dist.get_world_size(expert_parallel_group)
+                ep_rank = dist.get_rank(expert_parallel_group)
+
+        # Number of experts on this rank
+        num_experts_on_rank = moe_num_experts // ep_size
+
         # Detect activation type - check for GptOss-style swigluoai activation
         # GptOssExperts has alpha and limit attributes for swigluoai
         if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             topk_ids=expert_indices,
             n_experts_per_token=moe_top_k,
             activation=activation,
-            num_experts=moe_num_experts,
+            num_experts=num_experts_on_rank,
+            ep_rank=ep_rank,
+            ep_size=ep_size,
             is_fp8=is_fp8,
             is_int4=is_int4,
             is_mxfp4=is_mxfp4,
         )
 
+        # All-reduce across EP group to combine partial expert outputs
+        if ep_size > 1 and expert_parallel_group is not None:
+            import torch.distributed as dist
+            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
         # Restore original shape
         output = output.view(in_shape)
 
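Two of the additions above deserve a note. `_get_device_mesh` recovers a `DeviceMesh` from a forward pre-hook via CPython closure introspection: a function's captured names are listed in `__code__.co_freevars`, and their values sit in the `__closure__` cells at the same indices. A self-contained toy of just that mechanism (the hook and mesh are stand-ins, not the actual Transformers pre-hook):

```python
# Hedged sketch of the closure introspection used by _get_device_mesh.
# Toy stand-ins only: no real DeviceMesh or forward pre-hook is required.

def make_pre_hook(device_mesh):
    def hook(module, args):
        # Merely referencing device_mesh makes it a free variable of `hook`,
        # so it is listed in hook.__code__.co_freevars and its value is kept
        # alive in the matching hook.__closure__ cell.
        _ = device_mesh
        return args
    return hook

mesh = object()  # stand-in for torch.distributed.device_mesh.DeviceMesh
hook = make_pre_hook(mesh)

assert "device_mesh" in hook.__code__.co_freevars
idx = hook.__code__.co_freevars.index("device_mesh")
recovered = hook.__closure__[idx].cell_contents
assert recovered is mesh
```

The EP combine then follows from each rank computing partial sums over only its local experts, so the added `dist.all_reduce(..., ReduceOp.SUM)` over the EP group reassembles the full MoE output.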
build/torch210-cxx11-cu126-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so β†’ _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a86d0f456034ce972bae452ab95b8fb8fcf24e015985f74678fca4b673fc50dc
+oid sha256:55948eae893317a5e500315e47efd66c4482bb67449caef3f512b2cabffb7dc6
 size 15061056
build/torch210-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_cuda_dd32462
-ops = torch.ops._megablocks_cuda_dd32462
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_cuda_dd32462::{op_name}"
+    return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch210-cxx11-cu126-x86_64-linux/xpu_fused_moe.py CHANGED
(same diff as build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py above)
build/torch210-cxx11-cu128-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so β†’ _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:000b7e99c0d0afc09167bb819cb7c4de4f2f3de7136744020d58a9f0fd51a24a
+oid sha256:6e66fd44576448dc82e7392db0c935cd8654bfcb51db51ddc044e1c33bc82c60
 size 21009984
build/torch210-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_cuda_dd32462
-ops = torch.ops._megablocks_cuda_dd32462
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_cuda_dd32462::{op_name}"
+    return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/xpu_fused_moe.py CHANGED
(same diff as build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py above)
build/torch210-cxx11-cu130-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so β†’ _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e79c036505112f43afa8658e1e04dcc5bf536a297560fd03e7070ebbe21e2b54
+oid sha256:4ed503a781293a9d6150e0362edbe9360ef6e58590b511ee23596649ee9a437d
 size 12041592
build/torch210-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_cuda_dd32462
-ops = torch.ops._megablocks_cuda_dd32462
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_cuda_dd32462::{op_name}"
+    return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch210-cxx11-cu130-x86_64-linux/xpu_fused_moe.py CHANGED
(same diff as build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py above)
build/torch210-cxx11-xpu20253-x86_64-linux/{_megablocks_xpu_dd32462.abi3.so β†’ _megablocks_xpu_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1eafe2cfbec6a1c65fc7b523e3abdc3270f2e85b4f4bc64b88b9aeb41698484c
-size 5331960
+oid sha256:46cfa6050944b0bd6daeaf4848fe5393a68397ae29a5c7f0a04280e287cb0e7d
+size 5381760
build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_xpu_dd32462
-ops = torch.ops._megablocks_xpu_dd32462
+from . import _megablocks_xpu_6e04dec
+ops = torch.ops._megablocks_xpu_6e04dec
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_xpu_dd32462::{op_name}"
+    return f"_megablocks_xpu_6e04dec::{op_name}"
build/torch210-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py CHANGED
(same diff as build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py above)
build/torch29-cxx11-cpu-x86_64-linux/{_megablocks_cpu_dd32462.abi3.so β†’ _megablocks_cpu_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e97831d51919986a68bcb622208e131d3b66fa4a83b99da941b77708dd522edc
+oid sha256:18348238274eb1b281afe628b09ca6a4a5b8267370aaed7bf34a2bd91c9b815b
 size 2201200
build/torch29-cxx11-cpu-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_cpu_dd32462
-ops = torch.ops._megablocks_cpu_dd32462
+from . import _megablocks_cpu_6e04dec
+ops = torch.ops._megablocks_cpu_6e04dec
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_cpu_dd32462::{op_name}"
+    return f"_megablocks_cpu_6e04dec::{op_name}"
build/torch29-cxx11-cpu-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
- lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
- lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
  )
41
 
42
  _register_if_available(
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
 
 
205
  is_fp8=False,
206
  is_int4=False,
207
  is_mxfp4=False):
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
- workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
341
  workspace=workspace,
342
  hidden_size=hidden_size,
343
  inter_size=inter_size,
 
 
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
  expert_first_token_offset_bytes = workspace[
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
  src_to_dest_map_size]
 
 
 
 
354
 
355
  if torch.compiler.is_compiling():
356
  expert_first_token_offset = _bytes_to_typed_tensor(
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
359
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
  unpermuted_row_to_permuted_row_bytes, torch.int32
361
  )
 
 
 
362
  else:
363
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
 
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
451
  is_B_mxfp4=is_mxfp4)
452
 
453
  ops.moe_gather(output, gemm2_output, topk_weights,
 
454
  unpermuted_row_to_permuted_row,
 
455
  num_experts_per_node)
456
  return output
457
 
@@ -500,6 +514,21 @@ def route_tokens_xpu(
500
  return logits, expert_weights, expert_indices
501
 
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  class MegaBlocksMoeMLP(torch.nn.Module):
504
  can_torch_compile: bool = True
505
 
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
524
  self.experts, "normalize_expert_weights", None
525
  )
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  # Detect activation type - check for GptOss-style swigluoai activation
528
  # GptOssExperts has alpha and limit attributes for swigluoai
529
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
598
  topk_ids=expert_indices,
599
  n_experts_per_token=moe_top_k,
600
  activation=activation,
601
- num_experts=moe_num_experts,
 
 
602
  is_fp8=is_fp8,
603
  is_int4=is_int4,
604
  is_mxfp4=is_mxfp4,
605
  )
606
 
 
 
 
 
 
607
  # Restore original shape
608
  output = output.view(in_shape)
609
 
 
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
40
  )
41
 
42
  _register_if_available(
 
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
205
+ ep_rank=0,
206
+ ep_size=1,
207
  is_fp8=False,
208
  is_int4=False,
209
  is_mxfp4=False):
 
331
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
332
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
333
 
334
+ workspace = torch.zeros(map_offset,
335
  dtype=torch.uint8,
336
  device=hidden_states.device)
337
     if topk_ids.dtype == torch.int32:
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
         workspace=workspace,
         hidden_size=hidden_size,
         inter_size=inter_size,
+        ep_rank=ep_rank,
+        ep_size=ep_size,
         num_experts_on_rank=num_experts_per_node)
 
     expert_first_token_offset_bytes = workspace[
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
         ws_map["unpermuted_row_to_permuted_row"][1]:
         ws_map["unpermuted_row_to_permuted_row"][1] +
         src_to_dest_map_size]
+    permuted_row_to_unpermuted_row_bytes = workspace[
+        ws_map["permuted_row_to_unpermuted_row"][1]:
+        ws_map["permuted_row_to_unpermuted_row"][1] +
+        permuted_row_to_unpermuted_row_size]
 
     if torch.compiler.is_compiling():
         expert_first_token_offset = _bytes_to_typed_tensor(
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
         unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
             unpermuted_row_to_permuted_row_bytes, torch.int32
         )
+        permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+            permuted_row_to_unpermuted_row_bytes, torch.int32
+        )
     else:
         expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
         unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+        permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
     gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
                             ws_map["overlapped_gemm1_gemm2_inputs"][1] +
                             permuted_data_size].view(hidden_states.dtype).view(
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
                       is_B_mxfp4=is_mxfp4)
 
     ops.moe_gather(output, gemm2_output, topk_weights,
+                   permuted_row_to_unpermuted_row,
                    unpermuted_row_to_permuted_row,
+                   expert_first_token_offset,
                    num_experts_per_node)
     return output
 
@@ -500,6 +514,21 @@ def route_tokens_xpu(
     return logits, expert_weights, expert_indices
 
 
+def _get_device_mesh(model):
+    """Extract device_mesh from child's unused pre_hook closure for EP support."""
+    try:
+        hook = next(
+            h
+            for h in model.experts._forward_pre_hooks.values()
+            if "device_mesh" in h.__code__.co_freevars
+        )
+        return hook.__closure__[
+            hook.__code__.co_freevars.index("device_mesh")
+        ].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
     can_torch_compile: bool = True
 
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             self.experts, "normalize_expert_weights", None
         )
 
+        # Get EP (Expert Parallelism) parameters
+        ep_size = 1
+        ep_rank = 0
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = _get_device_mesh(self)
+            if device_mesh is not None:
+                expert_parallel_group = device_mesh.get_group()
+        if expert_parallel_group is not None:
+            import torch.distributed as dist
+            if dist.is_initialized():
+                ep_size = dist.get_world_size(expert_parallel_group)
+                ep_rank = dist.get_rank(expert_parallel_group)
+
+        # Number of experts on this rank
+        num_experts_on_rank = moe_num_experts // ep_size
+
         # Detect activation type - check for GptOss-style swigluoai activation
         # GptOssExperts has alpha and limit attributes for swigluoai
         if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             topk_ids=expert_indices,
             n_experts_per_token=moe_top_k,
             activation=activation,
-            num_experts=moe_num_experts,
+            num_experts=num_experts_on_rank,
+            ep_rank=ep_rank,
+            ep_size=ep_size,
             is_fp8=is_fp8,
             is_int4=is_int4,
             is_mxfp4=is_mxfp4,
         )
 
+        # All-reduce across EP group to combine partial expert outputs
+        if ep_size > 1 and expert_parallel_group is not None:
+            import torch.distributed as dist
+            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
         # Restore original shape
         output = output.view(in_shape)
 
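The tail above shows the whole expert-parallel (EP) path: resolve a process group, derive ep_rank/ep_size from it, compute only the moe_num_experts // ep_size experts owned locally, and SUM-all-reduce the partial outputs. A minimal sketch of that combine step, assuming (as the zero-initialized workspace suggests) that rows routed to off-rank experts contribute zeros locally; the function names here are illustrative, not part of the kernel API:

import torch
import torch.distributed as dist

def local_expert_range(ep_rank: int, ep_size: int, num_experts: int) -> tuple[int, int]:
    # Contiguous shard per rank, matching num_experts_on_rank = moe_num_experts // ep_size.
    per_rank = num_experts // ep_size
    return ep_rank * per_rank, (ep_rank + 1) * per_rank

def ep_combine(partial_output: torch.Tensor, group=None) -> torch.Tensor:
    # Each rank holds the sum over its local experts only; off-rank rows are
    # zero, so a SUM all-reduce reconstructs the full MoE output in place.
    if dist.is_initialized() and dist.get_world_size(group=group) > 1:
        dist.all_reduce(partial_output, op=dist.ReduceOp.SUM, group=group)
    return partial_output

With 8 experts and ep_size=2, rank 0 owns experts 0-3 and rank 1 owns experts 4-7, and the all-reduce is the only cross-rank communication in the forward pass.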
build/torch29-cxx11-cu126-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9a66815ed63ba2be83feeae47d8255e2d72ce1cd5e6a0e9f92d063e2cb81a522
+oid sha256:fae42809a452f57bb4ef6967a397029f4e557ad73424c1b68fb613070dcd3f0d
 size 15046832
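The .so entries in this commit change only their Git LFS pointer files; the binaries are content-addressed by SHA-256, so a rebuild shows up as a one-line oid swap (plus a size change when the payload actually differs, as in the xpu20252 build below). The pointer format is small enough to reproduce with the standard library:

import hashlib

def lfs_pointer(path: str) -> str:
    # Reproduce the three-line Git LFS pointer that git stores in place of the binary.
    with open(path, "rb") as f:
        data = f.read()
    oid = hashlib.sha256(data).hexdigest()
    return (
        "version https://git-lfs.github.com/spec/v1\n"
        f"oid sha256:{oid}\n"
        f"size {len(data)}\n"
    )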
build/torch29-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_cuda_dd32462
-ops = torch.ops._megablocks_cuda_dd32462
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_cuda_dd32462::{op_name}"
+    return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
 
     _register_if_available(
         "fused_moe_prologue",
-        lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
+        lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
     )
 
     _register_if_available(
         "moe_gather",
-        lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
+        lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
     )
 
     _register_if_available(
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
                   n_experts_per_token,
                   activation,
                   num_experts,
+                  ep_rank=0,
+                  ep_size=1,
                   is_fp8=False,
                   is_int4=False,
                   is_mxfp4=False):
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
     config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
     config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
 
-    workspace = torch.empty(map_offset,
+    workspace = torch.zeros(map_offset,
                             dtype=torch.uint8,
                             device=hidden_states.device)
     if topk_ids.dtype == torch.int32:
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
         workspace=workspace,
         hidden_size=hidden_size,
         inter_size=inter_size,
+        ep_rank=ep_rank,
+        ep_size=ep_size,
         num_experts_on_rank=num_experts_per_node)
 
     expert_first_token_offset_bytes = workspace[
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
         ws_map["unpermuted_row_to_permuted_row"][1]:
         ws_map["unpermuted_row_to_permuted_row"][1] +
         src_to_dest_map_size]
+    permuted_row_to_unpermuted_row_bytes = workspace[
+        ws_map["permuted_row_to_unpermuted_row"][1]:
+        ws_map["permuted_row_to_unpermuted_row"][1] +
+        permuted_row_to_unpermuted_row_size]
 
     if torch.compiler.is_compiling():
         expert_first_token_offset = _bytes_to_typed_tensor(
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
         unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
             unpermuted_row_to_permuted_row_bytes, torch.int32
         )
+        permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
+            permuted_row_to_unpermuted_row_bytes, torch.int32
+        )
     else:
         expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
         unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
+        permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
     gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
                             ws_map["overlapped_gemm1_gemm2_inputs"][1] +
                             permuted_data_size].view(hidden_states.dtype).view(
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
                       is_B_mxfp4=is_mxfp4)
 
     ops.moe_gather(output, gemm2_output, topk_weights,
+                   permuted_row_to_unpermuted_row,
                    unpermuted_row_to_permuted_row,
+                   expert_first_token_offset,
                    num_experts_per_node)
     return output
 
@@ -500,6 +514,21 @@ def route_tokens_xpu(
     return logits, expert_weights, expert_indices
 
 
+def _get_device_mesh(model):
+    """Extract device_mesh from child's unused pre_hook closure for EP support."""
+    try:
+        hook = next(
+            h
+            for h in model.experts._forward_pre_hooks.values()
+            if "device_mesh" in h.__code__.co_freevars
+        )
+        return hook.__closure__[
+            hook.__code__.co_freevars.index("device_mesh")
+        ].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
     can_torch_compile: bool = True
 
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             self.experts, "normalize_expert_weights", None
         )
 
+        # Get EP (Expert Parallelism) parameters
+        ep_size = 1
+        ep_rank = 0
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = _get_device_mesh(self)
+            if device_mesh is not None:
+                expert_parallel_group = device_mesh.get_group()
+        if expert_parallel_group is not None:
+            import torch.distributed as dist
+            if dist.is_initialized():
+                ep_size = dist.get_world_size(expert_parallel_group)
+                ep_rank = dist.get_rank(expert_parallel_group)
+
+        # Number of experts on this rank
+        num_experts_on_rank = moe_num_experts // ep_size
+
         # Detect activation type - check for GptOss-style swigluoai activation
         # GptOssExperts has alpha and limit attributes for swigluoai
         if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             topk_ids=expert_indices,
             n_experts_per_token=moe_top_k,
             activation=activation,
-            num_experts=moe_num_experts,
+            num_experts=num_experts_on_rank,
+            ep_rank=ep_rank,
+            ep_size=ep_size,
             is_fp8=is_fp8,
             is_int4=is_int4,
             is_mxfp4=is_mxfp4,
         )
 
+        # All-reduce across EP group to combine partial expert outputs
+        if ep_size > 1 and expert_parallel_group is not None:
+            import torch.distributed as dist
+            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
+
         # Restore original shape
         output = output.view(in_shape)
 
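Two workspace details in the diff above are easy to miss. The allocation switches from torch.empty to torch.zeros, presumably so the new row-mapping regions start in a defined state for rows this rank never writes; and every typed bookkeeping tensor (expert_first_token_offset and the two row maps) is carved out of that one flat uint8 buffer by byte offset and reinterpreted via Tensor.view(dtype), with _bytes_to_typed_tensor as the torch.compile-friendly spelling. A self-contained illustration of the zero-copy reinterpretation, with made-up offsets:

import torch

workspace = torch.zeros(64, dtype=torch.uint8)   # one flat byte buffer
offset, count = 16, 4                            # 4 int32 values == 16 bytes

int32_bytes = workspace[offset:offset + count * 4]
int32_view = int32_bytes.view(torch.int32)       # reinterpret the bytes, no copy

int32_view.fill_(7)                              # writes through to `workspace`
assert int(workspace[offset]) == 7               # low byte first on little-endian x86_64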
build/torch29-cxx11-cu128-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:da033fc8fa10230b35ddd1d0a45dc29bc44462739d0bb70ac7373cf5864b6634
+oid sha256:0349d7de015576f9dae76f82c321d491609d1ae84bc5f2cb8053891e167a0aca
 size 20995704
build/torch29-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_cuda_dd32462
-ops = torch.ops._megablocks_cuda_dd32462
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_cuda_dd32462::{op_name}"
+    return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch29-cxx11-cu128-x86_64-linux/xpu_fused_moe.py CHANGED
(Diff identical to the build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py diff shown above; the same expert-parallel changes apply to every per-build copy of this file.)
build/torch29-cxx11-cu130-x86_64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5e7d5b3f6e60dd4e7a0ada343f1665ccb8d5daf8b535808b9f455efe022a2783
+oid sha256:1e1383adbf7afa208f0769d84a826fcd43de9ee9ce39d676ebce97698759c526
 size 12031416
build/torch29-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_cuda_dd32462
-ops = torch.ops._megablocks_cuda_dd32462
+from . import _megablocks_cuda_6e04dec
+ops = torch.ops._megablocks_cuda_6e04dec
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_cuda_dd32462::{op_name}"
+    return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch29-cxx11-cu130-x86_64-linux/xpu_fused_moe.py CHANGED
(Diff identical to the build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py diff shown above.)
build/torch29-cxx11-xpu20252-x86_64-linux/{_megablocks_xpu_dd32462.abi3.so → _megablocks_xpu_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b8c5866635254310bb5a96138b0c7e78dc4509730ada6eb0a4c7d8d112c0585e
-size 5192232
+oid sha256:02442f31668da97521b3301b613a9acaa3478b83bfe838213ec690f7412c0157
+size 5197008
build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_xpu_dd32462
-ops = torch.ops._megablocks_xpu_dd32462
+from . import _megablocks_xpu_6e04dec
+ops = torch.ops._megablocks_xpu_6e04dec
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_xpu_dd32462::{op_name}"
+    return f"_megablocks_xpu_6e04dec::{op_name}"
build/torch29-cxx11-xpu20252-x86_64-linux/xpu_fused_moe.py CHANGED
(Diff identical to the build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py diff shown above.)
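Finally, the _get_device_mesh helper that each copy of this file gains deserves a note: there is no public attribute to read, so it recovers the device_mesh that an upstream integration captured inside a forward-pre-hook closure, by scanning __code__.co_freevars and reading the matching __closure__ cell, with the except Exception fallback degrading gracefully to single-rank behavior (ep_size = 1). A self-contained demo of the mechanism; the toy module and stand-in mesh object are hypothetical, where real code would find a torch.distributed DeviceMesh captured by someone else's hook:

import torch

class Experts(torch.nn.Module):
    def forward(self, x):
        return x

experts = Experts()
device_mesh = object()  # stand-in for a DeviceMesh

def make_hook(device_mesh):
    def pre_hook(module, args):
        _ = device_mesh  # captured in the closure, never otherwise used
    return pre_hook

experts.register_forward_pre_hook(make_hook(device_mesh))

# The extraction trick used by _get_device_mesh:
hook = next(
    h for h in experts._forward_pre_hooks.values()
    if "device_mesh" in h.__code__.co_freevars
)
recovered = hook.__closure__[
    hook.__code__.co_freevars.index("device_mesh")
].cell_contents
assert recovered is device_mesh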