# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

from typing import Optional, Tuple

import triton
import torch

from .._triton_kernels.quant.quant import (
    _static_per_tensor_quant_fp8_i8_kernel,
    _dynamic_per_tensor_quant_fp8_i8_kernel,
    _dynamic_per_token_quant_fp8_i8_kernel,
    _dynamic_mxfp4_quant_kernel,
    _mxfp4_quant_op,
    _dynamic_mxfp8_quant_kernel,
    _mxfp8_quant_op,
    _fp8_legacy_to_mxfp8_kernel,
    _dynamic_nvfp4_quant_kernel,
    _nvfp4_quant_op,
)
from ..utils.logger import AiterTritonLogger
from ..utils.types import e4m3_dtype

__all__ = [
    "static_per_tensor_quant_fp8_i8",
    "dynamic_per_tensor_quant_fp8_i8",
    "dynamic_per_token_quant_fp8_i8",
    "dynamic_mxfp4_quant",
    "_mxfp4_quant_op",
    "dynamic_mxfp8_quant",
    "fp8_legacy_to_mxfp8",
    "_mxfp8_quant_op",
    "dynamic_nvfp4_quant",
    "_nvfp4_quant_op",
]

# Number of values sharing one e8m0 scale in the MXFP8 wrappers below.
_MXFP8_QUANT_BLOCK_SIZE = 32
# Number of values sharing one fp32 scale in the legacy (FP8 + fp32 1x128) format.
_MXFP8_LEGACY_BLOCK_SIZE = 128

_LOGGER = AiterTritonLogger()


def _quant_dtype_max(dtype: torch.dtype):
    """
    Return the largest finite value of ``dtype``.

    Floating-point dtypes (including fp8) use ``torch.finfo``; integer
    dtypes use ``torch.iinfo``. The Python type of the result (float vs
    int) is preserved because it is passed to the kernels as ``DTYPE_MAX``.
    """
    if dtype.is_floating_point:
        return torch.finfo(dtype).max
    return torch.iinfo(dtype).max


def static_per_tensor_quant_fp8_i8(
    qx: torch.Tensor, x_in: torch.Tensor, scale_in: torch.Tensor
):
    """
    Quantize ``x_in`` into ``qx`` (int8 or fp8) using a caller-provided
    per-tensor scale.

    Parameters:
    - qx: Output tensor of same shape as x_in. Must be fp8 or int8 dtype and allocated by the caller.
    - x_in: Input tensor of shape (M, N).
    - scale_in: Input scale tensor of shape (1,) and dtype fp32.

    Returns:
    - qx: Quantized output values (the same tensor object that was passed in).
    """
    _LOGGER.info(f"STATIC_PER_TENSOR_QUANT_FP8_I8: x={tuple(x_in.shape)}")
    assert scale_in.numel() == 1  # only a single scale value is supported

    rows = x_in.shape[0]
    cols = x_in.shape[1]
    # One program per row; the column tile is padded to a power of two.
    NUM_COL_POW2 = triton.next_power_of_2(cols)
    grid = (rows,)
    # NOTE(review): only x_in.stride(0) is passed, so qx presumably must have
    # the same row stride as x_in — confirm against the kernel.
    _static_per_tensor_quant_fp8_i8_kernel[grid](
        qx, x_in, scale_in, cols, x_in.stride(0), NUM_COL_POW2=NUM_COL_POW2
    )

    return qx


def dynamic_per_tensor_quant_fp8_i8(
    qx: torch.Tensor, x_in: torch.Tensor, scale_out: torch.Tensor
):
    """
    Compute a per-tensor scale for ``x_in`` and then use it to quantize
    ``x_in`` into ``qx`` (fp8 or int8).

    Two kernels are launched: the first writes the scale into ``scale_out``,
    the second is the static per-tensor quantization kernel reading that scale.

    Parameters:
    - qx: Output tensor of same shape as x_in. Must be fp8 or int8 dtype and allocated by the caller.
    - x_in: Input tensor of shape (M, N).
    - scale_out: Output scale tensor of shape (1,), dtype fp32 and allocated by the caller.
      NOTE(review): the first kernel runs one program per row and writes a
      single scale; if it reduces atomically, scale_out presumably has to be
      zero-initialized by the caller — confirm against the kernel.

    Returns:
    - qx: Quantized output values of shape (M, N) with dtype fp8 or int8.
    - scale_out: Single scale value of shape (1,).
    """
    _LOGGER.info(f"DYNAMIC_PER_TENSOR_QUANT_FP8_I8: x={tuple(x_in.shape)}")
    # scale_out is consumed by the static kernel, which only supports one scale.
    assert scale_out.numel() == 1

    rows = x_in.shape[0]
    cols = x_in.shape[1]
    NUM_COL_POW2 = triton.next_power_of_2(cols)
    grid = (rows,)

    # Pass 1: derive the per-tensor scale from the data and the output dtype range.
    _dynamic_per_tensor_quant_fp8_i8_kernel[grid](
        x_in,
        scale_out,
        cols,
        x_in.stride(0),
        NUM_COL_POW2=NUM_COL_POW2,
        DTYPE_MAX=_quant_dtype_max(qx.dtype),
    )
    # Pass 2: quantize with the scale computed above.
    _static_per_tensor_quant_fp8_i8_kernel[grid](
        qx, x_in, scale_out, cols, x_in.stride(0), NUM_COL_POW2=NUM_COL_POW2
    )

    return qx, scale_out


def dynamic_per_token_quant_fp8_i8(
    qx: torch.Tensor,
    x_in: torch.Tensor,
    scale_out: torch.Tensor,
):
    """
    Quantize ``x_in`` into ``qx`` with one dynamically computed scale per row
    (token), launched as a single kernel with one program per row.

    Parameters:
    - qx: Output tensor of same shape as x_in. Must be fp8 or int8 dtype and allocated by the caller.
    - x_in: Input tensor of shape (M, N).
    - scale_out: Output scale tensor of shape (M,), dtype fp32 and allocated by the caller.

    The maximum representable value of ``qx.dtype`` is derived internally
    and passed to the kernel; it is not a caller parameter.

    Returns:
    - qx: Quantized output values.
    - scale_out: Scale tensor of shape (M,).
    """
    _LOGGER.info(f"DYNAMIC_PER_TOKEN_QUANT_FP8_I8: x={tuple(x_in.shape)}")

    rows = x_in.shape[0]
    cols = x_in.shape[1]
    NUM_COL_POW2 = triton.next_power_of_2(cols)
    grid = (rows,)
    _dynamic_per_token_quant_fp8_i8_kernel[grid](
        qx,
        scale_out,
        x_in,
        cols,
        x_in.stride(0),
        NUM_COL_POW2=NUM_COL_POW2,
        DTYPE_MAX=_quant_dtype_max(qx.dtype),
    )

    return qx, scale_out


def _fp4_quant_launch_config(M: int, N: int) -> Tuple[int, int, int, int, int]:
    """
    Choose launch parameters for the FP4 quantization kernels (MXFP4 and NVFP4
    share the same heuristic).

    Returns:
        (NUM_ITER, BLOCK_SIZE_M, BLOCK_SIZE_N, NUM_WARPS, NUM_STAGES)
    """
    if M <= 32:
        # Few rows: a single small tile per program.
        num_iter = 1
        block_m = triton.next_power_of_2(M)
        block_n = 32
        num_warps = 1
        num_stages = 1
    else:
        num_iter = 4
        num_warps = 4
        num_stages = 2
        if N <= 16384:
            block_m, block_n = 32, 128
        else:
            # for large N values
            block_m, block_n = 64, 64

    # for small N values
    if N <= 1024:
        num_iter = 1
        num_stages = 1
        num_warps = 4
        # BLOCK_SIZE_N is capped at 256 and must be a multiple of 32.
        block_n = max(32, min(256, triton.next_power_of_2(N)))
        block_m = min(8, triton.next_power_of_2(M))

    return num_iter, block_m, block_n, num_warps, num_stages


def dynamic_mxfp4_quant(
    x: torch.Tensor, scaling_mode: str = "even"
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to MX FP4 format.

    Args:
        x: The input 2D tensor of shape (M, N), typically fp16 or bf16.
        scaling_mode: The method to calculate MX block scaling.
            - "even" (default): `even_round` in `quark.torch.quantization.utils`.
            NOTE(review): this argument is currently not consulted — the
            kernel is always launched with SCALING_MODE=0, so every value
            behaves like the default. Confirm that 0 corresponds to "even".

    Returns:
        A tuple of (x_fp4, blockscale_e8m0):
            x_fp4: uint8 tensor of shape (M, N // 2) holding the packed FP4 values.
            blockscale_e8m0: uint8 tensor of shape (M, ceil(N / 32)); it is a
                transposed view of a (ceil(N / 32), M) allocation, so it is
                column-major.
    """
    _LOGGER.info(f"DYNAMIC_MXFP4_QUANT: x={tuple(x.shape)}")

    # Assume x is 2D-Tensor for now
    M, N = x.shape

    assert (N // 2) % 2 == 0

    # This is fixed by spec for MXFP4. Do not tune this.
    MXFP4_QUANT_BLOCK_SIZE = 32
    x_fp4 = torch.empty((M, N // 2), dtype=torch.uint8, device=x.device)
    blockscale_e8m0 = torch.empty(
        ((N + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE, M),
        dtype=torch.uint8,
        device=x.device,
    ).T

    NUM_ITER, BLOCK_SIZE_M, BLOCK_SIZE_N, NUM_WARPS, NUM_STAGES = (
        _fp4_quant_launch_config(M, N)
    )

    grid = (
        triton.cdiv(M, BLOCK_SIZE_M),
        triton.cdiv(N, BLOCK_SIZE_N * NUM_ITER),
    )

    _dynamic_mxfp4_quant_kernel[grid](
        x,
        x_fp4,
        blockscale_e8m0,
        *x.stride(),
        *x_fp4.stride(),
        *blockscale_e8m0.stride(),
        M=M,
        N=N,
        MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE,
        SCALING_MODE=0,
        NUM_ITER=NUM_ITER,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        NUM_STAGES=NUM_STAGES,
        num_warps=NUM_WARPS,
        waves_per_eu=0,
        num_stages=1,
    )

    return (x_fp4, blockscale_e8m0)


def dynamic_mxfp8_quant(
    x: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    quant_dtype: torch.dtype = torch.float8_e4m3fn,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Per-1x32 MXFP8 quantization (e8m0 scale + FP8 e4m3 values).

    Args:
        x: Input tensor (..., K), at least 2D. Typically bf16 or fp16. K % 32 == 0.
            Leading dimensions are flattened into M rows.
        scale: Pre-allocated scale tensor (M, K // 32) uint8. Optional.
        quant_dtype: FP8 dtype to cast quantized values to. On MI3xx
            torch.float8_e4m3fnuz is the canonical FP8 e4m3 type.
            torch.float8_e4m3fn is acceptable on hardware that supports it.

    Returns:
        Tuple of:
            y: FP8 tensor of shape x.shape.
            s: e8m0 (uint8) scale tensor of shape (..., K // 32).
    """
    _LOGGER.info(f"DYNAMIC_MXFP8_QUANT: x={tuple(x.shape)}")
    assert x.dim() >= 2, f"x must be at least 2D, got {x.dim()}"
    orig_shape = x.shape
    K = orig_shape[-1]
    assert (
        K % _MXFP8_QUANT_BLOCK_SIZE == 0
    ), f"last dim K={K} must be a multiple of {_MXFP8_QUANT_BLOCK_SIZE}"

    x2d = x.reshape(-1, K).contiguous()
    M = x2d.shape[0]
    Ns = K // _MXFP8_QUANT_BLOCK_SIZE  # number of scales per row

    y = torch.empty((M, K), dtype=quant_dtype, device=x.device)
    if scale is None:
        scale = torch.empty((M, Ns), dtype=torch.uint8, device=x.device)
    else:
        assert scale.shape == (M, Ns), f"scale shape {scale.shape} != ({M},{Ns})"
        assert scale.dtype == torch.uint8

    # One program per row; each program covers the whole row, padded to a power of two.
    BLOCK_SIZE_N = triton.next_power_of_2(K)
    NUM_PRGMS = M
    grid = (NUM_PRGMS,)

    _dynamic_mxfp8_quant_kernel[grid](
        x2d,
        y,
        scale,
        M,
        K,
        x2d.stride(0),
        x2d.stride(1),
        y.stride(0),
        y.stride(1),
        scale.stride(0),
        scale.stride(1),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        QUANT_BLOCK_SIZE=_MXFP8_QUANT_BLOCK_SIZE,
        NUM_PRGMS=NUM_PRGMS,
    )

    # Restore the caller's leading dimensions.
    y = y.view(*orig_shape[:-1], K)
    s = scale.view(*orig_shape[:-1], Ns)
    return y, s


def fp8_legacy_to_mxfp8(
    x_fnuz: torch.Tensor,
    x_scale_fp32: torch.Tensor,
    y_fn: Optional[torch.Tensor] = None,
    y_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Transcode (FP8 e4m3fnuz, fp32 1x128 scale) -> (FP8 e4m3fn, e8m0 1x32 scale)
    in a single Triton launch.

    Replaces the Python dequant+requant cascade used when the MXFP8 path
    receives legacy-formatted (FP8 + fp32 1x128) inputs.

    Args:
        x_fnuz: FP8 e4m3fnuz tensor of shape (M, N). N must be a multiple of
            128 (and therefore of 32).
        x_scale_fp32: fp32 scale of shape (M, N // 128).
        y_fn: optional preallocated output FP8 e4m3fn tensor of shape (M, N).
        y_scale: optional preallocated uint8 e8m0 scale tensor of shape (M, N // 32).

    Returns:
        y_fn (M, N) fp8 e4m3fn, y_scale (M, N // 32) uint8 e8m0.
    """
    _LOGGER.info(f"FP8_LEGACY_TO_MXFP8: x={tuple(x_fnuz.shape)}")
    assert x_fnuz.dim() == 2, f"x must be 2D, got {x_fnuz.dim()}"
    M, N = x_fnuz.shape
    assert N % _MXFP8_QUANT_BLOCK_SIZE == 0
    assert N % _MXFP8_LEGACY_BLOCK_SIZE == 0
    assert x_scale_fp32.shape == (
        M,
        N // _MXFP8_LEGACY_BLOCK_SIZE,
    ), f"x_scale_fp32 shape {x_scale_fp32.shape} != ({M},{N // _MXFP8_LEGACY_BLOCK_SIZE})"

    Ns = N // _MXFP8_QUANT_BLOCK_SIZE
    if y_fn is None:
        y_fn = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x_fnuz.device)
    else:
        # The kernel writes through the strides of y_fn; a mis-sized buffer
        # would be written out of bounds.
        assert y_fn.shape == (M, N), f"y_fn shape {y_fn.shape} != ({M},{N})"
    if y_scale is None:
        y_scale = torch.empty((M, Ns), dtype=torch.uint8, device=x_fnuz.device)
    else:
        assert y_scale.shape == (
            M,
            Ns,
        ), f"y_scale shape {y_scale.shape} != ({M},{Ns})"

    # Grid: one program per (row, 32-wide output scale group).
    BLOCK_SIZE_M = 1
    grid = (triton.cdiv(M, BLOCK_SIZE_M), Ns)

    _fp8_legacy_to_mxfp8_kernel[grid](
        x_fnuz,
        x_scale_fp32,
        y_fn,
        y_scale,
        M,
        N,
        x_fnuz.stride(0),
        x_fnuz.stride(1),
        x_scale_fp32.stride(0),
        x_scale_fp32.stride(1),
        y_fn.stride(0),
        y_fn.stride(1),
        y_scale.stride(0),
        y_scale.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        QUANT_BLOCK_SIZE=_MXFP8_QUANT_BLOCK_SIZE,
        LEGACY_BLOCK_SIZE=_MXFP8_LEGACY_BLOCK_SIZE,
    )
    return y_fn, y_scale


def dynamic_nvfp4_quant(
    x: torch.Tensor,
    global_scale: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to NVFP4 format (FP4 values with one e4m3 block scale
    per 16 values).

    Args:
        x: The input 2D tensor of shape (M, N), typically fp16 or bf16.
        global_scale: NOTE(review): accepted for API compatibility but
            currently unused — it is not passed to the kernel. Confirm
            whether the kernel is meant to consume it.

    Returns:
        A tuple of (x_fp4, blockscale_e4m3):
            x_fp4: uint8 tensor of shape (M, N // 2) holding the packed FP4 values.
            blockscale_e4m3: e4m3 tensor of shape (M, ceil(N / 16)); it is a
                transposed view of a (ceil(N / 16), M) allocation, so it is
                column-major.
    """
    _LOGGER.info(f"DYNAMIC_NVFP4_QUANT: x={tuple(x.shape)}")

    # Assume x is 2D-Tensor for now
    M, N = x.shape

    assert (N // 2) % 2 == 0

    # This is fixed by spec for NVFP4. Do not tune this.
    NVFP4_QUANT_BLOCK_SIZE = 16
    x_fp4 = torch.empty((M, N // 2), dtype=torch.uint8, device=x.device)
    blockscale_e4m3 = torch.empty(
        ((N + NVFP4_QUANT_BLOCK_SIZE - 1) // NVFP4_QUANT_BLOCK_SIZE, M),
        dtype=e4m3_dtype,
        device=x.device,
    ).T

    NUM_ITER, BLOCK_SIZE_M, BLOCK_SIZE_N, NUM_WARPS, NUM_STAGES = (
        _fp4_quant_launch_config(M, N)
    )

    grid = (
        triton.cdiv(M, BLOCK_SIZE_M),
        triton.cdiv(N, BLOCK_SIZE_N * NUM_ITER),
    )

    _dynamic_nvfp4_quant_kernel[grid](
        x,
        x_fp4,
        blockscale_e4m3,
        *x.stride(),
        *x_fp4.stride(),
        *blockscale_e4m3.stride(),
        M=M,
        N=N,
        NVFP4_QUANT_BLOCK_SIZE=NVFP4_QUANT_BLOCK_SIZE,
        NUM_ITER=NUM_ITER,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        NUM_STAGES=NUM_STAGES,
        num_warps=NUM_WARPS,
        waves_per_eu=0,
        num_stages=1,
    )

    return x_fp4, blockscale_e4m3