Build uploaded using `kernels`.
Browse files- build/torch210-cxx11-cpu-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so} +1 -1
- build/torch210-cxx11-cpu-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py +93 -74
- build/torch210-cxx11-cu126-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so} +1 -1
- build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu126-x86_64-linux/xpu_fused_moe.py +93 -74
- build/torch210-cxx11-cu128-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so} +1 -1
- build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu128-x86_64-linux/xpu_fused_moe.py +93 -74
- build/torch210-cxx11-cu130-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so} +1 -1
- build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu130-x86_64-linux/xpu_fused_moe.py +93 -74
- build/torch210-cxx11-xpu20253-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so} +1 -1
- build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py +93 -74
- build/torch29-cxx11-cpu-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so} +1 -1
- build/torch29-cxx11-cpu-x86_64-linux/_ops.py +3 -3
- build/torch29-cxx11-cpu-x86_64-linux/xpu_fused_moe.py +93 -74
- build/torch29-cxx11-cu126-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so} +1 -1
- build/torch29-cxx11-cu126-x86_64-linux/_ops.py +3 -3
- build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py +93 -74
- build/torch29-cxx11-cu128-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so} +1 -1
- build/torch29-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch29-cxx11-cu128-x86_64-linux/xpu_fused_moe.py +93 -74
- build/torch29-cxx11-cu130-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so} +1 -1
- build/torch29-cxx11-cu130-x86_64-linux/_ops.py +3 -3
- build/torch29-cxx11-cu130-x86_64-linux/xpu_fused_moe.py +93 -74
- build/torch29-cxx11-xpu20252-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so} +1 -1
- build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py +3 -3
- build/torch29-cxx11-xpu20252-x86_64-linux/xpu_fused_moe.py +93 -74
build/torch210-cxx11-cpu-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2219056
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1bb9607d2d00b6eb3f3fe58da8dd972deb37b0658b8682807fc2863129f7aa8d
|
| 3 |
size 2219056
|
build/torch210-cxx11-cpu-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_9be3a32
|
| 3 |
+
ops = torch.ops._megablocks_9be3a32
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_9be3a32::{op_name}"
|
build/torch210-cxx11-cpu-x86_64-linux/xpu_fused_moe.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from ._ops import ops
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def resolve_dtensor(weight: torch.Tensor):
|
|
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
|
|
| 14 |
return weight
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return None
|
| 77 |
-
return orig_fn(*args, **kwargs)
|
| 78 |
-
return act_with_meta
|
| 79 |
-
|
| 80 |
-
setattr(ops, act_name, make_act_wrapper(original_act))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Install meta kernels on module load
|
| 84 |
-
_install_xpu_meta_kernels()
|
| 85 |
|
| 86 |
|
| 87 |
# default
|
|
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
|
|
| 151 |
return 1024
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def implement_zp(qweight):
|
| 155 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 156 |
# only support default zero point now
|
|
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 321 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 322 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 323 |
|
| 324 |
-
workspace = torch.
|
| 325 |
dtype=torch.uint8,
|
| 326 |
device=hidden_states.device)
|
| 327 |
if topk_ids.dtype == torch.int32:
|
|
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
|
|
| 335 |
inter_size=inter_size,
|
| 336 |
num_experts_on_rank=num_experts_per_node)
|
| 337 |
|
| 338 |
-
|
| 339 |
ws_map["expert_first_token_offset"][1]:
|
| 340 |
ws_map["expert_first_token_offset"][1] +
|
| 341 |
-
expert_first_token_offset_size]
|
| 342 |
-
|
| 343 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 344 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 345 |
-
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 347 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 348 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
+
|
| 8 |
+
from torch.library import register_fake
|
| 9 |
|
| 10 |
|
| 11 |
def resolve_dtensor(weight: torch.Tensor):
|
|
|
|
| 16 |
return weight
|
| 17 |
|
| 18 |
|
| 19 |
+
# Register fake/meta kernels for torch.compile compatibility
|
| 20 |
+
def _register_xpu_fake_kernels():
|
| 21 |
+
"""Register fake kernels for XPU MoE operations to support torch.compile."""
|
| 22 |
+
|
| 23 |
+
def _register_if_available(op_name, fn):
|
| 24 |
+
if hasattr(ops, op_name):
|
| 25 |
+
register_fake(add_op_namespace_prefix(op_name))(fn)
|
| 26 |
+
|
| 27 |
+
_register_if_available(
|
| 28 |
+
"cutlass_grouped_gemm_interface",
|
| 29 |
+
lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_register_if_available(
|
| 33 |
+
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_register_if_available(
|
| 38 |
+
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_register_if_available(
|
| 43 |
+
"silu_and_mul",
|
| 44 |
+
lambda out, input: None,
|
| 45 |
+
)
|
| 46 |
+
_register_if_available(
|
| 47 |
+
"mul_and_silu",
|
| 48 |
+
lambda out, input: None,
|
| 49 |
+
)
|
| 50 |
+
_register_if_available(
|
| 51 |
+
"gelu_and_mul",
|
| 52 |
+
lambda out, input: None,
|
| 53 |
+
)
|
| 54 |
+
_register_if_available(
|
| 55 |
+
"gelu_tanh_and_mul",
|
| 56 |
+
lambda out, input: None,
|
| 57 |
+
)
|
| 58 |
+
_register_if_available(
|
| 59 |
+
"gelu_fast",
|
| 60 |
+
lambda out, input: None,
|
| 61 |
+
)
|
| 62 |
+
_register_if_available(
|
| 63 |
+
"gelu_new",
|
| 64 |
+
lambda out, input: None,
|
| 65 |
+
)
|
| 66 |
+
_register_if_available(
|
| 67 |
+
"gelu_quick",
|
| 68 |
+
lambda out, input: None,
|
| 69 |
+
)
|
| 70 |
+
_register_if_available(
|
| 71 |
+
"swigluoai_and_mul",
|
| 72 |
+
lambda out, input, alpha=1.702, limit=7.0: None,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Register fake kernels on module load
|
| 77 |
+
_register_xpu_fake_kernels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# default
|
|
|
|
| 144 |
return 1024
|
| 145 |
|
| 146 |
|
| 147 |
+
def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 148 |
+
"""Reinterpret a uint8 buffer as a typed tensor by copying bytes.
|
| 149 |
+
|
| 150 |
+
This avoids `Tensor.view(dtype)` which can fail under torch.compile
|
| 151 |
+
constant folding when shape divisibility is not proven.
|
| 152 |
+
"""
|
| 153 |
+
if byte_tensor.dtype != torch.uint8:
|
| 154 |
+
raise ValueError("byte_tensor must be uint8")
|
| 155 |
+
itemsize = torch.empty((), dtype=dtype).element_size()
|
| 156 |
+
numel = byte_tensor.numel() // itemsize
|
| 157 |
+
out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
|
| 158 |
+
out.view(torch.uint8).copy_(byte_tensor.contiguous())
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def implement_zp(qweight):
|
| 163 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 164 |
# only support default zero point now
|
|
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
+
workspace = torch.empty(map_offset,
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
inter_size=inter_size,
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
+
expert_first_token_offset_bytes = workspace[
|
| 347 |
ws_map["expert_first_token_offset"][1]:
|
| 348 |
ws_map["expert_first_token_offset"][1] +
|
| 349 |
+
expert_first_token_offset_size]
|
| 350 |
+
unpermuted_row_to_permuted_row_bytes = workspace[
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
+
src_to_dest_map_size]
|
| 354 |
+
|
| 355 |
+
if torch.compiler.is_compiling():
|
| 356 |
+
expert_first_token_offset = _bytes_to_typed_tensor(
|
| 357 |
+
expert_first_token_offset_bytes, torch.int64
|
| 358 |
+
)
|
| 359 |
+
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
+
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
+
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
build/torch210-cxx11-cu126-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 15061032
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:321e1bb305fd100b1abc99234f480634d05a901ee3a758628d94615d535e2caf
|
| 3 |
size 15061032
|
build/torch210-cxx11-cu126-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_9be3a32
|
| 3 |
+
ops = torch.ops._megablocks_9be3a32
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_9be3a32::{op_name}"
|
build/torch210-cxx11-cu126-x86_64-linux/xpu_fused_moe.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from ._ops import ops
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def resolve_dtensor(weight: torch.Tensor):
|
|
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
|
|
| 14 |
return weight
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return None
|
| 77 |
-
return orig_fn(*args, **kwargs)
|
| 78 |
-
return act_with_meta
|
| 79 |
-
|
| 80 |
-
setattr(ops, act_name, make_act_wrapper(original_act))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Install meta kernels on module load
|
| 84 |
-
_install_xpu_meta_kernels()
|
| 85 |
|
| 86 |
|
| 87 |
# default
|
|
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
|
|
| 151 |
return 1024
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def implement_zp(qweight):
|
| 155 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 156 |
# only support default zero point now
|
|
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 321 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 322 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 323 |
|
| 324 |
-
workspace = torch.
|
| 325 |
dtype=torch.uint8,
|
| 326 |
device=hidden_states.device)
|
| 327 |
if topk_ids.dtype == torch.int32:
|
|
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
|
|
| 335 |
inter_size=inter_size,
|
| 336 |
num_experts_on_rank=num_experts_per_node)
|
| 337 |
|
| 338 |
-
|
| 339 |
ws_map["expert_first_token_offset"][1]:
|
| 340 |
ws_map["expert_first_token_offset"][1] +
|
| 341 |
-
expert_first_token_offset_size]
|
| 342 |
-
|
| 343 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 344 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 345 |
-
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 347 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 348 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
+
|
| 8 |
+
from torch.library import register_fake
|
| 9 |
|
| 10 |
|
| 11 |
def resolve_dtensor(weight: torch.Tensor):
|
|
|
|
| 16 |
return weight
|
| 17 |
|
| 18 |
|
| 19 |
+
# Register fake/meta kernels for torch.compile compatibility
|
| 20 |
+
def _register_xpu_fake_kernels():
|
| 21 |
+
"""Register fake kernels for XPU MoE operations to support torch.compile."""
|
| 22 |
+
|
| 23 |
+
def _register_if_available(op_name, fn):
|
| 24 |
+
if hasattr(ops, op_name):
|
| 25 |
+
register_fake(add_op_namespace_prefix(op_name))(fn)
|
| 26 |
+
|
| 27 |
+
_register_if_available(
|
| 28 |
+
"cutlass_grouped_gemm_interface",
|
| 29 |
+
lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_register_if_available(
|
| 33 |
+
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_register_if_available(
|
| 38 |
+
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_register_if_available(
|
| 43 |
+
"silu_and_mul",
|
| 44 |
+
lambda out, input: None,
|
| 45 |
+
)
|
| 46 |
+
_register_if_available(
|
| 47 |
+
"mul_and_silu",
|
| 48 |
+
lambda out, input: None,
|
| 49 |
+
)
|
| 50 |
+
_register_if_available(
|
| 51 |
+
"gelu_and_mul",
|
| 52 |
+
lambda out, input: None,
|
| 53 |
+
)
|
| 54 |
+
_register_if_available(
|
| 55 |
+
"gelu_tanh_and_mul",
|
| 56 |
+
lambda out, input: None,
|
| 57 |
+
)
|
| 58 |
+
_register_if_available(
|
| 59 |
+
"gelu_fast",
|
| 60 |
+
lambda out, input: None,
|
| 61 |
+
)
|
| 62 |
+
_register_if_available(
|
| 63 |
+
"gelu_new",
|
| 64 |
+
lambda out, input: None,
|
| 65 |
+
)
|
| 66 |
+
_register_if_available(
|
| 67 |
+
"gelu_quick",
|
| 68 |
+
lambda out, input: None,
|
| 69 |
+
)
|
| 70 |
+
_register_if_available(
|
| 71 |
+
"swigluoai_and_mul",
|
| 72 |
+
lambda out, input, alpha=1.702, limit=7.0: None,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Register fake kernels on module load
|
| 77 |
+
_register_xpu_fake_kernels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# default
|
|
|
|
| 144 |
return 1024
|
| 145 |
|
| 146 |
|
| 147 |
+
def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 148 |
+
"""Reinterpret a uint8 buffer as a typed tensor by copying bytes.
|
| 149 |
+
|
| 150 |
+
This avoids `Tensor.view(dtype)` which can fail under torch.compile
|
| 151 |
+
constant folding when shape divisibility is not proven.
|
| 152 |
+
"""
|
| 153 |
+
if byte_tensor.dtype != torch.uint8:
|
| 154 |
+
raise ValueError("byte_tensor must be uint8")
|
| 155 |
+
itemsize = torch.empty((), dtype=dtype).element_size()
|
| 156 |
+
numel = byte_tensor.numel() // itemsize
|
| 157 |
+
out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
|
| 158 |
+
out.view(torch.uint8).copy_(byte_tensor.contiguous())
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def implement_zp(qweight):
|
| 163 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 164 |
# only support default zero point now
|
|
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
+
workspace = torch.empty(map_offset,
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
inter_size=inter_size,
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
+
expert_first_token_offset_bytes = workspace[
|
| 347 |
ws_map["expert_first_token_offset"][1]:
|
| 348 |
ws_map["expert_first_token_offset"][1] +
|
| 349 |
+
expert_first_token_offset_size]
|
| 350 |
+
unpermuted_row_to_permuted_row_bytes = workspace[
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
+
src_to_dest_map_size]
|
| 354 |
+
|
| 355 |
+
if torch.compiler.is_compiling():
|
| 356 |
+
expert_first_token_offset = _bytes_to_typed_tensor(
|
| 357 |
+
expert_first_token_offset_bytes, torch.int64
|
| 358 |
+
)
|
| 359 |
+
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
+
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
+
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
build/torch210-cxx11-cu128-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 21009952
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:83c64c2e54082d931c9fc3027ef6522bf3f3acd4c49d4c5c14dbfcb5ab038b12
|
| 3 |
size 21009952
|
build/torch210-cxx11-cu128-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_9be3a32
|
| 3 |
+
ops = torch.ops._megablocks_9be3a32
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_9be3a32::{op_name}"
|
build/torch210-cxx11-cu128-x86_64-linux/xpu_fused_moe.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from ._ops import ops
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def resolve_dtensor(weight: torch.Tensor):
|
|
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
|
|
| 14 |
return weight
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return None
|
| 77 |
-
return orig_fn(*args, **kwargs)
|
| 78 |
-
return act_with_meta
|
| 79 |
-
|
| 80 |
-
setattr(ops, act_name, make_act_wrapper(original_act))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Install meta kernels on module load
|
| 84 |
-
_install_xpu_meta_kernels()
|
| 85 |
|
| 86 |
|
| 87 |
# default
|
|
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
|
|
| 151 |
return 1024
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def implement_zp(qweight):
|
| 155 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 156 |
# only support default zero point now
|
|
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 321 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 322 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 323 |
|
| 324 |
-
workspace = torch.
|
| 325 |
dtype=torch.uint8,
|
| 326 |
device=hidden_states.device)
|
| 327 |
if topk_ids.dtype == torch.int32:
|
|
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
|
|
| 335 |
inter_size=inter_size,
|
| 336 |
num_experts_on_rank=num_experts_per_node)
|
| 337 |
|
| 338 |
-
|
| 339 |
ws_map["expert_first_token_offset"][1]:
|
| 340 |
ws_map["expert_first_token_offset"][1] +
|
| 341 |
-
expert_first_token_offset_size]
|
| 342 |
-
|
| 343 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 344 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 345 |
-
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 347 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 348 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
+
|
| 8 |
+
from torch.library import register_fake
|
| 9 |
|
| 10 |
|
| 11 |
def resolve_dtensor(weight: torch.Tensor):
|
|
|
|
| 16 |
return weight
|
| 17 |
|
| 18 |
|
| 19 |
+
# Register fake/meta kernels for torch.compile compatibility
|
| 20 |
+
def _register_xpu_fake_kernels():
|
| 21 |
+
"""Register fake kernels for XPU MoE operations to support torch.compile."""
|
| 22 |
+
|
| 23 |
+
def _register_if_available(op_name, fn):
|
| 24 |
+
if hasattr(ops, op_name):
|
| 25 |
+
register_fake(add_op_namespace_prefix(op_name))(fn)
|
| 26 |
+
|
| 27 |
+
_register_if_available(
|
| 28 |
+
"cutlass_grouped_gemm_interface",
|
| 29 |
+
lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_register_if_available(
|
| 33 |
+
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_register_if_available(
|
| 38 |
+
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_register_if_available(
|
| 43 |
+
"silu_and_mul",
|
| 44 |
+
lambda out, input: None,
|
| 45 |
+
)
|
| 46 |
+
_register_if_available(
|
| 47 |
+
"mul_and_silu",
|
| 48 |
+
lambda out, input: None,
|
| 49 |
+
)
|
| 50 |
+
_register_if_available(
|
| 51 |
+
"gelu_and_mul",
|
| 52 |
+
lambda out, input: None,
|
| 53 |
+
)
|
| 54 |
+
_register_if_available(
|
| 55 |
+
"gelu_tanh_and_mul",
|
| 56 |
+
lambda out, input: None,
|
| 57 |
+
)
|
| 58 |
+
_register_if_available(
|
| 59 |
+
"gelu_fast",
|
| 60 |
+
lambda out, input: None,
|
| 61 |
+
)
|
| 62 |
+
_register_if_available(
|
| 63 |
+
"gelu_new",
|
| 64 |
+
lambda out, input: None,
|
| 65 |
+
)
|
| 66 |
+
_register_if_available(
|
| 67 |
+
"gelu_quick",
|
| 68 |
+
lambda out, input: None,
|
| 69 |
+
)
|
| 70 |
+
_register_if_available(
|
| 71 |
+
"swigluoai_and_mul",
|
| 72 |
+
lambda out, input, alpha=1.702, limit=7.0: None,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Register fake kernels on module load
|
| 77 |
+
_register_xpu_fake_kernels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# default
|
|
|
|
| 144 |
return 1024
|
| 145 |
|
| 146 |
|
| 147 |
+
def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 148 |
+
"""Reinterpret a uint8 buffer as a typed tensor by copying bytes.
|
| 149 |
+
|
| 150 |
+
This avoids `Tensor.view(dtype)` which can fail under torch.compile
|
| 151 |
+
constant folding when shape divisibility is not proven.
|
| 152 |
+
"""
|
| 153 |
+
if byte_tensor.dtype != torch.uint8:
|
| 154 |
+
raise ValueError("byte_tensor must be uint8")
|
| 155 |
+
itemsize = torch.empty((), dtype=dtype).element_size()
|
| 156 |
+
numel = byte_tensor.numel() // itemsize
|
| 157 |
+
out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
|
| 158 |
+
out.view(torch.uint8).copy_(byte_tensor.contiguous())
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def implement_zp(qweight):
|
| 163 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 164 |
# only support default zero point now
|
|
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
+
workspace = torch.empty(map_offset,
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
inter_size=inter_size,
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
+
expert_first_token_offset_bytes = workspace[
|
| 347 |
ws_map["expert_first_token_offset"][1]:
|
| 348 |
ws_map["expert_first_token_offset"][1] +
|
| 349 |
+
expert_first_token_offset_size]
|
| 350 |
+
unpermuted_row_to_permuted_row_bytes = workspace[
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
+
src_to_dest_map_size]
|
| 354 |
+
|
| 355 |
+
if torch.compiler.is_compiling():
|
| 356 |
+
expert_first_token_offset = _bytes_to_typed_tensor(
|
| 357 |
+
expert_first_token_offset_bytes, torch.int64
|
| 358 |
+
)
|
| 359 |
+
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
+
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
+
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
build/torch210-cxx11-cu130-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 12041568
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f48c4762cbfdf923c9547acd7d792dd7edec4bcfe5a857b605ce370f807be23a
|
| 3 |
size 12041568
|
build/torch210-cxx11-cu130-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_9be3a32
|
| 3 |
+
ops = torch.ops._megablocks_9be3a32
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_9be3a32::{op_name}"
|
build/torch210-cxx11-cu130-x86_64-linux/xpu_fused_moe.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from ._ops import ops
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def resolve_dtensor(weight: torch.Tensor):
|
|
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
|
|
| 14 |
return weight
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return None
|
| 77 |
-
return orig_fn(*args, **kwargs)
|
| 78 |
-
return act_with_meta
|
| 79 |
-
|
| 80 |
-
setattr(ops, act_name, make_act_wrapper(original_act))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Install meta kernels on module load
|
| 84 |
-
_install_xpu_meta_kernels()
|
| 85 |
|
| 86 |
|
| 87 |
# default
|
|
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
|
|
| 151 |
return 1024
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def implement_zp(qweight):
|
| 155 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 156 |
# only support default zero point now
|
|
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 321 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 322 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 323 |
|
| 324 |
-
workspace = torch.
|
| 325 |
dtype=torch.uint8,
|
| 326 |
device=hidden_states.device)
|
| 327 |
if topk_ids.dtype == torch.int32:
|
|
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
|
|
| 335 |
inter_size=inter_size,
|
| 336 |
num_experts_on_rank=num_experts_per_node)
|
| 337 |
|
| 338 |
-
|
| 339 |
ws_map["expert_first_token_offset"][1]:
|
| 340 |
ws_map["expert_first_token_offset"][1] +
|
| 341 |
-
expert_first_token_offset_size]
|
| 342 |
-
|
| 343 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 344 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 345 |
-
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 347 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 348 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
+
|
| 8 |
+
from torch.library import register_fake
|
| 9 |
|
| 10 |
|
| 11 |
def resolve_dtensor(weight: torch.Tensor):
|
|
|
|
| 16 |
return weight
|
| 17 |
|
| 18 |
|
| 19 |
+
# Register fake/meta kernels for torch.compile compatibility
|
| 20 |
+
def _register_xpu_fake_kernels():
|
| 21 |
+
"""Register fake kernels for XPU MoE operations to support torch.compile."""
|
| 22 |
+
|
| 23 |
+
def _register_if_available(op_name, fn):
|
| 24 |
+
if hasattr(ops, op_name):
|
| 25 |
+
register_fake(add_op_namespace_prefix(op_name))(fn)
|
| 26 |
+
|
| 27 |
+
_register_if_available(
|
| 28 |
+
"cutlass_grouped_gemm_interface",
|
| 29 |
+
lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_register_if_available(
|
| 33 |
+
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_register_if_available(
|
| 38 |
+
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_register_if_available(
|
| 43 |
+
"silu_and_mul",
|
| 44 |
+
lambda out, input: None,
|
| 45 |
+
)
|
| 46 |
+
_register_if_available(
|
| 47 |
+
"mul_and_silu",
|
| 48 |
+
lambda out, input: None,
|
| 49 |
+
)
|
| 50 |
+
_register_if_available(
|
| 51 |
+
"gelu_and_mul",
|
| 52 |
+
lambda out, input: None,
|
| 53 |
+
)
|
| 54 |
+
_register_if_available(
|
| 55 |
+
"gelu_tanh_and_mul",
|
| 56 |
+
lambda out, input: None,
|
| 57 |
+
)
|
| 58 |
+
_register_if_available(
|
| 59 |
+
"gelu_fast",
|
| 60 |
+
lambda out, input: None,
|
| 61 |
+
)
|
| 62 |
+
_register_if_available(
|
| 63 |
+
"gelu_new",
|
| 64 |
+
lambda out, input: None,
|
| 65 |
+
)
|
| 66 |
+
_register_if_available(
|
| 67 |
+
"gelu_quick",
|
| 68 |
+
lambda out, input: None,
|
| 69 |
+
)
|
| 70 |
+
_register_if_available(
|
| 71 |
+
"swigluoai_and_mul",
|
| 72 |
+
lambda out, input, alpha=1.702, limit=7.0: None,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Register fake kernels on module load
|
| 77 |
+
_register_xpu_fake_kernels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# default
|
|
|
|
| 144 |
return 1024
|
| 145 |
|
| 146 |
|
| 147 |
+
def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 148 |
+
"""Reinterpret a uint8 buffer as a typed tensor by copying bytes.
|
| 149 |
+
|
| 150 |
+
This avoids `Tensor.view(dtype)` which can fail under torch.compile
|
| 151 |
+
constant folding when shape divisibility is not proven.
|
| 152 |
+
"""
|
| 153 |
+
if byte_tensor.dtype != torch.uint8:
|
| 154 |
+
raise ValueError("byte_tensor must be uint8")
|
| 155 |
+
itemsize = torch.empty((), dtype=dtype).element_size()
|
| 156 |
+
numel = byte_tensor.numel() // itemsize
|
| 157 |
+
out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
|
| 158 |
+
out.view(torch.uint8).copy_(byte_tensor.contiguous())
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def implement_zp(qweight):
|
| 163 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 164 |
# only support default zero point now
|
|
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
+
workspace = torch.empty(map_offset,
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
inter_size=inter_size,
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
+
expert_first_token_offset_bytes = workspace[
|
| 347 |
ws_map["expert_first_token_offset"][1]:
|
| 348 |
ws_map["expert_first_token_offset"][1] +
|
| 349 |
+
expert_first_token_offset_size]
|
| 350 |
+
unpermuted_row_to_permuted_row_bytes = workspace[
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
+
src_to_dest_map_size]
|
| 354 |
+
|
| 355 |
+
if torch.compiler.is_compiling():
|
| 356 |
+
expert_first_token_offset = _bytes_to_typed_tensor(
|
| 357 |
+
expert_first_token_offset_bytes, torch.int64
|
| 358 |
+
)
|
| 359 |
+
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
+
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
+
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
build/torch210-cxx11-xpu20253-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 4227888
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e840b67c3d3ee92b1150b7c0e4eaab1eda0998347131838eea3bc1bd44049093
|
| 3 |
size 4227888
|
build/torch210-cxx11-xpu20253-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_9be3a32
|
| 3 |
+
ops = torch.ops._megablocks_9be3a32
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_9be3a32::{op_name}"
|
build/torch210-cxx11-xpu20253-x86_64-linux/xpu_fused_moe.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from ._ops import ops
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def resolve_dtensor(weight: torch.Tensor):
|
|
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
|
|
| 14 |
return weight
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return None
|
| 77 |
-
return orig_fn(*args, **kwargs)
|
| 78 |
-
return act_with_meta
|
| 79 |
-
|
| 80 |
-
setattr(ops, act_name, make_act_wrapper(original_act))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Install meta kernels on module load
|
| 84 |
-
_install_xpu_meta_kernels()
|
| 85 |
|
| 86 |
|
| 87 |
# default
|
|
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
|
|
| 151 |
return 1024
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def implement_zp(qweight):
|
| 155 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 156 |
# only support default zero point now
|
|
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 321 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 322 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 323 |
|
| 324 |
-
workspace = torch.
|
| 325 |
dtype=torch.uint8,
|
| 326 |
device=hidden_states.device)
|
| 327 |
if topk_ids.dtype == torch.int32:
|
|
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
|
|
| 335 |
inter_size=inter_size,
|
| 336 |
num_experts_on_rank=num_experts_per_node)
|
| 337 |
|
| 338 |
-
|
| 339 |
ws_map["expert_first_token_offset"][1]:
|
| 340 |
ws_map["expert_first_token_offset"][1] +
|
| 341 |
-
expert_first_token_offset_size]
|
| 342 |
-
|
| 343 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 344 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 345 |
-
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 347 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 348 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
+
|
| 8 |
+
from torch.library import register_fake
|
| 9 |
|
| 10 |
|
| 11 |
def resolve_dtensor(weight: torch.Tensor):
|
|
|
|
| 16 |
return weight
|
| 17 |
|
| 18 |
|
| 19 |
+
# Register fake/meta kernels for torch.compile compatibility
|
| 20 |
+
def _register_xpu_fake_kernels():
|
| 21 |
+
"""Register fake kernels for XPU MoE operations to support torch.compile."""
|
| 22 |
+
|
| 23 |
+
def _register_if_available(op_name, fn):
|
| 24 |
+
if hasattr(ops, op_name):
|
| 25 |
+
register_fake(add_op_namespace_prefix(op_name))(fn)
|
| 26 |
+
|
| 27 |
+
_register_if_available(
|
| 28 |
+
"cutlass_grouped_gemm_interface",
|
| 29 |
+
lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_register_if_available(
|
| 33 |
+
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_register_if_available(
|
| 38 |
+
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_register_if_available(
|
| 43 |
+
"silu_and_mul",
|
| 44 |
+
lambda out, input: None,
|
| 45 |
+
)
|
| 46 |
+
_register_if_available(
|
| 47 |
+
"mul_and_silu",
|
| 48 |
+
lambda out, input: None,
|
| 49 |
+
)
|
| 50 |
+
_register_if_available(
|
| 51 |
+
"gelu_and_mul",
|
| 52 |
+
lambda out, input: None,
|
| 53 |
+
)
|
| 54 |
+
_register_if_available(
|
| 55 |
+
"gelu_tanh_and_mul",
|
| 56 |
+
lambda out, input: None,
|
| 57 |
+
)
|
| 58 |
+
_register_if_available(
|
| 59 |
+
"gelu_fast",
|
| 60 |
+
lambda out, input: None,
|
| 61 |
+
)
|
| 62 |
+
_register_if_available(
|
| 63 |
+
"gelu_new",
|
| 64 |
+
lambda out, input: None,
|
| 65 |
+
)
|
| 66 |
+
_register_if_available(
|
| 67 |
+
"gelu_quick",
|
| 68 |
+
lambda out, input: None,
|
| 69 |
+
)
|
| 70 |
+
_register_if_available(
|
| 71 |
+
"swigluoai_and_mul",
|
| 72 |
+
lambda out, input, alpha=1.702, limit=7.0: None,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Register fake kernels on module load
|
| 77 |
+
_register_xpu_fake_kernels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# default
|
|
|
|
| 144 |
return 1024
|
| 145 |
|
| 146 |
|
| 147 |
+
def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 148 |
+
"""Reinterpret a uint8 buffer as a typed tensor by copying bytes.
|
| 149 |
+
|
| 150 |
+
This avoids `Tensor.view(dtype)` which can fail under torch.compile
|
| 151 |
+
constant folding when shape divisibility is not proven.
|
| 152 |
+
"""
|
| 153 |
+
if byte_tensor.dtype != torch.uint8:
|
| 154 |
+
raise ValueError("byte_tensor must be uint8")
|
| 155 |
+
itemsize = torch.empty((), dtype=dtype).element_size()
|
| 156 |
+
numel = byte_tensor.numel() // itemsize
|
| 157 |
+
out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
|
| 158 |
+
out.view(torch.uint8).copy_(byte_tensor.contiguous())
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def implement_zp(qweight):
|
| 163 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 164 |
# only support default zero point now
|
|
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
+
workspace = torch.empty(map_offset,
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
inter_size=inter_size,
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
+
expert_first_token_offset_bytes = workspace[
|
| 347 |
ws_map["expert_first_token_offset"][1]:
|
| 348 |
ws_map["expert_first_token_offset"][1] +
|
| 349 |
+
expert_first_token_offset_size]
|
| 350 |
+
unpermuted_row_to_permuted_row_bytes = workspace[
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
+
src_to_dest_map_size]
|
| 354 |
+
|
| 355 |
+
if torch.compiler.is_compiling():
|
| 356 |
+
expert_first_token_offset = _bytes_to_typed_tensor(
|
| 357 |
+
expert_first_token_offset_bytes, torch.int64
|
| 358 |
+
)
|
| 359 |
+
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
+
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
+
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
build/torch29-cxx11-cpu-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2201176
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24c19663574a3afb94a458ee318e8b63d47d24f6b1f457a605c115a567810a08
|
| 3 |
size 2201176
|
build/torch29-cxx11-cpu-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_9be3a32
|
| 3 |
+
ops = torch.ops._megablocks_9be3a32
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_9be3a32::{op_name}"
|
build/torch29-cxx11-cpu-x86_64-linux/xpu_fused_moe.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from ._ops import ops
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def resolve_dtensor(weight: torch.Tensor):
|
|
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
|
|
| 14 |
return weight
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return None
|
| 77 |
-
return orig_fn(*args, **kwargs)
|
| 78 |
-
return act_with_meta
|
| 79 |
-
|
| 80 |
-
setattr(ops, act_name, make_act_wrapper(original_act))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Install meta kernels on module load
|
| 84 |
-
_install_xpu_meta_kernels()
|
| 85 |
|
| 86 |
|
| 87 |
# default
|
|
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
|
|
| 151 |
return 1024
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def implement_zp(qweight):
|
| 155 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 156 |
# only support default zero point now
|
|
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 321 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 322 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 323 |
|
| 324 |
-
workspace = torch.
|
| 325 |
dtype=torch.uint8,
|
| 326 |
device=hidden_states.device)
|
| 327 |
if topk_ids.dtype == torch.int32:
|
|
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
|
|
| 335 |
inter_size=inter_size,
|
| 336 |
num_experts_on_rank=num_experts_per_node)
|
| 337 |
|
| 338 |
-
|
| 339 |
ws_map["expert_first_token_offset"][1]:
|
| 340 |
ws_map["expert_first_token_offset"][1] +
|
| 341 |
-
expert_first_token_offset_size]
|
| 342 |
-
|
| 343 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 344 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 345 |
-
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 347 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 348 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
+
|
| 8 |
+
from torch.library import register_fake
|
| 9 |
|
| 10 |
|
| 11 |
def resolve_dtensor(weight: torch.Tensor):
|
|
|
|
| 16 |
return weight
|
| 17 |
|
| 18 |
|
| 19 |
+
# Register fake/meta kernels for torch.compile compatibility
|
| 20 |
+
def _register_xpu_fake_kernels():
|
| 21 |
+
"""Register fake kernels for XPU MoE operations to support torch.compile."""
|
| 22 |
+
|
| 23 |
+
def _register_if_available(op_name, fn):
|
| 24 |
+
if hasattr(ops, op_name):
|
| 25 |
+
register_fake(add_op_namespace_prefix(op_name))(fn)
|
| 26 |
+
|
| 27 |
+
_register_if_available(
|
| 28 |
+
"cutlass_grouped_gemm_interface",
|
| 29 |
+
lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_register_if_available(
|
| 33 |
+
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_register_if_available(
|
| 38 |
+
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_register_if_available(
|
| 43 |
+
"silu_and_mul",
|
| 44 |
+
lambda out, input: None,
|
| 45 |
+
)
|
| 46 |
+
_register_if_available(
|
| 47 |
+
"mul_and_silu",
|
| 48 |
+
lambda out, input: None,
|
| 49 |
+
)
|
| 50 |
+
_register_if_available(
|
| 51 |
+
"gelu_and_mul",
|
| 52 |
+
lambda out, input: None,
|
| 53 |
+
)
|
| 54 |
+
_register_if_available(
|
| 55 |
+
"gelu_tanh_and_mul",
|
| 56 |
+
lambda out, input: None,
|
| 57 |
+
)
|
| 58 |
+
_register_if_available(
|
| 59 |
+
"gelu_fast",
|
| 60 |
+
lambda out, input: None,
|
| 61 |
+
)
|
| 62 |
+
_register_if_available(
|
| 63 |
+
"gelu_new",
|
| 64 |
+
lambda out, input: None,
|
| 65 |
+
)
|
| 66 |
+
_register_if_available(
|
| 67 |
+
"gelu_quick",
|
| 68 |
+
lambda out, input: None,
|
| 69 |
+
)
|
| 70 |
+
_register_if_available(
|
| 71 |
+
"swigluoai_and_mul",
|
| 72 |
+
lambda out, input, alpha=1.702, limit=7.0: None,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Register fake kernels on module load
|
| 77 |
+
_register_xpu_fake_kernels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# default
|
|
|
|
| 144 |
return 1024
|
| 145 |
|
| 146 |
|
| 147 |
+
def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 148 |
+
"""Reinterpret a uint8 buffer as a typed tensor by copying bytes.
|
| 149 |
+
|
| 150 |
+
This avoids `Tensor.view(dtype)` which can fail under torch.compile
|
| 151 |
+
constant folding when shape divisibility is not proven.
|
| 152 |
+
"""
|
| 153 |
+
if byte_tensor.dtype != torch.uint8:
|
| 154 |
+
raise ValueError("byte_tensor must be uint8")
|
| 155 |
+
itemsize = torch.empty((), dtype=dtype).element_size()
|
| 156 |
+
numel = byte_tensor.numel() // itemsize
|
| 157 |
+
out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
|
| 158 |
+
out.view(torch.uint8).copy_(byte_tensor.contiguous())
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def implement_zp(qweight):
|
| 163 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 164 |
# only support default zero point now
|
|
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
+
workspace = torch.empty(map_offset,
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
inter_size=inter_size,
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
+
expert_first_token_offset_bytes = workspace[
|
| 347 |
ws_map["expert_first_token_offset"][1]:
|
| 348 |
ws_map["expert_first_token_offset"][1] +
|
| 349 |
+
expert_first_token_offset_size]
|
| 350 |
+
unpermuted_row_to_permuted_row_bytes = workspace[
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
+
src_to_dest_map_size]
|
| 354 |
+
|
| 355 |
+
if torch.compiler.is_compiling():
|
| 356 |
+
expert_first_token_offset = _bytes_to_typed_tensor(
|
| 357 |
+
expert_first_token_offset_bytes, torch.int64
|
| 358 |
+
)
|
| 359 |
+
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
+
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
+
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
build/torch29-cxx11-cu126-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 15046808
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bc4e092bd6f32001e850abf73dd6ee609e9a25800d87fd9e19a0e4a6c30f8e9c
|
| 3 |
size 15046808
|
build/torch29-cxx11-cu126-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_9be3a32
|
| 3 |
+
ops = torch.ops._megablocks_9be3a32
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_9be3a32::{op_name}"
|
build/torch29-cxx11-cu126-x86_64-linux/xpu_fused_moe.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from ._ops import ops
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def resolve_dtensor(weight: torch.Tensor):
|
|
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
|
|
| 14 |
return weight
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return None
|
| 77 |
-
return orig_fn(*args, **kwargs)
|
| 78 |
-
return act_with_meta
|
| 79 |
-
|
| 80 |
-
setattr(ops, act_name, make_act_wrapper(original_act))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Install meta kernels on module load
|
| 84 |
-
_install_xpu_meta_kernels()
|
| 85 |
|
| 86 |
|
| 87 |
# default
|
|
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
|
|
| 151 |
return 1024
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def implement_zp(qweight):
|
| 155 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 156 |
# only support default zero point now
|
|
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 321 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 322 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 323 |
|
| 324 |
-
workspace = torch.
|
| 325 |
dtype=torch.uint8,
|
| 326 |
device=hidden_states.device)
|
| 327 |
if topk_ids.dtype == torch.int32:
|
|
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
|
|
| 335 |
inter_size=inter_size,
|
| 336 |
num_experts_on_rank=num_experts_per_node)
|
| 337 |
|
| 338 |
-
|
| 339 |
ws_map["expert_first_token_offset"][1]:
|
| 340 |
ws_map["expert_first_token_offset"][1] +
|
| 341 |
-
expert_first_token_offset_size]
|
| 342 |
-
|
| 343 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 344 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 345 |
-
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 347 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 348 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
+
|
| 8 |
+
from torch.library import register_fake
|
| 9 |
|
| 10 |
|
| 11 |
def resolve_dtensor(weight: torch.Tensor):
|
|
|
|
| 16 |
return weight
|
| 17 |
|
| 18 |
|
| 19 |
+
# Register fake/meta kernels for torch.compile compatibility
|
| 20 |
+
def _register_xpu_fake_kernels():
|
| 21 |
+
"""Register fake kernels for XPU MoE operations to support torch.compile."""
|
| 22 |
+
|
| 23 |
+
def _register_if_available(op_name, fn):
|
| 24 |
+
if hasattr(ops, op_name):
|
| 25 |
+
register_fake(add_op_namespace_prefix(op_name))(fn)
|
| 26 |
+
|
| 27 |
+
_register_if_available(
|
| 28 |
+
"cutlass_grouped_gemm_interface",
|
| 29 |
+
lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_register_if_available(
|
| 33 |
+
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_register_if_available(
|
| 38 |
+
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_register_if_available(
|
| 43 |
+
"silu_and_mul",
|
| 44 |
+
lambda out, input: None,
|
| 45 |
+
)
|
| 46 |
+
_register_if_available(
|
| 47 |
+
"mul_and_silu",
|
| 48 |
+
lambda out, input: None,
|
| 49 |
+
)
|
| 50 |
+
_register_if_available(
|
| 51 |
+
"gelu_and_mul",
|
| 52 |
+
lambda out, input: None,
|
| 53 |
+
)
|
| 54 |
+
_register_if_available(
|
| 55 |
+
"gelu_tanh_and_mul",
|
| 56 |
+
lambda out, input: None,
|
| 57 |
+
)
|
| 58 |
+
_register_if_available(
|
| 59 |
+
"gelu_fast",
|
| 60 |
+
lambda out, input: None,
|
| 61 |
+
)
|
| 62 |
+
_register_if_available(
|
| 63 |
+
"gelu_new",
|
| 64 |
+
lambda out, input: None,
|
| 65 |
+
)
|
| 66 |
+
_register_if_available(
|
| 67 |
+
"gelu_quick",
|
| 68 |
+
lambda out, input: None,
|
| 69 |
+
)
|
| 70 |
+
_register_if_available(
|
| 71 |
+
"swigluoai_and_mul",
|
| 72 |
+
lambda out, input, alpha=1.702, limit=7.0: None,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Register fake kernels on module load
|
| 77 |
+
_register_xpu_fake_kernels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# default
|
|
|
|
| 144 |
return 1024
|
| 145 |
|
| 146 |
|
| 147 |
+
def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 148 |
+
"""Reinterpret a uint8 buffer as a typed tensor by copying bytes.
|
| 149 |
+
|
| 150 |
+
This avoids `Tensor.view(dtype)` which can fail under torch.compile
|
| 151 |
+
constant folding when shape divisibility is not proven.
|
| 152 |
+
"""
|
| 153 |
+
if byte_tensor.dtype != torch.uint8:
|
| 154 |
+
raise ValueError("byte_tensor must be uint8")
|
| 155 |
+
itemsize = torch.empty((), dtype=dtype).element_size()
|
| 156 |
+
numel = byte_tensor.numel() // itemsize
|
| 157 |
+
out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
|
| 158 |
+
out.view(torch.uint8).copy_(byte_tensor.contiguous())
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def implement_zp(qweight):
|
| 163 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 164 |
# only support default zero point now
|
|
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
+
workspace = torch.empty(map_offset,
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
inter_size=inter_size,
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
+
expert_first_token_offset_bytes = workspace[
|
| 347 |
ws_map["expert_first_token_offset"][1]:
|
| 348 |
ws_map["expert_first_token_offset"][1] +
|
| 349 |
+
expert_first_token_offset_size]
|
| 350 |
+
unpermuted_row_to_permuted_row_bytes = workspace[
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
+
src_to_dest_map_size]
|
| 354 |
+
|
| 355 |
+
if torch.compiler.is_compiling():
|
| 356 |
+
expert_first_token_offset = _bytes_to_typed_tensor(
|
| 357 |
+
expert_first_token_offset_bytes, torch.int64
|
| 358 |
+
)
|
| 359 |
+
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
+
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
+
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
build/torch29-cxx11-cu128-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 20995680
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9018001f72f4a1b7f364d1ca582d8a756cbe452ed798efc4c42e74c49ca1839c
|
| 3 |
size 20995680
|
build/torch29-cxx11-cu128-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_9be3a32
|
| 3 |
+
ops = torch.ops._megablocks_9be3a32
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_9be3a32::{op_name}"
|
build/torch29-cxx11-cu128-x86_64-linux/xpu_fused_moe.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from ._ops import ops
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def resolve_dtensor(weight: torch.Tensor):
|
|
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
|
|
| 14 |
return weight
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return None
|
| 77 |
-
return orig_fn(*args, **kwargs)
|
| 78 |
-
return act_with_meta
|
| 79 |
-
|
| 80 |
-
setattr(ops, act_name, make_act_wrapper(original_act))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Install meta kernels on module load
|
| 84 |
-
_install_xpu_meta_kernels()
|
| 85 |
|
| 86 |
|
| 87 |
# default
|
|
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
|
|
| 151 |
return 1024
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def implement_zp(qweight):
|
| 155 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 156 |
# only support default zero point now
|
|
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 321 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 322 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 323 |
|
| 324 |
-
workspace = torch.
|
| 325 |
dtype=torch.uint8,
|
| 326 |
device=hidden_states.device)
|
| 327 |
if topk_ids.dtype == torch.int32:
|
|
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
|
|
| 335 |
inter_size=inter_size,
|
| 336 |
num_experts_on_rank=num_experts_per_node)
|
| 337 |
|
| 338 |
-
|
| 339 |
ws_map["expert_first_token_offset"][1]:
|
| 340 |
ws_map["expert_first_token_offset"][1] +
|
| 341 |
-
expert_first_token_offset_size]
|
| 342 |
-
|
| 343 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 344 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 345 |
-
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 347 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 348 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
+
|
| 8 |
+
from torch.library import register_fake
|
| 9 |
|
| 10 |
|
| 11 |
def resolve_dtensor(weight: torch.Tensor):
|
|
|
|
| 16 |
return weight
|
| 17 |
|
| 18 |
|
| 19 |
+
# Register fake/meta kernels for torch.compile compatibility
|
| 20 |
+
def _register_xpu_fake_kernels():
|
| 21 |
+
"""Register fake kernels for XPU MoE operations to support torch.compile."""
|
| 22 |
+
|
| 23 |
+
def _register_if_available(op_name, fn):
|
| 24 |
+
if hasattr(ops, op_name):
|
| 25 |
+
register_fake(add_op_namespace_prefix(op_name))(fn)
|
| 26 |
+
|
| 27 |
+
_register_if_available(
|
| 28 |
+
"cutlass_grouped_gemm_interface",
|
| 29 |
+
lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_register_if_available(
|
| 33 |
+
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_register_if_available(
|
| 38 |
+
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_register_if_available(
|
| 43 |
+
"silu_and_mul",
|
| 44 |
+
lambda out, input: None,
|
| 45 |
+
)
|
| 46 |
+
_register_if_available(
|
| 47 |
+
"mul_and_silu",
|
| 48 |
+
lambda out, input: None,
|
| 49 |
+
)
|
| 50 |
+
_register_if_available(
|
| 51 |
+
"gelu_and_mul",
|
| 52 |
+
lambda out, input: None,
|
| 53 |
+
)
|
| 54 |
+
_register_if_available(
|
| 55 |
+
"gelu_tanh_and_mul",
|
| 56 |
+
lambda out, input: None,
|
| 57 |
+
)
|
| 58 |
+
_register_if_available(
|
| 59 |
+
"gelu_fast",
|
| 60 |
+
lambda out, input: None,
|
| 61 |
+
)
|
| 62 |
+
_register_if_available(
|
| 63 |
+
"gelu_new",
|
| 64 |
+
lambda out, input: None,
|
| 65 |
+
)
|
| 66 |
+
_register_if_available(
|
| 67 |
+
"gelu_quick",
|
| 68 |
+
lambda out, input: None,
|
| 69 |
+
)
|
| 70 |
+
_register_if_available(
|
| 71 |
+
"swigluoai_and_mul",
|
| 72 |
+
lambda out, input, alpha=1.702, limit=7.0: None,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Register fake kernels on module load
|
| 77 |
+
_register_xpu_fake_kernels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# default
|
|
|
|
| 144 |
return 1024
|
| 145 |
|
| 146 |
|
| 147 |
+
def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 148 |
+
"""Reinterpret a uint8 buffer as a typed tensor by copying bytes.
|
| 149 |
+
|
| 150 |
+
This avoids `Tensor.view(dtype)` which can fail under torch.compile
|
| 151 |
+
constant folding when shape divisibility is not proven.
|
| 152 |
+
"""
|
| 153 |
+
if byte_tensor.dtype != torch.uint8:
|
| 154 |
+
raise ValueError("byte_tensor must be uint8")
|
| 155 |
+
itemsize = torch.empty((), dtype=dtype).element_size()
|
| 156 |
+
numel = byte_tensor.numel() // itemsize
|
| 157 |
+
out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
|
| 158 |
+
out.view(torch.uint8).copy_(byte_tensor.contiguous())
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def implement_zp(qweight):
|
| 163 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 164 |
# only support default zero point now
|
|
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
+
workspace = torch.empty(map_offset,
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
inter_size=inter_size,
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
+
expert_first_token_offset_bytes = workspace[
|
| 347 |
ws_map["expert_first_token_offset"][1]:
|
| 348 |
ws_map["expert_first_token_offset"][1] +
|
| 349 |
+
expert_first_token_offset_size]
|
| 350 |
+
unpermuted_row_to_permuted_row_bytes = workspace[
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
+
src_to_dest_map_size]
|
| 354 |
+
|
| 355 |
+
if torch.compiler.is_compiling():
|
| 356 |
+
expert_first_token_offset = _bytes_to_typed_tensor(
|
| 357 |
+
expert_first_token_offset_bytes, torch.int64
|
| 358 |
+
)
|
| 359 |
+
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
+
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
+
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
build/torch29-cxx11-cu130-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 12031392
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:49caf38e644493142784e8ad8fac70c1ec9f249c798399950f4228570a570c04
|
| 3 |
size 12031392
|
build/torch29-cxx11-cu130-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_9be3a32
|
| 3 |
+
ops = torch.ops._megablocks_9be3a32
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_9be3a32::{op_name}"
|
build/torch29-cxx11-cu130-x86_64-linux/xpu_fused_moe.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from ._ops import ops
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def resolve_dtensor(weight: torch.Tensor):
|
|
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
|
|
| 14 |
return weight
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return None
|
| 77 |
-
return orig_fn(*args, **kwargs)
|
| 78 |
-
return act_with_meta
|
| 79 |
-
|
| 80 |
-
setattr(ops, act_name, make_act_wrapper(original_act))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Install meta kernels on module load
|
| 84 |
-
_install_xpu_meta_kernels()
|
| 85 |
|
| 86 |
|
| 87 |
# default
|
|
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
|
|
| 151 |
return 1024
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def implement_zp(qweight):
|
| 155 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 156 |
# only support default zero point now
|
|
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 321 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 322 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 323 |
|
| 324 |
-
workspace = torch.
|
| 325 |
dtype=torch.uint8,
|
| 326 |
device=hidden_states.device)
|
| 327 |
if topk_ids.dtype == torch.int32:
|
|
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
|
|
| 335 |
inter_size=inter_size,
|
| 336 |
num_experts_on_rank=num_experts_per_node)
|
| 337 |
|
| 338 |
-
|
| 339 |
ws_map["expert_first_token_offset"][1]:
|
| 340 |
ws_map["expert_first_token_offset"][1] +
|
| 341 |
-
expert_first_token_offset_size]
|
| 342 |
-
|
| 343 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 344 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 345 |
-
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 347 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 348 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
+
|
| 8 |
+
from torch.library import register_fake
|
| 9 |
|
| 10 |
|
| 11 |
def resolve_dtensor(weight: torch.Tensor):
|
|
|
|
| 16 |
return weight
|
| 17 |
|
| 18 |
|
| 19 |
+
# Register fake/meta kernels for torch.compile compatibility
|
| 20 |
+
def _register_xpu_fake_kernels():
|
| 21 |
+
"""Register fake kernels for XPU MoE operations to support torch.compile."""
|
| 22 |
+
|
| 23 |
+
def _register_if_available(op_name, fn):
|
| 24 |
+
if hasattr(ops, op_name):
|
| 25 |
+
register_fake(add_op_namespace_prefix(op_name))(fn)
|
| 26 |
+
|
| 27 |
+
_register_if_available(
|
| 28 |
+
"cutlass_grouped_gemm_interface",
|
| 29 |
+
lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_register_if_available(
|
| 33 |
+
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_register_if_available(
|
| 38 |
+
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_register_if_available(
|
| 43 |
+
"silu_and_mul",
|
| 44 |
+
lambda out, input: None,
|
| 45 |
+
)
|
| 46 |
+
_register_if_available(
|
| 47 |
+
"mul_and_silu",
|
| 48 |
+
lambda out, input: None,
|
| 49 |
+
)
|
| 50 |
+
_register_if_available(
|
| 51 |
+
"gelu_and_mul",
|
| 52 |
+
lambda out, input: None,
|
| 53 |
+
)
|
| 54 |
+
_register_if_available(
|
| 55 |
+
"gelu_tanh_and_mul",
|
| 56 |
+
lambda out, input: None,
|
| 57 |
+
)
|
| 58 |
+
_register_if_available(
|
| 59 |
+
"gelu_fast",
|
| 60 |
+
lambda out, input: None,
|
| 61 |
+
)
|
| 62 |
+
_register_if_available(
|
| 63 |
+
"gelu_new",
|
| 64 |
+
lambda out, input: None,
|
| 65 |
+
)
|
| 66 |
+
_register_if_available(
|
| 67 |
+
"gelu_quick",
|
| 68 |
+
lambda out, input: None,
|
| 69 |
+
)
|
| 70 |
+
_register_if_available(
|
| 71 |
+
"swigluoai_and_mul",
|
| 72 |
+
lambda out, input, alpha=1.702, limit=7.0: None,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Register fake kernels on module load
|
| 77 |
+
_register_xpu_fake_kernels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# default
|
|
|
|
| 144 |
return 1024
|
| 145 |
|
| 146 |
|
| 147 |
+
def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 148 |
+
"""Reinterpret a uint8 buffer as a typed tensor by copying bytes.
|
| 149 |
+
|
| 150 |
+
This avoids `Tensor.view(dtype)` which can fail under torch.compile
|
| 151 |
+
constant folding when shape divisibility is not proven.
|
| 152 |
+
"""
|
| 153 |
+
if byte_tensor.dtype != torch.uint8:
|
| 154 |
+
raise ValueError("byte_tensor must be uint8")
|
| 155 |
+
itemsize = torch.empty((), dtype=dtype).element_size()
|
| 156 |
+
numel = byte_tensor.numel() // itemsize
|
| 157 |
+
out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
|
| 158 |
+
out.view(torch.uint8).copy_(byte_tensor.contiguous())
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def implement_zp(qweight):
|
| 163 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 164 |
# only support default zero point now
|
|
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
+
workspace = torch.empty(map_offset,
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
inter_size=inter_size,
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
+
expert_first_token_offset_bytes = workspace[
|
| 347 |
ws_map["expert_first_token_offset"][1]:
|
| 348 |
ws_map["expert_first_token_offset"][1] +
|
| 349 |
+
expert_first_token_offset_size]
|
| 350 |
+
unpermuted_row_to_permuted_row_bytes = workspace[
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
+
src_to_dest_map_size]
|
| 354 |
+
|
| 355 |
+
if torch.compiler.is_compiling():
|
| 356 |
+
expert_first_token_offset = _bytes_to_typed_tensor(
|
| 357 |
+
expert_first_token_offset_bytes, torch.int64
|
| 358 |
+
)
|
| 359 |
+
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
+
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
+
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|
build/torch29-cxx11-xpu20252-x86_64-linux/{_megablocks_099ac3c.abi3.so β _megablocks_9be3a32.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 4075712
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb6f2e895e92997f9d93107066513438e413bdba0012d0ee59737105b7ff6f1c
|
| 3 |
size 4075712
|
build/torch29-cxx11-xpu20252-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _megablocks_9be3a32
|
| 3 |
+
ops = torch.ops._megablocks_9be3a32
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_megablocks_9be3a32::{op_name}"
|
build/torch29-cxx11-xpu20252-x86_64-linux/xpu_fused_moe.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
-
from ._ops import ops
|
|
|
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def resolve_dtensor(weight: torch.Tensor):
|
|
@@ -14,74 +16,65 @@ def resolve_dtensor(weight: torch.Tensor):
|
|
| 14 |
return weight
|
| 15 |
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
def
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return None
|
| 77 |
-
return orig_fn(*args, **kwargs)
|
| 78 |
-
return act_with_meta
|
| 79 |
-
|
| 80 |
-
setattr(ops, act_name, make_act_wrapper(original_act))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Install meta kernels on module load
|
| 84 |
-
_install_xpu_meta_kernels()
|
| 85 |
|
| 86 |
|
| 87 |
# default
|
|
@@ -151,6 +144,21 @@ def compute_num_tokens_per_block(num_tokens, num_experts_per_node):
|
|
| 151 |
return 1024
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def implement_zp(qweight):
|
| 155 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 156 |
# only support default zero point now
|
|
@@ -321,7 +329,7 @@ def xpu_fused_moe(hidden_states,
|
|
| 321 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 322 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 323 |
|
| 324 |
-
workspace = torch.
|
| 325 |
dtype=torch.uint8,
|
| 326 |
device=hidden_states.device)
|
| 327 |
if topk_ids.dtype == torch.int32:
|
|
@@ -335,14 +343,25 @@ def xpu_fused_moe(hidden_states,
|
|
| 335 |
inter_size=inter_size,
|
| 336 |
num_experts_on_rank=num_experts_per_node)
|
| 337 |
|
| 338 |
-
|
| 339 |
ws_map["expert_first_token_offset"][1]:
|
| 340 |
ws_map["expert_first_token_offset"][1] +
|
| 341 |
-
expert_first_token_offset_size]
|
| 342 |
-
|
| 343 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 344 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 345 |
-
src_to_dest_map_size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 347 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 348 |
permuted_data_size].view(hidden_states.dtype).view(
|
|
|
|
| 3 |
import os
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from ._ops import ops, add_op_namespace_prefix
|
| 7 |
+
|
| 8 |
+
from torch.library import register_fake
|
| 9 |
|
| 10 |
|
| 11 |
def resolve_dtensor(weight: torch.Tensor):
|
|
|
|
| 16 |
return weight
|
| 17 |
|
| 18 |
|
| 19 |
+
# Register fake/meta kernels for torch.compile compatibility
|
| 20 |
+
def _register_xpu_fake_kernels():
|
| 21 |
+
"""Register fake kernels for XPU MoE operations to support torch.compile."""
|
| 22 |
+
|
| 23 |
+
def _register_if_available(op_name, fn):
|
| 24 |
+
if hasattr(ops, op_name):
|
| 25 |
+
register_fake(add_op_namespace_prefix(op_name))(fn)
|
| 26 |
+
|
| 27 |
+
_register_if_available(
|
| 28 |
+
"cutlass_grouped_gemm_interface",
|
| 29 |
+
lambda ptr_A, ptr_B, ptr_scales, ptr_bias, ptr_D, expert_first_token_offset, N, K, num_experts, is_B_int4, is_B_mxfp4: ptr_D,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
_register_if_available(
|
| 33 |
+
"fused_moe_prologue",
|
| 34 |
+
lambda input, token_selected_experts, token_final_scales, workspace, hidden_size, inter_size, num_experts_on_rank: None,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
_register_if_available(
|
| 38 |
+
"moe_gather",
|
| 39 |
+
lambda output, moe_output, topk_weights, unpermuted_row_to_permuted_row, num_experts: None,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
_register_if_available(
|
| 43 |
+
"silu_and_mul",
|
| 44 |
+
lambda out, input: None,
|
| 45 |
+
)
|
| 46 |
+
_register_if_available(
|
| 47 |
+
"mul_and_silu",
|
| 48 |
+
lambda out, input: None,
|
| 49 |
+
)
|
| 50 |
+
_register_if_available(
|
| 51 |
+
"gelu_and_mul",
|
| 52 |
+
lambda out, input: None,
|
| 53 |
+
)
|
| 54 |
+
_register_if_available(
|
| 55 |
+
"gelu_tanh_and_mul",
|
| 56 |
+
lambda out, input: None,
|
| 57 |
+
)
|
| 58 |
+
_register_if_available(
|
| 59 |
+
"gelu_fast",
|
| 60 |
+
lambda out, input: None,
|
| 61 |
+
)
|
| 62 |
+
_register_if_available(
|
| 63 |
+
"gelu_new",
|
| 64 |
+
lambda out, input: None,
|
| 65 |
+
)
|
| 66 |
+
_register_if_available(
|
| 67 |
+
"gelu_quick",
|
| 68 |
+
lambda out, input: None,
|
| 69 |
+
)
|
| 70 |
+
_register_if_available(
|
| 71 |
+
"swigluoai_and_mul",
|
| 72 |
+
lambda out, input, alpha=1.702, limit=7.0: None,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Register fake kernels on module load
|
| 77 |
+
_register_xpu_fake_kernels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# default
|
|
|
|
| 144 |
return 1024
|
| 145 |
|
| 146 |
|
| 147 |
+
def _bytes_to_typed_tensor(byte_tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 148 |
+
"""Reinterpret a uint8 buffer as a typed tensor by copying bytes.
|
| 149 |
+
|
| 150 |
+
This avoids `Tensor.view(dtype)` which can fail under torch.compile
|
| 151 |
+
constant folding when shape divisibility is not proven.
|
| 152 |
+
"""
|
| 153 |
+
if byte_tensor.dtype != torch.uint8:
|
| 154 |
+
raise ValueError("byte_tensor must be uint8")
|
| 155 |
+
itemsize = torch.empty((), dtype=dtype).element_size()
|
| 156 |
+
numel = byte_tensor.numel() // itemsize
|
| 157 |
+
out = torch.empty((numel,), dtype=dtype, device=byte_tensor.device)
|
| 158 |
+
out.view(torch.uint8).copy_(byte_tensor.contiguous())
|
| 159 |
+
return out
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def implement_zp(qweight):
|
| 163 |
# change u4 to s4 to avoid zero point in gemm kernel
|
| 164 |
# only support default zero point now
|
|
|
|
| 329 |
config_ws("permuted_token_final_scales", permuted_token_final_scales_size)
|
| 330 |
config_ws("overlapped_gemm1_gemm2_inputs", permuted_data_size)
|
| 331 |
|
| 332 |
+
workspace = torch.empty(map_offset,
|
| 333 |
dtype=torch.uint8,
|
| 334 |
device=hidden_states.device)
|
| 335 |
if topk_ids.dtype == torch.int32:
|
|
|
|
| 343 |
inter_size=inter_size,
|
| 344 |
num_experts_on_rank=num_experts_per_node)
|
| 345 |
|
| 346 |
+
expert_first_token_offset_bytes = workspace[
|
| 347 |
ws_map["expert_first_token_offset"][1]:
|
| 348 |
ws_map["expert_first_token_offset"][1] +
|
| 349 |
+
expert_first_token_offset_size]
|
| 350 |
+
unpermuted_row_to_permuted_row_bytes = workspace[
|
| 351 |
ws_map["unpermuted_row_to_permuted_row"][1]:
|
| 352 |
ws_map["unpermuted_row_to_permuted_row"][1] +
|
| 353 |
+
src_to_dest_map_size]
|
| 354 |
+
|
| 355 |
+
if torch.compiler.is_compiling():
|
| 356 |
+
expert_first_token_offset = _bytes_to_typed_tensor(
|
| 357 |
+
expert_first_token_offset_bytes, torch.int64
|
| 358 |
+
)
|
| 359 |
+
unpermuted_row_to_permuted_row = _bytes_to_typed_tensor(
|
| 360 |
+
unpermuted_row_to_permuted_row_bytes, torch.int32
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
expert_first_token_offset = expert_first_token_offset_bytes.view(torch.int64)
|
| 364 |
+
unpermuted_row_to_permuted_row = unpermuted_row_to_permuted_row_bytes.view(torch.int32)
|
| 365 |
gemm1_input = workspace[ws_map["overlapped_gemm1_gemm2_inputs"][1]:
|
| 366 |
ws_map["overlapped_gemm1_gemm2_inputs"][1] +
|
| 367 |
permuted_data_size].view(hidden_states.dtype).view(
|