| import torch | |
| from torch.profiler import record_function | |
| import triton_kernels | |
| import triton_kernels.swiglu | |
| from triton_kernels.numerics_details.mxfp import downcast_to_mxfp | |
| from triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation | |
| from triton_kernels.matmul_ogs import matmul_ogs | |
| from triton_kernels.numerics import InFlexData | |
| from triton_kernels.routing import routing | |
| from triton_kernels.tensor import convert_layout | |
| from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout | |
| from triton_kernels.tensor import wrap_torch_tensor, FP4 | |
def quantize_mx4(w):
    """Quantize ``w`` to MXFP4 and convert it to the layouts used by ``matmul_ogs``.

    Args:
        w: weight tensor; it is cast to bfloat16 before quantization.

    Returns:
        ``(values, scales)``:
        * ``values``: the quantized data wrapped as ``FP4`` in
          ``HopperMXValueLayout`` (with ``mx_axis=1``).
        * ``scales``: the per-block scales in ``StridedLayout``.
    """
    # Quantization blocks run along axis 1. The bf16 copy is a temporary that
    # is released as soon as downcast_to_mxfp returns.
    values, scales = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
    # Rebind instead of introducing new names so the raw uint8 buffers are not
    # kept alive longer than in the straightforward version.
    values = convert_layout(
        wrap_torch_tensor(values, dtype=FP4),
        HopperMXValueLayout,
        mx_axis=1,
    )
    scales = convert_layout(wrap_torch_tensor(scales), StridedLayout)
    return values, scales
def swiglu(x, alpha: float = 1.702, limit: float = 7.0, interleaved: bool = True):
    """Apply the clamped SwiGLU activation over the last dimension of ``x``.

    The last dimension is split into a gate half and a linear half. The result
    is ``gate * sigmoid(alpha * gate) * (linear + 1)``, which halves the size of
    the last dimension.

    Args:
        x: input tensor.
        alpha: scale applied inside the sigmoid.
        limit: ``gate`` is clamped from above at ``limit`` (no lower bound);
            ``linear`` is clamped to ``[-limit, limit]``.
        interleaved: if True, gate values sit at even indices and linear values
            at odd indices of the last dimension. If False, the first half is
            gate and the second half is linear.

    Returns:
        Tensor with the same leading dimensions and half the last dimension.
    """
    gate, linear = (x[..., ::2], x[..., 1::2]) if interleaved else torch.chunk(x, 2, dim=-1)
    # Only an upper bound on the gate: large negative inputs pass through
    # because sigmoid already drives them toward zero.
    gate = torch.clamp(gate, max=limit)
    linear = torch.clamp(linear, min=-limit, max=limit)
    return gate * torch.sigmoid(alpha * gate) * (linear + 1)
def moe(
    x,
    wg,
    w1,
    w1_mx,
    w2,
    w2_mx,
    bg,
    b1,
    b2,
    experts_per_token=4,
    num_experts=128,
    swiglu_limit=7.0,
    fused_act=True,
    interleaved=True,
    swiglu_alpha=1.702,
):
    """Run one mixture-of-experts MLP layer using triton_kernels ``matmul_ogs``.

    Pipeline:
    1. Gate projection (``wg``/``bg``) produces routing logits.
    2. ``routing`` picks ``experts_per_token`` experts per token (presumably
       top-k over the logits; confirm against triton_kernels).
    3. First expert projection (``w1``/``b1``, gathered via ``gather_indx``)
       with SwiGLU.
    4. Second expert projection (``w2``/``b2``, scattered via ``scatter_indx``),
       with ``rdata.gate_scal`` passed as ``gammas``.

    Each stage is wrapped in ``torch.profiler.record_function`` for profiling.

    Args:
        x: input activations. Returned as-is when it has no elements.
        wg, bg: gate (router) weight and bias.
        w1, w1_mx, b1: first expert weight, its weight scale (passed as
            ``PrecisionConfig.weight_scale``), and bias.
        w2, w2_mx, b2: second expert weight, its weight scale, and bias.
        experts_per_token: number of experts passed to ``routing``.
        num_experts: not used by this function; kept for interface
            compatibility.
        swiglu_limit: clamp limit for the SwiGLU activation.
        fused_act: if True, fuse SwiGLU into the first matmul through
            ``FusedActivation``. Otherwise run the ``swiglu`` helper after a
            plain matmul.
        interleaved: whether the first projection's output interleaves gate
            and linear columns. Must be True when ``fused_act`` is True.
        swiglu_alpha: sigmoid scale for SwiGLU. The same value is used by both
            the fused and unfused paths so they stay consistent. The default
            matches the previously hard-coded value.

    Returns:
        Output of the second expert projection, or ``x`` itself if it is empty.

    Raises:
        ValueError: if ``fused_act`` is True and ``interleaved`` is False.
            This was previously an ``assert``, which ``python -O`` strips.
    """
    if x.numel() == 0:
        return x
    # Validate the configuration before launching any kernel. A bare assert
    # disappears under -O and would let the fused kernel run on a layout it
    # does not support.
    if fused_act and not interleaved:
        raise ValueError("Fused activation requires interleaved weights")

    pc1 = PrecisionConfig(weight_scale=w1_mx, flex_ctx=FlexCtx(rhs_data=InFlexData()))
    pc2 = PrecisionConfig(weight_scale=w2_mx, flex_ctx=FlexCtx(rhs_data=InFlexData()))
    pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=InFlexData()))

    with record_function("wg"):
        logits = matmul_ogs(x, wg, bg, precision_config=pcg)
    with record_function("routing"):
        rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, simulated_ep=1)

    if fused_act:
        with record_function("w1+swiglu"):
            # NOTE(review): the trailing 2 presumably tells the kernel that
            # SwiGLU halves the output width; confirm against FusedActivation.
            act = FusedActivation(
                FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
                (swiglu_alpha, swiglu_limit),
                2,
            )
            x = matmul_ogs(
                x, w1, b1, rdata,
                gather_indx=gather_indx,
                precision_config=pc1,
                fused_activation=act,
            )
    else:
        with record_function("w1"):
            x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
        with record_function("swiglu"):
            x = swiglu(x, alpha=swiglu_alpha, limit=swiglu_limit, interleaved=interleaved)

    with record_function("w2"):
        x = matmul_ogs(
            x, w2, b2, rdata,
            scatter_indx=scatter_indx,
            precision_config=pc2,
            gammas=rdata.gate_scal,
        )
    return x