| | """ |
| | NVFP4 kernels for DeepSeek inference on SM120 (RTX Pro 6000 Blackwell). |
| | |
| | This module provides NVFP4 equivalents for the FP8 kernels in kernel.py: |
| | - nvfp4_gemm: Block-scaled NVFP4 matrix multiplication |
| | - act_quant_nvfp4: Quantize activations to NVFP4 |
| | |
| | Weight format: |
| | weight: [N, K/2] packed uint8 (2 FP4 E2M1 per byte) |
| | weight_scale: [N, K/16] FP8 E4M3 per-block scale |
| | weight_scale_2: [1] FP32 global scale |
| | """ |


import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from typing import Tuple, Optional
import functools


NVFP4_LUT = torch.tensor([
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.float32)
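# The LUT index is the raw 4-bit E2M1 code: bit 3 is the sign and bits 0-2 select
# the magnitude, so e.g. code 0b0111 -> 6.0 and code 0b1001 -> -0.5.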


@functools.lru_cache(maxsize=8)
def _get_nvfp4_lut(device_str: str) -> torch.Tensor:
    """Get the NVFP4 lookup table on the specified device (cached).

    Args:
        device_str: Device string (e.g., 'cpu', 'cuda:0')

    Returns:
        NVFP4 lookup table on the specified device
    """
    return NVFP4_LUT.to(device=device_str)


# Number of elements covered by each FP8 per-block scale.
NVFP4_BLOCK_SIZE = 16

def get_nvfp4_configs():
    """Return GEMM tile configs, with smaller tiles and fewer stages on SM120 (compute capability 12)."""
    capability = torch.cuda.get_device_capability()[0]
    if capability == 12:
        return {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 128,
            "num_stages": 2,
            "VEC_SIZE": 16,
        }
    else:
        return {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 256,
            "num_stages": 4,
            "VEC_SIZE": 16,
        }


def linear_to_triton_scale(
    scale_linear: torch.Tensor,
    M: int,
    K: int,
    VEC_SIZE: int = 16,
) -> torch.Tensor:
    """
    Convert the linear scale format to Triton's 5D TMA layout.

    Args:
        scale_linear: [M, K // VEC_SIZE] FP8 E4M3 scales
        M: Number of rows
        K: Number of columns
        VEC_SIZE: Number of elements per scale block

    Returns:
        scale_triton: [1, M // 128, K // (VEC_SIZE * 4), 2, 256] tensor for TMA
    """
    assert scale_linear.shape == (M, K // VEC_SIZE), \
        f"Expected shape {(M, K // VEC_SIZE)}, got {scale_linear.shape}"

    num_m_chunks = M // 128
    num_k_chunks = (K // VEC_SIZE) // 4

    # Group scales into 128-row x 4-column chunks, then swizzle them into the
    # [1, m_chunks, k_chunks, 2, 256] layout expected by the TMA descriptor.
    scale = scale_linear.reshape(num_m_chunks, 4, 32, num_k_chunks, 4)
    scale = scale.permute(0, 3, 2, 1, 4)
    scale = scale.reshape(num_m_chunks, num_k_chunks, 32, 16)
    scale = scale.reshape(1, num_m_chunks, num_k_chunks, 2, 256)

    return scale.contiguous()
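# Worked example (illustrative): for M = 256, K = 2048 and VEC_SIZE = 16,
# scale_linear has shape [256, 128]; num_m_chunks = 2 and num_k_chunks = 32,
# so the returned tensor has shape [1, 2, 32, 2, 256].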


def dequantize_nvfp4(
    packed: torch.Tensor,
    scale: torch.Tensor,
    scale_2: torch.Tensor,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Dequantize an NVFP4 tensor to float for reference/fallback use.

    Args:
        packed: [M, K/2] uint8 packed tensor
        scale: [M, K/16] FP8 E4M3 per-block scales
        scale_2: [1] FP32 global scale
        dtype: Output dtype

    Returns:
        tensor: [M, K] dequantized tensor
    """
    M, K_half = packed.shape
    K = K_half * 2
    block_size = NVFP4_BLOCK_SIZE

    # Unpack the two FP4 codes per byte: the low nibble holds the even element,
    # the high nibble holds the odd element.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    fp4_tensor = torch.stack([low, high], dim=-1).reshape(M, K)

    # Map the 4-bit codes to their E2M1 values via the lookup table.
    lut = _get_nvfp4_lut(str(packed.device))
    tensor = lut[fp4_tensor.long()]

    # Apply the per-block and global scales.
    scale_f32 = scale.to(torch.float32)
    tensor = tensor.reshape(M, K // block_size, block_size)
    tensor = tensor * scale_f32.unsqueeze(-1) * scale_2
    tensor = tensor.reshape(M, K)

    return tensor.to(dtype)


def nvfp4_gemm_dequant(
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
) -> torch.Tensor:
    """
    NVFP4 GEMM via a dequantization fallback.

    This is a simple but slow implementation that dequantizes the weights
    to bfloat16 and uses a standard matmul. Use it for testing/validation.

    Args:
        x: Input activation [M, K] in bfloat16
        weight: NVFP4 weight [N, K/2] packed uint8
        weight_scale: Per-block scales [N, K/16] FP8 E4M3
        weight_scale_2: Global scale [1] FP32

    Returns:
        y: Output [M, N] in bfloat16
    """
    N, K_half = weight.shape
    K = K_half * 2

    # Dequantize the packed weight to bfloat16, then use a dense matmul.
    weight_bf16 = dequantize_nvfp4(weight, weight_scale, weight_scale_2, dtype=torch.bfloat16)

    return torch.matmul(x, weight_bf16.T)
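# Example (illustrative): validate a freshly quantized weight against its bf16
# original; `x` here stands for any [M, 1024] bfloat16 activation.
#
#     w = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
#     w_q, w_s, w_s2 = quantize_act_nvfp4(w)
#     y_ref = x @ w.T                                 # exact bf16 reference
#     y_q = nvfp4_gemm_dequant(x, w_q, w_s, w_s2)     # differs only by quantization error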


@triton.jit
def nvfp4_gemm_kernel(
    a_desc,
    a_scale_desc,
    b_desc,
    b_scale_desc,
    c_desc,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    VEC_SIZE: tl.constexpr,
    rep_m: tl.constexpr,
    rep_n: tl.constexpr,
    rep_k: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Triton NVFP4 block-scaled GEMM kernel."""
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m

    offs_am = pid_m * BLOCK_M
    offs_bn = pid_n * BLOCK_N
    offs_k_a = 0
    offs_k_b = 0
    offs_scale_m = pid_m * rep_m
    offs_scale_n = pid_n * rep_n
    offs_scale_k = 0

    # Scalar int32 zero used for the fixed indices of the 5D scale loads.
    c0 = tl.zeros((1,), dtype=tl.int32)[0]

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        a = a_desc.load([offs_am, offs_k_a])
        b = b_desc.load([offs_bn, offs_k_b])
        scale_a = a_scale_desc.load([c0, offs_scale_m, offs_scale_k, c0, c0])
        scale_b = b_scale_desc.load([c0, offs_scale_n, offs_scale_k, c0, c0])

        # Undo the 5D TMA swizzle back to a [BLOCK, BLOCK_K // VEC_SIZE] scale tile per operand.
        scale_a = scale_a.reshape(rep_m, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // VEC_SIZE)
        scale_b = scale_b.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K // VEC_SIZE)

        accumulator = tl.dot_scaled(a, scale_a, "e2m1", b.T, scale_b, "e2m1", accumulator)

        # Both operands are packed two values per byte, so advance by BLOCK_K // 2 bytes.
        offs_k_a += BLOCK_K // 2
        offs_k_b += BLOCK_K // 2
        offs_scale_k += rep_k

    c_desc.store([offs_am, offs_bn], accumulator.to(tl.float16))


def nvfp4_gemm(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    b_scale_2: Optional[torch.Tensor] = None,
    a_scale_2: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Perform NVFP4 GEMM using the Triton kernel: y = a @ b.T

    For weight-only quantization (the common case):
    - a: bfloat16 activation [M, K]
    - b: NVFP4 weight [N, K/2] packed
    - b_scale: [N, K/16] FP8 E4M3
    - b_scale_2: [1] FP32 global scale

    Args:
        a: Activation tensor [M, K] (bfloat16 or NVFP4 packed)
        a_scale: Activation scale [M, K/16] (or None for bfloat16 input)
        b: Weight tensor [N, K/2] packed uint8
        b_scale: Weight per-block scale [N, K/16] FP8 E4M3
        b_scale_2: Weight global scale [1] FP32
        a_scale_2: Activation global scale [1] FP32 (optional)

    Returns:
        y: Output [M, N] in bfloat16
    """
    if a.dtype == torch.uint8:
        M, K_half = a.shape
        K = K_half * 2
    else:
        M, K = a.shape

    N = b.shape[0]

    configs = get_nvfp4_configs()
    BLOCK_M = configs["BLOCK_SIZE_M"]
    BLOCK_N = configs["BLOCK_SIZE_N"]
    BLOCK_K = configs["BLOCK_SIZE_K"]
    VEC_SIZE = configs["VEC_SIZE"]
    num_stages = configs["num_stages"]

    # Shapes that are not tile-aligned fall back to the slow dequantize + matmul path.
    if M % BLOCK_M != 0 or N % BLOCK_N != 0 or K % BLOCK_K != 0:
        b_s2 = b_scale_2 if b_scale_2 is not None else torch.ones(1, device=a.device)
        if a.dtype == torch.uint8:
            # The fallback expects bfloat16 activations, so dequantize packed inputs first.
            a_s2 = a_scale_2 if a_scale_2 is not None else torch.ones(1, device=a.device)
            a = dequantize_nvfp4(a, a_scale, a_s2, dtype=torch.bfloat16)
        return nvfp4_gemm_dequant(a, b, b_scale, b_s2)

    # Quantize bfloat16 activations to NVFP4 on the fly.
    if a.dtype != torch.uint8:
        a_nvfp4, a_scale, a_scale_2 = quantize_act_nvfp4(a)
    else:
        a_nvfp4 = a

    # Convert the linear scales into the 5D TMA layout.
    a_scale_triton = linear_to_triton_scale(a_scale, M, K, VEC_SIZE)
    b_scale_triton = linear_to_triton_scale(b_scale, N, K, VEC_SIZE)

    # Tensor descriptors over the packed operands (BLOCK_K // 2 bytes per K tile).
    a_desc = TensorDescriptor.from_tensor(a_nvfp4, [BLOCK_M, BLOCK_K // 2])
    b_desc = TensorDescriptor.from_tensor(b, [BLOCK_N, BLOCK_K // 2])

    rep_m = BLOCK_M // 128
    rep_n = BLOCK_N // 128
    rep_k = BLOCK_K // VEC_SIZE // 4

    a_scale_block_shape = [1, rep_m, rep_k, 2, 256]
    b_scale_block_shape = [1, rep_n, rep_k, 2, 256]

    a_scale_desc = TensorDescriptor.from_tensor(a_scale_triton, block_shape=a_scale_block_shape)
    b_scale_desc = TensorDescriptor.from_tensor(b_scale_triton, block_shape=b_scale_block_shape)

    # The kernel accumulates in FP32 and stores FP16; global scales are applied afterwards.
    output = torch.empty((M, N), dtype=torch.float16, device=a.device)
    c_desc = TensorDescriptor.from_tensor(output, [BLOCK_M, BLOCK_N])

    # One program per (BLOCK_M, BLOCK_N) output tile.
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)

    nvfp4_gemm_kernel[grid](
        a_desc, a_scale_desc,
        b_desc, b_scale_desc,
        c_desc,
        M, N, K,
        BLOCK_M, BLOCK_N, BLOCK_K,
        VEC_SIZE,
        rep_m, rep_n, rep_k,
        num_stages,
    )

    # Fold the FP32 global scales back into the block-scaled result.
    if a_scale_2 is not None and b_scale_2 is not None:
        output = output * (a_scale_2 * b_scale_2)
    elif b_scale_2 is not None:
        output = output * b_scale_2

    return output.to(torch.bfloat16)
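# Example usage (illustrative; shapes chosen to be tile-aligned so the Triton
# kernel path is taken rather than the dequant fallback):
#
#     x = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)
#     w = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
#     w_q, w_s, w_s2 = quantize_act_nvfp4(w)        # quantize the weight offline
#     y = nvfp4_gemm(x, None, w_q, w_s, w_s2)       # y: [256, 512] bfloat16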


def quantize_act_nvfp4(
    x: torch.Tensor,
    block_size: int = NVFP4_BLOCK_SIZE,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Quantize an activation tensor to NVFP4 format.

    Args:
        x: Input tensor [M, K] in float/bfloat16
        block_size: Number of elements per scale block

    Returns:
        packed: [M, K/2] uint8 packed tensor
        scale: [M, K/block_size] FP8 E4M3 per-block scales
        scale_2: [1] FP32 global scale
    """
    M, K = x.shape
    device = x.device
    x = x.to(torch.float32)

    # Global scale: map the tensor amax onto the FP4 (6.0) x FP8-scale (448.0) range.
    amax = x.abs().max()
    scale_2 = amax / (6.0 * 448.0)
    scale_2 = scale_2.clamp(min=1e-12)

    # Per-block FP8 E4M3 scales.
    x_blocks = x.reshape(M, K // block_size, block_size)
    block_amax = x_blocks.abs().amax(dim=-1)
    scale = (block_amax / (6.0 * scale_2)).clamp(min=1e-12, max=448.0)
    scale = scale.to(torch.float8_e4m3fn)

    # Normalize each block into the representable FP4 range [-6, 6].
    scale_f32 = scale.to(torch.float32)
    scale_expanded = (scale_f32 * scale_2).unsqueeze(-1)
    scaled_x = x_blocks / scale_expanded

    # Round each element to the nearest positive E2M1 magnitude.
    nvfp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=device)
    abs_x = scaled_x.abs()
    signs = scaled_x.sign()

    diffs = (abs_x.unsqueeze(-1) - nvfp4_values).abs()
    indices = diffs.argmin(dim=-1)

    # Codes 0-7 are the positive values; adding 8 sets the sign bit for negatives.
    fp4_values = indices.to(torch.uint8)
    fp4_values = torch.where(signs < 0, fp4_values + 8, fp4_values)
    fp4_tensor = fp4_values.reshape(M, K)

    # Pack two FP4 codes per byte: even elements in the low nibble, odd in the high nibble.
    packed = (fp4_tensor[:, 0::2] & 0x0F) | ((fp4_tensor[:, 1::2] & 0x0F) << 4)
    packed = packed.to(torch.uint8)

    return packed, scale, scale_2.reshape(1)
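# Worked example of the scaling (illustrative): if the tensor amax is 2688
# (= 6.0 * 448.0), then scale_2 = 1.0; a block whose amax is 3.0 gets
# scale = 3.0 / 6.0 = 0.5, and an element equal to 3.0 is normalized to
# 3.0 / (0.5 * 1.0) = 6.0, the largest E2M1 magnitude.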


def act_quant_nvfp4(
    x: torch.Tensor,
    block_size: int = NVFP4_BLOCK_SIZE,
    scale_fmt: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize an activation with an interface matching the original act_quant.

    Args:
        x: Input tensor
        block_size: Block size for quantization
        scale_fmt: Scale format (unused, for API compatibility)

    Returns:
        y: Quantized tensor [M, K/2] packed uint8
        s: Scale tensor [M, K/block_size] FP8 E4M3, with the FP32 global scale
           attached as its `scale_2` attribute
    """
    packed, scale, scale_2 = quantize_act_nvfp4(x.view(-1, x.size(-1)), block_size)

    # Attach the global scale to the tensor that is actually returned; attaching it
    # to `scale` before the view would lose it, since .view() creates a new object.
    s = scale.view(*x.shape[:-1], -1)
    s.scale_2 = scale_2

    return packed.view(*x.shape[:-1], x.size(-1) // 2), s
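# Example (illustrative): the returned scale tensor carries the global scale as
# an attribute, so callers using the (packed, scale) interface can still reach it.
#
#     y, s = act_quant_nvfp4(x)      # y: [..., K/2] uint8, s: [..., K/16] fp8
#     global_scale = s.scale_2       # [1] float32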


def test_nvfp4_gemm():
    """Test the NVFP4 GEMM dequantization path."""
    M, N, K = 256, 512, 1024

    # Random bfloat16 activation and weight.
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    weight_bf16 = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

    # Quantize the weight to NVFP4 (quantize_act_nvfp4 works for any 2D tensor).
    weight_packed, weight_scale, weight_scale_2 = quantize_act_nvfp4(weight_bf16)

    # Reference: matmul against the explicitly dequantized weight.
    weight_deq = dequantize_nvfp4(weight_packed, weight_scale, weight_scale_2, dtype=torch.bfloat16)
    ref = torch.matmul(x, weight_deq.T)

    # GEMM via the dequantization fallback should match the reference closely.
    out_deq = nvfp4_gemm_dequant(x, weight_packed, weight_scale, weight_scale_2)

    error = (ref - out_deq).abs().mean()
    print(f"PASS: NVFP4 GEMM dequant test: mean abs error = {error:.6f}")

    return True
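

# Extra sanity check (an illustrative addition, not part of the original test):
# quantize-then-dequantize round-trip error should be small for well-scaled inputs.
def test_nvfp4_quant_roundtrip():
    """Quantize a random tensor to NVFP4 and measure the round-trip error."""
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

    packed, scale, scale_2 = quantize_act_nvfp4(x)
    x_hat = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.bfloat16)

    rel_err = (x.float() - x_hat.float()).abs().mean() / x.float().abs().mean()
    print(f"PASS: NVFP4 quantization round-trip: mean relative error = {rel_err:.6f}")

    return True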


if __name__ == "__main__":
    test_nvfp4_gemm()
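    # Also run the NVFP4 round-trip sanity check defined above.
    test_nvfp4_quant_roundtrip()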