Kernels
danieldk HF Staff commited on
Commit
4e8d945
Β·
verified Β·
1 Parent(s): 4f20330

Build uploaded using `kernels`.

Browse files
Files changed (30) hide show
  1. build/torch210-cxx11-cpu-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} +1 -1
  2. build/torch210-cxx11-cpu-x86_64-linux/_ops.py +3 -3
  3. build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py +93 -74
  4. build/torch210-cxx11-cu126-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} +1 -1
  5. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  6. build/torch210-cxx11-cu126-x86_64-linux/xpu_fused_moe.py +93 -74
  7. build/torch210-cxx11-cu128-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} +1 -1
  8. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  9. build/torch210-cxx11-cu128-x86_64-linux/xpu_fused_moe.py +93 -74
  10. build/torch210-cxx11-cu130-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} +1 -1
  11. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
  12. build/torch210-cxx11-cu130-x86_64-linux/xpu_fused_moe.py +93 -74
  13. build/torch210-cxx11-xpu20253-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} +1 -1
  14. build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py +3 -3
  15. build/torch210-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py +93 -74
  16. build/torch29-cxx11-cpu-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} +1 -1
  17. build/torch29-cxx11-cpu-x86_64-linux/_ops.py +3 -3
  18. build/torch29-cxx11-cpu-x86_64-linux/xpu_fused_moe.py +93 -74
  19. build/torch29-cxx11-cu126-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} +1 -1
  20. build/torch29-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  21. build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py +93 -74
  22. build/torch29-cxx11-cu128-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} +1 -1
  23. build/torch29-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  24. build/torch29-cxx11-cu128-x86_64-linux/xpu_fused_moe.py +93 -74
  25. build/torch29-cxx11-cu130-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} +1 -1
  26. build/torch29-cxx11-cu130-x86_64-linux/_ops.py +3 -3
  27. build/torch29-cxx11-cu130-x86_64-linux/xpu_fused_moe.py +93 -74
  28. build/torch29-cxx11-xpu20252-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} +1 -1
  29. build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py +3 -3
  30. build/torch29-cxx11-xpu20252-x86_64-linux/xpu_fused_moe.py +93 -74
build/torch210-cxx11-cpu-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3a81c0cc23130a95d05263f0509e8de560183f6472f458f4316c97e6e8d8f533
3
  size 2219056
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1bb9607d2d00b6eb3f3fe58da8dd972deb37b0658b8682807fc2863129f7aa8d
3
  size 2219056
build/torch210-cxx11-cpu-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_099ac3c
3
- ops = torch.ops._megablocks_099ac3c
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_099ac3c::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_9be3a32
3
+ ops = torch.ops._megablocks_9be3a32
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_9be3a32::{op_name}"
build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -3,7 +3,9 @@
3
  import os
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
 
9
  def resolve_dtensor(weight: torch.Tensor):
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
14
  return weight
15
 
16
 
17
- # Install meta kernels for torch.compile compatibility
18
- def _install_xpu_meta_kernels():
19
- """Install meta kernels for XPU MoE operations to support torch.compile"""
20
-
21
- # Patch cutlass_grouped_gemm_interface
22
- if hasattr(ops, "cutlass_grouped_gemm_interface"):
23
- original_gemm = ops.cutlass_grouped_gemm_interface
24
-
25
- def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
26
- expert_first_token_offset, N, K, num_experts,
27
- is_B_int4, is_B_mxfp4):
28
- if torch.compiler.is_compiling():
29
- # Meta implementation - ptr_D is the output, return it
30
- return ptr_D
31
- return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
32
- expert_first_token_offset, N, K, num_experts,
33
- is_B_int4, is_B_mxfp4)
34
-
35
- ops.cutlass_grouped_gemm_interface = gemm_with_meta
36
-
37
- # Patch fused_moe_prologue
38
- if hasattr(ops, "fused_moe_prologue"):
39
- original_prologue = ops.fused_moe_prologue
40
-
41
- def prologue_with_meta(input, token_selected_experts, token_final_scales,
42
- workspace, hidden_size, inter_size, num_experts_on_rank):
43
- if torch.compiler.is_compiling():
44
- # Meta implementation - this op modifies workspace in-place
45
- return None
46
- return original_prologue(input, token_selected_experts, token_final_scales,
47
- workspace, hidden_size, inter_size, num_experts_on_rank)
48
-
49
- ops.fused_moe_prologue = prologue_with_meta
50
-
51
- # Patch moe_gather
52
- if hasattr(ops, "moe_gather"):
53
- original_gather = ops.moe_gather
54
-
55
- def gather_with_meta(output, moe_output, topk_weights,
56
- unpermuted_row_to_permuted_row, num_experts):
57
- if torch.compiler.is_compiling():
58
- # Meta implementation - output is modified in-place
59
- return None
60
- return original_gather(output, moe_output, topk_weights,
61
- unpermuted_row_to_permuted_row, num_experts)
62
-
63
- ops.moe_gather = gather_with_meta
64
-
65
- # Patch activation ops
66
- for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
67
- "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
68
- "swigluoai_and_mul"]:
69
- if hasattr(ops, act_name):
70
- original_act = getattr(ops, act_name)
71
-
72
- def make_act_wrapper(orig_fn):
73
- def act_with_meta(*args, **kwargs):
74
- if torch.compiler.is_compiling():
75
- # Meta implementation - in-place ops, return None
76
- return None
77
- return orig_fn(*args, **kwargs)
78
- return act_with_meta
79
-
80
- setattr(ops, act_name, make_act_wrapper(original_act))
81
-
82
-
83
- # Install meta kernels on module load
84
- _install_xpu_meta_kernels()
85
 
86
 
87
  # default
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
151
  return 1024
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def implement_zp(qweight):
155
  # change u4 to s4 to avoid zero point in gemm kernel
156
  # only support default zero point now
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
321
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
322
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
323
 
324
- workspace = torch.zeros(map_offset,
325
  dtype=torch.uint8,
326
  device=hidden_states.device)
327
  if topk_ids.dtype == torch.int32:
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
335
  inter_size=inter_size,
336
  num_experts_on_rank=num_experts_per_node)
337
 
338
- expert_first_token_offset = workspace[
339
  ws_map["expert_first_token_offset"][1]:
340
  ws_map["expert_first_token_offset"][1] +
341
- expert_first_token_offset_size].view(torch.int64)
342
- unpermuted_row_to_permuted_row = workspace[
343
  ws_map["unpermuted_row_to_permuted_row"][1]:
344
  ws_map["unpermuted_row_to_permuted_row"][1] +
345
- src_to_dest_map_size].view(torch.int32)
 
 
 
 
 
 
 
 
 
 
 
346
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
347
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
348
  permuted_data_size].view(hidden_states.dtype).view(
 
3
  import os
4
  import torch
5
 
6
+ from ._ops import ops, add_op_namespace_prefix
7
+
8
+ from torch.library import register_fake
9
 
10
 
11
  def resolve_dtensor(weight: torch.Tensor):
 
16
  return weight
17
 
18
 
19
+ # Register fake/meta kernels for torch.compile compatibility
20
+ def _register_xpu_fake_kernels():
21
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
22
+
23
+ def _register_if_available(op_name, fn):
24
+ if hasattr(ops, op_name):
25
+ register_fake(add_op_namespace_prefix(op_name))(fn)
26
+
27
+ _register_if_available(
28
+ "cutlass_grouped_gemm_interface",
29
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
30
+ )
31
+
32
+ _register_if_available(
33
+ "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
+ )
36
+
37
+ _register_if_available(
38
+ "moe_gather",
39
+ lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
+ )
41
+
42
+ _register_if_available(
43
+ "silu_and_mul",
44
+ lambda out, input: None,
45
+ )
46
+ _register_if_available(
47
+ "mul_and_silu",
48
+ lambda out, input: None,
49
+ )
50
+ _register_if_available(
51
+ "gelu_and_mul",
52
+ lambda out, input: None,
53
+ )
54
+ _register_if_available(
55
+ "gelu_tanh_and_mul",
56
+ lambda out, input: None,
57
+ )
58
+ _register_if_available(
59
+ "gelu_fast",
60
+ lambda out, input: None,
61
+ )
62
+ _register_if_available(
63
+ "gelu_new",
64
+ lambda out, input: None,
65
+ )
66
+ _register_if_available(
67
+ "gelu_quick",
68
+ lambda out, input: None,
69
+ )
70
+ _register_if_available(
71
+ "swigluoai_and_mul",
72
+ lambda out, input, alpha=1.702, limit=7.0: None,
73
+ )
74
+
75
+
76
+ # Register fake kernels on module load
77
+ _register_xpu_fake_kernels()
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # default
 
144
  return 1024
145
 
146
 
147
+ def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
148
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
149
+
150
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
151
+ constant folding when shape divisibility is not proven.
152
+ """
153
+ if byte_tensor.dtype != torch.uint8:
154
+ raise ValueError("byte_tensor must be uint8")
155
+ itemsize = torch.empty((), dtype=dtype).element_size()
156
+ numel = byte_tensor.numel() // itemsize
157
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
158
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
159
+ return out
160
+
161
+
162
  def implement_zp(qweight):
163
  # change u4 to s4 to avoid zero point in gemm kernel
164
  # only support default zero point now
 
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
+ workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
 
343
  inter_size=inter_size,
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
+ expert_first_token_offset_bytes = workspace[
347
  ws_map["expert_first_token_offset"][1]:
348
  ws_map["expert_first_token_offset"][1] +
349
+ expert_first_token_offset_size]
350
+ unpermuted_row_to_permuted_row_bytes = workspace[
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
+ src_to_dest_map_size]
354
+
355
+ if torch.compiler.is_compiling():
356
+ expert_first_token_offset = _bytes_to_typed_tensor(
357
+ expert_first_token_offset_bytes, torch.int64
358
+ )
359
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
+ unpermuted_row_to_permuted_row_bytes, torch.int32
361
+ )
362
+ else:
363
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
build/torch210-cxx11-cu126-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d482577c55ffe1abd34983ce45eeeb280a817e55f92d6585b5e92173b2860749
3
  size 15061032
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:321e1bb305fd100b1abc99234f480634d05a901ee3a758628d94615d535e2caf
3
  size 15061032
build/torch210-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_099ac3c
3
- ops = torch.ops._megablocks_099ac3c
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_099ac3c::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_9be3a32
3
+ ops = torch.ops._megablocks_9be3a32
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_9be3a32::{op_name}"
build/torch210-cxx11-cu126-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -3,7 +3,9 @@
3
  import os
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
 
9
  def resolve_dtensor(weight: torch.Tensor):
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
14
  return weight
15
 
16
 
17
- # Install meta kernels for torch.compile compatibility
18
- def _install_xpu_meta_kernels():
19
- """Install meta kernels for XPU MoE operations to support torch.compile"""
20
-
21
- # Patch cutlass_grouped_gemm_interface
22
- if hasattr(ops, "cutlass_grouped_gemm_interface"):
23
- original_gemm = ops.cutlass_grouped_gemm_interface
24
-
25
- def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
26
- expert_first_token_offset, N, K, num_experts,
27
- is_B_int4, is_B_mxfp4):
28
- if torch.compiler.is_compiling():
29
- # Meta implementation - ptr_D is the output, return it
30
- return ptr_D
31
- return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
32
- expert_first_token_offset, N, K, num_experts,
33
- is_B_int4, is_B_mxfp4)
34
-
35
- ops.cutlass_grouped_gemm_interface = gemm_with_meta
36
-
37
- # Patch fused_moe_prologue
38
- if hasattr(ops, "fused_moe_prologue"):
39
- original_prologue = ops.fused_moe_prologue
40
-
41
- def prologue_with_meta(input, token_selected_experts, token_final_scales,
42
- workspace, hidden_size, inter_size, num_experts_on_rank):
43
- if torch.compiler.is_compiling():
44
- # Meta implementation - this op modifies workspace in-place
45
- return None
46
- return original_prologue(input, token_selected_experts, token_final_scales,
47
- workspace, hidden_size, inter_size, num_experts_on_rank)
48
-
49
- ops.fused_moe_prologue = prologue_with_meta
50
-
51
- # Patch moe_gather
52
- if hasattr(ops, "moe_gather"):
53
- original_gather = ops.moe_gather
54
-
55
- def gather_with_meta(output, moe_output, topk_weights,
56
- unpermuted_row_to_permuted_row, num_experts):
57
- if torch.compiler.is_compiling():
58
- # Meta implementation - output is modified in-place
59
- return None
60
- return original_gather(output, moe_output, topk_weights,
61
- unpermuted_row_to_permuted_row, num_experts)
62
-
63
- ops.moe_gather = gather_with_meta
64
-
65
- # Patch activation ops
66
- for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
67
- "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
68
- "swigluoai_and_mul"]:
69
- if hasattr(ops, act_name):
70
- original_act = getattr(ops, act_name)
71
-
72
- def make_act_wrapper(orig_fn):
73
- def act_with_meta(*args, **kwargs):
74
- if torch.compiler.is_compiling():
75
- # Meta implementation - in-place ops, return None
76
- return None
77
- return orig_fn(*args, **kwargs)
78
- return act_with_meta
79
-
80
- setattr(ops, act_name, make_act_wrapper(original_act))
81
-
82
-
83
- # Install meta kernels on module load
84
- _install_xpu_meta_kernels()
85
 
86
 
87
  # default
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
151
  return 1024
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def implement_zp(qweight):
155
  # change u4 to s4 to avoid zero point in gemm kernel
156
  # only support default zero point now
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
321
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
322
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
323
 
324
- workspace = torch.zeros(map_offset,
325
  dtype=torch.uint8,
326
  device=hidden_states.device)
327
  if topk_ids.dtype == torch.int32:
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
335
  inter_size=inter_size,
336
  num_experts_on_rank=num_experts_per_node)
337
 
338
- expert_first_token_offset = workspace[
339
  ws_map["expert_first_token_offset"][1]:
340
  ws_map["expert_first_token_offset"][1] +
341
- expert_first_token_offset_size].view(torch.int64)
342
- unpermuted_row_to_permuted_row = workspace[
343
  ws_map["unpermuted_row_to_permuted_row"][1]:
344
  ws_map["unpermuted_row_to_permuted_row"][1] +
345
- src_to_dest_map_size].view(torch.int32)
 
 
 
 
 
 
 
 
 
 
 
346
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
347
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
348
  permuted_data_size].view(hidden_states.dtype).view(
 
3
  import os
4
  import torch
5
 
6
+ from ._ops import ops, add_op_namespace_prefix
7
+
8
+ from torch.library import register_fake
9
 
10
 
11
  def resolve_dtensor(weight: torch.Tensor):
 
16
  return weight
17
 
18
 
19
+ # Register fake/meta kernels for torch.compile compatibility
20
+ def _register_xpu_fake_kernels():
21
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
22
+
23
+ def _register_if_available(op_name, fn):
24
+ if hasattr(ops, op_name):
25
+ register_fake(add_op_namespace_prefix(op_name))(fn)
26
+
27
+ _register_if_available(
28
+ "cutlass_grouped_gemm_interface",
29
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
30
+ )
31
+
32
+ _register_if_available(
33
+ "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
+ )
36
+
37
+ _register_if_available(
38
+ "moe_gather",
39
+ lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
+ )
41
+
42
+ _register_if_available(
43
+ "silu_and_mul",
44
+ lambda out, input: None,
45
+ )
46
+ _register_if_available(
47
+ "mul_and_silu",
48
+ lambda out, input: None,
49
+ )
50
+ _register_if_available(
51
+ "gelu_and_mul",
52
+ lambda out, input: None,
53
+ )
54
+ _register_if_available(
55
+ "gelu_tanh_and_mul",
56
+ lambda out, input: None,
57
+ )
58
+ _register_if_available(
59
+ "gelu_fast",
60
+ lambda out, input: None,
61
+ )
62
+ _register_if_available(
63
+ "gelu_new",
64
+ lambda out, input: None,
65
+ )
66
+ _register_if_available(
67
+ "gelu_quick",
68
+ lambda out, input: None,
69
+ )
70
+ _register_if_available(
71
+ "swigluoai_and_mul",
72
+ lambda out, input, alpha=1.702, limit=7.0: None,
73
+ )
74
+
75
+
76
+ # Register fake kernels on module load
77
+ _register_xpu_fake_kernels()
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # default
 
144
  return 1024
145
 
146
 
147
+ def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
148
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
149
+
150
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
151
+ constant folding when shape divisibility is not proven.
152
+ """
153
+ if byte_tensor.dtype != torch.uint8:
154
+ raise ValueError("byte_tensor must be uint8")
155
+ itemsize = torch.empty((), dtype=dtype).element_size()
156
+ numel = byte_tensor.numel() // itemsize
157
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
158
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
159
+ return out
160
+
161
+
162
  def implement_zp(qweight):
163
  # change u4 to s4 to avoid zero point in gemm kernel
164
  # only support default zero point now
 
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
+ workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
 
343
  inter_size=inter_size,
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
+ expert_first_token_offset_bytes = workspace[
347
  ws_map["expert_first_token_offset"][1]:
348
  ws_map["expert_first_token_offset"][1] +
349
+ expert_first_token_offset_size]
350
+ unpermuted_row_to_permuted_row_bytes = workspace[
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
+ src_to_dest_map_size]
354
+
355
+ if torch.compiler.is_compiling():
356
+ expert_first_token_offset = _bytes_to_typed_tensor(
357
+ expert_first_token_offset_bytes, torch.int64
358
+ )
359
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
+ unpermuted_row_to_permuted_row_bytes, torch.int32
361
+ )
362
+ else:
363
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
build/torch210-cxx11-cu128-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c0876dbd4267e12fa67f24fac60cedbee8e6dd41b85104c4c241b173729bee9a
3
  size 21009952
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83c64c2e54082d931c9fc3027ef6522bf3f3acd4c49d4c5c14dbfcb5ab038b12
3
  size 21009952
build/torch210-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_099ac3c
3
- ops = torch.ops._megablocks_099ac3c
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_099ac3c::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_9be3a32
3
+ ops = torch.ops._megablocks_9be3a32
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_9be3a32::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -3,7 +3,9 @@
3
  import os
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
 
9
  def resolve_dtensor(weight: torch.Tensor):
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
14
  return weight
15
 
16
 
17
- # Install meta kernels for torch.compile compatibility
18
- def _install_xpu_meta_kernels():
19
- """Install meta kernels for XPU MoE operations to support torch.compile"""
20
-
21
- # Patch cutlass_grouped_gemm_interface
22
- if hasattr(ops, "cutlass_grouped_gemm_interface"):
23
- original_gemm = ops.cutlass_grouped_gemm_interface
24
-
25
- def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
26
- expert_first_token_offset, N, K, num_experts,
27
- is_B_int4, is_B_mxfp4):
28
- if torch.compiler.is_compiling():
29
- # Meta implementation - ptr_D is the output, return it
30
- return ptr_D
31
- return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
32
- expert_first_token_offset, N, K, num_experts,
33
- is_B_int4, is_B_mxfp4)
34
-
35
- ops.cutlass_grouped_gemm_interface = gemm_with_meta
36
-
37
- # Patch fused_moe_prologue
38
- if hasattr(ops, "fused_moe_prologue"):
39
- original_prologue = ops.fused_moe_prologue
40
-
41
- def prologue_with_meta(input, token_selected_experts, token_final_scales,
42
- workspace, hidden_size, inter_size, num_experts_on_rank):
43
- if torch.compiler.is_compiling():
44
- # Meta implementation - this op modifies workspace in-place
45
- return None
46
- return original_prologue(input, token_selected_experts, token_final_scales,
47
- workspace, hidden_size, inter_size, num_experts_on_rank)
48
-
49
- ops.fused_moe_prologue = prologue_with_meta
50
-
51
- # Patch moe_gather
52
- if hasattr(ops, "moe_gather"):
53
- original_gather = ops.moe_gather
54
-
55
- def gather_with_meta(output, moe_output, topk_weights,
56
- unpermuted_row_to_permuted_row, num_experts):
57
- if torch.compiler.is_compiling():
58
- # Meta implementation - output is modified in-place
59
- return None
60
- return original_gather(output, moe_output, topk_weights,
61
- unpermuted_row_to_permuted_row, num_experts)
62
-
63
- ops.moe_gather = gather_with_meta
64
-
65
- # Patch activation ops
66
- for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
67
- "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
68
- "swigluoai_and_mul"]:
69
- if hasattr(ops, act_name):
70
- original_act = getattr(ops, act_name)
71
-
72
- def make_act_wrapper(orig_fn):
73
- def act_with_meta(*args, **kwargs):
74
- if torch.compiler.is_compiling():
75
- # Meta implementation - in-place ops, return None
76
- return None
77
- return orig_fn(*args, **kwargs)
78
- return act_with_meta
79
-
80
- setattr(ops, act_name, make_act_wrapper(original_act))
81
-
82
-
83
- # Install meta kernels on module load
84
- _install_xpu_meta_kernels()
85
 
86
 
87
  # default
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
151
  return 1024
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def implement_zp(qweight):
155
  # change u4 to s4 to avoid zero point in gemm kernel
156
  # only support default zero point now
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
321
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
322
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
323
 
324
- workspace = torch.zeros(map_offset,
325
  dtype=torch.uint8,
326
  device=hidden_states.device)
327
  if topk_ids.dtype == torch.int32:
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
335
  inter_size=inter_size,
336
  num_experts_on_rank=num_experts_per_node)
337
 
338
- expert_first_token_offset = workspace[
339
  ws_map["expert_first_token_offset"][1]:
340
  ws_map["expert_first_token_offset"][1] +
341
- expert_first_token_offset_size].view(torch.int64)
342
- unpermuted_row_to_permuted_row = workspace[
343
  ws_map["unpermuted_row_to_permuted_row"][1]:
344
  ws_map["unpermuted_row_to_permuted_row"][1] +
345
- src_to_dest_map_size].view(torch.int32)
 
 
 
 
 
 
 
 
 
 
 
346
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
347
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
348
  permuted_data_size].view(hidden_states.dtype).view(
 
3
  import os
4
  import torch
5
 
6
+ from ._ops import ops, add_op_namespace_prefix
7
+
8
+ from torch.library import register_fake
9
 
10
 
11
  def resolve_dtensor(weight: torch.Tensor):
 
16
  return weight
17
 
18
 
19
+ # Register fake/meta kernels for torch.compile compatibility
20
+ def _register_xpu_fake_kernels():
21
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
22
+
23
+ def _register_if_available(op_name, fn):
24
+ if hasattr(ops, op_name):
25
+ register_fake(add_op_namespace_prefix(op_name))(fn)
26
+
27
+ _register_if_available(
28
+ "cutlass_grouped_gemm_interface",
29
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
30
+ )
31
+
32
+ _register_if_available(
33
+ "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
+ )
36
+
37
+ _register_if_available(
38
+ "moe_gather",
39
+ lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
+ )
41
+
42
+ _register_if_available(
43
+ "silu_and_mul",
44
+ lambda out, input: None,
45
+ )
46
+ _register_if_available(
47
+ "mul_and_silu",
48
+ lambda out, input: None,
49
+ )
50
+ _register_if_available(
51
+ "gelu_and_mul",
52
+ lambda out, input: None,
53
+ )
54
+ _register_if_available(
55
+ "gelu_tanh_and_mul",
56
+ lambda out, input: None,
57
+ )
58
+ _register_if_available(
59
+ "gelu_fast",
60
+ lambda out, input: None,
61
+ )
62
+ _register_if_available(
63
+ "gelu_new",
64
+ lambda out, input: None,
65
+ )
66
+ _register_if_available(
67
+ "gelu_quick",
68
+ lambda out, input: None,
69
+ )
70
+ _register_if_available(
71
+ "swigluoai_and_mul",
72
+ lambda out, input, alpha=1.702, limit=7.0: None,
73
+ )
74
+
75
+
76
+ # Register fake kernels on module load
77
+ _register_xpu_fake_kernels()
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # default
 
144
  return 1024
145
 
146
 
147
+ def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
148
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
149
+
150
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
151
+ constant folding when shape divisibility is not proven.
152
+ """
153
+ if byte_tensor.dtype != torch.uint8:
154
+ raise ValueError("byte_tensor must be uint8")
155
+ itemsize = torch.empty((), dtype=dtype).element_size()
156
+ numel = byte_tensor.numel() // itemsize
157
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
158
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
159
+ return out
160
+
161
+
162
  def implement_zp(qweight):
163
  # change u4 to s4 to avoid zero point in gemm kernel
164
  # only support default zero point now
 
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
+ workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
 
343
  inter_size=inter_size,
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
+ expert_first_token_offset_bytes = workspace[
347
  ws_map["expert_first_token_offset"][1]:
348
  ws_map["expert_first_token_offset"][1] +
349
+ expert_first_token_offset_size]
350
+ unpermuted_row_to_permuted_row_bytes = workspace[
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
+ src_to_dest_map_size]
354
+
355
+ if torch.compiler.is_compiling():
356
+ expert_first_token_offset = _bytes_to_typed_tensor(
357
+ expert_first_token_offset_bytes, torch.int64
358
+ )
359
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
+ unpermuted_row_to_permuted_row_bytes, torch.int32
361
+ )
362
+ else:
363
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
build/torch210-cxx11-cu130-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4c7bc97e0aadcd94b0f6d3d7198269823d894fd5a36f6af9744864211ae0fd71
3
  size 12041568
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f48c4762cbfdf923c9547acd7d792dd7edec4bcfe5a857b605ce370f807be23a
3
  size 12041568
build/torch210-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_099ac3c
3
- ops = torch.ops._megablocks_099ac3c
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_099ac3c::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_9be3a32
3
+ ops = torch.ops._megablocks_9be3a32
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_9be3a32::{op_name}"
build/torch210-cxx11-cu130-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -3,7 +3,9 @@
3
  import os
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
 
9
  def resolve_dtensor(weight: torch.Tensor):
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
14
  return weight
15
 
16
 
17
- # Install meta kernels for torch.compile compatibility
18
- def _install_xpu_meta_kernels():
19
- """Install meta kernels for XPU MoE operations to support torch.compile"""
20
-
21
- # Patch cutlass_grouped_gemm_interface
22
- if hasattr(ops, "cutlass_grouped_gemm_interface"):
23
- original_gemm = ops.cutlass_grouped_gemm_interface
24
-
25
- def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
26
- expert_first_token_offset, N, K, num_experts,
27
- is_B_int4, is_B_mxfp4):
28
- if torch.compiler.is_compiling():
29
- # Meta implementation - ptr_D is the output, return it
30
- return ptr_D
31
- return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
32
- expert_first_token_offset, N, K, num_experts,
33
- is_B_int4, is_B_mxfp4)
34
-
35
- ops.cutlass_grouped_gemm_interface = gemm_with_meta
36
-
37
- # Patch fused_moe_prologue
38
- if hasattr(ops, "fused_moe_prologue"):
39
- original_prologue = ops.fused_moe_prologue
40
-
41
- def prologue_with_meta(input, token_selected_experts, token_final_scales,
42
- workspace, hidden_size, inter_size, num_experts_on_rank):
43
- if torch.compiler.is_compiling():
44
- # Meta implementation - this op modifies workspace in-place
45
- return None
46
- return original_prologue(input, token_selected_experts, token_final_scales,
47
- workspace, hidden_size, inter_size, num_experts_on_rank)
48
-
49
- ops.fused_moe_prologue = prologue_with_meta
50
-
51
- # Patch moe_gather
52
- if hasattr(ops, "moe_gather"):
53
- original_gather = ops.moe_gather
54
-
55
- def gather_with_meta(output, moe_output, topk_weights,
56
- unpermuted_row_to_permuted_row, num_experts):
57
- if torch.compiler.is_compiling():
58
- # Meta implementation - output is modified in-place
59
- return None
60
- return original_gather(output, moe_output, topk_weights,
61
- unpermuted_row_to_permuted_row, num_experts)
62
-
63
- ops.moe_gather = gather_with_meta
64
-
65
- # Patch activation ops
66
- for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
67
- "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
68
- "swigluoai_and_mul"]:
69
- if hasattr(ops, act_name):
70
- original_act = getattr(ops, act_name)
71
-
72
- def make_act_wrapper(orig_fn):
73
- def act_with_meta(*args, **kwargs):
74
- if torch.compiler.is_compiling():
75
- # Meta implementation - in-place ops, return None
76
- return None
77
- return orig_fn(*args, **kwargs)
78
- return act_with_meta
79
-
80
- setattr(ops, act_name, make_act_wrapper(original_act))
81
-
82
-
83
- # Install meta kernels on module load
84
- _install_xpu_meta_kernels()
85
 
86
 
87
  # default
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
151
  return 1024
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def implement_zp(qweight):
155
  # change u4 to s4 to avoid zero point in gemm kernel
156
  # only support default zero point now
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
321
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
322
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
323
 
324
- workspace = torch.zeros(map_offset,
325
  dtype=torch.uint8,
326
  device=hidden_states.device)
327
  if topk_ids.dtype == torch.int32:
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
335
  inter_size=inter_size,
336
  num_experts_on_rank=num_experts_per_node)
337
 
338
- expert_first_token_offset = workspace[
339
  ws_map["expert_first_token_offset"][1]:
340
  ws_map["expert_first_token_offset"][1] +
341
- expert_first_token_offset_size].view(torch.int64)
342
- unpermuted_row_to_permuted_row = workspace[
343
  ws_map["unpermuted_row_to_permuted_row"][1]:
344
  ws_map["unpermuted_row_to_permuted_row"][1] +
345
- src_to_dest_map_size].view(torch.int32)
 
 
 
 
 
 
 
 
 
 
 
346
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
347
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
348
  permuted_data_size].view(hidden_states.dtype).view(
 
3
  import os
4
  import torch
5
 
6
+ from ._ops import ops, add_op_namespace_prefix
7
+
8
+ from torch.library import register_fake
9
 
10
 
11
  def resolve_dtensor(weight: torch.Tensor):
 
16
  return weight
17
 
18
 
19
+ # Register fake/meta kernels for torch.compile compatibility
20
+ def _register_xpu_fake_kernels():
21
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
22
+
23
+ def _register_if_available(op_name, fn):
24
+ if hasattr(ops, op_name):
25
+ register_fake(add_op_namespace_prefix(op_name))(fn)
26
+
27
+ _register_if_available(
28
+ "cutlass_grouped_gemm_interface",
29
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
30
+ )
31
+
32
+ _register_if_available(
33
+ "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
+ )
36
+
37
+ _register_if_available(
38
+ "moe_gather",
39
+ lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
+ )
41
+
42
+ _register_if_available(
43
+ "silu_and_mul",
44
+ lambda out, input: None,
45
+ )
46
+ _register_if_available(
47
+ "mul_and_silu",
48
+ lambda out, input: None,
49
+ )
50
+ _register_if_available(
51
+ "gelu_and_mul",
52
+ lambda out, input: None,
53
+ )
54
+ _register_if_available(
55
+ "gelu_tanh_and_mul",
56
+ lambda out, input: None,
57
+ )
58
+ _register_if_available(
59
+ "gelu_fast",
60
+ lambda out, input: None,
61
+ )
62
+ _register_if_available(
63
+ "gelu_new",
64
+ lambda out, input: None,
65
+ )
66
+ _register_if_available(
67
+ "gelu_quick",
68
+ lambda out, input: None,
69
+ )
70
+ _register_if_available(
71
+ "swigluoai_and_mul",
72
+ lambda out, input, alpha=1.702, limit=7.0: None,
73
+ )
74
+
75
+
76
+ # Register fake kernels on module load
77
+ _register_xpu_fake_kernels()
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # default
 
144
  return 1024
145
 
146
 
147
+ def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
148
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
149
+
150
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
151
+ constant folding when shape divisibility is not proven.
152
+ """
153
+ if byte_tensor.dtype != torch.uint8:
154
+ raise ValueError("byte_tensor must be uint8")
155
+ itemsize = torch.empty((), dtype=dtype).element_size()
156
+ numel = byte_tensor.numel() // itemsize
157
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
158
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
159
+ return out
160
+
161
+
162
  def implement_zp(qweight):
163
  # change u4 to s4 to avoid zero point in gemm kernel
164
  # only support default zero point now
 
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
+ workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
 
343
  inter_size=inter_size,
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
+ expert_first_token_offset_bytes = workspace[
347
  ws_map["expert_first_token_offset"][1]:
348
  ws_map["expert_first_token_offset"][1] +
349
+ expert_first_token_offset_size]
350
+ unpermuted_row_to_permuted_row_bytes = workspace[
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
+ src_to_dest_map_size]
354
+
355
+ if torch.compiler.is_compiling():
356
+ expert_first_token_offset = _bytes_to_typed_tensor(
357
+ expert_first_token_offset_bytes, torch.int64
358
+ )
359
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
+ unpermuted_row_to_permuted_row_bytes, torch.int32
361
+ )
362
+ else:
363
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
build/torch210-cxx11-xpu20253-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dbf6091a3c2622e19367385fb8c82b507f841749bc9c4177421884232856c021
3
  size 4227888
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e840b67c3d3ee92b1150b7c0e4eaab1eda0998347131838eea3bc1bd44049093
3
  size 4227888
build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_099ac3c
3
- ops = torch.ops._megablocks_099ac3c
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_099ac3c::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_9be3a32
3
+ ops = torch.ops._megablocks_9be3a32
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_9be3a32::{op_name}"
build/torch210-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -3,7 +3,9 @@
3
  import os
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
 
9
  def resolve_dtensor(weight: torch.Tensor):
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
14
  return weight
15
 
16
 
17
- # Install meta kernels for torch.compile compatibility
18
- def _install_xpu_meta_kernels():
19
- """Install meta kernels for XPU MoE operations to support torch.compile"""
20
-
21
- # Patch cutlass_grouped_gemm_interface
22
- if hasattr(ops, "cutlass_grouped_gemm_interface"):
23
- original_gemm = ops.cutlass_grouped_gemm_interface
24
-
25
- def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
26
- expert_first_token_offset, N, K, num_experts,
27
- is_B_int4, is_B_mxfp4):
28
- if torch.compiler.is_compiling():
29
- # Meta implementation - ptr_D is the output, return it
30
- return ptr_D
31
- return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
32
- expert_first_token_offset, N, K, num_experts,
33
- is_B_int4, is_B_mxfp4)
34
-
35
- ops.cutlass_grouped_gemm_interface = gemm_with_meta
36
-
37
- # Patch fused_moe_prologue
38
- if hasattr(ops, "fused_moe_prologue"):
39
- original_prologue = ops.fused_moe_prologue
40
-
41
- def prologue_with_meta(input, token_selected_experts, token_final_scales,
42
- workspace, hidden_size, inter_size, num_experts_on_rank):
43
- if torch.compiler.is_compiling():
44
- # Meta implementation - this op modifies workspace in-place
45
- return None
46
- return original_prologue(input, token_selected_experts, token_final_scales,
47
- workspace, hidden_size, inter_size, num_experts_on_rank)
48
-
49
- ops.fused_moe_prologue = prologue_with_meta
50
-
51
- # Patch moe_gather
52
- if hasattr(ops, "moe_gather"):
53
- original_gather = ops.moe_gather
54
-
55
- def gather_with_meta(output, moe_output, topk_weights,
56
- unpermuted_row_to_permuted_row, num_experts):
57
- if torch.compiler.is_compiling():
58
- # Meta implementation - output is modified in-place
59
- return None
60
- return original_gather(output, moe_output, topk_weights,
61
- unpermuted_row_to_permuted_row, num_experts)
62
-
63
- ops.moe_gather = gather_with_meta
64
-
65
- # Patch activation ops
66
- for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
67
- "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
68
- "swigluoai_and_mul"]:
69
- if hasattr(ops, act_name):
70
- original_act = getattr(ops, act_name)
71
-
72
- def make_act_wrapper(orig_fn):
73
- def act_with_meta(*args, **kwargs):
74
- if torch.compiler.is_compiling():
75
- # Meta implementation - in-place ops, return None
76
- return None
77
- return orig_fn(*args, **kwargs)
78
- return act_with_meta
79
-
80
- setattr(ops, act_name, make_act_wrapper(original_act))
81
-
82
-
83
- # Install meta kernels on module load
84
- _install_xpu_meta_kernels()
85
 
86
 
87
  # default
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
151
  return 1024
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def implement_zp(qweight):
155
  # change u4 to s4 to avoid zero point in gemm kernel
156
  # only support default zero point now
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
321
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
322
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
323
 
324
- workspace = torch.zeros(map_offset,
325
  dtype=torch.uint8,
326
  device=hidden_states.device)
327
  if topk_ids.dtype == torch.int32:
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
335
  inter_size=inter_size,
336
  num_experts_on_rank=num_experts_per_node)
337
 
338
- expert_first_token_offset = workspace[
339
  ws_map["expert_first_token_offset"][1]:
340
  ws_map["expert_first_token_offset"][1] +
341
- expert_first_token_offset_size].view(torch.int64)
342
- unpermuted_row_to_permuted_row = workspace[
343
  ws_map["unpermuted_row_to_permuted_row"][1]:
344
  ws_map["unpermuted_row_to_permuted_row"][1] +
345
- src_to_dest_map_size].view(torch.int32)
 
 
 
 
 
 
 
 
 
 
 
346
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
347
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
348
  permuted_data_size].view(hidden_states.dtype).view(
 
3
  import os
4
  import torch
5
 
6
+ from ._ops import ops, add_op_namespace_prefix
7
+
8
+ from torch.library import register_fake
9
 
10
 
11
  def resolve_dtensor(weight: torch.Tensor):
 
16
  return weight
17
 
18
 
19
+ # Register fake/meta kernels for torch.compile compatibility
20
+ def _register_xpu_fake_kernels():
21
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
22
+
23
+ def _register_if_available(op_name, fn):
24
+ if hasattr(ops, op_name):
25
+ register_fake(add_op_namespace_prefix(op_name))(fn)
26
+
27
+ _register_if_available(
28
+ "cutlass_grouped_gemm_interface",
29
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
30
+ )
31
+
32
+ _register_if_available(
33
+ "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
+ )
36
+
37
+ _register_if_available(
38
+ "moe_gather",
39
+ lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
+ )
41
+
42
+ _register_if_available(
43
+ "silu_and_mul",
44
+ lambda out, input: None,
45
+ )
46
+ _register_if_available(
47
+ "mul_and_silu",
48
+ lambda out, input: None,
49
+ )
50
+ _register_if_available(
51
+ "gelu_and_mul",
52
+ lambda out, input: None,
53
+ )
54
+ _register_if_available(
55
+ "gelu_tanh_and_mul",
56
+ lambda out, input: None,
57
+ )
58
+ _register_if_available(
59
+ "gelu_fast",
60
+ lambda out, input: None,
61
+ )
62
+ _register_if_available(
63
+ "gelu_new",
64
+ lambda out, input: None,
65
+ )
66
+ _register_if_available(
67
+ "gelu_quick",
68
+ lambda out, input: None,
69
+ )
70
+ _register_if_available(
71
+ "swigluoai_and_mul",
72
+ lambda out, input, alpha=1.702, limit=7.0: None,
73
+ )
74
+
75
+
76
+ # Register fake kernels on module load
77
+ _register_xpu_fake_kernels()
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # default
 
144
  return 1024
145
 
146
 
147
+ def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
148
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
149
+
150
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
151
+ constant folding when shape divisibility is not proven.
152
+ """
153
+ if byte_tensor.dtype != torch.uint8:
154
+ raise ValueError("byte_tensor must be uint8")
155
+ itemsize = torch.empty((), dtype=dtype).element_size()
156
+ numel = byte_tensor.numel() // itemsize
157
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
158
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
159
+ return out
160
+
161
+
162
  def implement_zp(qweight):
163
  # change u4 to s4 to avoid zero point in gemm kernel
164
  # only support default zero point now
 
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
+ workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
 
343
  inter_size=inter_size,
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
+ expert_first_token_offset_bytes = workspace[
347
  ws_map["expert_first_token_offset"][1]:
348
  ws_map["expert_first_token_offset"][1] +
349
+ expert_first_token_offset_size]
350
+ unpermuted_row_to_permuted_row_bytes = workspace[
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
+ src_to_dest_map_size]
354
+
355
+ if torch.compiler.is_compiling():
356
+ expert_first_token_offset = _bytes_to_typed_tensor(
357
+ expert_first_token_offset_bytes, torch.int64
358
+ )
359
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
+ unpermuted_row_to_permuted_row_bytes, torch.int32
361
+ )
362
+ else:
363
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
build/torch29-cxx11-cpu-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8b3f1c2f3058c4c5c08291c7a51be003046657e7567454a779911c7cebfdc3d9
3
  size 2201176
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24c19663574a3afb94a458ee318e8b63d47d24f6b1f457a605c115a567810a08
3
  size 2201176
build/torch29-cxx11-cpu-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_099ac3c
3
- ops = torch.ops._megablocks_099ac3c
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_099ac3c::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_9be3a32
3
+ ops = torch.ops._megablocks_9be3a32
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_9be3a32::{op_name}"
build/torch29-cxx11-cpu-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -3,7 +3,9 @@
3
  import os
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
 
9
  def resolve_dtensor(weight: torch.Tensor):
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
14
  return weight
15
 
16
 
17
- # Install meta kernels for torch.compile compatibility
18
- def _install_xpu_meta_kernels():
19
- """Install meta kernels for XPU MoE operations to support torch.compile"""
20
-
21
- # Patch cutlass_grouped_gemm_interface
22
- if hasattr(ops, "cutlass_grouped_gemm_interface"):
23
- original_gemm = ops.cutlass_grouped_gemm_interface
24
-
25
- def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
26
- expert_first_token_offset, N, K, num_experts,
27
- is_B_int4, is_B_mxfp4):
28
- if torch.compiler.is_compiling():
29
- # Meta implementation - ptr_D is the output, return it
30
- return ptr_D
31
- return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
32
- expert_first_token_offset, N, K, num_experts,
33
- is_B_int4, is_B_mxfp4)
34
-
35
- ops.cutlass_grouped_gemm_interface = gemm_with_meta
36
-
37
- # Patch fused_moe_prologue
38
- if hasattr(ops, "fused_moe_prologue"):
39
- original_prologue = ops.fused_moe_prologue
40
-
41
- def prologue_with_meta(input, token_selected_experts, token_final_scales,
42
- workspace, hidden_size, inter_size, num_experts_on_rank):
43
- if torch.compiler.is_compiling():
44
- # Meta implementation - this op modifies workspace in-place
45
- return None
46
- return original_prologue(input, token_selected_experts, token_final_scales,
47
- workspace, hidden_size, inter_size, num_experts_on_rank)
48
-
49
- ops.fused_moe_prologue = prologue_with_meta
50
-
51
- # Patch moe_gather
52
- if hasattr(ops, "moe_gather"):
53
- original_gather = ops.moe_gather
54
-
55
- def gather_with_meta(output, moe_output, topk_weights,
56
- unpermuted_row_to_permuted_row, num_experts):
57
- if torch.compiler.is_compiling():
58
- # Meta implementation - output is modified in-place
59
- return None
60
- return original_gather(output, moe_output, topk_weights,
61
- unpermuted_row_to_permuted_row, num_experts)
62
-
63
- ops.moe_gather = gather_with_meta
64
-
65
- # Patch activation ops
66
- for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
67
- "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
68
- "swigluoai_and_mul"]:
69
- if hasattr(ops, act_name):
70
- original_act = getattr(ops, act_name)
71
-
72
- def make_act_wrapper(orig_fn):
73
- def act_with_meta(*args, **kwargs):
74
- if torch.compiler.is_compiling():
75
- # Meta implementation - in-place ops, return None
76
- return None
77
- return orig_fn(*args, **kwargs)
78
- return act_with_meta
79
-
80
- setattr(ops, act_name, make_act_wrapper(original_act))
81
-
82
-
83
- # Install meta kernels on module load
84
- _install_xpu_meta_kernels()
85
 
86
 
87
  # default
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
151
  return 1024
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def implement_zp(qweight):
155
  # change u4 to s4 to avoid zero point in gemm kernel
156
  # only support default zero point now
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
321
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
322
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
323
 
324
- workspace = torch.zeros(map_offset,
325
  dtype=torch.uint8,
326
  device=hidden_states.device)
327
  if topk_ids.dtype == torch.int32:
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
335
  inter_size=inter_size,
336
  num_experts_on_rank=num_experts_per_node)
337
 
338
- expert_first_token_offset = workspace[
339
  ws_map["expert_first_token_offset"][1]:
340
  ws_map["expert_first_token_offset"][1] +
341
- expert_first_token_offset_size].view(torch.int64)
342
- unpermuted_row_to_permuted_row = workspace[
343
  ws_map["unpermuted_row_to_permuted_row"][1]:
344
  ws_map["unpermuted_row_to_permuted_row"][1] +
345
- src_to_dest_map_size].view(torch.int32)
 
 
 
 
 
 
 
 
 
 
 
346
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
347
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
348
  permuted_data_size].view(hidden_states.dtype).view(
 
3
  import os
4
  import torch
5
 
6
+ from ._ops import ops, add_op_namespace_prefix
7
+
8
+ from torch.library import register_fake
9
 
10
 
11
  def resolve_dtensor(weight: torch.Tensor):
 
16
  return weight
17
 
18
 
19
+ # Register fake/meta kernels for torch.compile compatibility
20
+ def _register_xpu_fake_kernels():
21
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
22
+
23
+ def _register_if_available(op_name, fn):
24
+ if hasattr(ops, op_name):
25
+ register_fake(add_op_namespace_prefix(op_name))(fn)
26
+
27
+ _register_if_available(
28
+ "cutlass_grouped_gemm_interface",
29
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
30
+ )
31
+
32
+ _register_if_available(
33
+ "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
+ )
36
+
37
+ _register_if_available(
38
+ "moe_gather",
39
+ lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
+ )
41
+
42
+ _register_if_available(
43
+ "silu_and_mul",
44
+ lambda out, input: None,
45
+ )
46
+ _register_if_available(
47
+ "mul_and_silu",
48
+ lambda out, input: None,
49
+ )
50
+ _register_if_available(
51
+ "gelu_and_mul",
52
+ lambda out, input: None,
53
+ )
54
+ _register_if_available(
55
+ "gelu_tanh_and_mul",
56
+ lambda out, input: None,
57
+ )
58
+ _register_if_available(
59
+ "gelu_fast",
60
+ lambda out, input: None,
61
+ )
62
+ _register_if_available(
63
+ "gelu_new",
64
+ lambda out, input: None,
65
+ )
66
+ _register_if_available(
67
+ "gelu_quick",
68
+ lambda out, input: None,
69
+ )
70
+ _register_if_available(
71
+ "swigluoai_and_mul",
72
+ lambda out, input, alpha=1.702, limit=7.0: None,
73
+ )
74
+
75
+
76
+ # Register fake kernels on module load
77
+ _register_xpu_fake_kernels()
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # default
 
144
  return 1024
145
 
146
 
147
+ def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
148
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
149
+
150
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
151
+ constant folding when shape divisibility is not proven.
152
+ """
153
+ if byte_tensor.dtype != torch.uint8:
154
+ raise ValueError("byte_tensor must be uint8")
155
+ itemsize = torch.empty((), dtype=dtype).element_size()
156
+ numel = byte_tensor.numel() // itemsize
157
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
158
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
159
+ return out
160
+
161
+
162
  def implement_zp(qweight):
163
  # change u4 to s4 to avoid zero point in gemm kernel
164
  # only support default zero point now
 
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
+ workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
 
343
  inter_size=inter_size,
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
+ expert_first_token_offset_bytes = workspace[
347
  ws_map["expert_first_token_offset"][1]:
348
  ws_map["expert_first_token_offset"][1] +
349
+ expert_first_token_offset_size]
350
+ unpermuted_row_to_permuted_row_bytes = workspace[
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
+ src_to_dest_map_size]
354
+
355
+ if torch.compiler.is_compiling():
356
+ expert_first_token_offset = _bytes_to_typed_tensor(
357
+ expert_first_token_offset_bytes, torch.int64
358
+ )
359
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
+ unpermuted_row_to_permuted_row_bytes, torch.int32
361
+ )
362
+ else:
363
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
build/torch29-cxx11-cu126-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4d58bdd86403eaa524fac1db9361b0025a175f4b10dcddd8fa0bf99892172e54
3
  size 15046808
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc4e092bd6f32001e850abf73dd6ee609e9a25800d87fd9e19a0e4a6c30f8e9c
3
  size 15046808
build/torch29-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_099ac3c
3
- ops = torch.ops._megablocks_099ac3c
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_099ac3c::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_9be3a32
3
+ ops = torch.ops._megablocks_9be3a32
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_9be3a32::{op_name}"
build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -3,7 +3,9 @@
3
  import os
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
 
9
  def resolve_dtensor(weight: torch.Tensor):
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
14
  return weight
15
 
16
 
17
- # Install meta kernels for torch.compile compatibility
18
- def _install_xpu_meta_kernels():
19
- """Install meta kernels for XPU MoE operations to support torch.compile"""
20
-
21
- # Patch cutlass_grouped_gemm_interface
22
- if hasattr(ops, "cutlass_grouped_gemm_interface"):
23
- original_gemm = ops.cutlass_grouped_gemm_interface
24
-
25
- def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
26
- expert_first_token_offset, N, K, num_experts,
27
- is_B_int4, is_B_mxfp4):
28
- if torch.compiler.is_compiling():
29
- # Meta implementation - ptr_D is the output, return it
30
- return ptr_D
31
- return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
32
- expert_first_token_offset, N, K, num_experts,
33
- is_B_int4, is_B_mxfp4)
34
-
35
- ops.cutlass_grouped_gemm_interface = gemm_with_meta
36
-
37
- # Patch fused_moe_prologue
38
- if hasattr(ops, "fused_moe_prologue"):
39
- original_prologue = ops.fused_moe_prologue
40
-
41
- def prologue_with_meta(input, token_selected_experts, token_final_scales,
42
- workspace, hidden_size, inter_size, num_experts_on_rank):
43
- if torch.compiler.is_compiling():
44
- # Meta implementation - this op modifies workspace in-place
45
- return None
46
- return original_prologue(input, token_selected_experts, token_final_scales,
47
- workspace, hidden_size, inter_size, num_experts_on_rank)
48
-
49
- ops.fused_moe_prologue = prologue_with_meta
50
-
51
- # Patch moe_gather
52
- if hasattr(ops, "moe_gather"):
53
- original_gather = ops.moe_gather
54
-
55
- def gather_with_meta(output, moe_output, topk_weights,
56
- unpermuted_row_to_permuted_row, num_experts):
57
- if torch.compiler.is_compiling():
58
- # Meta implementation - output is modified in-place
59
- return None
60
- return original_gather(output, moe_output, topk_weights,
61
- unpermuted_row_to_permuted_row, num_experts)
62
-
63
- ops.moe_gather = gather_with_meta
64
-
65
- # Patch activation ops
66
- for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
67
- "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
68
- "swigluoai_and_mul"]:
69
- if hasattr(ops, act_name):
70
- original_act = getattr(ops, act_name)
71
-
72
- def make_act_wrapper(orig_fn):
73
- def act_with_meta(*args, **kwargs):
74
- if torch.compiler.is_compiling():
75
- # Meta implementation - in-place ops, return None
76
- return None
77
- return orig_fn(*args, **kwargs)
78
- return act_with_meta
79
-
80
- setattr(ops, act_name, make_act_wrapper(original_act))
81
-
82
-
83
- # Install meta kernels on module load
84
- _install_xpu_meta_kernels()
85
 
86
 
87
  # default
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
151
  return 1024
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def implement_zp(qweight):
155
  # change u4 to s4 to avoid zero point in gemm kernel
156
  # only support default zero point now
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
321
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
322
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
323
 
324
- workspace = torch.zeros(map_offset,
325
  dtype=torch.uint8,
326
  device=hidden_states.device)
327
  if topk_ids.dtype == torch.int32:
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
335
  inter_size=inter_size,
336
  num_experts_on_rank=num_experts_per_node)
337
 
338
- expert_first_token_offset = workspace[
339
  ws_map["expert_first_token_offset"][1]:
340
  ws_map["expert_first_token_offset"][1] +
341
- expert_first_token_offset_size].view(torch.int64)
342
- unpermuted_row_to_permuted_row = workspace[
343
  ws_map["unpermuted_row_to_permuted_row"][1]:
344
  ws_map["unpermuted_row_to_permuted_row"][1] +
345
- src_to_dest_map_size].view(torch.int32)
 
 
 
 
 
 
 
 
 
 
 
346
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
347
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
348
  permuted_data_size].view(hidden_states.dtype).view(
 
3
  import os
4
  import torch
5
 
6
+ from ._ops import ops, add_op_namespace_prefix
7
+
8
+ from torch.library import register_fake
9
 
10
 
11
  def resolve_dtensor(weight: torch.Tensor):
 
16
  return weight
17
 
18
 
19
+ # Register fake/meta kernels for torch.compile compatibility
20
+ def _register_xpu_fake_kernels():
21
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
22
+
23
+ def _register_if_available(op_name, fn):
24
+ if hasattr(ops, op_name):
25
+ register_fake(add_op_namespace_prefix(op_name))(fn)
26
+
27
+ _register_if_available(
28
+ "cutlass_grouped_gemm_interface",
29
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
30
+ )
31
+
32
+ _register_if_available(
33
+ "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
+ )
36
+
37
+ _register_if_available(
38
+ "moe_gather",
39
+ lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
+ )
41
+
42
+ _register_if_available(
43
+ "silu_and_mul",
44
+ lambda out, input: None,
45
+ )
46
+ _register_if_available(
47
+ "mul_and_silu",
48
+ lambda out, input: None,
49
+ )
50
+ _register_if_available(
51
+ "gelu_and_mul",
52
+ lambda out, input: None,
53
+ )
54
+ _register_if_available(
55
+ "gelu_tanh_and_mul",
56
+ lambda out, input: None,
57
+ )
58
+ _register_if_available(
59
+ "gelu_fast",
60
+ lambda out, input: None,
61
+ )
62
+ _register_if_available(
63
+ "gelu_new",
64
+ lambda out, input: None,
65
+ )
66
+ _register_if_available(
67
+ "gelu_quick",
68
+ lambda out, input: None,
69
+ )
70
+ _register_if_available(
71
+ "swigluoai_and_mul",
72
+ lambda out, input, alpha=1.702, limit=7.0: None,
73
+ )
74
+
75
+
76
+ # Register fake kernels on module load
77
+ _register_xpu_fake_kernels()
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # default
 
144
  return 1024
145
 
146
 
147
+ def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
148
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
149
+
150
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
151
+ constant folding when shape divisibility is not proven.
152
+ """
153
+ if byte_tensor.dtype != torch.uint8:
154
+ raise ValueError("byte_tensor must be uint8")
155
+ itemsize = torch.empty((), dtype=dtype).element_size()
156
+ numel = byte_tensor.numel() // itemsize
157
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
158
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
159
+ return out
160
+
161
+
162
  def implement_zp(qweight):
163
  # change u4 to s4 to avoid zero point in gemm kernel
164
  # only support default zero point now
 
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
+ workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
 
343
  inter_size=inter_size,
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
+ expert_first_token_offset_bytes = workspace[
347
  ws_map["expert_first_token_offset"][1]:
348
  ws_map["expert_first_token_offset"][1] +
349
+ expert_first_token_offset_size]
350
+ unpermuted_row_to_permuted_row_bytes = workspace[
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
+ src_to_dest_map_size]
354
+
355
+ if torch.compiler.is_compiling():
356
+ expert_first_token_offset = _bytes_to_typed_tensor(
357
+ expert_first_token_offset_bytes, torch.int64
358
+ )
359
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
+ unpermuted_row_to_permuted_row_bytes, torch.int32
361
+ )
362
+ else:
363
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
build/torch29-cxx11-cu128-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a5c3c17f0fa54822f12b05fe5c22f8b61ad1a9711a02de13a706e1e8f63e141b
3
  size 20995680
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9018001f72f4a1b7f364d1ca582d8a756cbe452ed798efc4c42e74c49ca1839c
3
  size 20995680
build/torch29-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_099ac3c
3
- ops = torch.ops._megablocks_099ac3c
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_099ac3c::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_9be3a32
3
+ ops = torch.ops._megablocks_9be3a32
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_9be3a32::{op_name}"
build/torch29-cxx11-cu128-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -3,7 +3,9 @@
3
  import os
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
 
9
  def resolve_dtensor(weight: torch.Tensor):
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
14
  return weight
15
 
16
 
17
- # Install meta kernels for torch.compile compatibility
18
- def _install_xpu_meta_kernels():
19
- """Install meta kernels for XPU MoE operations to support torch.compile"""
20
-
21
- # Patch cutlass_grouped_gemm_interface
22
- if hasattr(ops, "cutlass_grouped_gemm_interface"):
23
- original_gemm = ops.cutlass_grouped_gemm_interface
24
-
25
- def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
26
- expert_first_token_offset, N, K, num_experts,
27
- is_B_int4, is_B_mxfp4):
28
- if torch.compiler.is_compiling():
29
- # Meta implementation - ptr_D is the output, return it
30
- return ptr_D
31
- return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
32
- expert_first_token_offset, N, K, num_experts,
33
- is_B_int4, is_B_mxfp4)
34
-
35
- ops.cutlass_grouped_gemm_interface = gemm_with_meta
36
-
37
- # Patch fused_moe_prologue
38
- if hasattr(ops, "fused_moe_prologue"):
39
- original_prologue = ops.fused_moe_prologue
40
-
41
- def prologue_with_meta(input, token_selected_experts, token_final_scales,
42
- workspace, hidden_size, inter_size, num_experts_on_rank):
43
- if torch.compiler.is_compiling():
44
- # Meta implementation - this op modifies workspace in-place
45
- return None
46
- return original_prologue(input, token_selected_experts, token_final_scales,
47
- workspace, hidden_size, inter_size, num_experts_on_rank)
48
-
49
- ops.fused_moe_prologue = prologue_with_meta
50
-
51
- # Patch moe_gather
52
- if hasattr(ops, "moe_gather"):
53
- original_gather = ops.moe_gather
54
-
55
- def gather_with_meta(output, moe_output, topk_weights,
56
- unpermuted_row_to_permuted_row, num_experts):
57
- if torch.compiler.is_compiling():
58
- # Meta implementation - output is modified in-place
59
- return None
60
- return original_gather(output, moe_output, topk_weights,
61
- unpermuted_row_to_permuted_row, num_experts)
62
-
63
- ops.moe_gather = gather_with_meta
64
-
65
- # Patch activation ops
66
- for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
67
- "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
68
- "swigluoai_and_mul"]:
69
- if hasattr(ops, act_name):
70
- original_act = getattr(ops, act_name)
71
-
72
- def make_act_wrapper(orig_fn):
73
- def act_with_meta(*args, **kwargs):
74
- if torch.compiler.is_compiling():
75
- # Meta implementation - in-place ops, return None
76
- return None
77
- return orig_fn(*args, **kwargs)
78
- return act_with_meta
79
-
80
- setattr(ops, act_name, make_act_wrapper(original_act))
81
-
82
-
83
- # Install meta kernels on module load
84
- _install_xpu_meta_kernels()
85
 
86
 
87
  # default
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
151
  return 1024
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def implement_zp(qweight):
155
  # change u4 to s4 to avoid zero point in gemm kernel
156
  # only support default zero point now
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
321
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
322
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
323
 
324
- workspace = torch.zeros(map_offset,
325
  dtype=torch.uint8,
326
  device=hidden_states.device)
327
  if topk_ids.dtype == torch.int32:
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
335
  inter_size=inter_size,
336
  num_experts_on_rank=num_experts_per_node)
337
 
338
- expert_first_token_offset = workspace[
339
  ws_map["expert_first_token_offset"][1]:
340
  ws_map["expert_first_token_offset"][1] +
341
- expert_first_token_offset_size].view(torch.int64)
342
- unpermuted_row_to_permuted_row = workspace[
343
  ws_map["unpermuted_row_to_permuted_row"][1]:
344
  ws_map["unpermuted_row_to_permuted_row"][1] +
345
- src_to_dest_map_size].view(torch.int32)
 
 
 
 
 
 
 
 
 
 
 
346
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
347
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
348
  permuted_data_size].view(hidden_states.dtype).view(
 
3
  import os
4
  import torch
5
 
6
+ from ._ops import ops, add_op_namespace_prefix
7
+
8
+ from torch.library import register_fake
9
 
10
 
11
  def resolve_dtensor(weight: torch.Tensor):
 
16
  return weight
17
 
18
 
19
+ # Register fake/meta kernels for torch.compile compatibility
20
+ def _register_xpu_fake_kernels():
21
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
22
+
23
+ def _register_if_available(op_name, fn):
24
+ if hasattr(ops, op_name):
25
+ register_fake(add_op_namespace_prefix(op_name))(fn)
26
+
27
+ _register_if_available(
28
+ "cutlass_grouped_gemm_interface",
29
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
30
+ )
31
+
32
+ _register_if_available(
33
+ "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
+ )
36
+
37
+ _register_if_available(
38
+ "moe_gather",
39
+ lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
+ )
41
+
42
+ _register_if_available(
43
+ "silu_and_mul",
44
+ lambda out, input: None,
45
+ )
46
+ _register_if_available(
47
+ "mul_and_silu",
48
+ lambda out, input: None,
49
+ )
50
+ _register_if_available(
51
+ "gelu_and_mul",
52
+ lambda out, input: None,
53
+ )
54
+ _register_if_available(
55
+ "gelu_tanh_and_mul",
56
+ lambda out, input: None,
57
+ )
58
+ _register_if_available(
59
+ "gelu_fast",
60
+ lambda out, input: None,
61
+ )
62
+ _register_if_available(
63
+ "gelu_new",
64
+ lambda out, input: None,
65
+ )
66
+ _register_if_available(
67
+ "gelu_quick",
68
+ lambda out, input: None,
69
+ )
70
+ _register_if_available(
71
+ "swigluoai_and_mul",
72
+ lambda out, input, alpha=1.702, limit=7.0: None,
73
+ )
74
+
75
+
76
+ # Register fake kernels on module load
77
+ _register_xpu_fake_kernels()
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # default
 
144
  return 1024
145
 
146
 
147
+ def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
148
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
149
+
150
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
151
+ constant folding when shape divisibility is not proven.
152
+ """
153
+ if byte_tensor.dtype != torch.uint8:
154
+ raise ValueError("byte_tensor must be uint8")
155
+ itemsize = torch.empty((), dtype=dtype).element_size()
156
+ numel = byte_tensor.numel() // itemsize
157
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
158
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
159
+ return out
160
+
161
+
162
  def implement_zp(qweight):
163
  # change u4 to s4 to avoid zero point in gemm kernel
164
  # only support default zero point now
 
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
+ workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
 
343
  inter_size=inter_size,
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
+ expert_first_token_offset_bytes = workspace[
347
  ws_map["expert_first_token_offset"][1]:
348
  ws_map["expert_first_token_offset"][1] +
349
+ expert_first_token_offset_size]
350
+ unpermuted_row_to_permuted_row_bytes = workspace[
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
+ src_to_dest_map_size]
354
+
355
+ if torch.compiler.is_compiling():
356
+ expert_first_token_offset = _bytes_to_typed_tensor(
357
+ expert_first_token_offset_bytes, torch.int64
358
+ )
359
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
+ unpermuted_row_to_permuted_row_bytes, torch.int32
361
+ )
362
+ else:
363
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
build/torch29-cxx11-cu130-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:609492272ed9672ab824abf87b08f078f409696c8db453ccc5f46dff39d84f98
3
  size 12031392
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49caf38e644493142784e8ad8fac70c1ec9f249c798399950f4228570a570c04
3
  size 12031392
build/torch29-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_099ac3c
3
- ops = torch.ops._megablocks_099ac3c
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_099ac3c::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_9be3a32
3
+ ops = torch.ops._megablocks_9be3a32
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_9be3a32::{op_name}"
build/torch29-cxx11-cu130-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -3,7 +3,9 @@
3
  import os
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
 
9
  def resolve_dtensor(weight: torch.Tensor):
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
14
  return weight
15
 
16
 
17
- # Install meta kernels for torch.compile compatibility
18
- def _install_xpu_meta_kernels():
19
- """Install meta kernels for XPU MoE operations to support torch.compile"""
20
-
21
- # Patch cutlass_grouped_gemm_interface
22
- if hasattr(ops, "cutlass_grouped_gemm_interface"):
23
- original_gemm = ops.cutlass_grouped_gemm_interface
24
-
25
- def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
26
- expert_first_token_offset, N, K, num_experts,
27
- is_B_int4, is_B_mxfp4):
28
- if torch.compiler.is_compiling():
29
- # Meta implementation - ptr_D is the output, return it
30
- return ptr_D
31
- return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
32
- expert_first_token_offset, N, K, num_experts,
33
- is_B_int4, is_B_mxfp4)
34
-
35
- ops.cutlass_grouped_gemm_interface = gemm_with_meta
36
-
37
- # Patch fused_moe_prologue
38
- if hasattr(ops, "fused_moe_prologue"):
39
- original_prologue = ops.fused_moe_prologue
40
-
41
- def prologue_with_meta(input, token_selected_experts, token_final_scales,
42
- workspace, hidden_size, inter_size, num_experts_on_rank):
43
- if torch.compiler.is_compiling():
44
- # Meta implementation - this op modifies workspace in-place
45
- return None
46
- return original_prologue(input, token_selected_experts, token_final_scales,
47
- workspace, hidden_size, inter_size, num_experts_on_rank)
48
-
49
- ops.fused_moe_prologue = prologue_with_meta
50
-
51
- # Patch moe_gather
52
- if hasattr(ops, "moe_gather"):
53
- original_gather = ops.moe_gather
54
-
55
- def gather_with_meta(output, moe_output, topk_weights,
56
- unpermuted_row_to_permuted_row, num_experts):
57
- if torch.compiler.is_compiling():
58
- # Meta implementation - output is modified in-place
59
- return None
60
- return original_gather(output, moe_output, topk_weights,
61
- unpermuted_row_to_permuted_row, num_experts)
62
-
63
- ops.moe_gather = gather_with_meta
64
-
65
- # Patch activation ops
66
- for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
67
- "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
68
- "swigluoai_and_mul"]:
69
- if hasattr(ops, act_name):
70
- original_act = getattr(ops, act_name)
71
-
72
- def make_act_wrapper(orig_fn):
73
- def act_with_meta(*args, **kwargs):
74
- if torch.compiler.is_compiling():
75
- # Meta implementation - in-place ops, return None
76
- return None
77
- return orig_fn(*args, **kwargs)
78
- return act_with_meta
79
-
80
- setattr(ops, act_name, make_act_wrapper(original_act))
81
-
82
-
83
- # Install meta kernels on module load
84
- _install_xpu_meta_kernels()
85
 
86
 
87
  # default
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
151
  return 1024
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def implement_zp(qweight):
155
  # change u4 to s4 to avoid zero point in gemm kernel
156
  # only support default zero point now
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
321
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
322
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
323
 
324
- workspace = torch.zeros(map_offset,
325
  dtype=torch.uint8,
326
  device=hidden_states.device)
327
  if topk_ids.dtype == torch.int32:
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
335
  inter_size=inter_size,
336
  num_experts_on_rank=num_experts_per_node)
337
 
338
- expert_first_token_offset = workspace[
339
  ws_map["expert_first_token_offset"][1]:
340
  ws_map["expert_first_token_offset"][1] +
341
- expert_first_token_offset_size].view(torch.int64)
342
- unpermuted_row_to_permuted_row = workspace[
343
  ws_map["unpermuted_row_to_permuted_row"][1]:
344
  ws_map["unpermuted_row_to_permuted_row"][1] +
345
- src_to_dest_map_size].view(torch.int32)
 
 
 
 
 
 
 
 
 
 
 
346
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
347
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
348
  permuted_data_size].view(hidden_states.dtype).view(
 
3
  import os
4
  import torch
5
 
6
+ from ._ops import ops, add_op_namespace_prefix
7
+
8
+ from torch.library import register_fake
9
 
10
 
11
  def resolve_dtensor(weight: torch.Tensor):
 
16
  return weight
17
 
18
 
19
+ # Register fake/meta kernels for torch.compile compatibility
20
+ def _register_xpu_fake_kernels():
21
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
22
+
23
+ def _register_if_available(op_name, fn):
24
+ if hasattr(ops, op_name):
25
+ register_fake(add_op_namespace_prefix(op_name))(fn)
26
+
27
+ _register_if_available(
28
+ "cutlass_grouped_gemm_interface",
29
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
30
+ )
31
+
32
+ _register_if_available(
33
+ "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
+ )
36
+
37
+ _register_if_available(
38
+ "moe_gather",
39
+ lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
+ )
41
+
42
+ _register_if_available(
43
+ "silu_and_mul",
44
+ lambda out, input: None,
45
+ )
46
+ _register_if_available(
47
+ "mul_and_silu",
48
+ lambda out, input: None,
49
+ )
50
+ _register_if_available(
51
+ "gelu_and_mul",
52
+ lambda out, input: None,
53
+ )
54
+ _register_if_available(
55
+ "gelu_tanh_and_mul",
56
+ lambda out, input: None,
57
+ )
58
+ _register_if_available(
59
+ "gelu_fast",
60
+ lambda out, input: None,
61
+ )
62
+ _register_if_available(
63
+ "gelu_new",
64
+ lambda out, input: None,
65
+ )
66
+ _register_if_available(
67
+ "gelu_quick",
68
+ lambda out, input: None,
69
+ )
70
+ _register_if_available(
71
+ "swigluoai_and_mul",
72
+ lambda out, input, alpha=1.702, limit=7.0: None,
73
+ )
74
+
75
+
76
+ # Register fake kernels on module load
77
+ _register_xpu_fake_kernels()
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # default
 
144
  return 1024
145
 
146
 
147
+ def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
148
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
149
+
150
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
151
+ constant folding when shape divisibility is not proven.
152
+ """
153
+ if byte_tensor.dtype != torch.uint8:
154
+ raise ValueError("byte_tensor must be uint8")
155
+ itemsize = torch.empty((), dtype=dtype).element_size()
156
+ numel = byte_tensor.numel() // itemsize
157
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
158
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
159
+ return out
160
+
161
+
162
  def implement_zp(qweight):
163
  # change u4 to s4 to avoid zero point in gemm kernel
164
  # only support default zero point now
 
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
+ workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
 
343
  inter_size=inter_size,
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
+ expert_first_token_offset_bytes = workspace[
347
  ws_map["expert_first_token_offset"][1]:
348
  ws_map["expert_first_token_offset"][1] +
349
+ expert_first_token_offset_size]
350
+ unpermuted_row_to_permuted_row_bytes = workspace[
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
+ src_to_dest_map_size]
354
+
355
+ if torch.compiler.is_compiling():
356
+ expert_first_token_offset = _bytes_to_typed_tensor(
357
+ expert_first_token_offset_bytes, torch.int64
358
+ )
359
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
+ unpermuted_row_to_permuted_row_bytes, torch.int32
361
+ )
362
+ else:
363
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(
build/torch29-cxx11-xpu20252-x86_64-linux/{_megablocks_099ac3c.abi3.so β†’ _megablocks_9be3a32.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:82d4807a02abe216da87ac6d4fbbf4870fdefa64ef182d09ab3408528107f08b
3
  size 4075712
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb6f2e895e92997f9d93107066513438e413bdba0012d0ee59737105b7ff6f1c
3
  size 4075712
build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_099ac3c
3
- ops = torch.ops._megablocks_099ac3c
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_099ac3c::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_9be3a32
3
+ ops = torch.ops._megablocks_9be3a32
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_9be3a32::{op_name}"
build/torch29-cxx11-xpu20252-x86_64-linux/xpu_fused_moe.py CHANGED
@@ -3,7 +3,9 @@
3
  import os
4
  import torch
5
 
6
- from ._ops import ops
 
 
7
 
8
 
9
  def resolve_dtensor(weight: torch.Tensor):
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
14
  return weight
15
 
16
 
17
- # Install meta kernels for torch.compile compatibility
18
- def _install_xpu_meta_kernels():
19
- """Install meta kernels for XPU MoE operations to support torch.compile"""
20
-
21
- # Patch cutlass_grouped_gemm_interface
22
- if hasattr(ops, "cutlass_grouped_gemm_interface"):
23
- original_gemm = ops.cutlass_grouped_gemm_interface
24
-
25
- def gemm_with_meta(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
26
- expert_first_token_offset, N, K, num_experts,
27
- is_B_int4, is_B_mxfp4):
28
- if torch.compiler.is_compiling():
29
- # Meta implementation - ptr_D is the output, return it
30
- return ptr_D
31
- return original_gemm(ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D,
32
- expert_first_token_offset, N, K, num_experts,
33
- is_B_int4, is_B_mxfp4)
34
-
35
- ops.cutlass_grouped_gemm_interface = gemm_with_meta
36
-
37
- # Patch fused_moe_prologue
38
- if hasattr(ops, "fused_moe_prologue"):
39
- original_prologue = ops.fused_moe_prologue
40
-
41
- def prologue_with_meta(input, token_selected_experts, token_final_scales,
42
- workspace, hidden_size, inter_size, num_experts_on_rank):
43
- if torch.compiler.is_compiling():
44
- # Meta implementation - this op modifies workspace in-place
45
- return None
46
- return original_prologue(input, token_selected_experts, token_final_scales,
47
- workspace, hidden_size, inter_size, num_experts_on_rank)
48
-
49
- ops.fused_moe_prologue = prologue_with_meta
50
-
51
- # Patch moe_gather
52
- if hasattr(ops, "moe_gather"):
53
- original_gather = ops.moe_gather
54
-
55
- def gather_with_meta(output, moe_output, topk_weights,
56
- unpermuted_row_to_permuted_row, num_experts):
57
- if torch.compiler.is_compiling():
58
- # Meta implementation - output is modified in-place
59
- return None
60
- return original_gather(output, moe_output, topk_weights,
61
- unpermuted_row_to_permuted_row, num_experts)
62
-
63
- ops.moe_gather = gather_with_meta
64
-
65
- # Patch activation ops
66
- for act_name in ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
67
- "gelu_fast", "gelu_new", "gelu_quick", "mul_and_silu",
68
- "swigluoai_and_mul"]:
69
- if hasattr(ops, act_name):
70
- original_act = getattr(ops, act_name)
71
-
72
- def make_act_wrapper(orig_fn):
73
- def act_with_meta(*args, **kwargs):
74
- if torch.compiler.is_compiling():
75
- # Meta implementation - in-place ops, return None
76
- return None
77
- return orig_fn(*args, **kwargs)
78
- return act_with_meta
79
-
80
- setattr(ops, act_name, make_act_wrapper(original_act))
81
-
82
-
83
- # Install meta kernels on module load
84
- _install_xpu_meta_kernels()
85
 
86
 
87
  # default
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
151
  return 1024
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def implement_zp(qweight):
155
  # change u4 to s4 to avoid zero point in gemm kernel
156
  # only support default zero point now
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
321
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
322
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
323
 
324
- workspace = torch.zeros(map_offset,
325
  dtype=torch.uint8,
326
  device=hidden_states.device)
327
  if topk_ids.dtype == torch.int32:
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
335
  inter_size=inter_size,
336
  num_experts_on_rank=num_experts_per_node)
337
 
338
- expert_first_token_offset = workspace[
339
  ws_map["expert_first_token_offset"][1]:
340
  ws_map["expert_first_token_offset"][1] +
341
- expert_first_token_offset_size].view(torch.int64)
342
- unpermuted_row_to_permuted_row = workspace[
343
  ws_map["unpermuted_row_to_permuted_row"][1]:
344
  ws_map["unpermuted_row_to_permuted_row"][1] +
345
- src_to_dest_map_size].view(torch.int32)
 
 
 
 
 
 
 
 
 
 
 
346
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
347
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
348
  permuted_data_size].view(hidden_states.dtype).view(
 
3
  import os
4
  import torch
5
 
6
+ from ._ops import ops, add_op_namespace_prefix
7
+
8
+ from torch.library import register_fake
9
 
10
 
11
  def resolve_dtensor(weight: torch.Tensor):
 
16
  return weight
17
 
18
 
19
+ # Register fake/meta kernels for torch.compile compatibility
20
+ def _register_xpu_fake_kernels():
21
+ """Register fake kernels for XPU MoE operations to support torch.compile."""
22
+
23
+ def _register_if_available(op_name, fn):
24
+ if hasattr(ops, op_name):
25
+ register_fake(add_op_namespace_prefix(op_name))(fn)
26
+
27
+ _register_if_available(
28
+ "cutlass_grouped_gemm_interface",
29
+ lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
30
+ )
31
+
32
+ _register_if_available(
33
+ "fused_moe_prologue",
34
+ lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
35
+ )
36
+
37
+ _register_if_available(
38
+ "moe_gather",
39
+ lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
40
+ )
41
+
42
+ _register_if_available(
43
+ "silu_and_mul",
44
+ lambda out, input: None,
45
+ )
46
+ _register_if_available(
47
+ "mul_and_silu",
48
+ lambda out, input: None,
49
+ )
50
+ _register_if_available(
51
+ "gelu_and_mul",
52
+ lambda out, input: None,
53
+ )
54
+ _register_if_available(
55
+ "gelu_tanh_and_mul",
56
+ lambda out, input: None,
57
+ )
58
+ _register_if_available(
59
+ "gelu_fast",
60
+ lambda out, input: None,
61
+ )
62
+ _register_if_available(
63
+ "gelu_new",
64
+ lambda out, input: None,
65
+ )
66
+ _register_if_available(
67
+ "gelu_quick",
68
+ lambda out, input: None,
69
+ )
70
+ _register_if_available(
71
+ "swigluoai_and_mul",
72
+ lambda out, input, alpha=1.702, limit=7.0: None,
73
+ )
74
+
75
+
76
+ # Register fake kernels on module load
77
+ _register_xpu_fake_kernels()
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  # default
 
144
  return 1024
145
 
146
 
147
+ def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
148
+ """Reinterpret a uint8 buffer as a typed tensor by copying bytes.
149
+
150
+ This avoids `Tensor.view(dtype)` which can fail under torch.compile
151
+ constant folding when shape divisibility is not proven.
152
+ """
153
+ if byte_tensor.dtype != torch.uint8:
154
+ raise ValueError("byte_tensor must be uint8")
155
+ itemsize = torch.empty((), dtype=dtype).element_size()
156
+ numel = byte_tensor.numel() // itemsize
157
+ out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
158
+ out.view(torch.uint8).copy_(byte_tensor.contiguous())
159
+ return out
160
+
161
+
162
  def implement_zp(qweight):
163
  # change u4 to s4 to avoid zero point in gemm kernel
164
  # only support default zero point now
 
329
  config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
330
  config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
331
 
332
+ workspace = torch.empty(map_offset,
333
  dtype=torch.uint8,
334
  device=hidden_states.device)
335
  if topk_ids.dtype == torch.int32:
 
343
  inter_size=inter_size,
344
  num_experts_on_rank=num_experts_per_node)
345
 
346
+ expert_first_token_offset_bytes = workspace[
347
  ws_map["expert_first_token_offset"][1]:
348
  ws_map["expert_first_token_offset"][1] +
349
+ expert_first_token_offset_size]
350
+ unpermuted_row_to_permuted_row_bytes = workspace[
351
  ws_map["unpermuted_row_to_permuted_row"][1]:
352
  ws_map["unpermuted_row_to_permuted_row"][1] +
353
+ src_to_dest_map_size]
354
+
355
+ if torch.compiler.is_compiling():
356
+ expert_first_token_offset = _bytes_to_typed_tensor(
357
+ expert_first_token_offset_bytes, torch.int64
358
+ )
359
+ unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
360
+ unpermuted_row_to_permuted_row_bytes, torch.int32
361
+ )
362
+ else:
363
+ expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
364
+ unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
365
  gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
366
  ws_map["overlapped_gemm1_gemm2_inputs"][1] +
367
  permuted_data_size].view(hidden_states.dtype).view(