import torch
from torch.profiler import record_function

import triton_kernels
import triton_kernels.swiglu
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
from triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
from triton_kernels.matmul_ogs import matmul_ogs
from triton_kernels.numerics import InFlexData
from triton_kernels.routing import routing
from triton_kernels.tensor import convert_layout
from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout
from triton_kernels.tensor import wrap_torch_tensor, FP4

# Default SwiGLU sigmoid slope. Shared by the PyTorch `swiglu` fallback and the
# fused-kernel path in `moe` so the two activation paths cannot drift apart.
_DEFAULT_SWIGLU_ALPHA = 1.702


def quantize_mx4(w):
    """Quantize ``w`` to MXFP4 and wrap the result in triton_kernels layouts.

    ``w`` is cast to bfloat16 and downcast with ``downcast_to_mxfp`` along
    ``axis=1`` (storage dtype ``torch.uint8``). The quantized values are
    wrapped as an ``FP4`` tensor and converted to ``HopperMXValueLayout``
    (``mx_axis=1``); the scales are wrapped and converted to ``StridedLayout``.

    Args:
        w: Weight tensor to quantize.

    Returns:
        ``(w, w_scale)``: the layout-converted quantized values and their
        layout-converted scales.
    """
    w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
    w = convert_layout(wrap_torch_tensor(w, dtype=FP4), HopperMXValueLayout, mx_axis=1)
    w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)
    return w, w_scale


def swiglu(x, alpha: float = _DEFAULT_SWIGLU_ALPHA, limit: float = 7.0, interleaved: bool = True):
    """Clamped SwiGLU over the last dimension of ``x`` (PyTorch reference path).

    The last dimension is split into a gate half and a linear half. When
    ``interleaved`` is true they are the even- and odd-indexed columns;
    otherwise they are the first and second halves (``torch.chunk``).

    Args:
        x: Input tensor whose last dimension holds both halves (assumed to
            be of even size so the halves line up).
        alpha: Slope applied inside the sigmoid gate.
        limit: The gate half is clamped from above at ``limit`` (no lower
            bound); the linear half is clamped to ``[-limit, limit]``.
        interleaved: Selects the split scheme described above.

    Returns:
        ``gate * sigmoid(alpha * gate) * (linear + 1)``, whose last dimension
        is half that of ``x``.
    """
    if interleaved:
        x_glu, x_linear = x[..., ::2], x[..., 1::2]
    else:
        x_glu, x_linear = torch.chunk(x, 2, dim=-1)
    x_glu = x_glu.clamp(min=None, max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    # The "+ 1" on the linear branch is part of this activation's definition.
    return out_glu * (x_linear + 1)


def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_experts=128,
        swiglu_limit=7.0, fused_act=True, interleaved=True, swiglu_alpha=_DEFAULT_SWIGLU_ALPHA):
    """Mixture-of-experts block built on triton_kernels ``matmul_ogs``.

    Pipeline (each stage wrapped in a ``torch.profiler.record_function`` label):
      1. ``wg``: gate logits ``matmul_ogs(x, wg, bg)``.
      2. ``routing``: ``routing(logits, experts_per_token, simulated_ep=1)``
         yields routing data plus gather/scatter indices.
      3. ``w1`` + SwiGLU: either fused into the matmul (``fused_act=True``)
         or as a separate matmul followed by the PyTorch ``swiglu``.
      4. ``w2``: second expert matmul using ``scatter_indx`` and
         ``gammas=rdata.gate_scal`` (presumably the per-token gate weights
         used to combine expert outputs — confirm against triton_kernels).

    Args:
        x: Input activations. An input with no elements is returned as-is.
        wg, bg: Gate weight and bias.
        w1, w1_mx, b1: First expert weight, its MX scales, and bias.
        w2, w2_mx, b2: Second expert weight, its MX scales, and bias.
        experts_per_token: Number of experts each token is routed to.
        num_experts: Currently unused by this function; kept for interface
            compatibility.
        swiglu_limit: Clamp limit forwarded to the SwiGLU activation.
        fused_act: Use the fused SwiGLU epilogue in the ``w1`` matmul.
        interleaved: Whether the ``w1`` output is interleaved (gate/linear
            alternating columns). Required to be true when ``fused_act``.
        swiglu_alpha: Sigmoid slope for SwiGLU, used by both the fused and
            the unfused path.

    Returns:
        The output of the ``w2`` matmul.

    Raises:
        ValueError: If ``fused_act`` is true but ``interleaved`` is false.
    """
    if x.numel() == 0:
        return x
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would let the fused kernel run on non-interleaved weights and
    # silently compute a wrong activation. Done before any GPU work.
    if fused_act and not interleaved:
        raise ValueError("Fused activation requires interleaved weights")

    pc1 = PrecisionConfig(weight_scale=w1_mx, flex_ctx=FlexCtx(rhs_data=InFlexData()))
    pc2 = PrecisionConfig(weight_scale=w2_mx, flex_ctx=FlexCtx(rhs_data=InFlexData()))
    pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=InFlexData()))

    with record_function("wg"):
        logits = matmul_ogs(x, wg, bg, precision_config=pcg)
    with record_function("routing"):
        rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, simulated_ep=1)

    if fused_act:
        with record_function("w1+swiglu"):
            # The trailing `2` is FusedActivation's reduction argument;
            # presumably it reflects SwiGLU halving the output width — confirm.
            act = FusedActivation(
                FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
                (swiglu_alpha, swiglu_limit),
                2,
            )
            x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1,
                           fused_activation=act)
    else:
        with record_function("w1"):
            x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
        with record_function("swiglu"):
            x = swiglu(x, alpha=swiglu_alpha, limit=swiglu_limit, interleaved=interleaved)

    with record_function("w2"):
        x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2,
                       gammas=rdata.gate_scal)
    return x