| import pytest |
| import torch |
|
|
| from triton_kernels.numerics_details.mxfp import ( |
| DequantScaleRoundingMode, |
| downcast_to_mxfp, |
| downcast_to_mxfp_torch, |
| get_max_quant_val, |
| upcast_from_mxfp, |
| upcast_from_mxfp_torch, |
| ) |
| from triton_kernels.testing import assert_close, assert_equal |
|
|
|
|
def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
    """Translate a dtype name used by the test parametrization into a torch dtype.

    Args:
        dtype_str: Either the special name ``"float4_e2m1"`` or the name of an
            attribute of the ``torch`` module (e.g. ``"bfloat16"``).

    Returns:
        ``torch.uint8`` for ``"float4_e2m1"``; otherwise ``getattr(torch, dtype_str)``.

    Raises:
        AttributeError: If ``dtype_str`` is not ``"float4_e2m1"`` and ``torch``
            has no attribute of that name.
    """
    # float4_e2m1 is represented by torch.uint8 in these tests.
    if dtype_str == "float4_e2m1":
        return torch.uint8
    return getattr(torch, dtype_str)
|
|
|
|
@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16"])
def test_mxfp4_rounding_cases(dst_dtype):
    """Check a fixed set of values through an mxfp (uint8 storage) round trip.

    The inputs are quantized along axis 1 with ``downcast_to_mxfp`` and
    dequantized with ``upcast_from_mxfp``; the result must equal a hard-coded
    list of expected values. The torch reference implementations must then
    produce identical quantized values, scales and dequantized output.
    """
    dst_dtype = dtype_str_to_torch(dst_dtype)

    raw_values = [6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3]
    expected_values = [6, 0, 0, 0.5, 1.0, 1.0, 1.0, 1.5]
    x = torch.tensor(raw_values).cuda().bfloat16().view(1, -1, 1)

    # Round trip through the downcast/upcast pair under test.
    quant, scale = downcast_to_mxfp(x, torch.uint8, axis=1)
    dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1)
    assert dequant.flatten().tolist() == expected_values, f"{dequant=}"

    # The torch reference must agree on both the quantized data and the scales.
    quant_torch, scale_torch = downcast_to_mxfp_torch(x, torch.uint8, axis=1)
    for reference, actual in ((quant_torch, quant), (scale_torch, scale)):
        assert_equal(reference, actual)

    # ...and on the dequantized result.
    dequant_torch = upcast_from_mxfp_torch(quant_torch, scale_torch, dst_dtype, axis=1)
    assert_equal(dequant_torch, dequant)
|
|
|
|
@pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16"])
def test_mxfp_quant_dequant(src_dtype, dst_dtype):
    """Assert that a quantize -> dequantize round trip reproduces its input exactly.

    The input is built from a small set of signed base values multiplied by
    powers of two, tiled into a larger tensor, and quantized along axis 1.
    float8 source dtypes are skipped when the CUDA device's major compute
    capability is below 9.
    """
    is_fp8 = "float8" in src_dtype
    if is_fp8 and torch.cuda.get_device_capability()[0] < 9:
        pytest.skip("Float8 not tested on A100")

    # Decide on the reduced range while the dtypes are still strings.
    # NOTE(review): presumably float16 cannot hold the full scaled e5m2 range,
    # hence the smaller maximum for this combination -- confirm.
    limit_range = src_dtype == "float8_e5m2" and dst_dtype == "float16"

    src_dtype = dtype_str_to_torch(src_dtype)
    dst_dtype = dtype_str_to_torch(dst_dtype)

    quant_max = get_max_quant_val(src_dtype)
    max_val = 128 if limit_range else quant_max

    # Base magnitudes followed by their negations, laid out as a column vector.
    base = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, max_val]
    pos_vals = torch.tensor(base, device="cuda", dtype=dst_dtype)
    signed_vals = torch.cat([pos_vals, -pos_vals])
    k_dim = signed_vals.reshape([signed_vals.shape[0], 1])

    # Power-of-two multipliers 2**-8 .. 2**7, laid out as a row vector.
    powers = torch.arange(-8, 8, device="cuda", dtype=dst_dtype)
    scales = (2**powers).reshape([1, powers.shape[0]])

    # Outer product of values and multipliers, tiled 9x along rows and 32x
    # along columns, with a leading dimension of size 1 added.
    tiled = (k_dim * scales).repeat((9, 32))
    weight = tiled.reshape([1, tiled.shape[0], tiled.shape[1]])

    # Transpose, materialize, transpose back: same shape and values, but the
    # memory layout now has unit stride along axis 1 (the quantization axis).
    transposed = weight.mT.contiguous()
    weight = transposed.mT

    quant, scale = downcast_to_mxfp(weight, src_dtype, axis=1)
    dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1)
    assert_equal(weight, dequant)
|
|
|
|
| |
@pytest.mark.parametrize(
    "shape, axis, quant_dtype, rounding_mode",
    [
        ((3, 4096, 1024), 1, "float4_e2m1", DequantScaleRoundingMode.ROUND_UP),
        ((10, 254, 60), 0, "float4_e2m1", DequantScaleRoundingMode.ROUND_DOWN),
        ((1, 320, 160), 2, "float8_e5m2", DequantScaleRoundingMode.ROUND_UP),
        ((2, 16, 512), -1, "float8_e4m3fn", DequantScaleRoundingMode.ROUND_DOWN),
    ],
)
@pytest.mark.parametrize("dequant_dtype", ["float16", "bfloat16"])
def test_mxfp_casting(
    shape: tuple[int, ...],
    axis: int,
    quant_dtype: str,
    dequant_dtype: str,
    rounding_mode: DequantScaleRoundingMode,
):
    """Compare the mxfp downcast/upcast functions with their torch references.

    Args:
        shape: Shape of the random input tensor.
        axis: Axis passed to the downcast/upcast functions.
        quant_dtype: Name of the quantized dtype (see ``dtype_str_to_torch``).
        dequant_dtype: Name of the input and dequantized dtype.
        rounding_mode: Forwarded to the downcast functions as
            ``DEQUANT_SCALE_ROUNDING_MODE``.

    float8 quantized dtypes are skipped when the CUDA device's major compute
    capability is below 9.
    """
    is_fp8 = "float8" in quant_dtype
    if is_fp8 and torch.cuda.get_device_capability()[0] < 9:
        pytest.skip("Float8 not tested on A100")

    q_dtype = dtype_str_to_torch(quant_dtype)
    dq_dtype = dtype_str_to_torch(dequant_dtype)

    # Random input in the dequantized dtype.
    x = torch.randn(shape, device="cuda", dtype=dq_dtype)

    # Quantize with both implementations using the same rounding mode.
    rounding_kwargs = {"DEQUANT_SCALE_ROUNDING_MODE": rounding_mode}
    quant, scale = downcast_to_mxfp(x, q_dtype, axis, **rounding_kwargs)
    quant_torch, scale_torch = downcast_to_mxfp_torch(x, q_dtype, axis, **rounding_kwargs)

    # Both implementations must agree exactly on the data and the scales ...
    assert_equal(quant_torch, quant)
    assert_equal(scale_torch, scale)
    # ... and both quantized outputs must have unit stride along ``axis``.
    assert_equal(1, quant.stride(axis))
    assert_equal(1, quant_torch.stride(axis))

    # Dequantize with both implementations; the results must match exactly.
    dequant = upcast_from_mxfp(quant, scale, dq_dtype, axis)
    dequant_torch = upcast_from_mxfp_torch(quant_torch, scale_torch, dq_dtype, axis)
    assert_equal(dequant, dequant_torch)

    # The round trip is lossy, so compare against the input within tolerances.
    assert_close(x, dequant, maxtol=0.5, rmstol=0.15)
|
|