Kernels
danieldk HF Staff commited on
Commit
af99866
·
verified ·
1 Parent(s): 4b022f0

Build uploaded using `kernels`.

Browse files
build/torch210-cxx11-cu126-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c22a3296d294dd5d350de36f19583cf331e7cf8a75c4afb2cce263116b149316
3
  size 15124328
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d43ea617155587acccc47750e126596b0438c63c7ada6f3607a2ed4603337f72
3
  size 15124328
build/torch210-cxx11-cu126-aarch64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_dd32462
3
- ops = torch.ops._megablocks_cuda_dd32462
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_dd32462::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_6e04dec
3
+ ops = torch.ops._megablocks_cuda_6e04dec
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch210-cxx11-cu126-aarch64-linux/xpu_fused_moe.py CHANGED
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
- lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
- lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
  )
41
 
42
  _register_if_available(
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
 
 
205
  is_fp8=False,
206
  is_int4=False,
207
  is_mxfp4=False):
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
- workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
341
  workspace=workspace,
342
  hidden_size=hidden_size,
343
  inter_size=inter_size,
 
 
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
  expert_first_token_offset_bytes = workspace[
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
  src_to_dest_map_size]
 
 
 
 
354
 
355
  if torch.compiler.is_compiling():
356
  expert_first_token_offset = _bytes_to_typed_tensor(
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
359
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
  unpermuted_row_to_permuted_row_bytes, torch.int32
361
  )
 
 
 
362
  else:
363
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
 
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
451
  is_B_mxfp4=is_mxfp4)
452
 
453
  ops.moe_gather(output, gemm2_output, topk_weights,
 
454
  unpermuted_row_to_permuted_row,
 
455
  num_experts_per_node)
456
  return output
457
 
@@ -500,6 +514,21 @@ def route_tokens_xpu(
500
  return logits, expert_weights, expert_indices
501
 
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  class MegaBlocksMoeMLP(torch.nn.Module):
504
  can_torch_compile: bool = True
505
 
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
524
  self.experts, "normalize_expert_weights", None
525
  )
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  # Detect activation type - check for GptOss-style swigluoai activation
528
  # GptOssExperts has alpha and limit attributes for swigluoai
529
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
598
  topk_ids=expert_indices,
599
  n_experts_per_token=moe_top_k,
600
  activation=activation,
601
- num_experts=moe_num_experts,
 
 
602
  is_fp8=is_fp8,
603
  is_int4=is_int4,
604
  is_mxfp4=is_mxfp4,
605
  )
606
 
 
 
 
 
 
607
  # Restore original shape
608
  output = output.view(in_shape)
609
 
 
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
40
  )
41
 
42
  _register_if_available(
 
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
205
+ ep_rank=0,
206
+ ep_size=1,
207
  is_fp8=False,
208
  is_int4=False,
209
  is_mxfp4=False):
 
331
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
332
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
333
 
334
+ workspace = torch.zeros(map_offset,
335
  dtype=torch.uint8,
336
  device=hidden_states.device)
337
  if topk_ids.dtype == torch.int32:
 
343
  workspace=workspace,
344
  hidden_size=hidden_size,
345
  inter_size=inter_size,
346
+ ep_rank=ep_rank,
347
+ ep_size=ep_size,
348
  num_experts_on_rank=num_experts_per_node)
349
 
350
  expert_first_token_offset_bytes = workspace[
 
355
  ws_map["unpermuted_row_to_permuted_row"][1]:
356
  ws_map["unpermuted_row_to_permuted_row"][1] +
357
  src_to_dest_map_size]
358
+ permuted_row_to_unpermuted_row_bytes = workspace[
359
+ ws_map["permuted_row_to_unpermuted_row"][1]:
360
+ ws_map["permuted_row_to_unpermuted_row"][1] +
361
+ permuted_row_to_unpermuted_row_size]
362
 
363
  if torch.compiler.is_compiling():
364
  expert_first_token_offset = _bytes_to_typed_tensor(
 
367
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
368
  unpermuted_row_to_permuted_row_bytes, torch.int32
369
  )
370
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
371
+ permuted_row_to_unpermuted_row_bytes, torch.int32
372
+ )
373
  else:
374
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
375
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
376
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
377
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
378
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
379
  permuted_data_size].view(hidden_states.dtype).view(
 
463
  is_B_mxfp4=is_mxfp4)
464
 
465
  ops.moe_gather(output, gemm2_output, topk_weights,
466
+ permuted_row_to_unpermuted_row,
467
  unpermuted_row_to_permuted_row,
468
+ expert_first_token_offset,
469
  num_experts_per_node)
470
  return output
471
 
 
514
  return logits, expert_weights, expert_indices
515
 
516
 
517
+ def _get_device_mesh(model):
518
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
519
+ try:
520
+ hook = next(
521
+ h
522
+ for h in model.experts._forward_pre_hooks.values()
523
+ if "device_mesh" in h.__code__.co_freevars
524
+ )
525
+ return hook.__closure__[
526
+ hook.__code__.co_freevars.index("device_mesh")
527
+ ].cell_contents
528
+ except Exception:
529
+ return None
530
+
531
+
532
  class MegaBlocksMoeMLP(torch.nn.Module):
533
  can_torch_compile: bool = True
534
 
 
553
  self.experts, "normalize_expert_weights", None
554
  )
555
 
556
+ # Get EP (Expert Parallelism) parameters
557
+ ep_size = 1
558
+ ep_rank = 0
559
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
560
+ if expert_parallel_group is None:
561
+ device_mesh = _get_device_mesh(self)
562
+ if device_mesh is not None:
563
+ expert_parallel_group = device_mesh.get_group()
564
+ if expert_parallel_group is not None:
565
+ import torch.distributed as dist
566
+ if dist.is_initialized():
567
+ ep_size = dist.get_world_size(expert_parallel_group)
568
+ ep_rank = dist.get_rank(expert_parallel_group)
569
+
570
+ # Number of experts on this rank
571
+ num_experts_on_rank = moe_num_experts // ep_size
572
+
573
  # Detect activation type - check for GptOss-style swigluoai activation
574
  # GptOssExperts has alpha and limit attributes for swigluoai
575
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
 
644
  topk_ids=expert_indices,
645
  n_experts_per_token=moe_top_k,
646
  activation=activation,
647
+ num_experts=num_experts_on_rank,
648
+ ep_rank=ep_rank,
649
+ ep_size=ep_size,
650
  is_fp8=is_fp8,
651
  is_int4=is_int4,
652
  is_mxfp4=is_mxfp4,
653
  )
654
 
655
+ # All-reduce across EP group to combine partial expert outputs
656
+ if ep_size > 1 and expert_parallel_group is not None:
657
+ import torch.distributed as dist
658
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
659
+
660
  # Restore original shape
661
  output = output.view(in_shape)
662
 
build/torch210-cxx11-cu128-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ba9d77e7863a9a03527ddfe47cc818b1089f384930e969c575be2da559c052f5
3
  size 21088232
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12705f4547b6a55442c52e081a303d4407202cdc26522f7269c983b627946ab9
3
  size 21088232
build/torch210-cxx11-cu128-aarch64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_dd32462
3
- ops = torch.ops._megablocks_cuda_dd32462
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_dd32462::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_6e04dec
3
+ ops = torch.ops._megablocks_cuda_6e04dec
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch210-cxx11-cu128-aarch64-linux/xpu_fused_moe.py CHANGED
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
- lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
- lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
  )
41
 
42
  _register_if_available(
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
 
 
205
  is_fp8=False,
206
  is_int4=False,
207
  is_mxfp4=False):
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
- workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
341
  workspace=workspace,
342
  hidden_size=hidden_size,
343
  inter_size=inter_size,
 
 
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
  expert_first_token_offset_bytes = workspace[
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
  src_to_dest_map_size]
 
 
 
 
354
 
355
  if torch.compiler.is_compiling():
356
  expert_first_token_offset = _bytes_to_typed_tensor(
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
359
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
  unpermuted_row_to_permuted_row_bytes, torch.int32
361
  )
 
 
 
362
  else:
363
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
 
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
451
  is_B_mxfp4=is_mxfp4)
452
 
453
  ops.moe_gather(output, gemm2_output, topk_weights,
 
454
  unpermuted_row_to_permuted_row,
 
455
  num_experts_per_node)
456
  return output
457
 
@@ -500,6 +514,21 @@ def route_tokens_xpu(
500
  return logits, expert_weights, expert_indices
501
 
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  class MegaBlocksMoeMLP(torch.nn.Module):
504
  can_torch_compile: bool = True
505
 
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
524
  self.experts, "normalize_expert_weights", None
525
  )
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  # Detect activation type - check for GptOss-style swigluoai activation
528
  # GptOssExperts has alpha and limit attributes for swigluoai
529
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
598
  topk_ids=expert_indices,
599
  n_experts_per_token=moe_top_k,
600
  activation=activation,
601
- num_experts=moe_num_experts,
 
 
602
  is_fp8=is_fp8,
603
  is_int4=is_int4,
604
  is_mxfp4=is_mxfp4,
605
  )
606
 
 
 
 
 
 
607
  # Restore original shape
608
  output = output.view(in_shape)
609
 
 
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
40
  )
41
 
42
  _register_if_available(
 
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
205
+ ep_rank=0,
206
+ ep_size=1,
207
  is_fp8=False,
208
  is_int4=False,
209
  is_mxfp4=False):
 
331
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
332
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
333
 
334
+ workspace = torch.zeros(map_offset,
335
  dtype=torch.uint8,
336
  device=hidden_states.device)
337
  if topk_ids.dtype == torch.int32:
 
343
  workspace=workspace,
344
  hidden_size=hidden_size,
345
  inter_size=inter_size,
346
+ ep_rank=ep_rank,
347
+ ep_size=ep_size,
348
  num_experts_on_rank=num_experts_per_node)
349
 
350
  expert_first_token_offset_bytes = workspace[
 
355
  ws_map["unpermuted_row_to_permuted_row"][1]:
356
  ws_map["unpermuted_row_to_permuted_row"][1] +
357
  src_to_dest_map_size]
358
+ permuted_row_to_unpermuted_row_bytes = workspace[
359
+ ws_map["permuted_row_to_unpermuted_row"][1]:
360
+ ws_map["permuted_row_to_unpermuted_row"][1] +
361
+ permuted_row_to_unpermuted_row_size]
362
 
363
  if torch.compiler.is_compiling():
364
  expert_first_token_offset = _bytes_to_typed_tensor(
 
367
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
368
  unpermuted_row_to_permuted_row_bytes, torch.int32
369
  )
370
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
371
+ permuted_row_to_unpermuted_row_bytes, torch.int32
372
+ )
373
  else:
374
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
375
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
376
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
377
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
378
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
379
  permuted_data_size].view(hidden_states.dtype).view(
 
463
  is_B_mxfp4=is_mxfp4)
464
 
465
  ops.moe_gather(output, gemm2_output, topk_weights,
466
+ permuted_row_to_unpermuted_row,
467
  unpermuted_row_to_permuted_row,
468
+ expert_first_token_offset,
469
  num_experts_per_node)
470
  return output
471
 
 
514
  return logits, expert_weights, expert_indices
515
 
516
 
517
+ def _get_device_mesh(model):
518
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
519
+ try:
520
+ hook = next(
521
+ h
522
+ for h in model.experts._forward_pre_hooks.values()
523
+ if "device_mesh" in h.__code__.co_freevars
524
+ )
525
+ return hook.__closure__[
526
+ hook.__code__.co_freevars.index("device_mesh")
527
+ ].cell_contents
528
+ except Exception:
529
+ return None
530
+
531
+
532
  class MegaBlocksMoeMLP(torch.nn.Module):
533
  can_torch_compile: bool = True
534
 
 
553
  self.experts, "normalize_expert_weights", None
554
  )
555
 
556
+ # Get EP (Expert Parallelism) parameters
557
+ ep_size = 1
558
+ ep_rank = 0
559
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
560
+ if expert_parallel_group is None:
561
+ device_mesh = _get_device_mesh(self)
562
+ if device_mesh is not None:
563
+ expert_parallel_group = device_mesh.get_group()
564
+ if expert_parallel_group is not None:
565
+ import torch.distributed as dist
566
+ if dist.is_initialized():
567
+ ep_size = dist.get_world_size(expert_parallel_group)
568
+ ep_rank = dist.get_rank(expert_parallel_group)
569
+
570
+ # Number of experts on this rank
571
+ num_experts_on_rank = moe_num_experts // ep_size
572
+
573
  # Detect activation type - check for GptOss-style swigluoai activation
574
  # GptOssExperts has alpha and limit attributes for swigluoai
575
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
 
644
  topk_ids=expert_indices,
645
  n_experts_per_token=moe_top_k,
646
  activation=activation,
647
+ num_experts=num_experts_on_rank,
648
+ ep_rank=ep_rank,
649
+ ep_size=ep_size,
650
  is_fp8=is_fp8,
651
  is_int4=is_int4,
652
  is_mxfp4=is_mxfp4,
653
  )
654
 
655
+ # All-reduce across EP group to combine partial expert outputs
656
+ if ep_size > 1 and expert_parallel_group is not None:
657
+ import torch.distributed as dist
658
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
659
+
660
  # Restore original shape
661
  output = output.view(in_shape)
662
 
build/torch210-cxx11-cu130-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:71bc5b507fe6153b2bdcd4f54a763bc90cf863baa85f026afd25d4eb1a82adb6
3
  size 12073200
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca7f2de93adbb930ffecaea6953cb94c870333295d05eade3c9c17296aa766a0
3
  size 12073200
build/torch210-cxx11-cu130-aarch64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_dd32462
3
- ops = torch.ops._megablocks_cuda_dd32462
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_dd32462::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_6e04dec
3
+ ops = torch.ops._megablocks_cuda_6e04dec
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch210-cxx11-cu130-aarch64-linux/xpu_fused_moe.py CHANGED
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
- lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
- lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
  )
41
 
42
  _register_if_available(
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
 
 
205
  is_fp8=False,
206
  is_int4=False,
207
  is_mxfp4=False):
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
- workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
341
  workspace=workspace,
342
  hidden_size=hidden_size,
343
  inter_size=inter_size,
 
 
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
  expert_first_token_offset_bytes = workspace[
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
  src_to_dest_map_size]
 
 
 
 
354
 
355
  if torch.compiler.is_compiling():
356
  expert_first_token_offset = _bytes_to_typed_tensor(
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
359
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
  unpermuted_row_to_permuted_row_bytes, torch.int32
361
  )
 
 
 
362
  else:
363
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
 
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
451
  is_B_mxfp4=is_mxfp4)
452
 
453
  ops.moe_gather(output, gemm2_output, topk_weights,
 
454
  unpermuted_row_to_permuted_row,
 
455
  num_experts_per_node)
456
  return output
457
 
@@ -500,6 +514,21 @@ def route_tokens_xpu(
500
  return logits, expert_weights, expert_indices
501
 
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  class MegaBlocksMoeMLP(torch.nn.Module):
504
  can_torch_compile: bool = True
505
 
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
524
  self.experts, "normalize_expert_weights", None
525
  )
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  # Detect activation type - check for GptOss-style swigluoai activation
528
  # GptOssExperts has alpha and limit attributes for swigluoai
529
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
598
  topk_ids=expert_indices,
599
  n_experts_per_token=moe_top_k,
600
  activation=activation,
601
- num_experts=moe_num_experts,
 
 
602
  is_fp8=is_fp8,
603
  is_int4=is_int4,
604
  is_mxfp4=is_mxfp4,
605
  )
606
 
 
 
 
 
 
607
  # Restore original shape
608
  output = output.view(in_shape)
609
 
 
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
40
  )
41
 
42
  _register_if_available(
 
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
205
+ ep_rank=0,
206
+ ep_size=1,
207
  is_fp8=False,
208
  is_int4=False,
209
  is_mxfp4=False):
 
331
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
332
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
333
 
334
+ workspace = torch.zeros(map_offset,
335
  dtype=torch.uint8,
336
  device=hidden_states.device)
337
  if topk_ids.dtype == torch.int32:
 
343
  workspace=workspace,
344
  hidden_size=hidden_size,
345
  inter_size=inter_size,
346
+ ep_rank=ep_rank,
347
+ ep_size=ep_size,
348
  num_experts_on_rank=num_experts_per_node)
349
 
350
  expert_first_token_offset_bytes = workspace[
 
355
  ws_map["unpermuted_row_to_permuted_row"][1]:
356
  ws_map["unpermuted_row_to_permuted_row"][1] +
357
  src_to_dest_map_size]
358
+ permuted_row_to_unpermuted_row_bytes = workspace[
359
+ ws_map["permuted_row_to_unpermuted_row"][1]:
360
+ ws_map["permuted_row_to_unpermuted_row"][1] +
361
+ permuted_row_to_unpermuted_row_size]
362
 
363
  if torch.compiler.is_compiling():
364
  expert_first_token_offset = _bytes_to_typed_tensor(
 
367
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
368
  unpermuted_row_to_permuted_row_bytes, torch.int32
369
  )
370
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
371
+ permuted_row_to_unpermuted_row_bytes, torch.int32
372
+ )
373
  else:
374
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
375
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
376
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
377
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
378
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
379
  permuted_data_size].view(hidden_states.dtype).view(
 
463
  is_B_mxfp4=is_mxfp4)
464
 
465
  ops.moe_gather(output, gemm2_output, topk_weights,
466
+ permuted_row_to_unpermuted_row,
467
  unpermuted_row_to_permuted_row,
468
+ expert_first_token_offset,
469
  num_experts_per_node)
470
  return output
471
 
 
514
  return logits, expert_weights, expert_indices
515
 
516
 
517
+ def _get_device_mesh(model):
518
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
519
+ try:
520
+ hook = next(
521
+ h
522
+ for h in model.experts._forward_pre_hooks.values()
523
+ if "device_mesh" in h.__code__.co_freevars
524
+ )
525
+ return hook.__closure__[
526
+ hook.__code__.co_freevars.index("device_mesh")
527
+ ].cell_contents
528
+ except Exception:
529
+ return None
530
+
531
+
532
  class MegaBlocksMoeMLP(torch.nn.Module):
533
  can_torch_compile: bool = True
534
 
 
553
  self.experts, "normalize_expert_weights", None
554
  )
555
 
556
+ # Get EP (Expert Parallelism) parameters
557
+ ep_size = 1
558
+ ep_rank = 0
559
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
560
+ if expert_parallel_group is None:
561
+ device_mesh = _get_device_mesh(self)
562
+ if device_mesh is not None:
563
+ expert_parallel_group = device_mesh.get_group()
564
+ if expert_parallel_group is not None:
565
+ import torch.distributed as dist
566
+ if dist.is_initialized():
567
+ ep_size = dist.get_world_size(expert_parallel_group)
568
+ ep_rank = dist.get_rank(expert_parallel_group)
569
+
570
+ # Number of experts on this rank
571
+ num_experts_on_rank = moe_num_experts // ep_size
572
+
573
  # Detect activation type - check for GptOss-style swigluoai activation
574
  # GptOssExperts has alpha and limit attributes for swigluoai
575
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
 
644
  topk_ids=expert_indices,
645
  n_experts_per_token=moe_top_k,
646
  activation=activation,
647
+ num_experts=num_experts_on_rank,
648
+ ep_rank=ep_rank,
649
+ ep_size=ep_size,
650
  is_fp8=is_fp8,
651
  is_int4=is_int4,
652
  is_mxfp4=is_mxfp4,
653
  )
654
 
655
+ # All-reduce across EP group to combine partial expert outputs
656
+ if ep_size > 1 and expert_parallel_group is not None:
657
+ import torch.distributed as dist
658
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
659
+
660
  # Restore original shape
661
  output = output.view(in_shape)
662
 
build/torch29-cxx11-cu126-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4fe1c804c13c22f3a6ba6b7d104ca92da9734c3dc463f085384eb83750769a96
3
  size 15121720
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:581f5d3cd17031f674e6da22c23430881408630004e4ece5a57f9c36583665b5
3
  size 15121720
build/torch29-cxx11-cu126-aarch64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_dd32462
3
- ops = torch.ops._megablocks_cuda_dd32462
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_dd32462::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_6e04dec
3
+ ops = torch.ops._megablocks_cuda_6e04dec
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch29-cxx11-cu126-aarch64-linux/xpu_fused_moe.py CHANGED
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
- lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
- lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
  )
41
 
42
  _register_if_available(
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
 
 
205
  is_fp8=False,
206
  is_int4=False,
207
  is_mxfp4=False):
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
- workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
341
  workspace=workspace,
342
  hidden_size=hidden_size,
343
  inter_size=inter_size,
 
 
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
  expert_first_token_offset_bytes = workspace[
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
  src_to_dest_map_size]
 
 
 
 
354
 
355
  if torch.compiler.is_compiling():
356
  expert_first_token_offset = _bytes_to_typed_tensor(
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
359
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
  unpermuted_row_to_permuted_row_bytes, torch.int32
361
  )
 
 
 
362
  else:
363
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
 
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
451
  is_B_mxfp4=is_mxfp4)
452
 
453
  ops.moe_gather(output, gemm2_output, topk_weights,
 
454
  unpermuted_row_to_permuted_row,
 
455
  num_experts_per_node)
456
  return output
457
 
@@ -500,6 +514,21 @@ def route_tokens_xpu(
500
  return logits, expert_weights, expert_indices
501
 
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  class MegaBlocksMoeMLP(torch.nn.Module):
504
  can_torch_compile: bool = True
505
 
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
524
  self.experts, "normalize_expert_weights", None
525
  )
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  # Detect activation type - check for GptOss-style swigluoai activation
528
  # GptOssExperts has alpha and limit attributes for swigluoai
529
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
598
  topk_ids=expert_indices,
599
  n_experts_per_token=moe_top_k,
600
  activation=activation,
601
- num_experts=moe_num_experts,
 
 
602
  is_fp8=is_fp8,
603
  is_int4=is_int4,
604
  is_mxfp4=is_mxfp4,
605
  )
606
 
 
 
 
 
 
607
  # Restore original shape
608
  output = output.view(in_shape)
609
 
 
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
40
  )
41
 
42
  _register_if_available(
 
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
205
+ ep_rank=0,
206
+ ep_size=1,
207
  is_fp8=False,
208
  is_int4=False,
209
  is_mxfp4=False):
 
331
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
332
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
333
 
334
+ workspace = torch.zeros(map_offset,
335
  dtype=torch.uint8,
336
  device=hidden_states.device)
337
  if topk_ids.dtype == torch.int32:
 
343
  workspace=workspace,
344
  hidden_size=hidden_size,
345
  inter_size=inter_size,
346
+ ep_rank=ep_rank,
347
+ ep_size=ep_size,
348
  num_experts_on_rank=num_experts_per_node)
349
 
350
  expert_first_token_offset_bytes = workspace[
 
355
  ws_map["unpermuted_row_to_permuted_row"][1]:
356
  ws_map["unpermuted_row_to_permuted_row"][1] +
357
  src_to_dest_map_size]
358
+ permuted_row_to_unpermuted_row_bytes = workspace[
359
+ ws_map["permuted_row_to_unpermuted_row"][1]:
360
+ ws_map["permuted_row_to_unpermuted_row"][1] +
361
+ permuted_row_to_unpermuted_row_size]
362
 
363
  if torch.compiler.is_compiling():
364
  expert_first_token_offset = _bytes_to_typed_tensor(
 
367
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
368
  unpermuted_row_to_permuted_row_bytes, torch.int32
369
  )
370
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
371
+ permuted_row_to_unpermuted_row_bytes, torch.int32
372
+ )
373
  else:
374
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
375
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
376
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
377
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
378
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
379
  permuted_data_size].view(hidden_states.dtype).view(
 
463
  is_B_mxfp4=is_mxfp4)
464
 
465
  ops.moe_gather(output, gemm2_output, topk_weights,
466
+ permuted_row_to_unpermuted_row,
467
  unpermuted_row_to_permuted_row,
468
+ expert_first_token_offset,
469
  num_experts_per_node)
470
  return output
471
 
 
514
  return logits, expert_weights, expert_indices
515
 
516
 
517
+ def _get_device_mesh(model):
518
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
519
+ try:
520
+ hook = next(
521
+ h
522
+ for h in model.experts._forward_pre_hooks.values()
523
+ if "device_mesh" in h.__code__.co_freevars
524
+ )
525
+ return hook.__closure__[
526
+ hook.__code__.co_freevars.index("device_mesh")
527
+ ].cell_contents
528
+ except Exception:
529
+ return None
530
+
531
+
532
  class MegaBlocksMoeMLP(torch.nn.Module):
533
  can_torch_compile: bool = True
534
 
 
553
  self.experts, "normalize_expert_weights", None
554
  )
555
 
556
+ # Get EP (Expert Parallelism) parameters
557
+ ep_size = 1
558
+ ep_rank = 0
559
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
560
+ if expert_parallel_group is None:
561
+ device_mesh = _get_device_mesh(self)
562
+ if device_mesh is not None:
563
+ expert_parallel_group = device_mesh.get_group()
564
+ if expert_parallel_group is not None:
565
+ import torch.distributed as dist
566
+ if dist.is_initialized():
567
+ ep_size = dist.get_world_size(expert_parallel_group)
568
+ ep_rank = dist.get_rank(expert_parallel_group)
569
+
570
+ # Number of experts on this rank
571
+ num_experts_on_rank = moe_num_experts // ep_size
572
+
573
  # Detect activation type - check for GptOss-style swigluoai activation
574
  # GptOssExperts has alpha and limit attributes for swigluoai
575
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
 
644
  topk_ids=expert_indices,
645
  n_experts_per_token=moe_top_k,
646
  activation=activation,
647
+ num_experts=num_experts_on_rank,
648
+ ep_rank=ep_rank,
649
+ ep_size=ep_size,
650
  is_fp8=is_fp8,
651
  is_int4=is_int4,
652
  is_mxfp4=is_mxfp4,
653
  )
654
 
655
+ # All-reduce across EP group to combine partial expert outputs
656
+ if ep_size > 1 and expert_parallel_group is not None:
657
+ import torch.distributed as dist
658
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
659
+
660
  # Restore original shape
661
  output = output.view(in_shape)
662
 
build/torch29-cxx11-cu128-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:64e119261f728c64bc971e5f9017b472ddaae621ee0527fabcea6b8e6dd7f815
3
  size 21085456
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81684a3eed6a7fb374cdbba3cf65f1cd46f5392ddc6d4992d37186c3b15f5734
3
  size 21085456
build/torch29-cxx11-cu128-aarch64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_dd32462
3
- ops = torch.ops._megablocks_cuda_dd32462
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_dd32462::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_6e04dec
3
+ ops = torch.ops._megablocks_cuda_6e04dec
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch29-cxx11-cu128-aarch64-linux/xpu_fused_moe.py CHANGED
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
- lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
- lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
  )
41
 
42
  _register_if_available(
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
 
 
205
  is_fp8=False,
206
  is_int4=False,
207
  is_mxfp4=False):
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
- workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
341
  workspace=workspace,
342
  hidden_size=hidden_size,
343
  inter_size=inter_size,
 
 
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
  expert_first_token_offset_bytes = workspace[
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
  src_to_dest_map_size]
 
 
 
 
354
 
355
  if torch.compiler.is_compiling():
356
  expert_first_token_offset = _bytes_to_typed_tensor(
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
359
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
  unpermuted_row_to_permuted_row_bytes, torch.int32
361
  )
 
 
 
362
  else:
363
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
 
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
451
  is_B_mxfp4=is_mxfp4)
452
 
453
  ops.moe_gather(output, gemm2_output, topk_weights,
 
454
  unpermuted_row_to_permuted_row,
 
455
  num_experts_per_node)
456
  return output
457
 
@@ -500,6 +514,21 @@ def route_tokens_xpu(
500
  return logits, expert_weights, expert_indices
501
 
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  class MegaBlocksMoeMLP(torch.nn.Module):
504
  can_torch_compile: bool = True
505
 
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
524
  self.experts, "normalize_expert_weights", None
525
  )
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  # Detect activation type - check for GptOss-style swigluoai activation
528
  # GptOssExperts has alpha and limit attributes for swigluoai
529
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
598
  topk_ids=expert_indices,
599
  n_experts_per_token=moe_top_k,
600
  activation=activation,
601
- num_experts=moe_num_experts,
 
 
602
  is_fp8=is_fp8,
603
  is_int4=is_int4,
604
  is_mxfp4=is_mxfp4,
605
  )
606
 
 
 
 
 
 
607
  # Restore original shape
608
  output = output.view(in_shape)
609
 
 
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
40
  )
41
 
42
  _register_if_available(
 
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
205
+ ep_rank=0,
206
+ ep_size=1,
207
  is_fp8=False,
208
  is_int4=False,
209
  is_mxfp4=False):
 
331
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
332
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
333
 
334
+ workspace = torch.zeros(map_offset,
335
  dtype=torch.uint8,
336
  device=hidden_states.device)
337
  if topk_ids.dtype == torch.int32:
 
343
  workspace=workspace,
344
  hidden_size=hidden_size,
345
  inter_size=inter_size,
346
+ ep_rank=ep_rank,
347
+ ep_size=ep_size,
348
  num_experts_on_rank=num_experts_per_node)
349
 
350
  expert_first_token_offset_bytes = workspace[
 
355
  ws_map["unpermuted_row_to_permuted_row"][1]:
356
  ws_map["unpermuted_row_to_permuted_row"][1] +
357
  src_to_dest_map_size]
358
+ permuted_row_to_unpermuted_row_bytes = workspace[
359
+ ws_map["permuted_row_to_unpermuted_row"][1]:
360
+ ws_map["permuted_row_to_unpermuted_row"][1] +
361
+ permuted_row_to_unpermuted_row_size]
362
 
363
  if torch.compiler.is_compiling():
364
  expert_first_token_offset = _bytes_to_typed_tensor(
 
367
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
368
  unpermuted_row_to_permuted_row_bytes, torch.int32
369
  )
370
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
371
+ permuted_row_to_unpermuted_row_bytes, torch.int32
372
+ )
373
  else:
374
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
375
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
376
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
377
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
378
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
379
  permuted_data_size].view(hidden_states.dtype).view(
 
463
  is_B_mxfp4=is_mxfp4)
464
 
465
  ops.moe_gather(output, gemm2_output, topk_weights,
466
+ permuted_row_to_unpermuted_row,
467
  unpermuted_row_to_permuted_row,
468
+ expert_first_token_offset,
469
  num_experts_per_node)
470
  return output
471
 
 
514
  return logits, expert_weights, expert_indices
515
 
516
 
517
+ def _get_device_mesh(model):
518
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
519
+ try:
520
+ hook = next(
521
+ h
522
+ for h in model.experts._forward_pre_hooks.values()
523
+ if "device_mesh" in h.__code__.co_freevars
524
+ )
525
+ return hook.__closure__[
526
+ hook.__code__.co_freevars.index("device_mesh")
527
+ ].cell_contents
528
+ except Exception:
529
+ return None
530
+
531
+
532
  class MegaBlocksMoeMLP(torch.nn.Module):
533
  can_torch_compile: bool = True
534
 
 
553
  self.experts, "normalize_expert_weights", None
554
  )
555
 
556
+ # Get EP (Expert Parallelism) parameters
557
+ ep_size = 1
558
+ ep_rank = 0
559
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
560
+ if expert_parallel_group is None:
561
+ device_mesh = _get_device_mesh(self)
562
+ if device_mesh is not None:
563
+ expert_parallel_group = device_mesh.get_group()
564
+ if expert_parallel_group is not None:
565
+ import torch.distributed as dist
566
+ if dist.is_initialized():
567
+ ep_size = dist.get_world_size(expert_parallel_group)
568
+ ep_rank = dist.get_rank(expert_parallel_group)
569
+
570
+ # Number of experts on this rank
571
+ num_experts_on_rank = moe_num_experts // ep_size
572
+
573
  # Detect activation type - check for GptOss-style swigluoai activation
574
  # GptOssExperts has alpha and limit attributes for swigluoai
575
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
 
644
  topk_ids=expert_indices,
645
  n_experts_per_token=moe_top_k,
646
  activation=activation,
647
+ num_experts=num_experts_on_rank,
648
+ ep_rank=ep_rank,
649
+ ep_size=ep_size,
650
  is_fp8=is_fp8,
651
  is_int4=is_int4,
652
  is_mxfp4=is_mxfp4,
653
  )
654
 
655
+ # All-reduce across EP group to combine partial expert outputs
656
+ if ep_size > 1 and expert_parallel_group is not None:
657
+ import torch.distributed as dist
658
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
659
+
660
  # Restore original shape
661
  output = output.view(in_shape)
662
 
build/torch29-cxx11-cu130-aarch64-linux/{_megablocks_cuda_dd32462.abi3.so → _megablocks_cuda_6e04dec.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6ee72001370d93ef391b01bca733d227ba8dad92eb99ce7cb51fd97d5589f0ac
3
  size 12070448
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8669b2a5cf6f36ab1d6c518040d4f4e2874d7b1c5880b4424d21f89c60e77c5f
3
  size 12070448
build/torch29-cxx11-cu130-aarch64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_cuda_dd32462
3
- ops = torch.ops._megablocks_cuda_dd32462
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_cuda_dd32462::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_cuda_6e04dec
3
+ ops = torch.ops._megablocks_cuda_6e04dec
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_cuda_6e04dec::{op_name}"
build/torch29-cxx11-cu130-aarch64-linux/xpu_fused_moe.py CHANGED
@@ -31,12 +31,12 @@ def _register_xpu_fake_kernels():
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
- lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
- lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
  )
41
 
42
  _register_if_available(
@@ -202,6 +202,8 @@ def xpu_fused_moe(hidden_states,
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
 
 
205
  is_fp8=False,
206
  is_int4=False,
207
  is_mxfp4=False):
@@ -329,7 +331,7 @@ def xpu_fused_moe(hidden_states,
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
- workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
@@ -341,6 +343,8 @@ def xpu_fused_moe(hidden_states,
341
  workspace=workspace,
342
  hidden_size=hidden_size,
343
  inter_size=inter_size,
 
 
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
  expert_first_token_offset_bytes = workspace[
@@ -351,6 +355,10 @@ def xpu_fused_moe(hidden_states,
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
  src_to_dest_map_size]
 
 
 
 
354
 
355
  if torch.compiler.is_compiling():
356
  expert_first_token_offset = _bytes_to_typed_tensor(
@@ -359,9 +367,13 @@ def xpu_fused_moe(hidden_states,
359
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
  unpermuted_row_to_permuted_row_bytes, torch.int32
361
  )
 
 
 
362
  else:
363
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
 
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
@@ -451,7 +463,9 @@ def xpu_fused_moe(hidden_states,
451
  is_B_mxfp4=is_mxfp4)
452
 
453
  ops.moe_gather(output, gemm2_output, topk_weights,
 
454
  unpermuted_row_to_permuted_row,
 
455
  num_experts_per_node)
456
  return output
457
 
@@ -500,6 +514,21 @@ def route_tokens_xpu(
500
  return logits, expert_weights, expert_indices
501
 
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
  class MegaBlocksMoeMLP(torch.nn.Module):
504
  can_torch_compile: bool = True
505
 
@@ -524,6 +553,23 @@ class MegaBlocksMoeMLP(torch.nn.Module):
524
  self.experts, "normalize_expert_weights", None
525
  )
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  # Detect activation type - check for GptOss-style swigluoai activation
528
  # GptOssExperts has alpha and limit attributes for swigluoai
529
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
@@ -598,12 +644,19 @@ class MegaBlocksMoeMLP(torch.nn.Module):
598
  topk_ids=expert_indices,
599
  n_experts_per_token=moe_top_k,
600
  activation=activation,
601
- num_experts=moe_num_experts,
 
 
602
  is_fp8=is_fp8,
603
  is_int4=is_int4,
604
  is_mxfp4=is_mxfp4,
605
  )
606
 
 
 
 
 
 
607
  # Restore original shape
608
  output = output.view(in_shape)
609
 
 
31
 
32
  _register_if_available(
33
  "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, ep_rank, ep_size, num_experts_on_rank: None,
35
  )
36
 
37
  _register_if_available(
38
  "moe_gather",
39
+ lambda output, moe_output, topk_weights, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_experts: None,
40
  )
41
 
42
  _register_if_available(
 
202
  n_experts_per_token,
203
  activation,
204
  num_experts,
205
+ ep_rank=0,
206
+ ep_size=1,
207
  is_fp8=False,
208
  is_int4=False,
209
  is_mxfp4=False):
 
331
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
332
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
333
 
334
+ workspace = torch.zeros(map_offset,
335
  dtype=torch.uint8,
336
  device=hidden_states.device)
337
  if topk_ids.dtype == torch.int32:
 
343
  workspace=workspace,
344
  hidden_size=hidden_size,
345
  inter_size=inter_size,
346
+ ep_rank=ep_rank,
347
+ ep_size=ep_size,
348
  num_experts_on_rank=num_experts_per_node)
349
 
350
  expert_first_token_offset_bytes = workspace[
 
355
  ws_map["unpermuted_row_to_permuted_row"][1]:
356
  ws_map["unpermuted_row_to_permuted_row"][1] +
357
  src_to_dest_map_size]
358
+ permuted_row_to_unpermuted_row_bytes = workspace[
359
+ ws_map["permuted_row_to_unpermuted_row"][1]:
360
+ ws_map["permuted_row_to_unpermuted_row"][1] +
361
+ permuted_row_to_unpermuted_row_size]
362
 
363
  if torch.compiler.is_compiling():
364
  expert_first_token_offset = _bytes_to_typed_tensor(
 
367
  unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
368
  unpermuted_row_to_permuted_row_bytes, torch.int32
369
  )
370
+ permuted_row_to_unpermuted_row = _bytes_to_typed_tensor(
371
+ permuted_row_to_unpermuted_row_bytes, torch.int32
372
+ )
373
  else:
374
  expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
375
  unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
376
+ permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_bytes.view(torch.int32)
377
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
378
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
379
  permuted_data_size].view(hidden_states.dtype).view(
 
463
  is_B_mxfp4=is_mxfp4)
464
 
465
  ops.moe_gather(output, gemm2_output, topk_weights,
466
+ permuted_row_to_unpermuted_row,
467
  unpermuted_row_to_permuted_row,
468
+ expert_first_token_offset,
469
  num_experts_per_node)
470
  return output
471
 
 
514
  return logits, expert_weights, expert_indices
515
 
516
 
517
+ def _get_device_mesh(model):
518
+ """Extract device_mesh from child's unused pre_hook closure for EP support."""
519
+ try:
520
+ hook = next(
521
+ h
522
+ for h in model.experts._forward_pre_hooks.values()
523
+ if "device_mesh" in h.__code__.co_freevars
524
+ )
525
+ return hook.__closure__[
526
+ hook.__code__.co_freevars.index("device_mesh")
527
+ ].cell_contents
528
+ except Exception:
529
+ return None
530
+
531
+
532
  class MegaBlocksMoeMLP(torch.nn.Module):
533
  can_torch_compile: bool = True
534
 
 
553
  self.experts, "normalize_expert_weights", None
554
  )
555
 
556
+ # Get EP (Expert Parallelism) parameters
557
+ ep_size = 1
558
+ ep_rank = 0
559
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
560
+ if expert_parallel_group is None:
561
+ device_mesh = _get_device_mesh(self)
562
+ if device_mesh is not None:
563
+ expert_parallel_group = device_mesh.get_group()
564
+ if expert_parallel_group is not None:
565
+ import torch.distributed as dist
566
+ if dist.is_initialized():
567
+ ep_size = dist.get_world_size(expert_parallel_group)
568
+ ep_rank = dist.get_rank(expert_parallel_group)
569
+
570
+ # Number of experts on this rank
571
+ num_experts_on_rank = moe_num_experts // ep_size
572
+
573
  # Detect activation type - check for GptOss-style swigluoai activation
574
  # GptOssExperts has alpha and limit attributes for swigluoai
575
  if hasattr(self.experts, "alpha") and hasattr(self.experts, "limit"):
 
644
  topk_ids=expert_indices,
645
  n_experts_per_token=moe_top_k,
646
  activation=activation,
647
+ num_experts=num_experts_on_rank,
648
+ ep_rank=ep_rank,
649
+ ep_size=ep_size,
650
  is_fp8=is_fp8,
651
  is_int4=is_int4,
652
  is_mxfp4=is_mxfp4,
653
  )
654
 
655
+ # All-reduce across EP group to combine partial expert outputs
656
+ if ep_size > 1 and expert_parallel_group is not None:
657
+ import torch.distributed as dist
658
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=expert_parallel_group)
659
+
660
  # Restore original shape
661
  output = output.view(in_shape)
662