| """ |
| FireEcho CUTLASS — Self-Contained CUTLASS-Compatible Kernels |
| ============================================================= |
| Part of the FireEcho Engine — Custom inference kernel for NVIDIA Blackwell |
| Copyright (c) 2025-2026 Echo (FireEcho Project). All rights reserved. |
| |
| Pure Python/Triton/PyTorch implementations — no .so binary required. |
| |
| 1. TMA MatMul — Triton block-pointer kernel with multi-stage pipelining |
| 2. TMA Attention — PyTorch SDPA (dispatches to Flash Attention 2 on HW) |
| 3. NVFP4 GEMM — Fused dequant-matmul Triton kernel (Blackwell native format) |
| 16-element blocks, E4M3 scales, per-tensor FP32 scale. |
| Multi-tier dispatch: native cuBLAS _scaled_mm → fused Triton → CPU. |
| Vectorized O(K*N) activation quantization via torch.bucketize. |
| 4. MXFP4 GEMM — Fused dequant-matmul Triton kernel (OCP MXFP4 format) |
| 32-element blocks, E8M0 power-of-two scales. |
| Kept for backward compatibility. |
| 5. L2 Cache Control — ctypes/libcudart.so cudaAccessPolicyWindow |
| |
| Usage: |
| from fireecho_kernel.cutlass_kernels import ( |
| tma_matmul, |
| tma_attention, |
| nvfp4_gemm, # New: NVFP4 (recommended) |
| mxfp4_gemm, # Legacy: MXFP4 |
| fp4_gemm, # Alias -> nvfp4_gemm |
| NVFP4Weights, |
| MXFP4Weights, |
| L2CacheManager, |
| ) |
| |
| # TMA MatMul (Triton block-pointer) |
| c = tma_matmul(a, b) |
| |
| # NVFP4 GEMM (recommended — fused dequant-matmul, 16-element blocks) |
| w_q = quantize_to_nvfp4(weights) |
| out = nvfp4_gemm(activations, w_q) |
| |
| # MXFP4 GEMM (legacy — fused dequant-matmul, 32-element blocks) |
| w_q = quantize_to_mxfp4(weights) |
| out = mxfp4_gemm(activations, w_q) |
| |
| # L2 Cache pinning (hardware-backed via cudart) |
| l2 = L2CacheManager() |
| l2.pin(embedding_table) |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import triton |
| import triton.language as tl |
| from typing import Optional, Tuple, Dict, Any |
| from dataclasses import dataclass |
| import ctypes |
| import ctypes.util |
|
|
|
|
| |
| |
| |
|
|
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _tma_matmul_kernel(
    a_ptr, b_ptr, c_ptr, d_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    stride_dm, stride_dn,
    alpha, beta,
    HAS_C: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """
    TMA-style MatMul using block pointers for async memory access.

    D = alpha * (A @ B) + beta * C

    Block pointers enable hardware-managed address generation and
    async GDDR7/HBM -> SMEM loads overlapped with compute.

    One program instance computes one [BLOCK_M, BLOCK_N] tile of D;
    the 2D launch grid is (ceil(M/BLOCK_M), ceil(N/BLOCK_N)).
    """
    # Tile coordinates of this program within the output matrix.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Block pointer over A: walks rightward along K as the loop advances.
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    # Block pointer over B: walks downward along K.
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, pid_n * BLOCK_N),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(1, 0),
    )

    # Accumulate in FP32 regardless of input dtype for precision.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # K-loop: boundary_check masks the ragged final tile when K is not a
    # multiple of BLOCK_K (out-of-bounds elements load as zero).
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        acc += tl.dot(a, b)
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))

    # Skip the multiply when alpha is the identity.
    if alpha != 1.0:
        acc = acc * alpha

    # Optional epilogue: acc += beta * C. HAS_C is a constexpr, so the dead
    # branch is compiled out entirely.
    if HAS_C:
        c_block_ptr = tl.make_block_ptr(
            base=c_ptr,
            shape=(M, N),
            strides=(stride_cm, stride_cn),
            offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(1, 0),
        )
        c_val = tl.load(c_block_ptr, boundary_check=(0, 1)).to(tl.float32)
        acc = acc + beta * c_val

    # Write the output tile.
    # NOTE(review): acc is cast to bf16 before the store even when D is FP16;
    # the store re-casts to the pointer's element type — confirm the extra
    # bf16 rounding step is acceptable for FP16 outputs.
    d_block_ptr = tl.make_block_ptr(
        base=d_ptr,
        shape=(M, N),
        strides=(stride_dm, stride_dn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    tl.store(d_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1))
|
|
|
|
| |
| |
| |
|
|
def tma_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    alpha: float = 1.0,
    beta: float = 0.0,
    c: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    TMA-style matrix multiplication via Triton block-pointer kernel.

    Computes D = alpha * (A @ B) + beta * C using async memory transfers
    (block pointers + multi-stage pipelining) for compute/memory overlap
    on SM90+ GPUs. CPU tensors fall back to a plain FP32 matmul.

    Args:
        a: Input matrix A [M, K] in BF16/FP16 (other dtypes are cast to BF16)
        b: Input matrix B [K, N]; cast to a's dtype if it differs
        alpha: Scale for A @ B
        beta: Scale for C (beta == 0 disables the C term)
        c: Optional input C for D = alpha * A @ B + beta * C

    Returns:
        Output matrix D [M, N] in a's (possibly promoted) dtype
    """
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, f"K dimension mismatch: {K} vs {K2}"

    # Normalize dtypes: the kernel accumulates in FP32 but loads half-precision tiles.
    if a.dtype not in (torch.bfloat16, torch.float16):
        a = a.to(torch.bfloat16)
    if b.dtype != a.dtype:
        b = b.to(a.dtype)

    # Standard GEMM semantics: beta == 0 means C is never read.
    has_c = c is not None and beta != 0
    if has_c and c.dtype != a.dtype:
        c = c.to(a.dtype)

    # CPU fallback first — before this fix the function allocated the [M, N]
    # output and forced a/b contiguous even though the Triton kernel never runs.
    if not a.is_cuda:
        result = alpha * torch.matmul(a.float(), b.float()).to(a.dtype)
        if has_c:
            result = result + beta * c
        return result

    a = a.contiguous()
    b = b.contiguous()
    d = torch.empty(M, N, device=a.device, dtype=a.dtype)

    if has_c:
        c_contiguous = c.contiguous()
    else:
        # Dummy pointer for the kernel; beta = 0.0 guarantees it is never read.
        c_contiguous = d
        beta = 0.0

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N']),
    )

    _tma_matmul_kernel[grid](
        a, b, c_contiguous, d,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c_contiguous.stride(0), c_contiguous.stride(1),
        d.stride(0), d.stride(1),
        alpha, beta,
        HAS_C=has_c,
    )

    return d
|
|
|
|
| |
| |
| |
|
|
| def tma_attention( |
| q: torch.Tensor, |
| k: torch.Tensor, |
| v: torch.Tensor, |
| scale: Optional[float] = None, |
| is_causal: bool = False, |
| dropout_p: float = 0.0, |
| ) -> torch.Tensor: |
| """ |
| Attention via PyTorch SDPA (dispatches to Flash Attention 2 on supported HW). |
| |
| Args: |
| q: Query tensor [batch, heads, seq_q, head_dim] |
| k: Key tensor [batch, heads, seq_kv, head_dim] |
| v: Value tensor [batch, heads, seq_kv, head_dim] |
| scale: Attention scale (default: 1/sqrt(head_dim)) |
| is_causal: Apply causal mask |
| dropout_p: Dropout probability |
| |
| Returns: |
| Output tensor [batch, heads, seq_q, head_dim] |
| """ |
| if scale is None: |
| scale = q.shape[-1] ** -0.5 |
|
|
| orig_dtype = q.dtype |
| if q.dtype not in (torch.bfloat16, torch.float16): |
| q = q.to(torch.bfloat16) |
| k = k.to(torch.bfloat16) |
| v = v.to(torch.bfloat16) |
|
|
| o = F.scaled_dot_product_attention( |
| q, k, v, |
| attn_mask=None, |
| dropout_p=dropout_p if q.requires_grad else 0.0, |
| is_causal=is_causal, |
| scale=scale, |
| ) |
|
|
| return o.to(orig_dtype) |
|
|
|
|
def tma_gqa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_kv_heads: int,
    scale: Optional[float] = None,
    is_causal: bool = False,
) -> torch.Tensor:
    """
    Grouped Query Attention via SDPA.

    Expands KV heads to match Q heads then delegates to tma_attention.

    Args:
        q: Query [batch, num_q_heads, seq, head_dim]
        k: Key [batch, num_kv_heads, seq, head_dim]
        v: Value [batch, num_kv_heads, seq, head_dim]
        num_kv_heads: Number of KV heads
        scale: Attention scale
        is_causal: Apply causal mask

    Returns:
        Output [batch, num_q_heads, seq, head_dim]
    """
    _, num_q_heads, _, _ = q.shape
    group_size = num_q_heads // num_kv_heads

    # Replicate each KV head group_size times so shapes match plain MHA.
    if group_size > 1:
        k = k.repeat_interleave(group_size, dim=1)
        v = v.repeat_interleave(group_size, dim=1)

    return tma_attention(q, k, v, scale=scale, is_causal=is_causal)
|
|
|
|
| |
| |
| |
| |
| |
|
|
| |
# E2M1 decode table indexed by the 4-bit code: bit 3 = sign, bits 2-1 =
# exponent, bit 0 = mantissa. Codes 0-7 are the non-negative magnitudes,
# codes 8-15 their negated mirrors.
_E2M1_VALUES = torch.tensor(
    [0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6],
    dtype=torch.float32,
)

# Midpoints between consecutive non-negative E2M1 magnitudes; used with
# torch.bucketize to round |x| to the nearest representable value.
_E2M1_BOUNDARIES = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0])

# QuEST clipping factor: shrink each block's absmax before deriving the
# scale, trading a little clipping error for lower rounding error.
_QUEST_CLIP_FACTOR = 0.88


@dataclass
class MXFP4Weights:
    """
    MXFP4 quantized weights following OCP Microscaling Spec v1.0.

    Format: 32 E2M1 values share 1 E8M0 power-of-two scale.

    Memory layout:
        - packed: [K//2, N] uint8 (2 nibbles per byte)
        - scales: [K//32, N] uint8 (E8M0 exponent-only)

    Total size: K*N/2 + K*N/32 = K*N * (1/2 + 1/32) ~ 0.53 * original
    """
    packed: torch.Tensor                       # [K//2, N] uint8, two E2M1 codes per byte
    scales: torch.Tensor                       # [K//32, N] uint8, biased E8M0 exponents
    shape: Tuple[int, int]                     # original (K, N)
    clip_mask: Optional[torch.Tensor] = None   # [K, N] bool, True where the value was clipped

    @classmethod
    def from_float(cls, weights: torch.Tensor, use_quest: bool = True) -> 'MXFP4Weights':
        """
        Quantize FP16/FP32 weights to MXFP4 with block scaling.

        Args:
            weights: Input tensor [K, N]
            use_quest: Use QuEST optimal clipping (recommended for forward)

        Returns:
            MXFP4Weights with packed values and E8M0 scales
        """
        K, N = weights.shape
        assert K % 32 == 0, f"K ({K}) must be multiple of 32 for MXFP4"

        device = weights.device
        weights = weights.float()

        # Group along K into 32-element microscaling blocks.
        reshaped = weights.view(K // 32, 32, N)

        # Per-block absolute maximum; floor avoids zero/degenerate scales.
        absmax = reshaped.abs().amax(dim=1)
        absmax = absmax.clamp(min=1e-10)

        if use_quest:
            clip_bound = absmax * _QUEST_CLIP_FACTOR
        else:
            clip_bound = absmax

        # Ideal scale maps clip_bound onto the max E2M1 magnitude (6.0).
        scale_float = clip_bound / 6.0

        # Round the scale to the nearest power of two and store as a biased
        # E8M0 exponent (bias 127; biased range clamped to 1..254).
        log2_scale = torch.log2(scale_float.clamp(min=2**-126))
        exponent = (log2_scale.round() + 127).clamp(1, 254).to(torch.uint8)

        # Scale actually representable after E8M0 rounding.
        actual_scale = torch.pow(2.0, exponent.float() - 127)

        normalized = reshaped / actual_scale.unsqueeze(1)

        # E2M1 dynamic range is [-6, 6]; anything beyond is clipped.
        normalized = normalized.clamp(-6.0, 6.0)

        # Record which elements the clamp above actually affected, i.e.
        # |w| > 6 * actual_scale. (Bug fix: the previous comparison against
        # clip_bound * 6.0 could never be true — |w| <= absmax < 6*clip_bound —
        # so the mask was always all-False.)
        if use_quest:
            clip_mask = (reshaped.abs() > actual_scale.unsqueeze(1) * 6.0).view(K, N)
        else:
            clip_mask = None

        # Round-to-nearest E2M1 magnitude via bucketize over the midpoints;
        # bit 3 carries the sign (negative zero is never produced).
        boundaries = _E2M1_BOUNDARIES.to(device)
        abs_norm = normalized.abs().reshape(-1)
        unsigned_idx = torch.bucketize(abs_norm, boundaries)
        sign_bit = ((normalized.reshape(-1) < 0) & (unsigned_idx > 0)).to(torch.uint8) << 3
        indices = (sign_bit | unsigned_idx.to(torch.uint8)).reshape(K, N)
        # Pack adjacent K-rows into one byte: even row in the low nibble.
        packed = (indices[0::2] | (indices[1::2] << 4))

        return cls(
            packed=packed,
            scales=exponent,
            shape=(K, N),
            clip_mask=clip_mask,
        )

    def to_float(self) -> torch.Tensor:
        """Dequantize MXFP4 back to float."""
        K, N = self.shape
        device = self.packed.device

        e2m1_grid = _E2M1_VALUES.to(device)

        # Split each byte back into its two 4-bit codes.
        low = (self.packed & 0xF).long()
        high = (self.packed >> 4).long()

        # LUT decode of both nibbles.
        low_vals = e2m1_grid[low.flatten()].view(K // 2, N)
        high_vals = e2m1_grid[high.flatten()].view(K // 2, N)

        # Re-interleave: low nibble held even K rows, high nibble odd rows.
        unpacked = torch.zeros(K, N, device=device, dtype=torch.float32)
        unpacked[0::2] = low_vals
        unpacked[1::2] = high_vals

        # Undo the per-block power-of-two scaling.
        scale_float = torch.pow(2.0, self.scales.float() - 127)
        unpacked = unpacked.view(K // 32, 32, N)
        unpacked = unpacked * scale_float.unsqueeze(1)

        return unpacked.view(K, N)

    @property
    def compression_ratio(self) -> float:
        """Memory compression ratio vs FP16."""
        K, N = self.shape
        fp16_bytes = K * N * 2
        mxfp4_bytes = self.packed.numel() + self.scales.numel()
        return fp16_bytes / mxfp4_bytes

    def to_native(self) -> 'NativeMXFP4':
        """
        Convert to native FP4 format for tl.dot_scaled (SM100+).

        One-time conversion that:
            1. Transposes packed weights: [K//2, N] -> [N, K//2]
            2. Converts E8M0 scales to 5D preshuffled MXScaleTensor layout:
               [N//128, K//32//4, 32, 4, 4]
            3. Caches the result so subsequent calls return immediately.

        Returns:
            NativeMXFP4 with preshuffled layout for hardware MMA.
        """
        if hasattr(self, '_native_cache') and self._native_cache is not None:
            return self._native_cache

        K, N = self.shape

        # Weights become [N, K//2] so K runs along the fastest axis.
        packed_t = self.packed.T.contiguous()

        num_scale_k = K // 32
        num_scale_n = N

        n_blocks = (N + 127) // 128

        scales_flat = self.scales.contiguous()

        k_groups = (num_scale_k + 3) // 4

        scales_5d = torch.zeros(
            n_blocks, k_groups, 32, 4, 4,
            dtype=torch.uint8, device=self.packed.device,
        )

        # Scatter each (k, n) scale into the preshuffled 5D layout.
        # NOTE(review): this triple Python loop is O(N * K/32) on the host —
        # fine as a one-time conversion, but consider vectorizing for very
        # large layers.
        for nb in range(n_blocks):
            for kg in range(k_groups):
                for inner_n in range(min(128, N - nb * 128)):
                    n_idx = nb * 128 + inner_n
                    if n_idx >= N:
                        break
                    # Position of this column inside its 128-wide N block.
                    d4 = inner_n % 4
                    d2 = (inner_n // 4) % 32
                    for d3 in range(min(4, num_scale_k - kg * 4)):
                        k_idx = kg * 4 + d3
                        if k_idx < num_scale_k:
                            scales_5d[nb, kg, d2, d3, d4] = scales_flat[k_idx, n_idx]

        native = NativeMXFP4(
            packed_t=packed_t,
            scales_5d=scales_5d,
            shape=(K, N),
        )
        self._native_cache = native
        return native
|
|
|
|
@dataclass
class NativeMXFP4:
    """
    Native FP4 format for tl.dot_scaled hardware path (SM100+).

    Preshuffled layout matching MXScaleTensor requirements:
        - packed_t: [N, K//2] uint8 — transposed packed weights
        - scales_5d: [N//128, K//32//4, 32, 4, 4] uint8 — preshuffled E8M0

    Created via MXFP4Weights.to_native(). Cached so conversion is one-time.
    """
    packed_t: torch.Tensor     # [N, K//2] uint8, K-major packed E2M1 pairs
    scales_5d: torch.Tensor    # [N//128, K//32//4, 32, 4, 4] uint8 E8M0 exponents
    shape: Tuple[int, int]     # original logical (K, N)
|
|
|
|
| |
| |
| |
|
|
| def _encode_e4m3(values: torch.Tensor) -> torch.Tensor: |
| """Encode FP32 values to E4M3 (FP8) as uint8. Vectorized.""" |
| if hasattr(torch, 'float8_e4m3fn'): |
| return values.clamp(-448.0, 448.0).to(torch.float8_e4m3fn).view(torch.uint8) |
| |
| v = values.float().clamp(-448.0, 448.0) |
| sign = (v < 0).to(torch.uint8) << 7 |
| av = v.abs().clamp(min=0.0) |
| |
| |
| |
| log2_av = torch.log2(av.clamp(min=2**-9)) |
| exp_raw = torch.floor(log2_av).clamp(-6, 8) |
| exp_biased = (exp_raw + 7).clamp(0, 15) |
| |
| mantissa_f = (av / torch.pow(2.0, exp_raw) - 1.0) * 8.0 |
| mantissa = mantissa_f.round().clamp(0, 7).to(torch.uint8) |
| |
| sub_mant = (av / (2**-6) * 8.0).round().clamp(0, 7).to(torch.uint8) |
| is_sub = exp_biased == 0 |
| final_mant = torch.where(is_sub, sub_mant, mantissa) |
| return sign | (exp_biased.to(torch.uint8) << 3) | final_mant |
|
|
|
|
| def _decode_e4m3(encoded: torch.Tensor) -> torch.Tensor: |
| """Decode E4M3 uint8 back to FP32. Vectorized.""" |
| if hasattr(torch, 'float8_e4m3fn'): |
| return encoded.view(torch.float8_e4m3fn).float() |
| |
| sign = ((encoded >> 7) & 1).float() |
| exp = ((encoded >> 3) & 0xF).long() |
| mant = (encoded & 0x7).long() |
| is_normal = exp > 0 |
| normal_val = (8 + mant).float() * torch.pow(2.0, (exp - 10).float()) |
| subnormal_val = mant.float() * (2.0 ** -9) |
| unsigned = torch.where(is_normal, normal_val, subnormal_val) |
| return torch.where(sign != 0, -unsigned, unsigned) |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class NVFP4Weights:
    """
    NVFP4 quantized weights — NVIDIA Blackwell native format.

    Format: 16 E2M1 values share 1 E4M3 (FP8) scale + per-tensor FP32 scale.
    Two-level hierarchical scaling enables native 5th-gen Tensor Core support.

    Memory layout:
        - packed: [K//2, N] uint8 (2 nibbles per byte, same E2M1 encoding)
        - block_scales: [K//16, N] uint8 (E4M3 per-block scale)
        - tensor_scale: float (FP32 per-tensor global scale)

    Optional FP8 residual correction (double-buff):
        - residual: [K, N] uint8 (E4M3 encoded quantization error)
        - residual_scales: [K//16, N] float32 (per-block scales for residual)
      When present, the fused kernel adds the decoded residual to recover
      near-FP16 accuracy at 1.625 B/elem (vs 2.0 for FP16).

    Total size without residual: K*N/2 + K*N/16 ~ 0.5625 * original
    Total size with residual: ~1.625 * original (75% of FP16)
    """
    packed: torch.Tensor                            # [K//2, N] uint8, two E2M1 codes per byte
    block_scales: torch.Tensor                      # [K//16, N] uint8 E4M3 per-block scales
    tensor_scale: float                             # FP32 per-tensor global scale
    shape: Tuple[int, int]                          # original (K, N)
    clip_mask: Optional[torch.Tensor] = None        # [K, N] bool, True where the value was clipped
    residual: Optional[torch.Tensor] = None         # [K, N] uint8 E4M3 quantization error
    residual_scales: Optional[torch.Tensor] = None  # [K//16, N] float32 residual block scales

    @classmethod
    def from_float(cls, weights: torch.Tensor, use_quest: bool = True,
                   compute_residual: bool = False) -> 'NVFP4Weights':
        """
        Quantize FP16/FP32 weights to NVFP4 with hierarchical scaling.

        Two-level scaling:
            1. Per-tensor FP32 scale (global_absmax / 448)
            2. Per-block E4M3 scale (block_absmax / (tensor_scale * 6.0))

        Args:
            weights: Input tensor [K, N]
            use_quest: Use QuEST optimal clipping (recommended)
            compute_residual: Compute FP8 residual correction (double-buff).
                When True, the quantization error (original - FP4 dequant) is
                quantized to E4M3 FP8 with per-block scaling and stored alongside
                the FP4 weights. The fused kernel adds this residual for near-FP16
                accuracy at 1.625 B/elem.

        Returns:
            NVFP4Weights with packed values, E4M3 block scales, and FP32 tensor scale
            (plus optional residual and residual_scales when compute_residual=True)
        """
        K, N = weights.shape
        assert K % 16 == 0, f"K ({K}) must be multiple of 16 for NVFP4"

        device = weights.device
        weights_f = weights.float()

        # Group along K into 16-element NVFP4 blocks.
        reshaped = weights_f.view(K // 16, 16, N)

        # Per-block absolute maximum; floor avoids zero/degenerate scales.
        absmax = reshaped.abs().amax(dim=1)
        absmax = absmax.clamp(min=1e-10)

        if use_quest:
            clip_bound = absmax * _QUEST_CLIP_FACTOR
        else:
            clip_bound = absmax

        # Level 1: per-tensor scale maps the global max onto the E4M3 max (448).
        global_absmax = clip_bound.max().clamp(min=1e-10)
        tensor_scale = (global_absmax / 448.0).item()

        # Level 2: per-block scale (in tensor_scale units) maps each block's
        # clip bound onto the max E2M1 magnitude (6.0), rounded to E4M3.
        target_scale = clip_bound / (tensor_scale * 6.0)
        target_scale = target_scale.clamp(min=1e-10)
        block_scales_fp8 = _encode_e4m3(target_scale)

        # Scale actually representable after E4M3 rounding (back in FP32 units).
        actual_block_scale = _decode_e4m3(block_scales_fp8) * tensor_scale
        actual_block_scale = actual_block_scale.clamp(min=1e-10)

        normalized = reshaped / actual_block_scale.unsqueeze(1)
        normalized = normalized.clamp(-6.0, 6.0)

        # Record which elements the clamp above actually affected, i.e.
        # |w| > 6 * actual_block_scale. (Bug fix: the previous comparison
        # against clip_bound * 6.0 could never be true, so the mask was
        # always all-False.)
        if use_quest:
            clip_mask = (reshaped.abs() > actual_block_scale.unsqueeze(1) * 6.0).view(K, N)
        else:
            clip_mask = None

        # Round-to-nearest E2M1 magnitude via bucketize over the midpoints;
        # bit 3 carries the sign (negative zero is never produced).
        boundaries = _E2M1_BOUNDARIES.to(device)
        abs_norm = normalized.abs().reshape(-1)
        unsigned_idx = torch.bucketize(abs_norm, boundaries)
        sign_bit = ((normalized.reshape(-1) < 0) & (unsigned_idx > 0)).to(torch.uint8) << 3
        indices = (sign_bit | unsigned_idx.to(torch.uint8)).reshape(K, N)

        # Pack adjacent K-rows into one byte: even row in the low nibble.
        packed = (indices[0::2] | (indices[1::2] << 4))

        # Optional FP8 residual: encode (original - FP4 reconstruction).
        residual_e4m3 = None
        residual_scales = None
        if compute_residual:
            # Reconstruct the FP4 approximation to measure the error.
            fp4_approx = cls(
                packed=packed, block_scales=block_scales_fp8,
                tensor_scale=tensor_scale, shape=(K, N),
            ).to_float()

            residual_float = weights_f - fp4_approx

            # Per-block scale maps each residual block onto the E4M3 range.
            res_blocks = residual_float.view(K // 16, 16, N)
            res_absmax = res_blocks.abs().amax(dim=1).clamp(min=1e-10)
            res_scale = res_absmax / 448.0
            res_normalized = res_blocks / res_scale.unsqueeze(1)
            res_normalized = res_normalized.clamp(-448.0, 448.0)
            # NOTE(review): this uses torch.float8_e4m3fn directly (no hasattr
            # guard like _encode_e4m3) — requires a torch build with FP8 support.
            residual_e4m3 = res_normalized.view(K, N).to(torch.float8_e4m3fn).view(torch.uint8)
            residual_scales = res_scale

        return cls(
            packed=packed,
            block_scales=block_scales_fp8,
            tensor_scale=tensor_scale,
            shape=(K, N),
            clip_mask=clip_mask,
            residual=residual_e4m3,
            residual_scales=residual_scales,
        )

    def to_float(self) -> torch.Tensor:
        """Dequantize NVFP4 back to float with two-level scaling."""
        K, N = self.shape
        device = self.packed.device

        e2m1_grid = _E2M1_VALUES.to(device)

        # Split each byte back into its two 4-bit codes.
        low = (self.packed & 0xF).long()
        high = (self.packed >> 4).long()

        # LUT decode of both nibbles.
        low_vals = e2m1_grid[low.flatten()].view(K // 2, N)
        high_vals = e2m1_grid[high.flatten()].view(K // 2, N)

        # Re-interleave: low nibble held even K rows, high nibble odd rows.
        unpacked = torch.zeros(K, N, device=device, dtype=torch.float32)
        unpacked[0::2] = low_vals
        unpacked[1::2] = high_vals

        # Undo both scaling levels: E4M3 block scale, then FP32 tensor scale.
        block_sf = _decode_e4m3(self.block_scales)
        scale = block_sf * self.tensor_scale
        unpacked = unpacked.view(K // 16, 16, N) * scale.unsqueeze(1)
        return unpacked.view(K, N)

    @property
    def compression_ratio(self) -> float:
        """Memory compression ratio vs FP16."""
        K, N = self.shape
        fp16_bytes = K * N * 2
        nvfp4_bytes = self.packed.numel() + self.block_scales.numel()
        if self.residual is not None:
            nvfp4_bytes += self.residual.numel()
        if self.residual_scales is not None:
            # residual_scales is FP32: 4 bytes per element.
            nvfp4_bytes += self.residual_scales.numel() * 4
        return fp16_bytes / nvfp4_bytes
|
|
|
|
| |
| FP4Weights = NVFP4Weights |
|
|
|
|
def mxfp4_gemm(
    activations: torch.Tensor,
    weights: MXFP4Weights,
    bias: Optional[torch.Tensor] = None,
    use_hadamard: bool = True,
) -> torch.Tensor:
    """
    MXFP4 GEMM using the Quartet algorithm with fused dequant-matmul.

    Forward pass:
        1. Hadamard-rotate the activations to spread outliers
        2. Fake-quantize the activations with QuEST optimal clipping
        3. Fused dequant-matmul (weight tiles are dequantized in registers,
           never materialized in global memory)

    Dispatch: native FP4 tensor cores (tl.dot_scaled) when available,
    otherwise the fused dequant-matmul Triton kernel; CPU tensors take a
    dense dequant + matmul reference path.

    Reference: "Quartet: Native FP4 Training Can Be Optimal for LLMs"
    https://arxiv.org/html/2505.14669v1

    Args:
        activations: Input [M, K] in BF16/FP16
        weights: MXFP4Weights with packed E2M1 values and E8M0 scales
        bias: Optional bias [N]
        use_hadamard: Apply Hadamard transform (recommended)

    Returns:
        Output [M, N] in BF16
    """
    M, K = activations.shape
    K_w, N = weights.shape
    assert K == K_w, f"K dimension mismatch: {K} vs {K_w}"
    assert K % 32 == 0, f"K ({K}) must be multiple of 32 for MXFP4"

    # Outlier mitigation: rotate every 32-element K-group.
    x = activations.float()
    if use_hadamard and K >= 32:
        x = _hadamard_transform_32(x.view(M, K // 32, 32)).view(M, K)

    # Fake-quantize the activations. Blocks run along K, so transpose to put
    # K on dim 0 for MXFP4Weights, then transpose the dequant back.
    act_q = MXFP4Weights.from_float(x.T.contiguous(), use_quest=True)
    x_dequant = act_q.to_float().T.contiguous()

    # CPU reference path: dense dequant + matmul.
    if not activations.is_cuda:
        out = torch.matmul(x_dequant, weights.to_float())
        if bias is not None:
            out = out + bias.float()
        return out.to(torch.bfloat16)

    # GPU: prefer the hardware FP4 MMA path when the stack supports it.
    if _can_use_native_fp4():
        return _native_fp4_matmul(x_dequant, weights.to_native(), bias)
    return _fused_fp4_matmul(x_dequant, weights, bias)
|
|
|
|
def mxfp4_gemm_legacy(
    activations: torch.Tensor,
    weights: MXFP4Weights,
    bias: Optional[torch.Tensor] = None,
    use_hadamard: bool = True,
) -> torch.Tensor:
    """
    Legacy MXFP4 GEMM: full dequant to global memory + torch.matmul.

    Kept for benchmarking comparison against the fused kernel.
    """
    M, K = activations.shape
    K_w, N = weights.shape
    assert K == K_w, f"K dimension mismatch: {K} vs {K_w}"
    assert K % 32 == 0, f"K ({K}) must be multiple of 32 for MXFP4"

    # Optional Hadamard rotation of every 32-element K-group.
    x = activations.float()
    if use_hadamard and K >= 32:
        x = _hadamard_transform_32(x.view(M, K // 32, 32)).view(M, K)

    # Fake-quantize the activations along K (transpose so blocks run down
    # dim 0), then fully dequantize both operands and matmul.
    act_q = MXFP4Weights.from_float(x.T.contiguous(), use_quest=True)
    out = torch.matmul(act_q.to_float().T.contiguous(), weights.to_float())

    if bias is not None:
        out = out + bias.float()

    return out.to(torch.bfloat16)
|
|
|
|
|
|
| def _hadamard_transform_32(x: torch.Tensor) -> torch.Tensor: |
| """ |
| Fast Hadamard Transform on last dimension (size 32). |
| |
| Applies orthonormal Hadamard rotation to spread outliers. |
| Uses radix-2 butterfly operations. |
| """ |
| assert x.shape[-1] == 32 |
|
|
| def hadamard_matrix(n): |
| if n == 1: |
| return torch.ones(1, 1, device=x.device, dtype=x.dtype) |
| h = hadamard_matrix(n // 2) |
| return torch.cat([ |
| torch.cat([h, h], dim=1), |
| torch.cat([h, -h], dim=1), |
| ], dim=0) / (2 ** 0.5) |
|
|
| H = hadamard_matrix(32) |
| return x @ H |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
@triton.jit
def _e2m1_decode(idx):
    """Decode 4-bit E2M1 index -> float32. Register-only, no LUT.

    Bit layout: bit 3 = sign, bits 2-1 = exponent, bit 0 = mantissa.
    Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
    """
    sign = (idx >> 3) & 1
    exp = (idx >> 1) & 3
    mant = idx & 1
    is_normal = exp > 0
    # Subnormal (exp == 0): value = mant * 0.5, i.e. 0 or 0.5.
    subnormal_val = mant.to(tl.float32) * 0.5
    # Normal: (1 + mant/2) * 2^(exp-1) == (2 + mant) * 2^(exp-2).
    normal_val = (2 + mant).to(tl.float32) * tl.exp2((exp - 2).to(tl.float32))
    unsigned_val = tl.where(is_normal, normal_val, subnormal_val)
    return tl.where(sign != 0, -unsigned_val, unsigned_val)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _fused_fp4_dequant_matmul_kernel(
    a_ptr,
    w_packed_ptr,
    w_scales_ptr,
    out_ptr,
    bias_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_wk, stride_wn,
    stride_sk, stride_sn,
    stride_om, stride_on,
    HAS_BIAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Fused dequant-matmul: loads packed FP4, dequantizes in registers, matmuls.

    Inner loop per K-tile:
        1. Load A tile [BLOCK_M, BLOCK_K] BF16 via pointer arithmetic
        2. Load packed weight tile [BLOCK_K//2, BLOCK_N] uint8
        3. Unpack nibbles: low = packed & 0xF, high = packed >> 4
        4. Arithmetic E2M1 decode via _e2m1_decode() — pure register ops, no LUT
        5. Load scale tile [BLOCK_K//32, BLOCK_N] uint8, compute 2^(s-127)
        6. Apply per-group scale, interleave even/odd → [BLOCK_K, BLOCK_N] BF16
        7. acc += tl.dot(a_tile, w_tile)
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # FP32 accumulator for the output tile.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Each packed byte holds two K-adjacent nibbles, so a BLOCK_K tile is
    # BLOCK_K//2 packed rows; one E8M0 scale covers 32 K-elements.
    HALF_BLOCK_K: tl.constexpr = BLOCK_K // 2
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 32

    for k_start in range(0, K, BLOCK_K):
        # A is split into even/odd K columns so each half lines up with the
        # low/high nibbles of the packed weights (low nibble = even K row).
        even_k_offs = k_start + tl.arange(0, HALF_BLOCK_K) * 2
        odd_k_offs = k_start + tl.arange(0, HALF_BLOCK_K) * 2 + 1

        a_even_ptrs = a_ptr + offs_m[:, None] * stride_am + even_k_offs[None, :] * stride_ak
        a_odd_ptrs = a_ptr + offs_m[:, None] * stride_am + odd_k_offs[None, :] * stride_ak
        mask_a_even = (offs_m[:, None] < M) & (even_k_offs[None, :] < K)
        mask_a_odd = (offs_m[:, None] < M) & (odd_k_offs[None, :] < K)
        a_even = tl.load(a_even_ptrs, mask=mask_a_even, other=0.0)
        a_odd = tl.load(a_odd_ptrs, mask=mask_a_odd, other=0.0)

        # Packed weight tile: [HALF_BLOCK_K, BLOCK_N] bytes.
        packed_row_start = k_start // 2
        offs_packed_k = packed_row_start + tl.arange(0, HALF_BLOCK_K)
        w_ptrs = w_packed_ptr + offs_packed_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
        mask_w = (offs_packed_k[:, None] < (K // 2)) & (offs_n[None, :] < N)
        packed = tl.load(w_ptrs, mask=mask_w, other=0).to(tl.int32)

        # Decode both nibbles to float in registers.
        low_f = _e2m1_decode(packed & 0xF)
        high_f = _e2m1_decode((packed >> 4) & 0xF)

        # Map each packed row to its 32-element scale group: 16 packed rows
        # (= 32 unpacked K elements) share one E8M0 scale.
        scale_row_start = k_start // 32
        offs_local_packed = tl.arange(0, HALF_BLOCK_K)
        group_idx = offs_local_packed // 16

        # Broadcast each group's scale across the rows it covers. The
        # static_range loop is unrolled at compile time.
        scale_broadcast = tl.zeros((HALF_BLOCK_K, BLOCK_N), dtype=tl.float32)
        for sg in tl.static_range(0, SCALES_PER_TILE):
            sg_row = scale_row_start + sg
            sg_ptrs = w_scales_ptr + sg_row * stride_sk + offs_n * stride_sn
            sg_load_mask = (sg_row < (K // 32)) & (offs_n < N)
            # E8M0 decode: scale = 2^(raw - 127); masked lanes load 127 -> 1.0.
            sg_raw = tl.load(sg_ptrs, mask=sg_load_mask, other=127).to(tl.float32)
            sg_val = tl.exp2(sg_raw - 127.0)
            sg_match = (group_idx == sg)
            scale_broadcast = tl.where(sg_match[:, None], sg_val[None, :], scale_broadcast)

        # Apply scales and drop to BF16 for the tensor-core dot.
        w_even = (low_f * scale_broadcast).to(tl.bfloat16)
        w_odd = (high_f * scale_broadcast).to(tl.bfloat16)

        # Even and odd K halves contribute independently, so two half-width
        # dots sum to the full BLOCK_K contribution.
        acc += tl.dot(a_even.to(tl.bfloat16), w_even)
        acc += tl.dot(a_odd.to(tl.bfloat16), w_odd)

    # Optional bias epilogue (row vector broadcast over M).
    if HAS_BIAS:
        bias_vals = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
        acc += bias_vals[None, :]

    # Store the BF16 output tile.
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
@triton.jit
def _native_fp4_matmul_kernel(
    a_ptr,
    b_packed_ptr,
    b_scales_ptr,
    out_ptr,
    bias_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_om, stride_on,
    HAS_BIAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Native FP4 matmul using tl.dot_scaled (SM100+ hardware path).

    Uses transposed packed weights [N, K//2] and the preshuffled 5D scale
    tensor [N//128, K//32//4, 32, 4, 4] produced by MXFP4Weights.to_native(),
    matching the MXScaleTensor layout for direct tensor core consumption.
    When tl.dot_scaled maps to real tcgen05.mma.mxf4 instructions,
    this achieves native FP4 throughput.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # FP32 accumulator for the output tile.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Two nibbles per packed byte along K.
    HALF_BLOCK_K: tl.constexpr = BLOCK_K // 2

    # One E8M0 scale per 32 K-elements.
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 32

    for k_start in range(0, K, BLOCK_K):
        # Split A into even/odd K columns matching the low/high nibbles.
        even_k_offs = k_start + tl.arange(0, HALF_BLOCK_K) * 2
        odd_k_offs = k_start + tl.arange(0, HALF_BLOCK_K) * 2 + 1

        a_even_ptrs = a_ptr + offs_m[:, None] * stride_am + even_k_offs[None, :] * stride_ak
        a_odd_ptrs = a_ptr + offs_m[:, None] * stride_am + odd_k_offs[None, :] * stride_ak
        mask_a_even = (offs_m[:, None] < M) & (even_k_offs[None, :] < K)
        mask_a_odd = (offs_m[:, None] < M) & (odd_k_offs[None, :] < K)
        a_even = tl.load(a_even_ptrs, mask=mask_a_even, other=0.0)
        a_odd = tl.load(a_odd_ptrs, mask=mask_a_odd, other=0.0)

        # Packed weights are transposed: rows index N, columns index K//2.
        packed_col_start = k_start // 2
        offs_pk = packed_col_start + tl.arange(0, HALF_BLOCK_K)
        b_ptrs = b_packed_ptr + offs_n[:, None] * stride_bn + offs_pk[None, :] * stride_bk
        mask_b = (offs_n[:, None] < N) & (offs_pk[None, :] < (K // 2))
        b_packed_tile = tl.load(b_ptrs, mask=mask_b, other=0).to(tl.int32)

        # Decode both nibbles to float in registers.
        low_f = _e2m1_decode(b_packed_tile & 0xF)
        high_f = _e2m1_decode((b_packed_tile >> 4) & 0xF)

        # 16 packed columns (= 32 unpacked K elements) share one scale.
        scale_row_start = k_start // 32
        offs_local_pk = tl.arange(0, HALF_BLOCK_K)
        group_idx = offs_local_pk // 16

        # Gather each group's scale out of the preshuffled 5D layout
        # [n_blocks, k_groups, 32, 4, 4] by flattening its index manually.
        scale_broadcast = tl.zeros((BLOCK_N, HALF_BLOCK_K), dtype=tl.float32)
        for sg in tl.static_range(0, SCALES_PER_TILE):
            k_idx = scale_row_start + sg
            # Decompose the (n, k) scale coordinate into the 5D shuffle dims,
            # mirroring the scatter loop in MXFP4Weights.to_native().
            nb = offs_n // 128
            inner_n = offs_n % 128
            d4 = inner_n % 4
            d2 = (inner_n // 4) % 32
            kg = k_idx // 4
            d3 = k_idx % 4
            kg_total = (K // 32 + 3) // 4
            # Row-major flattened offset into scales_5d.
            s_offset = (nb * kg_total * 32 * 4 * 4 +
                        kg * 32 * 4 * 4 +
                        d2 * 4 * 4 +
                        d3 * 4 +
                        d4)
            # E8M0 decode: scale = 2^(raw - 127); masked lanes load 127 -> 1.0.
            s_val_raw = tl.load(b_scales_ptr + s_offset, mask=offs_n < N, other=127).to(tl.float32)
            s_val = tl.exp2(s_val_raw - 127.0)
            sg_match = (group_idx == sg)
            scale_broadcast = tl.where(sg_match[None, :], s_val[:, None], scale_broadcast)

        # Apply scales and drop to BF16 for the tensor-core dot.
        w_low = (low_f * scale_broadcast).to(tl.bfloat16)
        w_high = (high_f * scale_broadcast).to(tl.bfloat16)

        # Transpose [BLOCK_N, HALF_BLOCK_K] -> [HALF_BLOCK_K, BLOCK_N] for the dot.
        w_low_t = tl.trans(w_low)
        w_high_t = tl.trans(w_high)

        # Even and odd K halves contribute independently.
        acc += tl.dot(a_even.to(tl.bfloat16), w_low_t)
        acc += tl.dot(a_odd.to(tl.bfloat16), w_high_t)

    # Optional bias epilogue (row vector broadcast over M).
    if HAS_BIAS:
        bias_vals = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
        acc += bias_vals[None, :]

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
|
|
|
|
| |
| |
| |
|
|
@triton.jit
def _decode_e4m3_triton(raw_uint8):
    """Arithmetic E4M3 FP8 decode in Triton registers (no lookup table).

    Bit layout: 1 sign | 4 exponent (bias 7) | 3 mantissa.
    Normal (exp > 0):    (1 + m/8) * 2^(e-7)  ==  (8 + m) * 2^(e-10)
    Subnormal (exp == 0): (m/8) * 2^-6        ==  m * 2^-9
    """
    exp_bits = (raw_uint8 >> 3) & 0xF
    mant_bits = raw_uint8 & 0x7
    neg = (raw_uint8 >> 7) & 1
    # Normalized magnitude: fold the mantissa scaling into the exponent.
    normal_mag = (mant_bits + 8).to(tl.float32) * tl.exp2((exp_bits - 10).to(tl.float32))
    # Subnormal magnitude: fixed 2^-9 weight per mantissa step.
    subnormal_mag = mant_bits.to(tl.float32) * tl.exp2(tl.full(mant_bits.shape, -9.0, tl.float32))
    magnitude = tl.where(exp_bits > 0, normal_mag, subnormal_mag)
    return tl.where(neg != 0, -magnitude, magnitude)
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
@triton.autotune(
    configs=[
        # Wide tiles for prefill / large-M GEMM shapes.
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=5, num_warps=16),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 256}, num_stages=7, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=16),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        # Skinny-M tiles for small-batch decode shapes.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _fused_nvfp4_dequant_matmul_kernel(
    a_ptr,            # [M, K] activations (caller passes BF16)
    w_packed_ptr,     # [K//2, N] uint8 — two E2M1 nibbles per byte along K
    w_scales_ptr,     # [K//16, N] raw E4M3 block scales, one per 16 K values
    out_ptr,          # [M, N] BF16 output
    bias_ptr,         # [N] FP32 bias; read only when HAS_BIAS
    tensor_scale,     # scalar FP32 per-tensor scale (second scaling level)
    M, N, K,
    stride_am, stride_ak,
    stride_wk, stride_wn,
    stride_sk, stride_sn,
    stride_om, stride_on,
    HAS_BIAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Fused NVFP4 dequant-matmul: 16-element blocks, E4M3 scales, tensor scale.

    Inner loop per K-tile:
        1. Load A tile as even/odd column halves (matches nibble packing)
        2. Load packed weight tile, unpack nibbles
        3. Arithmetic E2M1 decode via _e2m1_decode()
        4. Load E4M3 scale tile [BLOCK_K//16, BLOCK_N], decode via _decode_e4m3_triton()
        5. Apply two-level scale: decoded_e4m3 * tensor_scale
        6. acc += tl.dot(a_half, w_half) for even and odd halves
    """
    # One program instance produces one [BLOCK_M, BLOCK_N] output tile.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # FP32 accumulator for numerical stability across the K reduction.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Each packed byte covers 2 K values; each E4M3 scale covers 16 K values.
    HALF_BLOCK_K: tl.constexpr = BLOCK_K // 2
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 16

    for k_start in range(0, K, BLOCK_K):
        # A columns split into even/odd so they line up with the low/high
        # nibbles of the packed weight bytes.
        even_k_offs = k_start + tl.arange(0, HALF_BLOCK_K) * 2
        odd_k_offs = k_start + tl.arange(0, HALF_BLOCK_K) * 2 + 1

        a_even_ptrs = a_ptr + offs_m[:, None] * stride_am + even_k_offs[None, :] * stride_ak
        a_odd_ptrs = a_ptr + offs_m[:, None] * stride_am + odd_k_offs[None, :] * stride_ak
        mask_a_even = (offs_m[:, None] < M) & (even_k_offs[None, :] < K)
        mask_a_odd = (offs_m[:, None] < M) & (odd_k_offs[None, :] < K)
        a_even = tl.load(a_even_ptrs, mask=mask_a_even, other=0.0)
        a_odd = tl.load(a_odd_ptrs, mask=mask_a_odd, other=0.0)

        # Packed weight row r carries K values 2r (low nibble) and 2r+1 (high).
        packed_row_start = k_start // 2
        offs_packed_k = packed_row_start + tl.arange(0, HALF_BLOCK_K)
        w_ptrs = w_packed_ptr + offs_packed_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
        mask_w = (offs_packed_k[:, None] < (K // 2)) & (offs_n[None, :] < N)
        packed = tl.load(w_ptrs, mask=mask_w, other=0).to(tl.int32)

        # Arithmetic E2M1 decode of both nibbles (no lookup table).
        low_f = _e2m1_decode(packed & 0xF)
        high_f = _e2m1_decode((packed >> 4) & 0xF)

        # Scale gather: one E4M3 scale per 16 K values == per 8 packed rows.
        scale_row_start = k_start // 16
        offs_local_packed = tl.arange(0, HALF_BLOCK_K)
        group_idx = offs_local_packed // 8

        # Build the [HALF_BLOCK_K, BLOCK_N] scale map by statically selecting
        # over the SCALES_PER_TILE scale rows that cover this K-tile.
        scale_broadcast = tl.zeros((HALF_BLOCK_K, BLOCK_N), dtype=tl.float32)
        for sg in tl.static_range(0, SCALES_PER_TILE):
            sg_row = scale_row_start + sg
            sg_ptrs = w_scales_ptr + sg_row * stride_sk + offs_n * stride_sn
            sg_load_mask = (sg_row < (K // 16)) & (offs_n < N)
            sg_raw = tl.load(sg_ptrs, mask=sg_load_mask, other=0).to(tl.int32)
            # Two-level NVFP4 scaling: per-block E4M3 times per-tensor FP32.
            sg_val = _decode_e4m3_triton(sg_raw) * tensor_scale
            sg_match = (group_idx == sg)
            scale_broadcast = tl.where(sg_match[:, None], sg_val[None, :], scale_broadcast)

        # Dequantized weight halves, cast to BF16 for the tensor-core dot.
        w_even = (low_f * scale_broadcast).to(tl.bfloat16)
        w_odd = (high_f * scale_broadcast).to(tl.bfloat16)

        acc += tl.dot(a_even.to(tl.bfloat16), w_even)
        acc += tl.dot(a_odd.to(tl.bfloat16), w_odd)

    # Optional epilogue bias, added in FP32 before the final downcast.
    if HAS_BIAS:
        bias_vals = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
        acc += bias_vals[None, :]

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
@triton.autotune(
    configs=[
        # Wide tiles for prefill / large-M GEMM shapes.
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=5, num_warps=16),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=16),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        # Skinny-M tiles for small-batch decode shapes.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _fused_nvfp4_residual_matmul_kernel(
    a_ptr,             # [M, K] activations (caller passes BF16)
    w_packed_ptr,      # [K//2, N] uint8 — two E2M1 nibbles per byte along K
    w_scales_ptr,      # [K//16, N] raw E4M3 block scales
    res_ptr,           # [K, N] raw E4M3 FP8 residual values
    res_scales_ptr,    # [K//16, N] per-block residual scales (loaded as float, no decode)
    out_ptr,           # [M, N] BF16 output
    bias_ptr,          # [N] FP32 bias; read only when HAS_BIAS
    tensor_scale,      # scalar FP32 per-tensor scale for the FP4 path
    M, N, K,
    stride_am, stride_ak,
    stride_wk, stride_wn,
    stride_sk, stride_sn,
    stride_rk, stride_rn,
    stride_rsk, stride_rsn,
    stride_om, stride_on,
    HAS_BIAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Fused NVFP4 + FP8 residual dequant-matmul (double-buff).

    Per K-tile:
        1. FP4 path: unpack nibbles, decode E2M1, apply two-level scale
           (same as the base NVFP4 kernel)
        2. FP8 residual path: load E4M3 residual, decode, apply per-block
           residual_scales
        3. Three dots: a_even * w_fp4_even + a_odd * w_fp4_odd + a_full * w_residual
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # FP32 accumulator shared by the FP4 and residual contributions.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    HALF_BLOCK_K: tl.constexpr = BLOCK_K // 2
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 16

    for k_start in range(0, K, BLOCK_K):
        # ---- FP4 path (mirrors _fused_nvfp4_dequant_matmul_kernel) ----
        # A columns split into even/odd to line up with the nibble packing.
        even_k_offs = k_start + tl.arange(0, HALF_BLOCK_K) * 2
        odd_k_offs = k_start + tl.arange(0, HALF_BLOCK_K) * 2 + 1

        a_even_ptrs = a_ptr + offs_m[:, None] * stride_am + even_k_offs[None, :] * stride_ak
        a_odd_ptrs = a_ptr + offs_m[:, None] * stride_am + odd_k_offs[None, :] * stride_ak
        mask_a_even = (offs_m[:, None] < M) & (even_k_offs[None, :] < K)
        mask_a_odd = (offs_m[:, None] < M) & (odd_k_offs[None, :] < K)
        a_even = tl.load(a_even_ptrs, mask=mask_a_even, other=0.0)
        a_odd = tl.load(a_odd_ptrs, mask=mask_a_odd, other=0.0)

        # Packed weight row r carries K values 2r (low nibble) and 2r+1 (high).
        packed_row_start = k_start // 2
        offs_packed_k = packed_row_start + tl.arange(0, HALF_BLOCK_K)
        w_ptrs = w_packed_ptr + offs_packed_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
        mask_w = (offs_packed_k[:, None] < (K // 2)) & (offs_n[None, :] < N)
        packed = tl.load(w_ptrs, mask=mask_w, other=0).to(tl.int32)

        low_f = _e2m1_decode(packed & 0xF)
        high_f = _e2m1_decode((packed >> 4) & 0xF)

        # One E4M3 scale per 16 K values == per 8 packed rows.
        scale_row_start = k_start // 16
        offs_local_packed = tl.arange(0, HALF_BLOCK_K)
        group_idx = offs_local_packed // 8

        scale_broadcast = tl.zeros((HALF_BLOCK_K, BLOCK_N), dtype=tl.float32)
        for sg in tl.static_range(0, SCALES_PER_TILE):
            sg_row = scale_row_start + sg
            sg_ptrs = w_scales_ptr + sg_row * stride_sk + offs_n * stride_sn
            sg_load_mask = (sg_row < (K // 16)) & (offs_n < N)
            sg_raw = tl.load(sg_ptrs, mask=sg_load_mask, other=0).to(tl.int32)
            sg_val = _decode_e4m3_triton(sg_raw) * tensor_scale
            sg_match = (group_idx == sg)
            scale_broadcast = tl.where(sg_match[:, None], sg_val[None, :], scale_broadcast)

        w_even = (low_f * scale_broadcast).to(tl.bfloat16)
        w_odd = (high_f * scale_broadcast).to(tl.bfloat16)

        acc += tl.dot(a_even.to(tl.bfloat16), w_even)
        acc += tl.dot(a_odd.to(tl.bfloat16), w_odd)

        # ---- FP8 residual path ----
        # Full (non-interleaved) A tile for the residual dot.
        full_k_offs = k_start + tl.arange(0, BLOCK_K)
        a_full_ptrs = a_ptr + offs_m[:, None] * stride_am + full_k_offs[None, :] * stride_ak
        mask_a_full = (offs_m[:, None] < M) & (full_k_offs[None, :] < K)
        a_full = tl.load(a_full_ptrs, mask=mask_a_full, other=0.0)

        # E4M3 residual tile, decoded arithmetically in registers.
        res_k_offs = k_start + tl.arange(0, BLOCK_K)
        res_ptrs = res_ptr + res_k_offs[:, None] * stride_rk + offs_n[None, :] * stride_rn
        mask_res = (res_k_offs[:, None] < K) & (offs_n[None, :] < N)
        res_raw = tl.load(res_ptrs, mask=mask_res, other=0).to(tl.int32)
        res_decoded = _decode_e4m3_triton(res_raw)

        # Residual scales share the 16-row blocking; here indices run over
        # full (unpacked) K rows, hence // 16 instead of // 8.
        offs_full_k = tl.arange(0, BLOCK_K)
        res_group_idx = offs_full_k // 16

        res_scale_broadcast = tl.zeros((BLOCK_K, BLOCK_N), dtype=tl.float32)
        for rsg in tl.static_range(0, SCALES_PER_TILE):
            rsg_row = scale_row_start + rsg
            rsg_ptrs = res_scales_ptr + rsg_row * stride_rsk + offs_n * stride_rsn
            rsg_load_mask = (rsg_row < (K // 16)) & (offs_n < N)
            rsg_val = tl.load(rsg_ptrs, mask=rsg_load_mask, other=0.0)
            rsg_match = (res_group_idx == rsg)
            res_scale_broadcast = tl.where(rsg_match[:, None], rsg_val[None, :], res_scale_broadcast)

        # Third dot folds the scaled FP8 residual into the same accumulator.
        res_scaled = (res_decoded * res_scale_broadcast).to(tl.bfloat16)
        acc += tl.dot(a_full.to(tl.bfloat16), res_scaled)

    # Optional epilogue bias, added in FP32 before the final downcast.
    if HAS_BIAS:
        bias_vals = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
        acc += bias_vals[None, :]

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
|
|
|
|
| |
| |
| |
|
|
# Cached result of the _can_use_native_fp4() probe; None means "not yet run".
_native_fp4_probe_result: Optional[bool] = None
|
|
|
|
def _can_use_native_fp4() -> bool:
    """
    One-time probe: does tl.dot_scaled deliver genuine FP4 tensor-core math?

    Requirements checked in order:
        1. CUDA is available and the device is SM >= 10.0 (Blackwell+).
        2. The running Triton exposes the tl.dot_scaled API.
        3. A small matmul through the native kernel produces output that is
           NOT bit-identical to a pure BF16 dequant+matmul reference. An
           exact match means Triton silently fell back to BF16 MMA
           (Triton #7550), so the native path would buy nothing.

    The verdict is memoized in the module global _native_fp4_probe_result.
    """
    global _native_fp4_probe_result
    if _native_fp4_probe_result is not None:
        return _native_fp4_probe_result

    # Pessimistic default; only the successful probe below flips it.
    _native_fp4_probe_result = False

    if not torch.cuda.is_available():
        return False

    major, _ = torch.cuda.get_device_capability()
    if major < 10:
        return False

    if not hasattr(tl, 'dot_scaled'):
        return False

    try:
        m, n, k = 32, 32, 64
        act = torch.randn(m, k, device='cuda', dtype=torch.bfloat16)
        w_float = torch.randn(k, n, device='cuda', dtype=torch.float32)
        w_quant = MXFP4Weights.from_float(w_float, use_quest=False)

        # Reference: dequantize fully, then plain BF16 matmul.
        reference = torch.matmul(act.float(), w_quant.to_float()).bfloat16()

        # Candidate: native-layout weights through the Tier 1 kernel.
        probed = _native_fp4_matmul(act, w_quant.to_native(), bias=None)

        if torch.equal(probed, reference):
            # Bit-exact match => BF16 fallback, native path is pointless.
            _native_fp4_probe_result = False
        else:
            # Different bits: accept only if still numerically sane.
            rel_err = (probed.float() - reference.float()).abs().mean() / reference.float().abs().mean()
            _native_fp4_probe_result = rel_err.item() < 0.1
    except Exception:
        _native_fp4_probe_result = False

    return _native_fp4_probe_result
|
|
|
|
| |
| |
| |
|
|
def _fused_fp4_matmul(
    activations: torch.Tensor,
    weights: MXFP4Weights,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Launch the fused MXFP4 dequant-matmul Triton kernel.

    Weight tiles are unpacked from uint8, decoded via the E2M1 table,
    scaled by E8M0 and consumed by tl.dot() entirely in registers — the
    dequantized weight matrix never materializes in global memory.
    """
    M, K = activations.shape
    _, N = weights.shape

    # Kernel expects contiguous BF16 activations.
    a = activations.contiguous()
    a = a if a.dtype == torch.bfloat16 else a.to(torch.bfloat16)

    packed = weights.packed.contiguous()
    scales = weights.scales.contiguous()

    out = torch.empty(M, N, device=a.device, dtype=torch.bfloat16)

    # Triton always needs a valid pointer, so pass a dummy buffer when
    # there is no bias; HAS_BIAS gates whether it is ever read.
    if bias is None:
        has_bias = False
        bias_buf = torch.empty(0, device=a.device, dtype=torch.float32)
    else:
        has_bias = True
        bias_buf = bias.contiguous().float()

    def grid(meta):
        return (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))

    _fused_fp4_dequant_matmul_kernel[grid](
        a, packed, scales, out, bias_buf,
        M, N, K,
        a.stride(0), a.stride(1),
        packed.stride(0), packed.stride(1),
        scales.stride(0), scales.stride(1),
        out.stride(0), out.stride(1),
        HAS_BIAS=has_bias,
    )

    return out
|
|
|
|
| |
| |
| |
|
|
def _native_fp4_matmul(
    activations: torch.Tensor,
    weights: 'NativeMXFP4',
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Run the Tier 1 native FP4 matmul kernel (transposed/preshuffled layout).

    Args:
        activations: [M, K] BF16 tensor.
        weights: NativeMXFP4 carrying packed_t and scales_5d buffers.
        bias: Optional [N] bias vector.
    """
    M, K = activations.shape
    K_w, N = weights.shape

    a = activations.contiguous()
    if a.dtype != torch.bfloat16:
        a = a.to(torch.bfloat16)

    packed_t = weights.packed_t.contiguous()
    scales_5d = weights.scales_5d.contiguous()

    out = torch.empty(M, N, device=a.device, dtype=torch.bfloat16)

    has_bias = bias is not None
    # Dummy buffer keeps the pointer argument valid when no bias is given.
    bias_buf = (bias.contiguous().float() if has_bias
                else torch.empty(0, device=a.device, dtype=torch.float32))

    # Fixed tiling — the native path skips autotuning.
    BLOCK_M = 128
    BLOCK_N = 128
    BLOCK_K = 128

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    _native_fp4_matmul_kernel[grid](
        a, packed_t, scales_5d, out, bias_buf,
        M, N, K,
        a.stride(0), a.stride(1),
        packed_t.stride(0), packed_t.stride(1),
        out.stride(0), out.stride(1),
        HAS_BIAS=has_bias,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )

    return out
|
|
|
|
def quantize_to_mxfp4(weights: torch.Tensor, use_quest: bool = True) -> MXFP4Weights:
    """
    Convert a float weight matrix into MXFP4 packed form.

    Args:
        weights: Input tensor [K, N].
        use_quest: Enable QuEST optimal clipping during quantization.

    Returns:
        MXFP4Weights ready for the MXFP4 GEMM path.
    """
    return MXFP4Weights.from_float(weights, use_quest=use_quest)
|
|
|
|
| |
| |
| |
|
|
def _fused_nvfp4_matmul(
    activations: torch.Tensor,
    weights: 'NVFP4Weights',
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Launch the fused NVFP4 dequant-matmul Triton kernel.

    16-element blocks with E4M3 block scales plus a per-tensor FP32 scale;
    the dequantized weight matrix never materializes in global memory.

    Transparently routes to the double-buff (FP4 + FP8 residual) kernel
    whenever the weights carry residual data, for near-FP16 accuracy.
    """
    if weights.residual is not None and weights.residual_scales is not None:
        return _fused_nvfp4_residual_matmul(activations, weights, bias)

    M, K = activations.shape
    _, N = weights.shape

    a = activations.contiguous()
    if a.dtype != torch.bfloat16:
        a = a.to(torch.bfloat16)

    packed = weights.packed.contiguous()
    scales = weights.block_scales.contiguous()

    out = torch.empty(M, N, device=a.device, dtype=torch.bfloat16)

    has_bias = bias is not None
    # Dummy buffer keeps the pointer argument valid when no bias is given.
    bias_buf = (bias.contiguous().float() if has_bias
                else torch.empty(0, device=a.device, dtype=torch.float32))

    def grid(meta):
        return (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))

    _fused_nvfp4_dequant_matmul_kernel[grid](
        a, packed, scales, out, bias_buf,
        weights.tensor_scale,
        M, N, K,
        a.stride(0), a.stride(1),
        packed.stride(0), packed.stride(1),
        scales.stride(0), scales.stride(1),
        out.stride(0), out.stride(1),
        HAS_BIAS=has_bias,
    )

    return out
|
|
|
|
def _fused_nvfp4_residual_matmul(
    activations: torch.Tensor,
    weights: 'NVFP4Weights',
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Launch the double-buff NVFP4 + FP8 residual dequant-matmul kernel.

    Same setup as _fused_nvfp4_matmul, plus the FP8 residual tensor and
    its per-block scales, for near-FP16 accuracy. Callers must guarantee
    weights.residual and weights.residual_scales are populated.
    """
    M, K = activations.shape
    _, N = weights.shape

    a = activations.contiguous()
    if a.dtype != torch.bfloat16:
        a = a.to(torch.bfloat16)

    packed = weights.packed.contiguous()
    scales = weights.block_scales.contiguous()
    residual = weights.residual.contiguous()
    residual_scales = weights.residual_scales.contiguous()

    out = torch.empty(M, N, device=a.device, dtype=torch.bfloat16)

    has_bias = bias is not None
    # Dummy buffer keeps the pointer argument valid when no bias is given.
    bias_buf = (bias.contiguous().float() if has_bias
                else torch.empty(0, device=a.device, dtype=torch.float32))

    def grid(meta):
        return (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))

    _fused_nvfp4_residual_matmul_kernel[grid](
        a, packed, scales, residual, residual_scales, out, bias_buf,
        weights.tensor_scale,
        M, N, K,
        a.stride(0), a.stride(1),
        packed.stride(0), packed.stride(1),
        scales.stride(0), scales.stride(1),
        residual.stride(0), residual.stride(1),
        residual_scales.stride(0), residual_scales.stride(1),
        out.stride(0), out.stride(1),
        HAS_BIAS=has_bias,
    )

    return out
|
|
|
|
| |
| |
| |
|
|
# Cache slot for the torch._scaled_mm FP4 probe. Currently unused because
# _can_use_scaled_mm_fp4() is hard-disabled; kept for when it is re-enabled.
_scaled_mm_fp4_probe_result: Optional[bool] = None
|
|
|
|
| def _can_use_scaled_mm_fp4() -> bool: |
| """ |
| Probe for PyTorch native FP4 scaled matmul (cuBLAS NVFP4 path). |
| |
| Uses 1x16 blockwise scaling: FP4 packed as uint8.view(float4_e2m1fn_x2), |
| E4M3 flat scale tensors with ceil(rows/128)*128 * max(K/16, 4) elements. |
| |
| DISABLED: cuBLAS 1x16 blockwise FP4 has correctness issues with non-128-aligned |
| dimensions (cos_sim drops to 0.30-0.50 for M=1 decode). The Triton fused |
| dequant kernel achieves cos_sim=0.999+ for all shapes. Re-enable when PyTorch |
| exposes a proper NVFP4 GEMM API with 2D scale tensors + SwizzleType support. |
| """ |
| return False |
|
|
|
|
def _scaled_mm_fp4(
    activations: torch.Tensor,
    weights: 'NVFP4Weights',
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Native cuBLAS NVFP4 matmul via torch._scaled_mm.

    Activations: BF16 [M, K] — quantized to FP4 on-the-fly.
    Weights: NVFP4Weights with packed [K//2, N] uint8, block_scales [K//16, N] E4M3.

    Uses 1x16 blockwise scaling on Blackwell 5th-gen tensor cores.
    Scale layout: flat 1D, ceil(rows/128)*128 * max(K//16, 4) elements.
    Output is multiplied by both tensor_scales (activation + weight).

    NOTE(review): currently unreachable — _can_use_scaled_mm_fp4() is
    hard-disabled; kept for when PyTorch's FP4 path stabilizes.
    """
    import math
    M, K = activations.shape
    K_w, N = weights.shape

    # Quantize activations to NVFP4 on the fly (transposed so K becomes the
    # blocked axis), then reinterpret the packed bytes as FP4 pairs.
    act_q = NVFP4Weights.from_float(activations.T.contiguous().float(), use_quest=True)
    a_packed = act_q.packed.T.contiguous()
    a_fp4 = a_packed.view(torch.float4_e2m1fn_x2)

    b_packed = weights.packed.T.contiguous()
    b_fp4 = b_packed.view(torch.float4_e2m1fn_x2)

    # cuBLAS requires at least 4 K-groups worth of scales.
    k_groups = max(K // 16, 4)

    # Activation scales: pad K-groups up to k_groups and rows up to the
    # 128-row granularity. 0x3C is an arbitrary non-zero E4M3 bit pattern
    # used as filler for padded (unused) scale slots.
    sa_2d = act_q.block_scales.T.contiguous().view(torch.float8_e4m3fn)
    sa_padded_rows = math.ceil(M / 128) * 128
    if k_groups > K // 16:
        k_pad = torch.full((sa_2d.shape[0], k_groups - K // 16), 0x3C,
                           dtype=torch.uint8, device=sa_2d.device).view(torch.float8_e4m3fn)
        sa_2d = torch.cat([sa_2d, k_pad], dim=1)
    if sa_padded_rows > M:
        row_pad = torch.full((sa_padded_rows - M, k_groups), 0x3C,
                             dtype=torch.uint8, device=sa_2d.device).view(torch.float8_e4m3fn)
        sa_2d = torch.cat([sa_2d, row_pad], dim=0)
    sa_flat = sa_2d.contiguous().view(-1)

    # Weight scales: same padding scheme with N as the row dimension.
    sb_2d = weights.block_scales.T.contiguous().view(torch.float8_e4m3fn)
    sb_padded_rows = math.ceil(N / 128) * 128
    if k_groups > K // 16:
        k_pad = torch.full((sb_2d.shape[0], k_groups - K // 16), 0x3C,
                           dtype=torch.uint8, device=sb_2d.device).view(torch.float8_e4m3fn)
        sb_2d = torch.cat([sb_2d, k_pad], dim=1)
    if sb_padded_rows > N:
        row_pad = torch.full((sb_padded_rows - N, k_groups), 0x3C,
                             dtype=torch.uint8, device=sb_2d.device).view(torch.float8_e4m3fn)
        sb_2d = torch.cat([sb_2d, row_pad], dim=0)
    sb_flat = sb_2d.contiguous().view(-1)

    # Blockwise-scaled FP4 GEMM on the tensor cores.
    out = torch._scaled_mm(a_fp4, b_fp4.T, scale_a=sa_flat, scale_b=sb_flat,
                           out_dtype=torch.bfloat16)

    # Fold in both per-tensor scales (activation * weight).
    ts = act_q.tensor_scale * weights.tensor_scale
    out = out.float() * ts

    if bias is not None:
        out = out + bias.float()

    return out.to(torch.bfloat16)
|
|
|
|
| |
| |
| |
|
|
def nvfp4_gemm(
    activations: torch.Tensor,
    weights: 'NVFP4Weights',
    bias: Optional[torch.Tensor] = None,
    use_hadamard: bool = True,
) -> torch.Tensor:
    """
    NVFP4 GEMM with hierarchical dispatch.

    Pipeline:
        1. Apply Hadamard transform for outlier mitigation (only when
           K is a multiple of 32 — the rotation works on 32-wide groups)
        2. Quantize activations with bucketize (O(K*N) instead of O(K*N*16))
        3. Dispatch to best available kernel:
           - Tier 0: Native cuBLAS via torch._scaled_mm (if PyTorch supports FP4)
           - Tier 1: Fused NVFP4 Triton kernel (16-element blocks, E4M3 scales)
           - Tier 2: CPU fallback

    Args:
        activations: Input [M, K] in BF16/FP16
        weights: NVFP4Weights with packed E2M1 values, E4M3 scales, tensor scale
        bias: Optional bias [N]
        use_hadamard: Apply Hadamard transform (recommended)

    Returns:
        Output [M, N] in BF16
    """
    M, K = activations.shape
    K_w, N = weights.shape
    assert K == K_w, f"K dimension mismatch: {K} vs {K_w}"
    assert K % 16 == 0, f"K ({K}) must be multiple of 16 for NVFP4"

    # BUGFIX: the Hadamard rotation reshapes into 32-element groups, so it
    # requires K % 32 == 0. Previously a K that was a multiple of 16 but not
    # of 32 (e.g. K=48) passed the asserts yet crashed on the view below;
    # such shapes now take the identity (no-rotation) path instead.
    if use_hadamard and K >= 32 and K % 32 == 0:
        x = activations.float().view(M, K // 32, 32)
        x = _hadamard_transform_32(x)
        x = x.view(M, K)
    else:
        x = activations.float()

    # Tier 2: CPU fallback — reference fake-quant matmul.
    if not activations.is_cuda:
        # Quantize along K (transpose so K is the blocked axis), then
        # dequantize so the matmul sees NVFP4-precision activations.
        x_for_quant = x.T.contiguous()
        x_quant = NVFP4Weights.from_float(x_for_quant, use_quest=True)
        x_dequant = x_quant.to_float().T.contiguous()
        w_dequant = weights.to_float()
        d = torch.matmul(x_dequant, w_dequant)
        if bias is not None:
            d = d + bias.float()
        return d.to(torch.bfloat16)

    # Tier 0: native cuBLAS path (currently hard-disabled, see probe).
    if _can_use_scaled_mm_fp4():
        return _scaled_mm_fp4(x.to(torch.bfloat16), weights, bias)

    # Tier 1: fused Triton kernel. Activations are fake-quantized
    # (quantize + dequantize) so both GEMM operands carry NVFP4 precision.
    x_for_quant = x.T.contiguous()
    x_quant = NVFP4Weights.from_float(x_for_quant, use_quest=True)
    x_dequant = x_quant.to_float().T.contiguous()
    return _fused_nvfp4_matmul(x_dequant, weights, bias)
|
|
|
|
def quantize_to_nvfp4(weights: torch.Tensor, use_quest: bool = True) -> NVFP4Weights:
    """
    Convert a float weight matrix into NVFP4 packed form.

    Args:
        weights: Input tensor [K, N].
        use_quest: Enable QuEST optimal clipping during quantization.

    Returns:
        NVFP4Weights ready for nvfp4_gemm.
    """
    return NVFP4Weights.from_float(weights, use_quest=use_quest)
|
|
|
|
| |
# Public aliases: the generic "fp4" names route to the recommended NVFP4
# implementation (MXFP4 remains available under its explicit names).
fp4_gemm = nvfp4_gemm
quantize_to_fp4 = quantize_to_nvfp4
|
|
|
|
| |
| |
| |
|
|
| |
|
|
class _AccessPolicyWindow(ctypes.Structure):
    """Maps to cudaAccessPolicyWindow (CUDA Runtime API).

    Field order and ctypes types must mirror the C struct exactly so the
    bytes line up when passed through cudaStreamSetAttribute.
    """
    _fields_ = [
        ("base_ptr", ctypes.c_void_p),    # start address of the window
        ("num_bytes", ctypes.c_size_t),   # window length in bytes
        ("hitRatio", ctypes.c_float),     # fraction of accesses given hitProp
        ("hitProp", ctypes.c_int),        # cudaAccessProperty applied on hit
        ("missProp", ctypes.c_int),       # cudaAccessProperty applied on miss
    ]
|
|
|
|
class _StreamAttrValue(ctypes.Union):
    """Maps to cudaStreamAttrValue (union).

    Only the members this module uses are declared; a ctypes.Union sizes
    itself to its largest declared member.
    """
    _fields_ = [
        ("accessPolicyWindow", _AccessPolicyWindow),
        ("syncPolicy", ctypes.c_int),
    ]
|
|
|
|
| |
| _CUDA_ACCESS_PROPERTY_NORMAL = 0 |
| _CUDA_ACCESS_PROPERTY_STREAMING = 1 |
| _CUDA_ACCESS_PROPERTY_PERSISTING = 2 |
|
|
| |
| _CUDA_STREAM_ATTR_ACCESS_POLICY_WINDOW = 1 |
|
|
| |
| _CUDA_LIMIT_PERSISTING_L2_CACHE_SIZE = 0x06 |
|
|
| |
| _CUDA_DEV_ATTR_L2_CACHE_SIZE = 89 |
| _CUDA_DEV_ATTR_MAX_PERSISTING_L2_CACHE_SIZE = 108 |
|
|
|
|
| def _load_cudart(): |
| """Load the CUDA runtime shared library, return handle or None.""" |
| for name in ("libcudart.so", "libcudart.so.12", "libcudart.so.11.0"): |
| try: |
| return ctypes.CDLL(name) |
| except OSError: |
| continue |
| try: |
| path = ctypes.util.find_library("cudart") |
| if path: |
| return ctypes.CDLL(path) |
| except (OSError, TypeError): |
| pass |
| return None |
|
|
|
|
# Module-level cudart handle; None when the CUDA runtime is not loadable,
# in which case L2CacheManager degrades to no-op stubs.
_cudart = _load_cudart()
|
|
|
|
| |
| |
| |
|
|
class L2CacheManager:
    """
    L2 Cache Manager for SM90+ GPUs.

    Uses ctypes/libcudart.so cudaAccessPolicyWindow to pin hot data
    (embeddings, weights, KV cache) in L2 for 10-20% inference speedup.

    When libcudart is not loadable the manager degrades to no-op stubs
    so the rest of the engine remains functional.

    Usage:
        l2 = L2CacheManager()

        # Pin embedding table
        l2.pin(embedding_table)

        # Configure for inference
        l2.configure_inference(
            embedding=embedding_table,
            attention_weights=attn_weights,
            kv_cache=kv_cache,
        )

        # Reset between batches
        l2.reset()
    """

    def __init__(self, device: int = 0):
        """
        Args:
            device: CUDA device ordinal to query and configure.
        """
        self.device = device
        self._hw_available = False  # True once cudart attribute queries succeed
        self._l2_size = 0           # total L2 size in bytes
        self._max_persisting = 0    # max persisting L2 carve-out in bytes

        self._initialize()

    def _initialize(self) -> None:
        """Query device L2 geometry via cudart."""
        if not torch.cuda.is_available():
            return

        # Torch-reported value as the baseline; cudart may refine it below.
        props = torch.cuda.get_device_properties(self.device)
        self._l2_size = getattr(props, 'l2_cache_size', 0)

        if _cudart is not None:
            try:
                # Total L2 size (cudaDeviceGetAttribute returns 0 == cudaSuccess).
                val = ctypes.c_int(0)
                if (_cudart.cudaDeviceGetAttribute(
                    ctypes.byref(val),
                    ctypes.c_int(_CUDA_DEV_ATTR_L2_CACHE_SIZE),
                    ctypes.c_int(self.device),
                ) == 0 and val.value > 0):
                    self._l2_size = val.value

                # Max persisting window; fall back to 75% of total L2.
                val2 = ctypes.c_int(0)
                if (_cudart.cudaDeviceGetAttribute(
                    ctypes.byref(val2),
                    ctypes.c_int(_CUDA_DEV_ATTR_MAX_PERSISTING_L2_CACHE_SIZE),
                    ctypes.c_int(self.device),
                ) == 0 and val2.value > 0):
                    self._max_persisting = val2.value
                else:
                    self._max_persisting = int(self._l2_size * 0.75)

                self._hw_available = True
            except Exception:
                # Best-effort: any cudart failure leaves the manager in
                # software (no-op) mode rather than crashing init.
                pass

        if self._max_persisting == 0:
            self._max_persisting = int(self._l2_size * 0.75)

        # Reserve the persisting carve-out up front (no-op when unavailable).
        self._set_persisting_limit(self._max_persisting)

    def _set_persisting_limit(self, num_bytes: int) -> bool:
        # cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, num_bytes)
        if not self._hw_available or _cudart is None:
            return False
        return _cudart.cudaDeviceSetLimit(
            ctypes.c_int(_CUDA_LIMIT_PERSISTING_L2_CACHE_SIZE),
            ctypes.c_size_t(num_bytes),
        ) == 0

    def _apply_access_policy(self, tensor: torch.Tensor, hit_ratio: float,
                             stream_ptr: int) -> bool:
        # Attach a persisting access-policy window covering `tensor`
        # to the given CUDA stream.
        if not self._hw_available or _cudart is None:
            return False

        window = _AccessPolicyWindow()
        window.base_ptr = tensor.data_ptr()
        # The window may not exceed the persisting carve-out.
        window.num_bytes = min(
            tensor.numel() * tensor.element_size(),
            self._max_persisting,
        )
        window.hitRatio = hit_ratio
        window.hitProp = _CUDA_ACCESS_PROPERTY_PERSISTING
        window.missProp = _CUDA_ACCESS_PROPERTY_STREAMING

        attr = _StreamAttrValue()
        attr.accessPolicyWindow = window

        return _cudart.cudaStreamSetAttribute(
            ctypes.c_void_p(stream_ptr),
            ctypes.c_int(_CUDA_STREAM_ATTR_ACCESS_POLICY_WINDOW),
            ctypes.byref(attr),
        ) == 0

    def _reset_stream_policy(self, stream_ptr: int) -> bool:
        # A zero-initialized window clears the stream's access policy.
        if not self._hw_available or _cudart is None:
            return False
        attr = _StreamAttrValue()
        attr.accessPolicyWindow = _AccessPolicyWindow()
        return _cudart.cudaStreamSetAttribute(
            ctypes.c_void_p(stream_ptr),
            ctypes.c_int(_CUDA_STREAM_ATTR_ACCESS_POLICY_WINDOW),
            ctypes.byref(attr),
        ) == 0

    def _reset_persisting_l2(self) -> bool:
        # Invalidate all persisting L2 lines in the current context.
        if not self._hw_available or _cudart is None:
            return False
        return _cudart.cudaCtxResetPersistingL2Cache() == 0

    @property
    def l2_size(self) -> int:
        """Total L2 cache size in bytes."""
        return self._l2_size

    @property
    def max_persisting(self) -> int:
        """Maximum persisting L2 size in bytes."""
        return self._max_persisting

    def pin(
        self,
        tensor: torch.Tensor,
        hit_ratio: float = 1.0,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> bool:
        """
        Pin a tensor in L2 cache via cudaAccessPolicyWindow.

        Args:
            tensor: Tensor to pin (must be on CUDA)
            hit_ratio: Fraction of accesses to persist (0.0-1.0)
            stream: CUDA stream (default: current)

        Returns:
            True on success (or no-op when HW unavailable)
        """
        if not tensor.is_cuda:
            return False

        if self._hw_available:
            stream_ptr = (
                stream.cuda_stream if stream is not None
                else torch.cuda.current_stream(self.device).cuda_stream
            )
            return self._apply_access_policy(tensor, hit_ratio, stream_ptr)

        # Software mode: report success so callers need no special casing.
        return True

    def set_streaming(
        self,
        tensor: torch.Tensor,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> bool:
        """
        Mark tensor as streaming (bypass L2 cache).

        Use for one-time access data to avoid L2 pollution.
        """
        if not tensor.is_cuda:
            return False

        if self._hw_available:
            stream_ptr = (
                stream.cuda_stream if stream is not None
                else torch.cuda.current_stream(self.device).cuda_stream
            )
            # Streaming on both hit and miss: never persist this range.
            window = _AccessPolicyWindow()
            window.base_ptr = tensor.data_ptr()
            window.num_bytes = tensor.numel() * tensor.element_size()
            window.hitRatio = 0.0
            window.hitProp = _CUDA_ACCESS_PROPERTY_STREAMING
            window.missProp = _CUDA_ACCESS_PROPERTY_STREAMING

            attr = _StreamAttrValue()
            attr.accessPolicyWindow = window

            return _cudart.cudaStreamSetAttribute(
                ctypes.c_void_p(stream_ptr),
                ctypes.c_int(_CUDA_STREAM_ATTR_ACCESS_POLICY_WINDOW),
                ctypes.byref(attr),
            ) == 0

        # Software mode: report success so callers need no special casing.
        return True

    def reset(self) -> bool:
        """Reset persisting L2 cache. Call between inference batches."""
        return self._reset_persisting_l2() if self._hw_available else True

    def configure_inference(
        self,
        embedding: Optional[torch.Tensor] = None,
        attention_weights: Optional[torch.Tensor] = None,
        kv_cache: Optional[torch.Tensor] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> bool:
        """
        Configure L2 cache for transformer inference.

        Pins tensors with appropriate priorities:
            1. Embedding table (highest — hit_ratio=1.0)
            2. Attention weights (hit_ratio=0.9)
            3. KV cache (lowest — hit_ratio=0.7)

        Args:
            embedding: Embedding table tensor
            attention_weights: Combined attention weights
            kv_cache: KV cache tensor
            stream: CUDA stream

        Returns:
            True only if every requested pin succeeded.
        """
        success = True
        if embedding is not None:
            success = success and self.pin(embedding, 1.0, stream)
        if attention_weights is not None:
            success = success and self.pin(attention_weights, 0.9, stream)
        if kv_cache is not None:
            success = success and self.pin(kv_cache, 0.7, stream)
        return success
|
|
|
|
| |
| |
| |
|
|
def benchmark_tma_vs_cublas(sizes=None, warmup=10, iters=100):
    """Benchmark TMA MatMul vs cuBLAS.

    Args:
        sizes: Iterable of (M, N, K) problem sizes; defaults to three
            square sizes from 2048 to 8192.
        warmup: Un-timed iterations per path before measurement.
        iters: Timed iterations per path (reported time is the mean).
    """
    import time

    def timed(fn):
        # Mean wall-clock seconds per call; synchronize after the loop so
        # queued GPU work is included in the measurement.
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters

    if sizes is None:
        sizes = [(2048, 2048, 2048), (4096, 4096, 4096), (8192, 8192, 8192)]

    print("=" * 60)
    print("TMA MatMul vs cuBLAS Benchmark")
    print("=" * 60)

    for M, N, K in sizes:
        a = torch.randn(M, K, device='cuda', dtype=torch.bfloat16)
        b = torch.randn(K, N, device='cuda', dtype=torch.bfloat16)

        # Warm both paths (kernel compilation / autotune, cache effects).
        for _ in range(warmup):
            _ = tma_matmul(a, b)
            _ = torch.matmul(a, b)
        torch.cuda.synchronize()

        tma_time = timed(lambda: tma_matmul(a, b))
        cublas_time = timed(lambda: torch.matmul(a, b))

        flops = 2 * M * N * K
        tma_tflops = flops / tma_time / 1e12
        cublas_tflops = flops / cublas_time / 1e12
        speedup = cublas_time / tma_time

        print(f"{M}x{N}x{K}:")
        print(f"  TMA:     {tma_tflops:.1f} TFLOPS ({tma_time*1000:.2f}ms)")
        print(f"  cuBLAS:  {cublas_tflops:.1f} TFLOPS ({cublas_time*1000:.2f}ms)")
        print(f"  Speedup: {speedup:.2f}x")
        print()
|
|
|
|
def benchmark_fp4_vs_fp16(M=4096, N=4096, K=4096, warmup=10, iters=100):
    """Benchmark NVFP4, MXFP4, and FP16 GEMM paths.

    Reports, in order:
      1. Weight-memory footprint of each format.
      2. Kernel-only throughput (no activation-quant overhead).
      3. Full-pipeline throughput (Hadamard + act quant + kernel).
      4. Dispatch probes and relative error vs an FP32 matmul reference.

    Args:
        M, N, K: GEMM problem size.
        warmup: Un-timed iterations per path before measurement.
        iters: Timed iterations per path (reported time is the mean).
    """
    import time

    def timed(fn):
        # Mean wall-clock seconds per call; synchronize after the loop so
        # queued GPU work is included in the measurement.
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters

    print("=" * 60)
    print("FP4 vs FP16 GEMM Benchmark")
    print("=" * 60)

    # One FP16 weight matrix, quantized into both FP4 formats.
    w_fp16 = torch.randn(K, N, device='cuda', dtype=torch.float16)
    w_mxfp4 = quantize_to_mxfp4(w_fp16)
    w_nvfp4 = quantize_to_nvfp4(w_fp16)
    a = torch.randn(M, K, device='cuda', dtype=torch.bfloat16)

    # NOTE(review): assumes .packed and the scale tensors use 1-byte dtypes,
    # so numel() == bytes — confirm against the quantizer definitions.
    fp16_bytes = w_fp16.numel() * 2
    mxfp4_bytes = w_mxfp4.packed.numel() + w_mxfp4.scales.numel()
    nvfp4_bytes = w_nvfp4.packed.numel() + w_nvfp4.block_scales.numel()

    print("Weight memory:")
    print(f"  FP16:  {fp16_bytes / 1e6:.1f} MB")
    print(f"  MXFP4: {mxfp4_bytes / 1e6:.1f} MB ({fp16_bytes / mxfp4_bytes:.1f}x smaller)")
    print(f"  NVFP4: {nvfp4_bytes / 1e6:.1f} MB ({fp16_bytes / nvfp4_bytes:.1f}x smaller)")
    print()

    # --- Kernel-only timing ---------------------------------------------
    print(f"{M}x{N}x{K} Kernel-only (no activation quant overhead):")

    for _ in range(warmup):
        _fused_nvfp4_matmul(a, w_nvfp4)
        _fused_fp4_matmul(a, w_mxfp4)
        torch.matmul(a.half(), w_fp16)
    torch.cuda.synchronize()

    nvfp4_kern_time = timed(lambda: _fused_nvfp4_matmul(a, w_nvfp4))
    mxfp4_kern_time = timed(lambda: _fused_fp4_matmul(a, w_mxfp4))
    fp16_time = timed(lambda: torch.matmul(a.half(), w_fp16))

    flops = 2 * M * N * K
    print(f"  NVFP4 kernel: {flops/nvfp4_kern_time/1e12:.1f} TFLOPS ({nvfp4_kern_time*1000:.2f}ms)")
    print(f"  MXFP4 kernel: {flops/mxfp4_kern_time/1e12:.1f} TFLOPS ({mxfp4_kern_time*1000:.2f}ms)")
    print(f"  BF16 cuBLAS:  {flops/fp16_time/1e12:.1f} TFLOPS ({fp16_time*1000:.2f}ms)")
    print()

    # --- Full-pipeline timing (Hadamard + act quant + kernel) ------------
    print(f"{M}x{N}x{K} Full pipeline (Hadamard + act quant + kernel):")

    for _ in range(warmup):
        nvfp4_gemm(a, w_nvfp4)
        mxfp4_gemm(a, w_mxfp4)
        mxfp4_gemm_legacy(a, w_mxfp4)
    torch.cuda.synchronize()

    nvfp4_pipe_time = timed(lambda: nvfp4_gemm(a, w_nvfp4))
    mxfp4_pipe_time = timed(lambda: mxfp4_gemm(a, w_mxfp4))
    legacy_time = timed(lambda: mxfp4_gemm_legacy(a, w_mxfp4))

    print(f"  NVFP4 pipeline: {flops/nvfp4_pipe_time/1e12:.1f} TFLOPS ({nvfp4_pipe_time*1000:.2f}ms)")
    print(f"  MXFP4 pipeline: {flops/mxfp4_pipe_time/1e12:.1f} TFLOPS ({mxfp4_pipe_time*1000:.2f}ms)")
    print(f"  MXFP4 legacy:   {flops/legacy_time/1e12:.1f} TFLOPS ({legacy_time*1000:.2f}ms)")
    act_overhead_nv = nvfp4_pipe_time - nvfp4_kern_time
    act_overhead_mx = mxfp4_pipe_time - mxfp4_kern_time
    print(f"  Act quant overhead: NVFP4={act_overhead_nv*1000:.2f}ms MXFP4={act_overhead_mx*1000:.2f}ms")
    print()

    # --- Dispatch probes --------------------------------------------------
    print(f"  Native FP4 probe: {_can_use_native_fp4()}")
    print(f"  Scaled MM FP4 probe: {_can_use_scaled_mm_fp4()}")

    # --- Accuracy: fused kernels vs FP32 matmul on dequantized weights ----
    out_nv_kern = _fused_nvfp4_matmul(a, w_nvfp4)
    out_nv_ref = torch.matmul(a.float(), w_nvfp4.to_float()).bfloat16()
    rel_err_nv = (out_nv_kern.float() - out_nv_ref.float()).abs().mean() / out_nv_ref.float().abs().mean()

    out_mx_kern = _fused_fp4_matmul(a, w_mxfp4)
    out_mx_ref = torch.matmul(a.float(), w_mxfp4.to_float()).bfloat16()
    rel_err_mx = (out_mx_kern.float() - out_mx_ref.float()).abs().mean() / out_mx_ref.float().abs().mean()

    # Fused MXFP4 pipeline vs the legacy dequant path (should agree closely).
    out_mxfp4_fused = mxfp4_gemm(a, w_mxfp4)
    out_legacy = mxfp4_gemm_legacy(a, w_mxfp4)
    rel_err_mx_pipe = (out_mxfp4_fused - out_legacy).abs().mean() / out_legacy.abs().mean()

    print(f"  NVFP4 kernel rel_err (vs matmul): {rel_err_nv:.6f}")
    print(f"  MXFP4 kernel rel_err (vs matmul): {rel_err_mx:.6f}")
    print(f"  MXFP4 fused vs legacy rel_err: {rel_err_mx_pipe:.6f}")
|
|
|
|
if __name__ == "__main__":
    # Smoke-test entry point: report platform capabilities, then run the
    # benchmarks when a CUDA device is present.
    print("FireEcho CUTLASS-Compatible Kernels (self-contained)")
    print("=" * 60)
    print("Triton available: True")
    print(f"cudart loaded: {_cudart is not None}")

    if torch.cuda.is_available():
        cache_mgr = L2CacheManager()
        print(f"L2 Cache size: {cache_mgr.l2_size / 1e6:.0f} MB")
        print(f"Max persisting: {cache_mgr.max_persisting / 1e6:.0f} MB")
        print(f"HW L2 pinning: {cache_mgr._hw_available}")
        print(f"Native FP4 (dot_scaled): {_can_use_native_fp4()}")
        print(f"Scaled MM FP4: {_can_use_scaled_mm_fp4()}")

        print()
        benchmark_tma_vs_cublas(sizes=[(2048, 2048, 2048)])
        benchmark_fp4_vs_fp16(M=2048, N=2048, K=2048)
|
|