# Quillan-Ronin/src/AceMoE.py
# Uploaded by CrashOverrideX ("Add files using upload-large-folder tool"),
# commit 41a3927 (verified).
import torch
from torch.profiler import record_function
import triton_kernels
import triton_kernels.swiglu
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
from triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
from triton_kernels.matmul_ogs import matmul_ogs
from triton_kernels.numerics import InFlexData
from triton_kernels.routing import routing
from triton_kernels.tensor import convert_layout
from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout
from triton_kernels.tensor import wrap_torch_tensor, FP4
def quantize_mx4(w):
    """Quantize a weight tensor to MXFP4 for use with ``matmul_ogs``.

    The tensor is cast to bfloat16 and downcast with ``downcast_to_mxfp``
    along axis 1 (uint8 storage).

    Returns:
        ``(values, scales)``. ``values`` holds the packed FP4 data, converted to
        ``HopperMXValueLayout`` (``mx_axis=1``). ``scales`` holds the per-block
        scales, converted to ``StridedLayout``.
    """
    values, scales = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
    # Tag the raw uint8 storage as FP4 before relaying it out for Hopper kernels.
    fp4_values = wrap_torch_tensor(values, dtype=FP4)
    values = convert_layout(fp4_values, HopperMXValueLayout, mx_axis=1)
    scale_tensor = wrap_torch_tensor(scales)
    scales = convert_layout(scale_tensor, StridedLayout)
    return values, scales
def swiglu(x, alpha: float = 1.702, limit: float = 7.0, interleaved: bool = True):
    """Clamped SwiGLU over the last dimension of ``x``.

    The last dimension is split into a gate half and a linear half:
    - ``interleaved=True``: even columns are the gate and odd columns are the linear part.
    - ``interleaved=False``: the first half is the gate and the second half is the linear part.

    The gate is clamped from above at ``limit``; the linear part is clamped to
    ``[-limit, limit]``. The result is ``gate * sigmoid(alpha * gate) * (linear + 1)``,
    whose last dimension is half the size of the input's.
    """
    if not interleaved:
        gate, linear = torch.chunk(x, 2, dim=-1)
    else:
        gate = x[..., 0::2]
        linear = x[..., 1::2]

    gate = torch.clamp(gate, max=limit)
    linear = torch.clamp(linear, min=-limit, max=limit)

    activated = gate * torch.sigmoid(alpha * gate)
    return activated * (linear + 1)
def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_experts=128,
        swiglu_limit=7.0, fused_act=True, interleaved=True, swiglu_alpha=1.702):
    """Run a routed mixture-of-experts MLP with triton_kernels' ``matmul_ogs``.

    Stages, each wrapped in a ``record_function`` profiler range:
    1. "wg": gate/router matmul producing routing logits.
    2. "routing": top-``experts_per_token`` routing via ``routing``.
    3. The w1 matmul (gathered by ``gather_indx``) followed by SwiGLU. These are
       either fused into one kernel ("w1+swiglu") or run separately ("w1", "swiglu").
    4. "w2": the w2 matmul, scattered by ``scatter_indx``, with
       ``rdata.gate_scal`` passed as ``gammas``.

    Args:
        x: Input activations. If ``x`` has no elements, it is returned as-is.
        wg, bg: Gate weight and bias.
        w1, w1_mx, b1: First projection weight, its MX weight scale, and bias.
        w2, w2_mx, b2: Second projection weight, its MX weight scale, and bias.
        experts_per_token: Experts selected per token, passed to ``routing``.
        num_experts: Not read by this function; kept for call-site compatibility.
        swiglu_limit: Clamp limit for the SwiGLU activation.
        fused_act: Fuse SwiGLU into the w1 matmul when True; otherwise call
            ``swiglu`` on the w1 output.
        interleaved: Layout of the w1 output's gate/linear halves, as understood
            by ``swiglu``. Must be True when ``fused_act`` is True.
        swiglu_alpha: Sigmoid scale for SwiGLU. The default of 1.702 is the value
            previously hard-coded here.

    Returns:
        Output of the w2 matmul, or ``x`` itself when ``x`` is empty.

    Raises:
        ValueError: If ``fused_act`` is True and ``interleaved`` is False.
    """
    if x.numel() == 0:
        return x
    # Validate the configuration before launching any kernel. Use an explicit
    # raise rather than `assert`, which is stripped under `python -O`.
    if fused_act and not interleaved:
        raise ValueError("Fused activation requires interleaved weights")

    pc1 = PrecisionConfig(weight_scale=w1_mx, flex_ctx=FlexCtx(rhs_data=InFlexData()))
    pc2 = PrecisionConfig(weight_scale=w2_mx, flex_ctx=FlexCtx(rhs_data=InFlexData()))
    pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=InFlexData()))

    with record_function("wg"):
        logits = matmul_ogs(x, wg, bg, precision_config=pcg)
    with record_function("routing"):
        rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, simulated_ep=1)

    if fused_act:
        with record_function("w1+swiglu"):
            act = FusedActivation(
                # Argument names/values are matched positionally: alpha, limit.
                FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
                (swiglu_alpha, swiglu_limit),
                # NOTE(review): presumably the output-width reduction factor
                # (each gate/linear pair -> one value); confirm in triton_kernels.
                2,
            )
            x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx,
                           precision_config=pc1, fused_activation=act)
    else:
        with record_function("w1"):
            x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
        with record_function("swiglu"):
            x = swiglu(x, alpha=swiglu_alpha, limit=swiglu_limit, interleaved=interleaved)

    with record_function("w2"):
        x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx,
                       precision_config=pc2, gammas=rdata.gate_scal)
    return x