| # SPDX-License-Identifier: Apache-2.0 | |
| from typing import Callable | |
| import pytest | |
| import torch | |
| from flashinfer import fp4_quantize | |
| from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe | |
| from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant | |
| from torch.nn import functional as F | |
| from sglang.srt.layers.activation import SiluAndMul | |
| from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 | |
| from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType | |
| from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked | |
| from sglang.srt.layers.moe.topk import TopKConfig, select_experts | |
| if torch.cuda.get_device_capability() < (10, 0): | |
| pytest.skip( | |
| reason="Nvfp4 Requires compute capability of 10 or above.", | |
| allow_module_level=True, | |
| ) | |
| kE2M1ToFloat = torch.tensor( | |
| [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 | |
| ) | |
| FLOAT8_E4M3_MAX = 448.0 | |
| FLOAT4_E2M1_MAX = 6.0 | |
| def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): | |
| m_tiles = (m + 128 - 1) // 128 | |
| f = block_size * 4 | |
| k_tiles = (k + f - 1) // f | |
| tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) | |
| tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) | |
| out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) | |
| return out[0:m, 0:k] | |
| def dequantize_nvfp4_to_dtype( | |
| tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 | |
| ): | |
| """Dequantize the fp4 tensor back to high precision.""" | |
| # Two fp4 values are packed into one uint8. | |
| assert tensor_fp4.dtype == torch.uint8 | |
| m, packed_k = tensor_fp4.shape | |
| k = packed_k * 2 | |
| tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) | |
| tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) | |
| tensor_sf = tensor_sf.view(torch.float8_e4m3fn) | |
| tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) | |
| tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale | |
| # scale the tensor | |
| out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) | |
| return out.to(dtype=dtype) | |
| def break_fp4_bytes(a, dtype): | |
| assert a.dtype == torch.uint8 | |
| m, n = a.shape | |
| # Vectorized nibble processing | |
| a_flat = a.flatten() | |
| high = (a_flat & 0xF0) >> 4 # Upper nibbles | |
| low = a_flat & 0x0F # Lower nibbles | |
| # Combine nibbles for batch processing | |
| combined = torch.stack((low, high), dim=1).flatten() | |
| # Vectorized sign and magnitude extraction | |
| signs = (combined & 0x08).to(torch.bool) # Sign bits | |
| abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices | |
| # Device-aware lookup and sign application | |
| kE2M1 = kE2M1ToFloat.to(device=a.device) | |
| values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) | |
| # Reshape to final form | |
| return values.reshape(m, n * 2).to(dtype=dtype) | |
| def compute_routing(router_logits: torch.Tensor, top_k: int): | |
| routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float) | |
| routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) | |
| routing_weights /= routing_weights.sum(dim=-1, keepdim=True) | |
| routing_weights = routing_weights.float() | |
| return routing_weights, selected_experts | |
| def prepare_inputs( | |
| hidden_states: torch.Tensor, | |
| router_logits: torch.Tensor, | |
| num_experts: int, | |
| topk: int, | |
| ): | |
| routing_weights, topk_idx = compute_routing(router_logits, topk) | |
| masked_m = [] | |
| for i in range(num_experts): | |
| mask = topk_idx.view(-1) == i | |
| masked_m.append(mask.sum()) | |
| masked_m = torch.tensor(masked_m, dtype=torch.int32) | |
| hidden_states_3d = torch.empty( | |
| (num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype | |
| ) | |
| for i in range(num_experts): | |
| hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i] | |
| return hidden_states_3d, masked_m, topk_idx, routing_weights | |
| MNK_FACTORS = [ | |
| (2, 1024, 1024), | |
| (2, 1024, 1536), | |
| (2, 3072, 1024), | |
| (2, 3072, 1536), | |
| (64, 1024, 1024), | |
| (64, 1024, 1536), | |
| (64, 3072, 1024), | |
| (64, 2048, 1024), | |
| (224, 1024, 1024), | |
| (224, 1024, 1536), | |
| ] | |
| # Reference implementation of torch_moe | |
| def torch_moe(a, w1, w2, score, topk, expert_map): | |
| B, D = a.shape | |
| a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) | |
| out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) | |
| score = torch.softmax(score, dim=-1, dtype=torch.float32) | |
| topk_weight, topk_ids = torch.topk(score, topk) | |
| topk_weight = topk_weight.view(-1) | |
| topk_ids = topk_ids.view(-1) | |
| if expert_map is not None: | |
| topk_ids = expert_map[topk_ids] | |
| for i in range(w1.shape[0]): | |
| mask = topk_ids == i | |
| if mask.sum(): | |
| out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( | |
| 0, 1 | |
| ) | |
| return ( | |
| out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) | |
| ).sum(dim=1) | |
| def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids): | |
| B, D = a.shape | |
| a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) | |
| out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) | |
| topk_weight = topk_weight.view(-1) | |
| topk_ids = topk_ids.view(-1) | |
| for i in range(w1.shape[0]): | |
| mask = topk_ids == i | |
| if mask.sum(): | |
| m = w1[i].shape[0] | |
| assert m % 2 == 0 | |
| # Note: w1 and w3 are swapped! | |
| w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :] | |
| inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t()) | |
| inter_gs = torch.tensor(1.0).cuda() | |
| inter_q, inter_blockscale = fp4_quantize(inter, inter_gs) | |
| inter = dequantize_nvfp4_to_dtype( | |
| inter_q, | |
| inter_blockscale, | |
| inter_gs, | |
| dtype=inter.dtype, | |
| device=inter.device, | |
| block_size=16, | |
| ).cuda() | |
| out[mask] = inter @ w2[i].transpose(0, 1) | |
| return ( | |
| out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) | |
| ).sum(dim=1) | |
| def flashinfer_cutedsl_grouped_gemm_nt_masked( | |
| hidden_states: torch.Tensor, # 3d | |
| input_global_scale: torch.Tensor, # (l,) | |
| weights: torch.Tensor, | |
| w_global_scale: torch.Tensor, # (l,) | |
| masked_m: torch.Tensor, | |
| ): | |
| from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked | |
| # hidden_states: [l, m, k] | |
| # weights: [l, n, k] | |
| aq, aq_sf = scaled_fp4_grouped_quant( | |
| hidden_states, | |
| input_global_scale, | |
| masked_m.to(hidden_states.device), | |
| ) | |
| num_experts, n, k = weights.shape | |
| bq, bq_sf = scaled_fp4_grouped_quant( | |
| weights, | |
| w_global_scale, | |
| torch.ones(num_experts, device=weights.device, dtype=torch.int32) * n, | |
| ) | |
| out = torch.zeros( | |
| (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device | |
| ) | |
| out = out.permute(1, 2, 0) # requirement of kernel | |
| sf_vec_size = 16 | |
| ab_dtype = "float4_e2m1fn" | |
| sf_dtype = "float8_e4m3fn" | |
| c_dtype = "bfloat16" | |
| alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view( | |
| 1, 1, num_experts | |
| ) | |
| def get_cute_dtype(input: torch.Tensor) -> str: | |
| if input.dtype == torch.bfloat16: | |
| return "bfloat16" | |
| elif input.dtype == torch.float16: | |
| return "float16" | |
| elif input.dtype == torch.float32: | |
| return "float32" | |
| else: | |
| raise ValueError(f"Unsupported cute dtype {input.dtype}") | |
| grouped_gemm_nt_masked( | |
| (aq, aq_sf), | |
| (bq, bq_sf), | |
| out, | |
| masked_m.to(aq.device), | |
| ab_dtype=ab_dtype, | |
| sf_dtype=sf_dtype, | |
| c_dtype=c_dtype, | |
| sf_vec_size=sf_vec_size, | |
| alpha=alpha, | |
| alpha_dtype=get_cute_dtype(alpha), | |
| ) | |
| return out | |
| def check_moe( | |
| m: int, | |
| n: int, | |
| k: int, | |
| e: int, | |
| topk: int, | |
| dtype: torch.dtype, | |
| moe_impl: Callable, | |
| flip_w13: bool, | |
| ): | |
| torch.manual_seed(7) | |
| a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 | |
| w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 | |
| quant_blocksize = 16 | |
| round_up = lambda x, y: (x + y - 1) // y * y | |
| sf_w1_2n = round_up(2 * n, 128) | |
| sf_w1_k = round_up(k // quant_blocksize, 4) | |
| w1_blockscale = torch.empty( | |
| (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn | |
| ) | |
| w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 | |
| sf_w2_k = round_up(k, 128) | |
| sf_w2_n = round_up(n // quant_blocksize, 4) | |
| w2_blockscale = torch.empty( | |
| (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn | |
| ) | |
| w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8) | |
| w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) | |
| w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32) | |
| w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32) | |
| for expert in range(e): | |
| w1_amax = torch.abs(w1).max().to(torch.float32) | |
| w2_amax = torch.abs(w2).max().to(torch.float32) | |
| w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax | |
| w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax | |
| w1_q[expert], w1_blockscale[expert] = scaled_fp4_quant( | |
| w1[expert], w1_gs[expert] | |
| ) | |
| w2_q[expert], w2_blockscale[expert] = scaled_fp4_quant( | |
| w2[expert], w2_gs[expert] | |
| ) | |
| score = torch.randn((m, e), device="cuda", dtype=dtype) | |
| topk_output = select_experts( | |
| hidden_states=a, | |
| router_logits=score, | |
| topk_config=TopKConfig(top_k=topk, renormalize=False), | |
| ) | |
| topk_weights, topk_ids, _ = topk_output | |
| a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32) | |
| a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32) | |
| test_output = moe_impl( | |
| a=a, | |
| topk_weights=topk_weights, | |
| topk_ids=topk_ids, | |
| w1_q=w1_q, | |
| w2_q=w2_q, | |
| a1_gs=a1_gs, | |
| w1_blockscale=w1_blockscale, | |
| w1_alphas=(1 / w1_gs), | |
| a2_gs=a2_gs, | |
| w2_blockscale=w2_blockscale, | |
| w2_alphas=(1 / w2_gs), | |
| ) | |
| # Reference check: | |
| a_global_scale = ( | |
| (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1) | |
| ).to(torch.float32) | |
| a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale) | |
| _, m_k = a_fp4.shape | |
| a_in_dtype = dequantize_nvfp4_to_dtype( | |
| a_fp4, | |
| a_scale_interleaved, | |
| a_global_scale, | |
| dtype=a.dtype, | |
| device=a.device, | |
| block_size=quant_blocksize, | |
| ) | |
| w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) | |
| w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) | |
| for idx in range(0, e): | |
| w1_d[idx] = dequantize_nvfp4_to_dtype( | |
| w1_q[idx], | |
| w1_blockscale[idx], | |
| w1_gs[idx], | |
| dtype=w1.dtype, | |
| device=w1.device, | |
| block_size=quant_blocksize, | |
| ) | |
| w2_d[idx] = dequantize_nvfp4_to_dtype( | |
| w2_q[idx], | |
| w2_blockscale[idx], | |
| w2_gs[idx], | |
| dtype=w2.dtype, | |
| device=w2.device, | |
| block_size=quant_blocksize, | |
| ) | |
| if flip_w13: | |
| dim = -2 | |
| size = w1_d.size(dim) | |
| assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}" | |
| half = size // 2 | |
| # Reorder weight | |
| w1, w3 = w1_d.split(half, dim=dim) | |
| w1_d = torch.cat([w3, w1], dim=dim).contiguous() | |
| torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) | |
| torch.testing.assert_close(torch_output, test_output, atol=1e-1, rtol=1e-1) | |
| def test_cutlass_fp4_moe_no_graph( | |
| m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype | |
| ): | |
| def cutlass_moe_impl( | |
| a, | |
| topk_weights, | |
| topk_ids, | |
| w1_q, | |
| w2_q, | |
| a1_gs, | |
| w1_blockscale, | |
| w1_alphas, | |
| a2_gs, | |
| w2_blockscale, | |
| w2_alphas, | |
| ): | |
| params = CutlassMoEParams( | |
| CutlassMoEType.BlockscaledFP4, | |
| device=a.device, | |
| num_experts=e, | |
| intermediate_size_per_partition=n, # n | |
| hidden_size=k, | |
| ) # k | |
| return cutlass_moe_fp4( | |
| a=a, | |
| a1_gscale=a1_gs, | |
| w1_fp4=w1_q, | |
| w1_blockscale=w1_blockscale, | |
| w1_alphas=w1_alphas, | |
| a2_gscale=a2_gs, | |
| w2_fp4=w2_q, | |
| w2_blockscale=w2_blockscale, | |
| w2_alphas=w2_alphas, | |
| topk_weights=topk_weights, | |
| topk_ids=topk_ids, | |
| params=params, | |
| apply_router_weight_on_input=False, | |
| ) | |
| check_moe(m, n, k, e, topk, dtype, cutlass_moe_impl, flip_w13=False) | |
| def test_flashinfer_fp4_moe_no_graph( | |
| m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype | |
| ): | |
| def flashinfer_moe_impl( | |
| a, | |
| topk_weights, | |
| topk_ids, | |
| w1_q, | |
| w2_q, | |
| a1_gs, | |
| w1_blockscale, | |
| w1_alphas, | |
| a2_gs, | |
| w2_blockscale, | |
| w2_alphas, | |
| ): | |
| return flashinfer_cutlass_fused_moe( | |
| a, | |
| topk_ids.to(torch.int), | |
| topk_weights, | |
| w1_q.view(torch.long), | |
| w2_q.view(torch.long), | |
| a.dtype, | |
| quant_scales=[ | |
| a1_gs, | |
| w1_blockscale.view(torch.int32), | |
| w1_alphas, | |
| a2_gs, | |
| w2_blockscale.view(torch.int32), | |
| w2_alphas, | |
| ], | |
| )[0] | |
| check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True) | |
| def test_flashinfer_cutedsl_moe_masked( | |
| bs: int, hidden_dim: int, inter_dim: int, topk: int | |
| ): | |
| torch.manual_seed(42) | |
| device = "cuda" | |
| dtype = torch.bfloat16 | |
| num_experts = 8 | |
| hidden_states = ( | |
| torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0 | |
| ) | |
| w1 = ( | |
| torch.randn( | |
| num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device | |
| ) | |
| / 10.0 | |
| ) | |
| w2 = ( | |
| torch.randn( | |
| num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device | |
| ) | |
| / 10.0 | |
| ) | |
| router_logits = torch.randn(bs, num_experts, dtype=torch.float32) | |
| hidden_states_expanded = ( | |
| hidden_states.view(bs, -1, hidden_dim) | |
| .repeat(1, topk, 1) | |
| .reshape(-1, hidden_dim) | |
| ) | |
| hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs( | |
| hidden_states_expanded, router_logits, num_experts, topk | |
| ) | |
| w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device) | |
| w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device) | |
| input_global_scale = torch.ones( | |
| (num_experts,), dtype=torch.float32, device=hidden_states.device | |
| ) | |
| w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax | |
| w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax | |
| a2_global_scale = torch.ones( | |
| (num_experts,), dtype=torch.float32, device=hidden_states.device | |
| ) # assume intermediate scale is 1.0 | |
| w1_fp4, w1_blockscale = scaled_fp4_grouped_quant( | |
| w1, | |
| w1_global_scale, | |
| torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim, | |
| ) | |
| w2_fp4, w2_blockscale = scaled_fp4_grouped_quant( | |
| w2, | |
| w2_global_scale, | |
| torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim, | |
| ) | |
| w1_alpha = 1.0 / (input_global_scale * w1_global_scale) | |
| w2_alpha = 1.0 / (a2_global_scale * w2_global_scale) | |
| out = flashinfer_cutedsl_moe_masked( | |
| hidden_states_3d.to(hidden_states.device), | |
| input_global_scale, | |
| w1_fp4.permute(2, 0, 1), | |
| w1_blockscale, | |
| w1_alpha, | |
| w2_fp4.permute(2, 0, 1), | |
| a2_global_scale, | |
| w2_blockscale, | |
| w2_alpha, | |
| masked_m.to(hidden_states.device), | |
| ) | |
| # reference | |
| a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale) | |
| a_in_dtype = dequantize_nvfp4_to_dtype( | |
| a_fp4, | |
| a_scale_interleaved, | |
| input_global_scale, | |
| dtype=hidden_states.dtype, | |
| device=hidden_states.device, | |
| block_size=16, | |
| ) | |
| w1_d = torch.empty( | |
| (num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype | |
| ) | |
| w2_d = torch.empty( | |
| (num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype | |
| ) | |
| for idx in range(0, num_experts): | |
| w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize( | |
| w1[idx], w1_global_scale[idx] | |
| ) | |
| w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize( | |
| w2[idx], w2_global_scale[idx] | |
| ) | |
| w1_d[idx] = dequantize_nvfp4_to_dtype( | |
| w1_fp4_sliced, | |
| w1_blockscale_sliced, | |
| w1_global_scale[idx], | |
| dtype=w1.dtype, | |
| device=w1.device, | |
| block_size=16, | |
| ) | |
| w2_d[idx] = dequantize_nvfp4_to_dtype( | |
| w2_fp4_sliced, | |
| w2_blockscale_sliced, | |
| w2_global_scale[idx], | |
| dtype=w2.dtype, | |
| device=w2.device, | |
| block_size=16, | |
| ) | |
| ref_output = torch_moe_nvfp4( | |
| a_in_dtype, | |
| w1_d, | |
| w2_d, | |
| topk, | |
| routing_weights.to(a_in_dtype.device), | |
| topk_idx.to(a_in_dtype.device), | |
| ) | |
| out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype) | |
| positions = torch.nonzero(masked_m[topk_idx], as_tuple=False) | |
| rows, cols = positions[:, 0], positions[:, 1] | |
| experts = topk_idx[rows, cols] | |
| for i in range(num_experts): | |
| mask = experts == i | |
| if mask.any(): | |
| idx = torch.nonzero(mask, as_tuple=False).squeeze(-1) | |
| r, c = rows[idx], cols[idx] | |
| out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to( | |
| out.device | |
| ).unsqueeze(-1) | |
| torch.testing.assert_close( | |
| out_weighted.cpu(), ref_output.cpu(), atol=5e-2, rtol=5e-2 | |
| ) | |
| def test_grouped_gemm_nt_masked( | |
| bs: int, hidden_dim: int, inter_dim: int, topk: int | |
| ) -> None: | |
| torch.manual_seed(42) | |
| B = bs | |
| D = hidden_dim | |
| N = inter_dim | |
| num_experts = 8 | |
| hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda") | |
| weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda") | |
| router_logits = torch.randn(B, num_experts, dtype=torch.float32) | |
| hidden_states_expanded = ( | |
| hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) | |
| ) | |
| hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs( | |
| hidden_states_expanded, router_logits, num_experts, topk | |
| ) | |
| # reference | |
| out = torch.zeros( | |
| (B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device | |
| ) | |
| for i in range(num_experts): | |
| mask = topk_idx.view(-1) == i | |
| if mask.sum(): | |
| lhs = hidden_states_expanded[mask] | |
| rhs = weights[i] | |
| a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device) | |
| b_amax = rhs.abs().amax().to(torch.float32).to(weights.device) | |
| a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax | |
| b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax | |
| lhsq, lhsq_sf = fp4_quantize( | |
| lhs, | |
| a_gs, | |
| ) | |
| rhsq, rhsq_sf = fp4_quantize( | |
| rhs, | |
| b_gs, | |
| ) | |
| lhs_in_dtype = dequantize_nvfp4_to_dtype( | |
| lhsq, | |
| lhsq_sf, | |
| a_gs, | |
| dtype=hidden_states.dtype, | |
| device=hidden_states.device, | |
| block_size=16, | |
| ) | |
| rhs_in_dtype = dequantize_nvfp4_to_dtype( | |
| rhsq, | |
| rhsq_sf, | |
| b_gs, | |
| dtype=hidden_states.dtype, | |
| device=hidden_states.device, | |
| block_size=16, | |
| ) | |
| out[mask] = lhs_in_dtype @ rhs_in_dtype.t() | |
| a_amax = ( | |
| hidden_states_3d.abs() | |
| .amax(dim=(1, 2)) | |
| .to(torch.float32) | |
| .to(hidden_states.device) | |
| ) | |
| b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device) | |
| a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax | |
| b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax | |
| out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked( | |
| hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m | |
| ) | |
| # re-pack out into [num_experts, max_m, n] | |
| out_ref = torch.zeros( | |
| (num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype | |
| ) | |
| expert_slot = [0] * num_experts | |
| for i, expert_id in enumerate(topk_idx.view(-1).tolist()): | |
| out_ref[expert_id, expert_slot[expert_id], :] = out[i] | |
| expert_slot[expert_id] += 1 | |
| # Note: just to compare the masked position due to cutedsl may write nan | |
| # into unmasked position. | |
| for i in range(num_experts): | |
| torch.testing.assert_close( | |
| out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]], | |
| out_ref.to(out_flashinfer.device)[i, : masked_m[i]], | |
| atol=1e-1, | |
| rtol=5e-2, | |
| ) | |
| if __name__ == "__main__": | |
| test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half) | |
| test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half) | |
| test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4) | |
| test_grouped_gemm_nt_masked(16, 128, 512, 4) | |
Xet Storage Details
- Size:
- 22.4 kB
- Xet hash:
- 6d93496362bc641569e48639b3d1a5df522226fc7e5e77371ee8050abcb0e8dc
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.