| |
| |
| from dataclasses import dataclass, fields, replace |
| import pytest |
| import torch |
| from typing import Union |
| import triton |
| |
| from triton_kernels.routing import routing |
| |
| import triton_kernels.matmul_ogs_details.opt_flags as opt_flags |
| from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue |
| from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch |
| from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig |
| from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4 |
| from triton_kernels.tensor_details import layout |
| |
| from triton_kernels.numerics import InFlexData, OutFlexData |
| from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp, dequantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE |
| |
| from triton_kernels.testing import assert_close, compute_actual_scale |
| |
| from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4 |
|
|
| |
| |
| |
|
|
|
|
| def alloc_rand(shape, device, dtype, requires_grad=True): |
| if dtype.itemsize == 1: |
| tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) |
| return tmp.to(dtype).requires_grad_(requires_grad) |
| return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) |
|
|
|
|
def alloc_rand_like(x):
    """Return ``alloc_rand`` output matching ``x``'s shape, device, dtype and requires_grad flag."""
    return alloc_rand(shape=x.shape, device=x.device, dtype=x.dtype, requires_grad=x.requires_grad)
|
|
|
|
def mask_indx(idx, n_expts_act):
    """Invalidate the last ``n_expts_act`` entries of ``idx`` in place and return ``idx``.

    The source positions referenced by the trailing ``dst_indx`` entries are set to -1
    in ``src_indx``. Then those trailing ``dst_indx`` entries are themselves set to -1.
    """
    tail = slice(-n_expts_act, None)
    # Read the trailing destination entries before they are overwritten below.
    masked_positions = idx.dst_indx[tail]
    idx.src_indx[masked_positions] = -1
    idx.dst_indx[tail] = -1
    return idx
|
|
|
|
def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, device="cuda"):
    """Route ``m`` random float16 logit rows over ``n_expts_tot`` experts.

    Returns ``(m, routing_data, gather_idx, scatter_idx)``. The gather and scatter
    indices are replaced by ``None`` unless ``do_gather`` / ``do_scatter`` is set.
    ``routing_data.gate_scal`` is cleared.
    """
    logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True)
    rdata, gather, scatter = routing(logits, n_expts_act, simulated_ep=n_expt_shards)
    rdata.gate_scal = None
    if not do_gather:
        gather = None
    if not do_scatter:
        scatter = None
    return m, rdata, gather, scatter
|
|
|
|
def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype,
                      has_y_gammas, requires_grad=True, device="cuda"):
    """Create seeded random inputs ``(x, w, bias, gs0, gs1)`` for a matmul_ogs test.

    ``mode`` must be "plain", "batched" or "ragged". Shapes:

    * x: ``(n_expts_tot, rows, k)`` when batched, else ``(rows, k)``.
      ``rows = m`` when ``gindx`` is given, otherwise ``m * n_expts_act``.
    * w: ``(k, n)`` for plain, else ``(n_expts_tot // n_expt_shards, k, n)``.
    * bias: ``(n,)`` for plain, else ``(n_expts_tot // n_expt_shards, n)``, float32.
    * gs0 / gs1: float32 powers of two of length ``m * n_expts_act``, or ``None``.
      They are ``None`` when batched, when ``has_y_gammas`` is False, or when
      ``gindx`` is given and ``act_dtype`` is at least 2 bytes.

    ``sindx`` is accepted for signature symmetry but unused.
    """
    torch.manual_seed(0)
    assert mode in {'batched', "plain", 'ragged'}
    rows_in = m if gindx is not None else m * n_expts_act
    if mode == 'batched':
        x_shape = (n_expts_tot, rows_in, k)
    else:
        x_shape = (rows_in, k)
    batch_dims = () if mode == "plain" else (n_expts_tot // n_expt_shards, )
    x = alloc_rand(x_shape, device=device, dtype=act_dtype, requires_grad=requires_grad)
    w = alloc_rand((*batch_dims, k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
    bias = alloc_rand((*batch_dims, n), device=device, dtype=torch.float32, requires_grad=requires_grad)

    def _rand_gammas():
        # Powers of two in [2**-5, 2**-1]; detached so the result is a leaf tensor.
        g = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32,
                             requires_grad=requires_grad)
        return g.detach().requires_grad_(requires_grad)

    gs0 = _rand_gammas()
    gs1 = _rand_gammas()
    drop_gammas = mode == 'batched' or not has_y_gammas or (gindx is not None and act_dtype.itemsize >= 2)
    if drop_gammas:
        gs0 = gs1 = None
    if is_cuda() and "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
        # Same logical shape, but the last two dims become column-major in memory.
        # NOTE(review): presumably required by the fp8 kernels below capability 10 — confirm.
        w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
    return x, w, bias, gs0, gs1
|
|
|
|
| |
| |
| |
|
|
|
|
def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, device="cuda"):
    """Build a PrecisionConfig with fixed flexpoint scales for testing.

    * Activation flexpoint (lhs scale 1.25, expected out scale 4.00) is enabled
      by ``act_use_flexpoint``.
    * Weight flexpoint is enabled when ``weight_dtype`` is 1 byte and not mx.
      Its per-expert scale alternates 1.50, 1.25, ... across ``n_expts_tot`` entries.
    * ``acc_scale`` is 2.0 if any flexpoint is in use, otherwise 1.0.
    """
    weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp

    def _scalar(val):
        return torch.tensor([val], dtype=torch.float32, device=device)

    def _alternating(val0, val1):
        pattern = [val0, val1] * (n_expts_tot // 2)
        if n_expts_tot % 2:
            pattern = pattern + [val0]
        return torch.tensor(pattern, dtype=torch.float32, device=device)

    if act_use_flexpoint:
        lhs_data = InFlexData(dtype=out_dtype, scale=_scalar(1.25))
    else:
        lhs_data = InFlexData()
    if weight_use_flexpoint:
        rhs_data = InFlexData(dtype=weight_dtype, scale=_alternating(1.50, 1.25))
    else:
        rhs_data = InFlexData()
    if act_use_flexpoint:
        out_data = OutFlexData(dtype=out_dtype, expected_scale=_scalar(4.00), actual_scale=_scalar(0),
                               checksum_scale=_scalar(0))
    else:
        out_data = OutFlexData()

    flex_ctx = FlexCtx(lhs_data=lhs_data, rhs_data=rhs_data, out_data=out_data)
    uses_flex = act_use_flexpoint or weight_use_flexpoint
    return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=2.0 if uses_flex else 1.0, out_dtype=out_dtype)
|
|
|
|
def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config):
    """Build reference copies of the triton inputs.

    ``x_tri`` and ``w_tri`` are scaled by the lhs/rhs flex scales of
    ``precision_config.flex_ctx``:

    * a single-element scale multiplies the float32 upcast;
    * a multi-element scale must have one entry per leading dim of a 3-D tensor;
    * a ``None`` scale just clones the tensor.

    Bias and gammas are cloned; ``None`` gammas stay ``None``. Every returned
    tensor is a detached leaf with ``requires_grad=True``.
    """
    flex_ctx = precision_config.flex_ctx

    def _to_ref(t, scale):
        if scale is None:
            out = t.clone()
        elif scale.numel() == 1:
            out = t.float() * scale
        else:
            assert t.ndim == 3
            assert scale.numel() == t.shape[0]
            out = t.float() * scale[:, None, None]
        return out.detach().requires_grad_()

    def _maybe_clone(t):
        return None if t is None else _to_ref(t, None)

    x_ref = _to_ref(x_tri, flex_ctx.lhs_data.scale)
    w_ref = _to_ref(w_tri, flex_ctx.rhs_data.scale)
    bias_ref = _to_ref(bias_tri, None)
    gs0_ref = _maybe_clone(gs0_tri)
    gs1_ref = _maybe_clone(gs1_tri)
    return (x_ref, w_ref, bias_ref, gs0_ref, gs1_ref)
|
|
|
|
def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
    """Map a dtype name to a ``torch.dtype``.

    ``"float4_e2m1"`` maps to ``torch.uint8``, since this code carries fp4 data as
    raw bytes. Every other name is looked up as an attribute of ``torch``.
    """
    if dtype_str == "float4_e2m1":
        return torch.uint8
    return getattr(torch, dtype_str)
|
|
|
|
| |
@pytest.fixture
def opt_flags_scope(request):
    """Run the test, then reset any opt_flags constraints it installed."""
    try:
        yield
    finally:
        opt_flags.reset_opt_flags_constraints()
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class Case:
    """One ``test_op`` configuration.

    The field order defines the parametrize argument names and their order
    (via ``dataclasses.fields``), so do not reorder fields.
    """
    m: int  # number of token rows (rows of the routing logits / activation)
    n: int  # output columns (last dim of w)
    k: int  # reduction dim (last dim of x, second-to-last dim of w)
    mode: str  # "plain", "batched" or "ragged"
    act_dtype_str: str  # torch dtype name; an "mx" prefix selects mx activations
    weight_dtype_str: str  # torch dtype name; an "mx" prefix selects mx weights
    n_expts_tot: int = 1  # total number of experts
    n_expts_act: int = 1  # experts selected per token (passed to routing)
    n_expt_shards: int = 1  # passed to routing as simulated_ep
    split_k: int = 1  # split_k opt-flags constraint
    hbm_swizzling: bool = False  # use the swizzled mxfp4 weight/scale layouts
    epilogue_subtile: Union[int, None] = None  # epilogue_subtile opt-flags constraint
|
|
|
|
@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            # non-mx weights
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4),
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2),
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4),
            Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2),
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),
            Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1),
            Case(16, 256, 256, "batched", "float16", "float16", 5, 1),
            Case(16, 256, 256, "ragged", "float16", "float16", 3, 1),
            Case(256, 256, 256, "ragged", "float16", "float16", 4, 1),
            Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3),
            Case(300, 400, 400, "batched", "float16", "float16", 5, 1),
            Case(300, 400, 400, "ragged", "float16", "float16"),
            Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"),
            Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2),
            Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1),
            Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2),
            Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
            # mx weights
            Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
            Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1),
            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
            Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),
            Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4),
            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
            Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
            Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
            Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
            Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),
            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
            Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
            Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4),
            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
            Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
            Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False),
            Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4),
            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4),
            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
            # float8_e4m3fnuz (skipped outside AMD CDNA3) and float8_e4m3fn
            Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
            Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2),
            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2),
            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2),
            Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"),
            Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
            Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),
            Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),
        ]
    ],
)
@pytest.mark.parametrize("block_m", [16, 128])
@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
    (False, False, False),
    (True, False, False),
    (False, True, False),
    (True, True, False),
    (True, True, True),
])
@pytest.mark.parametrize("has_y_gammas", [False, True])
@pytest.mark.parametrize("is_persistent", [False, True])
def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
            device, opt_flags_scope, fresh_knobs):
    """Compare ``matmul_ogs`` against the ``matmul_ogs_torch`` reference.

    Covers plain / batched / ragged modes, flexpoint and mx precisions,
    gather / scatter (optionally fused) and split-k. For ragged non-mx cases it
    also checks the ``bytes`` value reported through the launch-enter hook.
    """
    # --- platform / dtype support skips ---
    if is_cuda():
        if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
            pytest.skip("Float8 not tested on A100")
        if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
            pytest.skip("float16 x mx not supported with cuda capability >= 10")
        if weight_dtype_str.startswith("mx"):
            if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
                pytest.skip("float8 x mx not supported with cuda capability < 10")
            if act_dtype_str == "mxfloat8_e4m3fn":
                if is_persistent:
                    pytest.skip("mx x mx not supported with persistent kernel")
        if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
            pytest.skip("Not enough memory on A100")

    elif is_hip():
        if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4():
            pytest.skip("float8 x mx only supported on CDNA4")
        if "float8" in act_dtype_str and "mxfloat8" in weight_dtype_str:
            pytest.skip("NYI: float8 x mxfloat8 not tested on AMD GPU")
        if act_dtype_str.startswith("mx") and weight_dtype_str.startswith("mx"):
            pytest.skip("NYI: mx x mx not tested on AMD GPU")
        if is_persistent:
            pytest.skip("NYI: Persistent kernel not supported on AMD GPU")
        if split_k > 1:
            pytest.skip("splitK hasn't been fully tested on AMD GPU.")

    if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
        pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")

    if fused_scatter and split_k > 1:
        pytest.skip("fused scatter scratchpad not supported with split_k")
    if hbm_swizzling:
        if is_hip():
            pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
        if is_cuda():
            if torch.cuda.get_device_capability()[0] < 9:
                pytest.skip("NYI. Ampere swizzling.")
            if torch.cuda.get_device_capability()[0] < 10:
                if "mxfloat4" not in weight_dtype_str:
                    pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
                if k % 64 != 0 or n % 64 != 0:
                    pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

    # Launch-metadata byte accounting is only checked for ragged, non-mx-weight cases.
    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)

    torch.manual_seed(0)

    block_k = None
    if is_cuda() and is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
        # Pin block_k for persistent mx kernels below capability 10.
        # NOTE(review): the motivation for 256 specifically is not visible here — confirm.
        block_k = 256

    constraints = {
        "block_m": block_m,
        "block_k": block_k,
        "split_k": split_k,
        "fused_scatter": fused_scatter,
        "is_persistent": is_persistent,
        "epilogue_subtile": epilogue_subtile,
    }
    opt_flags.update_opt_flags_constraints(constraints)

    # Strip the "mx" prefix so the remaining string names a torch dtype.
    weight_mxfp = weight_dtype_str.startswith("mx")
    if weight_mxfp:
        weight_dtype_str = weight_dtype_str[2:]
    act_mxfp8 = act_dtype_str.startswith("mx")
    act_is_float8 = act_dtype_str.startswith("float8")
    if act_mxfp8:
        act_dtype_str = act_dtype_str[2:]
        dequantize_mxfp8_spec = FnSpecs(FnName.DEQUANTIZE_MXFP8.name, dequantize_mxfp8_fn, (), ())

    test_bwd = False
    weight_dtype = dtype_str_to_torch(weight_dtype_str)
    act_dtype = dtype_str_to_torch(act_dtype_str)
    precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, n_expts_tot // n_expt_shards,
                                   device=device)
    if mode == "ragged":
        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
                                                   device=device)
    else:
        rdata = gindx = sindx = None
    # mx inputs are generated in bfloat16 and quantized below.
    x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act,
                                                                 n_expt_shards, mode,
                                                                 torch.bfloat16 if act_mxfp8 else act_dtype,
                                                                 torch.bfloat16 if weight_mxfp else weight_dtype,
                                                                 has_y_gammas, requires_grad=test_bwd, device=device)
    x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

    # A single-expert 3-D weight is also exercised as a 2-D (K, N) weight. The ndim
    # check keeps a plain-mode (K, N) weight with K == 1 from being squeezed to 1-D.
    if w_tri.ndim == 3 and w_tri.shape[0] == 1:
        w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
        w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)

    if weight_mxfp:
        mx_axis = w_tri.ndim - 2
        # choose the weight / weight-scale memory layouts
        w_layout, w_layout_opts = layout.StridedLayout, dict()
        w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict()
        if hbm_swizzling and "float4" in weight_dtype_str:
            w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis)
            w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
                mx_axis=mx_axis, num_warps=8)
        # quantize to mx; the reference uses the dequantized values
        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
        w_scale_tri = wrap_torch_tensor(w_scale_tri)
        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
        precision_opt.weight_scale = w_scale_tri
    epilogue = None
    if act_mxfp8:
        x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1)
        x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1)
        # Work out the output shape to size the output mx-scale buffer.
        is_input_batched = x_tri.ndim == 3
        y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
        n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
        y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
        if sindx is None or mode == "batched":
            if not is_input_batched:
                y_shape = (y_shape[1], y_shape[2])
        else:
            y_shape = (n_rows // rdata.n_expts_act, y_shape[-1])
        y_mx_scale_shape = y_shape[:-1] + (triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE),)
        y_mx_scale = torch.empty(y_mx_scale_shape, dtype=torch.uint8, device=x_tri.device)
        precision_opt = replace(precision_opt, act_scale=x_mx_scales_tri, out_scale=y_mx_scale)
        epilogue = Epilogue(dequantize_mxfp8_spec, tuple(), tuple(), effective_itemsize=6.0)

    if test_launch_metadata:

        def _clobber(t, used_mask):
            # Overwrite the entries the kernel should not read (127 for 1-byte dtypes,
            # inf otherwise), so an accidental read corrupts the output.
            if len(used_mask) == 1:
                return
            elif t.element_size() == 1:
                t.view(torch.int8)[~used_mask] = 127
            else:
                t[~used_mask] = torch.inf

        if rdata is not None:
            n_tokens = rdata.expt_hist.sum().item()
            used_expts = (rdata.expt_hist > 0)
            _clobber(w_tri, used_expts)
            n_w_bytes = used_expts.sum().item() * n * k * w_tri.element_size()
        else:
            n_tokens = m
            n_w_bytes = w_tri.numel() * w_tri.element_size()

        if gindx is not None:
            used_x_rows = (gindx.dst_indx.view(-1, n_expts_act) != -1).any(dim=1)
            _clobber(x_tri, used_x_rows)
            n_x_bytes = used_x_rows.sum().item() * k * x_tri.element_size()
        elif rdata is not None:
            n_x_bytes = n_tokens * k * x_tri.element_size()
        else:
            n_x_bytes = x_tri.numel() * x_tri.element_size()

        nbytes = None

        def _hook(launch_metadata):
            nonlocal nbytes
            metadata = launch_metadata.get()
            if "matmul_ogs" in metadata["name"]:
                nbytes = metadata["bytes"]

        triton.knobs.runtime.launch_enter_hook = _hook

    flex = precision_opt.flex_ctx

    # triton
    try:
        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue)
    except (opt_flags.InapplicableConstraint, NotImplementedError):
        pytest.skip("inapplicable opt_flags constraint")
    finally:
        # Always uninstall the hook (success, skip or failure) so it cannot leak
        # into later tests; all launches happen inside the call above.
        if test_launch_metadata:
            triton.knobs.runtime.launch_enter_hook = None
    # The reference rounds gathered / scattered intermediates to act_dtype only
    # when split_k == 1 (presumably split-k keeps an fp32 intermediate).
    sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1
    sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
    out_flex_scale = flex.out_data.expected_scale if act_is_float8 else 1

    if test_launch_metadata:
        if gindx is not None:
            n_y_bytes = (gindx.src_indx != -1).sum().item() * n * tri_y.element_size()
        elif rdata is not None:
            n_y_bytes = n_tokens * n * tri_y.element_size()
        else:
            n_y_bytes = tri_y.numel() * tri_y.element_size()
        assert nbytes == n_x_bytes + n_y_bytes + n_w_bytes

    def round_x(x, idx):
        return x.to(act_dtype).to(torch.float32) if sep_gather else x

    round_y = lambda y: (y / out_flex_scale).to(act_dtype).to(torch.float32) * out_flex_scale if sep_scatter else y
    ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref,
                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, device=device)
    scale = lambda val, scal: val if scal is None else val / scal
    if n_expt_shards > 1:
        if do_scatter:
            indx = sindx.dst_indx[sindx.dst_indx != -1]
            ref_y = ref_y[indx // n_expts_act, :]
            if act_is_float8:
                # index fp8 data through an int8 view, then restore the dtype
                tri_y = tri_y.view(torch.int8)
            tri_y = tri_y[indx // n_expts_act, :]
            if act_is_float8:
                tri_y = tri_y.view(act_dtype)
        else:
            n_rows = rdata.expt_hist.sum()
            assert n_rows > 0
            ref_y = ref_y[:n_rows]
            tri_y = tri_y[:n_rows]
    if act_mxfp8:
        # Target dtype is passed positionally, matching the other upcast_from_mxfp calls above.
        tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, torch.bfloat16, axis=-1).to(ref_y.dtype)
        ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1)
        ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1)
        maxtol = 4e-1
        rmstol = 4e-2
    else:
        maxtol = None
        rmstol = None
    assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y, maxtol=maxtol, rmstol=rmstol)

    if act_is_float8:
        tri_y_scale = flex.out_data.actual_scale.clone()
        ref_y_scale = compute_actual_scale(ref_y, tri_y.dtype)
        assert (ref_y_scale -
                tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"
|
|
|
|
def test_set_idle_sms(opt_flags_scope):
    """``matmul_ogs_set_idle_sms`` must be reflected in the flags chosen by ``make_opt_flags``.

    ``opt_flags_scope`` resets opt_flags constraints after the test, like the
    other tests in this module. NOTE(review): confirm that
    ``matmul_ogs_set_idle_sms`` stores its value as an opt_flags constraint;
    if it does not, the idle-SM setting still leaks into later tests.
    """
    if not is_cuda():
        pytest.skip("Only supported on CUDA")
    from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags
    num_idle_sms = 24
    matmul_ogs_set_idle_sms(num_idle_sms)
    flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(),
                           1024, 1024, 1024, None, True, False, 1)
    assert flags.idle_sms == num_idle_sms
|
|
|
|
@pytest.mark.parametrize("m, n, k, mode", [
    (1200, 704, 608, "ragged"),
    (800, 800, 400, "batched"),
])
@pytest.mark.parametrize("split_k", [1, 2])
@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
    (False, False, False),
    (True, False, False),
    (False, True, False),
    (True, True, False),
    (True, True, True),
])
@pytest.mark.parametrize("is_persistent, epilogue_subtile", [
    (False, None),
    (True, 1),
    (True, 4),
])
@pytest.mark.parametrize("swiglu_alpha, swiglu_limit", [
    (1.1, 1.4),
    (1.0, 1.2),
    (0.7, 1.0),
])
def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile,
                   swiglu_alpha, swiglu_limit, device, opt_flags_scope):
    """A SwiGLU fused into ``matmul_ogs`` must match ``swiglu`` applied to the unfused output."""
    if fused_scatter and split_k > 1:
        pytest.skip("fused scatter scratchpad not supported with split_k")
    torch.manual_seed(0)
    constraints = {
        "is_persistent": is_persistent,
        "epilogue_subtile": epilogue_subtile,
        "fused_scatter": fused_scatter,
        "split_k": split_k,
    }
    n_expts_tot = n_expts_act = n_expt_shards = 1
    opt_flags.update_opt_flags_constraints(constraints)

    weight_dtype = act_dtype = torch.float16
    if mode == "ragged":
        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
                                                   device=device)
    else:
        # batched mode runs without routing metadata
        rdata, gindx, sindx = None, None, None

    act_uses_flexpoint = str(act_dtype).startswith("torch.float8")
    precision_opt = init_precision(act_dtype, act_uses_flexpoint, weight_dtype, False, n_expts_tot // n_expt_shards,
                                   device=device)
    x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
                                         act_dtype, weight_dtype, False, requires_grad=False, device=device)

    try:
        plain_out = matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt)
        unfused = swiglu(plain_out, swiglu_alpha, precision_config=SwiGLUPrecisionConfig(swiglu_limit))
        fused_swiglu = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
                                       (swiglu_alpha, swiglu_limit), 2)
        fused = matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt, fused_activation=fused_swiglu)
    except opt_flags.InapplicableConstraint:
        pytest.skip("inapplicable constraint")
    assert_close(unfused, fused)
|
|