| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
"""
FireEcho Goliath — Native FP4/FP8/INT2/FE-XC/FE-XVQ Fused Triton GEMM Kernel
===============================================================================
Part of the FireEcho Engine — Custom inference kernel for NVIDIA Blackwell
Copyright (c) 2025-2026 Echo (FireEcho Project). All rights reserved.

A unified quantized GEMM kernel that dequantizes inside the Triton matmul
loop — no separate dequantization step, no global memory materialization.

Features:
    1. NVFP4 GEMM     — Fused dequant inside Triton matmul
    2. NVFP8 GEMM     — Same approach for FP8
    3. INT2 GEMM      — 2-bit quantization for cold MoE experts
    4. FE-XC GEMM     — Codebook 2-bit (2x8) with CodeGEMM psumbook (near-FP16 quality)
    5. FE-XVQ GEMM    — Hessian-weighted codebook 2-bit (VPTQ-inspired, second-order optimal)
    6. Packed MoE     — Contiguous [128, K//2, N] expert buffers, GPU expert IDs
    7. Fused SwiGLU+Down — Single-kernel gate+up+silu+mul+down
    8. Auto-dispatch  — Automatically choose FP4/FP8/INT2/FE-XC/FE-XVQ based on expert temperature
    9. Unified API:
       - goliath_quantize(tensor, bits=4 or 8)
       - goliath_gemm(activations, quantized_weights)

Design:
    - Loads packed FP4/FP8 data directly from global memory
    - Dequantizes in registers (not global memory) — zero extra traffic
    - Uses Triton block pointers for efficient memory access
    - Supports both BF16 and FP16 accumulation
    - Target: 5-10x over the 10.4 TFLOPS baseline

Usage:
    from goliath_kernel import goliath_quantize, goliath_gemm

    # FP4 quantization (maximum compression)
    w_q4 = goliath_quantize(weights, bits=4)
    out = goliath_gemm(activations, w_q4)

    # FP8 quantization (higher accuracy)
    w_q8 = goliath_quantize(weights, bits=8)
    out = goliath_gemm(activations, w_q8)

    # Auto mode (let Goliath decide)
    w_q = goliath_quantize(weights, bits='auto')
    out = goliath_gemm(activations, w_q)
"""
| |
|
| | import torch |
| | import triton |
| | import triton.language as tl |
| | from typing import Optional, Tuple, Union |
| | from dataclasses import dataclass |
| |
|
| |
|
| | |
| | |
| | |
# Optional bridge to the CUTLASS-backed NVFP4 kernels.  When the
# `cutlass_kernels` module imports cleanly we re-export its weight container
# and matmul entry points under private aliases; otherwise every alias is
# None and _NVFP4_BRIDGE_AVAILABLE gates that fast path off.
try:
    from cutlass_kernels import (
        NVFP4Weights as _BridgeNVFP4Weights,
        _fused_nvfp4_matmul as _bridge_fused_nvfp4,
        _can_use_scaled_mm_fp4 as _bridge_can_use_cublas_fp4,
        _scaled_mm_fp4 as _bridge_scaled_mm_fp4,
    )
    _NVFP4_BRIDGE_AVAILABLE = True
except Exception:
    # Broad catch is deliberate: any import-time failure (missing module,
    # CUDA/toolchain problems inside it) must degrade to the Triton path
    # rather than crash module import.
    _BridgeNVFP4Weights = None
    _bridge_fused_nvfp4 = None
    _bridge_can_use_cublas_fp4 = None
    _bridge_scaled_mm_fp4 = None
    _NVFP4_BRIDGE_AVAILABLE = False
| |
|
| |
|
| | |
| | |
| | |
| |
|
# The 16 representable E2M1 (FP4) values, indexed by the 4-bit code:
# codes 0-7 are the non-negative magnitudes, codes 8-15 mirror them
# negated (bit 3 of the code is the sign bit).
_E2M1_VALUES = torch.tensor(
    [0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6],
    dtype=torch.float32,
)

# Midpoints between consecutive non-negative E2M1 values, used with
# torch.bucketize to round |x| to the nearest code in 0-7.
_E2M1_BOUNDARIES = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0])
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _encode_e4m3(values: torch.Tensor) -> torch.Tensor: |
| | """Encode FP32 values to E4M3 (FP8) as uint8.""" |
| | if hasattr(torch, 'float8_e4m3fn'): |
| | return values.clamp(-448.0, 448.0).to(torch.float8_e4m3fn).view(torch.uint8) |
| | v = values.float().clamp(-448.0, 448.0) |
| | sign = (v < 0).to(torch.uint8) << 7 |
| | av = v.abs().clamp(min=0.0) |
| | log2_av = torch.log2(av.clamp(min=2**-9)) |
| | exp_raw = torch.floor(log2_av).clamp(-6, 8) |
| | exp_biased = (exp_raw + 7).clamp(0, 15) |
| | mantissa_f = (av / torch.pow(2.0, exp_raw) - 1.0) * 8.0 |
| | mantissa = mantissa_f.round().clamp(0, 7).to(torch.uint8) |
| | sub_mant = (av / (2**-6) * 8.0).round().clamp(0, 7).to(torch.uint8) |
| | is_sub = exp_biased == 0 |
| | final_mant = torch.where(is_sub, sub_mant, mantissa) |
| | return sign | (exp_biased.to(torch.uint8) << 3) | final_mant |
| |
|
| |
|
| | def _decode_e4m3(encoded: torch.Tensor) -> torch.Tensor: |
| | """Decode E4M3 uint8 back to FP32.""" |
| | if hasattr(torch, 'float8_e4m3fn'): |
| | return encoded.view(torch.float8_e4m3fn).float() |
| | sign = ((encoded >> 7) & 1).float() |
| | exp = ((encoded >> 3) & 0xF).long() |
| | mant = (encoded & 0x7).long() |
| | is_normal = exp > 0 |
| | normal_val = (8 + mant).float() * torch.pow(2.0, (exp - 10).float()) |
| | subnormal_val = mant.float() * (2.0 ** -9) |
| | unsigned = torch.where(is_normal, normal_val, subnormal_val) |
| | return torch.where(sign != 0, -unsigned, unsigned) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class GoliathFP4Weights:
    """
    Goliath FP4 quantized weights — fused-dequant format.

    16-element blocks along K, E4M3-encoded per-block scales, plus one
    per-tensor FP32 scale.  Codes are packed two nibbles per uint8 byte
    (same as NVFP4): byte [k, n] holds logical rows 2k (low nibble) and
    2k+1 (high nibble).

    Optional FP8 residual correction (double-buff) for near-FP16 accuracy.
    """
    packed: torch.Tensor            # [K//2, N] uint8 — two 4-bit E2M1 codes per byte
    block_scales: torch.Tensor      # [K//16, N] uint8 — E4M3-encoded block scales
    tensor_scale: float             # global FP32 scale applied on top of block scales
    shape: Tuple[int, int]          # logical (K, N) of the dequantized matrix
    bits: int = 4

    residual: Optional[torch.Tensor] = None         # [K, N] uint8 E4M3 residual codes
    residual_scales: Optional[torch.Tensor] = None  # [K//16, N] FP32 residual scales

    @classmethod
    def from_float(cls, weights: torch.Tensor, training: bool = False,
                   sr_seed: Optional[int] = None,
                   compute_residual: bool = False) -> 'GoliathFP4Weights':
        """Quantize an FP32/BF16 [K, N] matrix to Goliath FP4.

        Args:
            weights: Input [K, N] tensor; K must be a multiple of 16.
            training: Enable stochastic rounding (together with sr_seed).
            sr_seed: RNG seed for stochastic rounding; SR only runs when
                both training=True and sr_seed is not None.
            compute_residual: Also store an E4M3 encoding of the FP4
                quantization residual for optional correction at inference.

        Returns:
            GoliathFP4Weights instance.
        """
        K, N = weights.shape
        assert K % 16 == 0, f"K ({K}) must be multiple of 16 for FP4"

        device = weights.device
        w = weights.float()

        # Per-16-element-block absmax along K.
        reshaped = w.view(K // 16, 16, N)
        absmax = reshaped.abs().amax(dim=1).clamp(min=1e-10)

        # Global scale chosen so the largest block scale fits E4M3's max (448).
        global_absmax = absmax.max().clamp(min=1e-10)
        tensor_scale = (global_absmax / 448.0).item()

        # Encode per-block scales in E4M3; 6.0 is the largest |E2M1| value.
        target = absmax / (tensor_scale * 6.0)
        target = target.clamp(min=1e-10)
        block_scales_fp8 = _encode_e4m3(target)

        # Quantize against the *decoded* scales so the round-trip matches
        # exactly what the kernel will apply.
        actual_scale = _decode_e4m3(block_scales_fp8) * tensor_scale
        actual_scale = actual_scale.clamp(min=1e-10)

        normalized = (reshaped / actual_scale.unsqueeze(1)).clamp(-6.0, 6.0)

        # Round-to-nearest E2M1 code via bucketized midpoints.
        boundaries = _E2M1_BOUNDARIES.to(device)
        abs_norm = normalized.abs().reshape(-1)
        unsigned_idx = torch.bucketize(abs_norm, boundaries)

        # Optional stochastic rounding: promote to the next code with
        # probability equal to the fractional position between codes.
        if training and sr_seed is not None:
            e2m1_unsigned = _E2M1_VALUES[:8].to(device)
            lower_val = e2m1_unsigned[unsigned_idx.clamp(max=7)]
            upper_val = e2m1_unsigned[(unsigned_idx + 1).clamp(max=7)]
            spread = (upper_val - lower_val).clamp(min=1e-10)
            frac = (abs_norm - lower_val) / spread
            gen = torch.Generator(device=device).manual_seed(sr_seed)
            # FIX: torch.rand_like() does not accept a `generator` argument
            # (it raised TypeError here) — draw with torch.rand() on an
            # explicit shape/device instead.
            rand = torch.rand(frac.shape, device=device, generator=gen)
            unsigned_idx = torch.where(rand < frac, unsigned_idx + 1, unsigned_idx).clamp(0, 7)

        # Sign lives in bit 3 of the code; code 0 (zero) never carries a sign.
        sign_bit = ((normalized.reshape(-1) < 0) & (unsigned_idx > 0)).to(torch.uint8) << 3
        indices = (sign_bit | unsigned_idx.to(torch.uint8)).reshape(K, N)

        # Pack adjacent K rows: even row -> low nibble, odd row -> high nibble.
        packed = (indices[0::2] | (indices[1::2] << 4))

        # Optional FP8 residual: encode (w - dequant(w)) per 16-element block.
        residual_e4m3 = None
        residual_scales = None
        if compute_residual:
            fp4_approx = cls(
                packed=packed, block_scales=block_scales_fp8,
                tensor_scale=tensor_scale, shape=(K, N),
            ).to_float()
            residual_float = w - fp4_approx
            res_blocks = residual_float.view(K // 16, 16, N)
            res_absmax = res_blocks.abs().amax(dim=1).clamp(min=1e-10)
            res_scale = res_absmax / 448.0
            res_normalized = res_blocks / res_scale.unsqueeze(1)
            res_normalized = res_normalized.clamp(-448.0, 448.0)
            # Use the shared E4M3 encoder — same result when the native FP8
            # dtype exists, and keeps this path working on builds without it.
            residual_e4m3 = _encode_e4m3(res_normalized.view(K, N))
            residual_scales = res_scale

        return cls(
            packed=packed,
            block_scales=block_scales_fp8,
            tensor_scale=tensor_scale,
            shape=(K, N),
            residual=residual_e4m3,
            residual_scales=residual_scales,
        )

    def to_float(self) -> torch.Tensor:
        """Dequantize back to an FP32 [K, N] matrix (residual not applied)."""
        K, N = self.shape
        device = self.packed.device
        e2m1 = _E2M1_VALUES.to(device)

        low = (self.packed & 0xF).long()
        high = (self.packed >> 4).long()

        low_vals = e2m1[low.flatten()].view(K // 2, N)
        high_vals = e2m1[high.flatten()].view(K // 2, N)

        # Interleave the nibbles back into their original K rows.
        unpacked = torch.zeros(K, N, device=device, dtype=torch.float32)
        unpacked[0::2] = low_vals
        unpacked[1::2] = high_vals

        # Two-level scaling: E4M3 block scale x global tensor scale.
        block_sf = _decode_e4m3(self.block_scales)
        scale = block_sf * self.tensor_scale
        unpacked = unpacked.view(K // 16, 16, N) * scale.unsqueeze(1)
        return unpacked.view(K, N)
| |
|
| |
|
@dataclass
class GoliathFP4NativeWeights:
    """
    Goliath FP4 native weights for ``tl.dot_scaled`` code path.

    Layout required by ``tl.dot_scaled(..., "e2m1")``:
      - packed_col_major: [N, K//2] uint8 — E2M1 nibbles, col-major for RHS
      - scales_e8m0: [N, K//32] uint8 — E8M0 power-of-two scales (32-element blocks)

    Created from a GoliathFP4Weights instance via ``from_goliath_fp4()``.
    """
    packed_col_major: torch.Tensor  # [N, K//2] uint8
    scales_e8m0: torch.Tensor       # [N, K//32] uint8
    shape: Tuple[int, int]          # logical (K, N)
    bits: int = 4

    @classmethod
    def from_goliath_fp4(cls, w: GoliathFP4Weights) -> 'GoliathFP4NativeWeights':
        """Convert GoliathFP4Weights — native dot_scaled layout.

        Transposes packed data to col-major [N, K//2] and converts
        E4M3 16-element block scales to E8M0 32-element block scales.
        """
        K, N = w.shape
        device = w.packed.device

        # RHS operand of dot_scaled is indexed column-major.
        packed_col = w.packed.T.contiguous()

        # Merge pairs of 16-element block scales into 32-element blocks
        # (dot_scaled's scale granularity), taking the max of each pair so
        # neither half of a merged block can overflow its scale.
        scales_f = _decode_e4m3(w.block_scales)
        num_16_groups = K // 16

        if num_16_groups >= 2:
            # NOTE(review): this paired view assumes K // 16 is even
            # (i.e. K % 32 == 0); an odd group count makes view() raise —
            # confirm upstream shape guarantees.
            scales_paired = scales_f.view(num_16_groups // 2, 2, N)

            scales_32 = scales_paired.amax(dim=1) * w.tensor_scale
        else:
            scales_32 = scales_f * w.tensor_scale

        # Convert to E8M0: pure power-of-two exponent with bias 127.
        # NOTE(review): the / 6.0 presumably folds the maximum |E2M1|
        # magnitude (6) into the exponent so scaled products stay in range —
        # confirm this matches the dot_scaled kernel's scale convention.
        safe_scales = scales_32.clamp(min=1e-20)
        log2_val = torch.log2(safe_scales / 6.0)
        e8m0 = (log2_val.round() + 127).clamp(0, 254).to(torch.uint8)

        # Scales are also consumed column-major by the kernel.
        scales_col = e8m0.T.contiguous()

        return cls(
            packed_col_major=packed_col,
            scales_e8m0=scales_col,
            shape=(K, N),
        )
| |
|
| |
|
@dataclass
class GoliathFP8Weights:
    """
    Goliath FP8 quantized weights — fused-dequant format.

    Per-block FP32 scales; payload stored as uint8 E4M3 codes.
    Block size: 32 elements along K (matches Triton tile granularity
    for shared memory).
    """
    data: torch.Tensor          # [K, N] uint8 E4M3 codes
    block_scales: torch.Tensor  # [K//32, N] FP32 per-block scales
    shape: Tuple[int, int]      # logical (K, N)
    bits: int = 8

    @classmethod
    def from_float(cls, weights: torch.Tensor) -> 'GoliathFP8Weights':
        """Quantize an FP32/BF16 [K, N] matrix to 32-element-block FP8."""
        K, N = weights.shape
        assert K % 32 == 0, f"K ({K}) must be multiple of 32 for FP8"

        blocked = weights.float().view(K // 32, 32, N)

        # Each block's absmax maps onto E4M3's maximum magnitude (448).
        scales = blocked.abs().amax(dim=1).clamp(min=1e-10) / 448.0
        normalized = (blocked / scales.unsqueeze(1)).clamp(-448.0, 448.0)

        return cls(
            data=_encode_e4m3(normalized.reshape(K, N)),
            block_scales=scales,
            shape=(K, N),
        )

    def to_float(self) -> torch.Tensor:
        """Dequantize back to an FP32 [K, N] matrix."""
        K, N = self.shape
        values = _decode_e4m3(self.data).view(K // 32, 32, N)
        return (values * self.block_scales.unsqueeze(1)).view(K, N)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
@dataclass
class GoliathINT2Weights:
    """
    Goliath INT2 quantized weights — aggressive compression for cold experts.

    32-element groups along K, FP16 scales, four codes packed per uint8
    byte.  Stored values map to {-2, -1, 0, 1} x scale (symmetric 2-bit).
    """
    packed: torch.Tensor        # [K//4, N] uint8, four 2-bit codes per byte
    block_scales: torch.Tensor  # [K//32, N] float16 per-group scales
    shape: Tuple[int, int]      # logical (K, N)
    bits: int = 2

    @classmethod
    def from_float(cls, weights: torch.Tensor, stochastic: bool = False) -> 'GoliathINT2Weights':
        """Quantize FP32/BF16 weights to INT2 format.

        Args:
            weights: Input [K, N] tensor
            stochastic: Use stochastic rounding (reduces bias for training)

        Returns:
            GoliathINT2Weights with packed 2-bit data
        """
        K, N = weights.shape
        assert K % 32 == 0, f"K ({K}) must be multiple of 32 for INT2"

        grouped = weights.float().view(K // 32, 32, N)
        group_absmax = grouped.abs().amax(dim=1).clamp(min=1e-10)

        # Scale so the most negative code (-2) reaches the group absmax.
        fp16_scales = (group_absmax / 2.0).to(torch.float16)

        norm = grouped / fp16_scales.unsqueeze(1).float().clamp(min=1e-10)

        if stochastic:
            # NOTE(review): this dithers with +-0.25 uniform noise before
            # rounding (attenuated dither), not textbook stochastic rounding.
            norm = norm + (torch.rand_like(norm) - 0.5) * 0.5
        codes = torch.round(norm).clamp(-2, 1).to(torch.int8)

        # Shift {-2..1} -> {0..3} and pack four consecutive K rows per byte.
        u = (codes + 2).to(torch.uint8).view(K, N)
        packed = (u[0::4]
                  | (u[1::4] << 2)
                  | (u[2::4] << 4)
                  | (u[3::4] << 6))

        return cls(
            packed=packed,
            block_scales=fp16_scales,
            shape=(K, N),
        )

    @classmethod
    def from_fp4(cls, fp4_weights: GoliathFP4Weights) -> 'GoliathINT2Weights':
        """Convert FP4 weights to INT2 (for demoting cold experts)."""
        return cls.from_float(fp4_weights.to_float())

    def to_float(self) -> torch.Tensor:
        """Dequantize INT2 weights back to FP32."""
        K, N = self.shape

        out = torch.zeros(K, N, device=self.packed.device, dtype=torch.float32)
        # Unpack the four 2-bit lanes of each byte and undo the
        # {0..3} -> {-2..1} shift.
        for lane in range(4):
            lane_codes = ((self.packed >> (2 * lane)) & 0x3).to(torch.int8) - 2
            out[lane::4] = lane_codes.float()

        # Re-apply the per-group FP16 scales.
        scaled = out.view(K // 32, 32, N) * self.block_scales.unsqueeze(1).float()
        return scaled.view(K, N)

    def memory_bytes(self) -> int:
        """Return memory usage in bytes (uint8 codes + fp16 scales)."""
        return self.packed.numel() + 2 * self.block_scales.numel()
| |
|
| |
|
@dataclass
class GoliathFEXCWeights:
    """
    FE-XC (FireEcho Xtreme Compress) — Codebook-based 2-bit quantization.

    Uses AQLM-style 2x8 additive codebooks: each group of 8 weights is
    represented as C0[idx0] + C1[idx1] where C0/C1 are learned codebooks
    with 256 centroids each. Achieves near-FP16 quality at 2 bits/weight.

    Combined with CodeGEMM-style psumbook precomputation for fast inference:
        precompute psumbook[m,c,j] = dot(codebook[m][c], input[j*8:(j+1)*8])
    once per token, then the matmul reduces to scalar gathers + adds.

    Layout:
        codes:     [K, N//8, 2] uint8  — 2 codebook indices per 8-weight group
        codebooks: [2, 256, 8] float16 — 2 codebooks, 256 centroids, 8 elements
        scales:    [K] float16         — per-output-channel scale
    """
    codes: torch.Tensor        # [K, N//8, 2] uint8 codebook indices
    codebooks: torch.Tensor    # [2, 256, 8] float16 centroids
    scales: torch.Tensor       # [K] float16 per-row scales
    shape: Tuple[int, int]     # logical (K, N)
    bits: int = 2
    group_size: int = 8

    @classmethod
    def from_float(
        cls,
        weights: torch.Tensor,
        codebooks: Optional[torch.Tensor] = None,
        n_centroids: int = 256,
        n_iters: int = 20,
    ) -> 'GoliathFEXCWeights':
        """Quantize FP32/BF16 weights to FE-XC 2x8 codebook format.

        Uses residual k-means: learn codebook_0 on raw groups, then
        codebook_1 on residuals. If codebooks are provided (shared),
        skips k-means and only assigns codes.

        Args:
            weights: Input [K, N] tensor
            codebooks: Optional pre-learned [2, 256, 8] codebooks (shared across experts)
            n_centroids: Number of codebook entries (default 256)
            n_iters: K-means iterations (default 20)

        Returns:
            GoliathFEXCWeights with codebook indices + shared codebooks
        """
        K, N = weights.shape
        g = 8
        assert N % g == 0, f"N ({N}) must be multiple of group_size {g}"

        device = weights.device
        w = weights.float()

        # One row per 8-weight group (groups tile along N).
        groups = w.view(-1, g)
        num_groups = groups.shape[0]

        if codebooks is None:
            # Learn two additive codebooks with residual k-means.
            codebooks = torch.zeros(2, n_centroids, g, device=device, dtype=torch.float32)

            for cb_idx in range(2):
                if cb_idx == 0:
                    data = groups
                else:
                    # Codebook 1 is trained on the residual left by codebook 0.
                    nearest_0 = codebooks[0][codes_0.long()]
                    data = groups - nearest_0

                # Initialize centroids from a random sample of groups.
                # NOTE(review): assumes num_groups >= n_centroids; fewer
                # groups would produce a short centroid set — confirm
                # upstream shapes guarantee this.
                perm = torch.randperm(num_groups, device=device)[:n_centroids]
                centroids = data[perm].clone()

                for _ in range(n_iters):
                    # Assignment step (full pairwise distances).
                    dists = torch.cdist(data, centroids)
                    assignments = dists.argmin(dim=1)

                    # Update step: cluster means; empty clusters keep their
                    # previous centroid.
                    for c in range(n_centroids):
                        mask = (assignments == c)
                        if mask.any():
                            centroids[c] = data[mask].mean(dim=0)

                codebooks[cb_idx] = centroids

                if cb_idx == 0:
                    # `dists` is from the final assignment step, i.e. against
                    # the pre-update centroids; good enough for training
                    # codebook 1 — the final codes are recomputed below.
                    codes_0 = dists.argmin(dim=1).to(torch.uint8)

            codebooks = codebooks.to(torch.float16)
        else:
            codebooks = codebooks.to(device=device)

        # Final two-stage code assignment against the finished codebooks.
        cb_float = codebooks.float()

        # Stage 0: nearest centroid in codebook 0.
        dists_0 = torch.cdist(groups, cb_float[0])
        codes_0 = dists_0.argmin(dim=1).to(torch.uint8)

        # Stage 1: nearest centroid to the stage-0 residual.
        residual = groups - cb_float[0][codes_0.long()]

        dists_1 = torch.cdist(residual, cb_float[1])
        codes_1 = dists_1.argmin(dim=1).to(torch.uint8)

        # [num_groups, 2] -> [K, N//8, 2]
        codes = torch.stack([codes_0, codes_1], dim=1).view(K, N // g, 2)

        # Per-row scale matches reconstructed row norms to the originals.
        reconstructed = (cb_float[0][codes_0.long()] + cb_float[1][codes_1.long()]).view(K, N)
        row_norms_orig = w.norm(dim=1).clamp(min=1e-10)
        row_norms_recon = reconstructed.norm(dim=1).clamp(min=1e-10)
        scales = (row_norms_orig / row_norms_recon).to(torch.float16)

        return cls(
            codes=codes,
            codebooks=codebooks.to(torch.float16),
            scales=scales,
            shape=(K, N),
        )

    @classmethod
    def from_fp4(cls, fp4_weights: GoliathFP4Weights,
                 codebooks: Optional[torch.Tensor] = None) -> 'GoliathFEXCWeights':
        """Convert FP4 weights to FE-XC (for demoting cold experts)."""
        return cls.from_float(fp4_weights.to_float(), codebooks=codebooks)

    def to_float(self) -> torch.Tensor:
        """Dequantize FE-XC weights back to FP32."""
        K, N = self.shape
        g = self.group_size
        cb = self.codebooks.float()

        # Sum the two codebook entries selected for each 8-weight group.
        codes_flat = self.codes.view(-1, 2).long()
        reconstructed = cb[0][codes_flat[:, 0]] + cb[1][codes_flat[:, 1]]
        reconstructed = reconstructed.view(K, N)

        # Re-apply the per-output-channel scale.
        reconstructed = reconstructed * self.scales.float().unsqueeze(1)

        return reconstructed

    def memory_bytes(self) -> int:
        """Return memory usage in bytes (uint8 codes + fp16 codebooks/scales)."""
        return (self.codes.numel()
                + self.codebooks.numel() * 2
                + self.scales.numel() * 2)
| |
|
| |
|
@dataclass
class GoliathFEXVQWeights:
    """
    FE-XVQ (FireEcho XVector Quantization) — Hessian-weighted codebook 2-bit.

    VPTQ-inspired: uses second-order information (Hessian diagonal from
    calibration data) to weight the k-means objective. Errors in important
    dimensions (high Hessian) are penalized more, producing better codebooks
    than FE-XC's plain MSE k-means.

    Same storage format as FE-XC — reuses the same inference kernel (psumbook
    CodeGEMM). Only the codebook LEARNING differs.

    For a linear layer y = Wx, the Hessian H = X^T X (input covariance).
    Quantization error δw contributes δw^T H δw to output loss.
    FE-XVQ minimizes this weighted error instead of plain ||δw||^2.

    Layout (identical to FE-XC):
        codes:     [K, N//8, 2] uint8  — 2 codebook indices per 8-weight group
        codebooks: [2, 256, 8] float16 — 2 codebooks, 256 centroids, 8 elements
        scales:    [K] float16         — per-output-channel scale
    """
    codes: torch.Tensor        # [K, N//8, 2] uint8 codebook indices
    codebooks: torch.Tensor    # [2, 256, 8] float16 centroids
    scales: torch.Tensor       # [K] float16 per-row scales
    shape: Tuple[int, int]     # logical (K, N)
    bits: int = 2
    group_size: int = 8

    @classmethod
    def from_float(
        cls,
        weights: torch.Tensor,
        hessian_diag: Optional[torch.Tensor] = None,
        codebooks: Optional[torch.Tensor] = None,
        n_centroids: int = 256,
        n_iters: int = 20,
    ) -> 'GoliathFEXVQWeights':
        """Quantize weights to FE-XVQ format with Hessian-weighted codebooks.

        When hessian_diag is provided, uses importance-weighted k-means:
        distance = sum(h_i * (w_i - c_i)^2) instead of plain MSE.
        This prioritizes accuracy on dimensions that matter most for output.

        The Hessian diagonal is averaged within each group of 8 to produce
        per-element importance weights [8]. Both data and centroids are
        pre-scaled by sqrt(h_avg), converting Mahalanobis distance to
        Euclidean — enabling efficient torch.cdist.

        Falls back to plain k-means (FE-XC equivalent) when hessian_diag=None.

        Args:
            weights: Input [K, N] tensor (K=out_features, N=in_features)
            hessian_diag: Optional [N] tensor — diagonal of H = X^T X
            codebooks: Optional pre-learned [2, 256, 8] codebooks
            n_centroids: Number of codebook entries (default 256)
            n_iters: K-means iterations (default 20)

        Returns:
            GoliathFEXVQWeights with Hessian-optimal codebook indices
        """
        K, N = weights.shape
        g = 8
        assert N % g == 0, f"N ({N}) must be multiple of group_size {g}"

        device = weights.device
        w = weights.float()

        # One row per 8-weight group (groups tile along N).
        groups = w.view(-1, g)
        num_groups = groups.shape[0]

        # Build per-element importance from the Hessian diagonal: normalize
        # to unit mean, average over groups to get one weight per group
        # position, then take sqrt so a plain Euclidean cdist on pre-scaled
        # points computes the h-weighted (diagonal Mahalanobis) distance.
        if hessian_diag is not None:
            h = hessian_diag.float().to(device)
            h = h / h.mean().clamp(min=1e-10)
            h_groups = h.view(-1, g)
            h_avg = h_groups.mean(dim=0)
            sqrt_h = h_avg.sqrt().unsqueeze(0)
        else:
            sqrt_h = None
            h_avg = None

        # Importance-scaled copy of the groups (identity when no Hessian).
        groups_w = groups * sqrt_h if sqrt_h is not None else groups

        if codebooks is None:
            codebooks = torch.zeros(2, n_centroids, g, device=device, dtype=torch.float32)

            # Residual k-means: codebook 0 on raw groups, codebook 1 on
            # what codebook 0 leaves behind.
            for cb_idx in range(2):
                if cb_idx == 0:
                    data = groups
                else:
                    nearest_0 = codebooks[0][codes_0.long()]
                    data = groups - nearest_0

                data_w = data * sqrt_h if sqrt_h is not None else data

                # Initialize centroids from a random sample of groups.
                # NOTE(review): assumes num_groups >= n_centroids — confirm
                # upstream shape guarantees.
                perm = torch.randperm(num_groups, device=device)[:n_centroids]
                centroids = data[perm].clone()

                for _ in range(n_iters):
                    # Assignment in the sqrt(h)-scaled space equals the
                    # h-weighted distance in the original space.
                    cent_w = centroids * sqrt_h if sqrt_h is not None else centroids
                    dists = torch.cdist(data_w, cent_w)
                    assignments = dists.argmin(dim=1)

                    # Update step; empty clusters keep their old centroid.
                    for c in range(n_centroids):
                        mask = (assignments == c)
                        if mask.any():
                            if h_avg is not None:
                                # NOTE(review): with a per-DIMENSION weight
                                # h_avg this algebraically reduces to the
                                # plain mean — (h*d).sum(0) / (h*count) ==
                                # d.mean(0) elementwise — so the weighting
                                # cancels here; harmless, but presumably
                                # intended for per-sample weights. Confirm.
                                d_masked = data[mask]
                                centroids[c] = (h_avg * d_masked).sum(0) / (h_avg * mask.sum()).clamp(min=1e-10)
                            else:
                                centroids[c] = data[mask].mean(dim=0)

                codebooks[cb_idx] = centroids
                if cb_idx == 0:
                    # Codes against the pre-update centroids of the final
                    # iteration; final codes are recomputed below anyway.
                    codes_0 = dists.argmin(dim=1).to(torch.uint8)

            codebooks = codebooks.to(torch.float16)
        else:
            codebooks = codebooks.to(device=device)

        # Final two-stage assignment in the importance-weighted space.
        cb_float = codebooks.float()
        cb0_w = cb_float[0] * sqrt_h if sqrt_h is not None else cb_float[0]
        dists_0 = torch.cdist(groups_w, cb0_w)
        codes_0 = dists_0.argmin(dim=1).to(torch.uint8)

        residual = groups - cb_float[0][codes_0.long()]
        residual_w = residual * sqrt_h if sqrt_h is not None else residual
        cb1_w = cb_float[1] * sqrt_h if sqrt_h is not None else cb_float[1]
        dists_1 = torch.cdist(residual_w, cb1_w)
        codes_1 = dists_1.argmin(dim=1).to(torch.uint8)

        # [num_groups, 2] -> [K, N//8, 2]
        codes = torch.stack([codes_0, codes_1], dim=1).view(K, N // g, 2)

        # Per-row scale matches (Hessian-weighted) reconstructed row norms
        # to the originals.
        reconstructed = (cb_float[0][codes_0.long()] + cb_float[1][codes_1.long()]).view(K, N)
        if hessian_diag is not None:
            h_row = hessian_diag.float().to(device).unsqueeze(0)
            row_norms_orig = (w * w * h_row).sum(dim=1).sqrt().clamp(min=1e-10)
            row_norms_recon = (reconstructed * reconstructed * h_row).sum(dim=1).sqrt().clamp(min=1e-10)
        else:
            row_norms_orig = w.norm(dim=1).clamp(min=1e-10)
            row_norms_recon = reconstructed.norm(dim=1).clamp(min=1e-10)
        scales = (row_norms_orig / row_norms_recon).to(torch.float16)

        return cls(
            codes=codes,
            codebooks=codebooks.to(torch.float16),
            scales=scales,
            shape=(K, N),
        )

    @classmethod
    def from_fexc(cls, fexc_weights: GoliathFEXCWeights) -> 'GoliathFEXVQWeights':
        """Promote FE-XC weights to FE-XVQ (same data, different type tag)."""
        return cls(
            codes=fexc_weights.codes,
            codebooks=fexc_weights.codebooks,
            scales=fexc_weights.scales,
            shape=fexc_weights.shape,
        )

    def to_fexc(self) -> GoliathFEXCWeights:
        """Downcast to FE-XC (for using the FE-XC inference kernel)."""
        return GoliathFEXCWeights(
            codes=self.codes,
            codebooks=self.codebooks,
            scales=self.scales,
            shape=self.shape,
        )

    def to_float(self) -> torch.Tensor:
        """Dequantize FE-XVQ weights back to FP32 (same as FE-XC)."""
        K, N = self.shape
        g = self.group_size
        cb = self.codebooks.float()
        # Sum the two selected codebook entries, then re-apply row scales.
        codes_flat = self.codes.view(-1, 2).long()
        reconstructed = cb[0][codes_flat[:, 0]] + cb[1][codes_flat[:, 1]]
        reconstructed = reconstructed.view(K, N)
        reconstructed = reconstructed * self.scales.float().unsqueeze(1)
        return reconstructed

    def memory_bytes(self) -> int:
        """Return memory usage in bytes (uint8 codes + fp16 codebooks/scales)."""
        return (self.codes.numel()
                + self.codebooks.numel() * 2
                + self.scales.numel() * 2)
| |
|
| |
|
| | |
# Union of every Goliath quantized-weight container accepted by the
# dispatch layer (FP4, FP8, INT2, FE-XC, FE-XVQ).
GoliathWeights = Union[GoliathFP4Weights, GoliathFP8Weights, GoliathINT2Weights, GoliathFEXCWeights, GoliathFEXVQWeights]
| |
|
| |
|
| | |
| | |
| | |
| |
|
@triton.jit
def _int2_decode(packed_byte, offset):
    """Decode one 2-bit value from a packed byte at lane `offset` (0-3).

    packed_byte: uint8 holding 4 packed 2-bit values
    offset: which value to extract (0 = bits 0-1, 1 = bits 2-3, etc.)

    Returns: float32 value in {-2, -1, 0, 1}
    """
    # Shift the requested lane down, mask it, and undo the +2 code offset.
    code = (packed_byte >> (offset * 2)) & 0x3
    return code.to(tl.float32) - 2.0
| |
|
| |
|
@triton.jit
def _e2m1_decode(idx):
    """Decode a 4-bit E2M1 code -> float32 (magnitudes 0, .5, 1, 1.5, 2, 3, 4, 6)."""
    sign_bit = (idx >> 3) & 1
    exp_field = (idx >> 1) & 3
    mant_bit = idx & 1
    # Subnormal (exp == 0): mant * 0.5; normal: (2 + mant) * 2^(exp - 2).
    sub_val = mant_bit.to(tl.float32) * 0.5
    norm_val = (2 + mant_bit).to(tl.float32) * tl.exp2((exp_field - 2).to(tl.float32))
    magnitude = tl.where(exp_field > 0, norm_val, sub_val)
    return tl.where(sign_bit != 0, -magnitude, magnitude)
| |
|
| |
|
@triton.jit
def _decode_e4m3_triton(raw_uint8):
    """Decode E4M3 FP8 bytes to float32 entirely in Triton registers."""
    sign_bit = (raw_uint8 >> 7) & 1
    exp_field = (raw_uint8 >> 3) & 0xF
    mant_field = raw_uint8 & 0x7
    # Normal: (8 + mant) * 2^(exp - 10); subnormal (exp == 0): mant * 2^-9.
    norm_val = (8 + mant_field).to(tl.float32) * tl.exp2((exp_field - 10).to(tl.float32))
    sub_val = mant_field.to(tl.float32) * tl.exp2(tl.full(mant_field.shape, -9.0, tl.float32))
    magnitude = tl.where(exp_field > 0, norm_val, sub_val)
    return tl.where(sign_bit != 0, -magnitude, magnitude)
| |
|
| |
|
| | |
| | |
| | |
| |
|
@triton.autotune(
    configs=[
        # Large-M tiles (prefill / batched matmul).
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=5, num_warps=16),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=16),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        # Small-M tiles (shallow M, wide N).
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        # Small-M, narrow-N variants.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _goliath_fp4_kernel(
    a_ptr,           # activations, indexed [M, K] via stride_am/stride_ak
    w_packed_ptr,    # packed weights, [K//2, N] uint8, two E2M1 nibbles per byte
    w_scales_ptr,    # [K//16, N] uint8 E4M3 per-block scales
    out_ptr,         # [M, N] output (stored as bf16)
    bias_ptr,        # [N] bias; only dereferenced when HAS_BIAS
    tensor_scale,    # global FP32 scale multiplied into every block scale
    M, N, K,
    stride_am, stride_ak,
    stride_wk, stride_wn,
    stride_sk, stride_sn,
    stride_om, stride_on,
    HAS_BIAS: tl.constexpr,
    ACC_DTYPE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Goliath FP4 fused dequant-matmul kernel.

    Loads packed E2M1 nibbles, dequantizes in registers via arithmetic decode,
    applies two-level E4M3 + FP32 scaling, and accumulates via tl.dot().
    The full dequantized weight matrix NEVER exists in global memory.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # NOTE(review): ACC_DTYPE is accepted but unused in this body —
    # accumulation is hard-wired to float32; confirm whether callers expect
    # it to select the accumulator type.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    HALF_BLOCK_K: tl.constexpr = BLOCK_K // 2      # packed bytes per K tile
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 16  # 16-element scale groups per K tile

    for k_start in range(0, K, BLOCK_K):
        # Activations are gathered as even/odd K columns to mirror the
        # low/high nibble packing of the weights.
        even_k = k_start + tl.arange(0, HALF_BLOCK_K) * 2
        odd_k = k_start + tl.arange(0, HALF_BLOCK_K) * 2 + 1

        a_even_ptrs = a_ptr + offs_m[:, None] * stride_am + even_k[None, :] * stride_ak
        a_odd_ptrs = a_ptr + offs_m[:, None] * stride_am + odd_k[None, :] * stride_ak
        mask_ae = (offs_m[:, None] < M) & (even_k[None, :] < K)
        mask_ao = (offs_m[:, None] < M) & (odd_k[None, :] < K)
        a_even = tl.load(a_even_ptrs, mask=mask_ae, other=0.0)
        a_odd = tl.load(a_odd_ptrs, mask=mask_ao, other=0.0)

        # Load the packed weight bytes for this K tile.
        pk_start = k_start // 2
        offs_pk = pk_start + tl.arange(0, HALF_BLOCK_K)
        w_ptrs = w_packed_ptr + offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn
        mask_w = (offs_pk[:, None] < (K // 2)) & (offs_n[None, :] < N)
        packed = tl.load(w_ptrs, mask=mask_w, other=0).to(tl.int32)

        # Split each byte into its two E2M1 nibbles and decode to float.
        low_f = _e2m1_decode(packed & 0xF)
        high_f = _e2m1_decode((packed >> 4) & 0xF)

        # Broadcast per-16-element E4M3 scales over the packed rows: each
        # packed row covers 2 K elements, so 8 packed rows share one scale
        # group (group_idx picks the scale row for each packed row).
        scale_start = k_start // 16
        offs_local = tl.arange(0, HALF_BLOCK_K)
        group_idx = offs_local // 8

        scale_bc = tl.zeros((HALF_BLOCK_K, BLOCK_N), dtype=tl.float32)
        for sg in tl.static_range(0, SCALES_PER_TILE):
            sg_row = scale_start + sg
            sg_ptrs = w_scales_ptr + sg_row * stride_sk + offs_n * stride_sn
            sg_mask = (sg_row < (K // 16)) & (offs_n < N)
            sg_raw = tl.load(sg_ptrs, mask=sg_mask, other=0).to(tl.int32)
            sg_val = _decode_e4m3_triton(sg_raw) * tensor_scale
            sg_match = (group_idx == sg)
            scale_bc = tl.where(sg_match[:, None], sg_val[None, :], scale_bc)

        # Dequantize in registers and downcast for the tensor-core dot.
        w_even = (low_f * scale_bc).to(tl.bfloat16)
        w_odd = (high_f * scale_bc).to(tl.bfloat16)

        # Two half-width dots accumulate the even/odd K contributions.
        acc += tl.dot(a_even.to(tl.bfloat16), w_even)
        acc += tl.dot(a_odd.to(tl.bfloat16), w_odd)

    # Optional bias add (broadcast over rows).
    if HAS_BIAS:
        bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # Store the tile as bf16.
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 256}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 256}, num_stages=3, num_warps=8),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _goliath_fp4_dot_scaled_kernel(
    a_ptr,                      # [M, K] activations (cast to bf16 on load)
    w_packed_ptr,               # packed E2M1 nibbles, indexed [n, k//2] (N-major)
    w_scales_ptr,               # E8M0 scale bytes, indexed [n, k//32] (N-major)
    out_ptr,                    # [M, N] bf16 output
    bias_ptr,                   # [N] fp32 bias (only read when HAS_BIAS)
    M, N, K,
    stride_am, stride_ak,
    stride_wn, stride_wk,       # NOTE: N-stride first — weights stored N-major here
    stride_sn, stride_sk,
    stride_om, stride_on,
    HAS_BIAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Goliath FP4 dot_scaled kernel — uses tl.dot_scaled for native MXFP4 TCs.

    LHS = BF16 activations (no scale), RHS = E2M1 packed weights with E8M0 scales.
    When Triton maps this to tcgen05.mma.mxf4, throughput reaches 200-800+ TFLOPS.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # fp32 accumulator for the whole output tile.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Two E2M1 values per packed byte; one E8M0 scale per 32 weights.
    HALF_BLOCK_K: tl.constexpr = BLOCK_K // 2
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 32

    for k_start in range(0, K, BLOCK_K):
        # Activation tile [BLOCK_M, BLOCK_K], zero-padded out of range.
        offs_k = k_start + tl.arange(0, BLOCK_K)
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        mask_a = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        a_tile = tl.load(a_ptrs, mask=mask_a, other=0.0).to(tl.bfloat16)

        # Packed weight tile [BLOCK_N, HALF_BLOCK_K]: two E2M1 nibbles per byte.
        pk_start = k_start // 2
        offs_pk = pk_start + tl.arange(0, HALF_BLOCK_K)
        w_ptrs = w_packed_ptr + offs_n[:, None] * stride_wn + offs_pk[None, :] * stride_wk
        mask_w = (offs_n[:, None] < N) & (offs_pk[None, :] < (K // 2))
        b_tile = tl.load(w_ptrs, mask=mask_w, other=0)

        # Scale tile [BLOCK_N, SCALES_PER_TILE]. Padding uses 127, the E8M0
        # encoding of 1.0, so masked-out lanes are scaled neutrally.
        sc_start = k_start // 32
        offs_sc = sc_start + tl.arange(0, SCALES_PER_TILE)
        s_ptrs = w_scales_ptr + offs_n[:, None] * stride_sn + offs_sc[None, :] * stride_sk
        mask_s = (offs_n[:, None] < N) & (offs_sc[None, :] < (K // 32))
        b_scale = tl.load(s_ptrs, mask=mask_s, other=127)

        # Mixed-format scaled dot: bf16 LHS (scale=None) x packed e2m1 RHS with
        # per-32-element E8M0 scales, accumulated into fp32 `acc`.
        acc = tl.dot_scaled(a_tile, None, "bf16",
                            b_tile.T, b_scale, "e2m1",
                            acc)

    if HAS_BIAS:
        bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # Store the finished tile in bf16.
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
| |
|
| |
|
| | |
| | |
| | |
| |
|
@triton.autotune(
    configs=[
        # Large-batch (prefill) tiles.
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=5, num_warps=16),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=16),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        # Small-M (decode) tiles with wide N.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=5, num_warps=8),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=5, num_warps=4),
        # Small-M narrow-N tiles.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _goliath_fp8_kernel(
    a_ptr,                      # [M, K] activations
    w_data_ptr,                 # [K, N] E4M3-encoded uint8 weights
    w_scales_ptr,               # [K//32, N] fp32 per-block scales
    out_ptr,                    # [M, N] bf16 output
    bias_ptr,                   # [N] fp32 bias (only read when HAS_BIAS)
    M, N, K,
    stride_am, stride_ak,
    stride_wk, stride_wn,
    stride_sk, stride_sn,
    stride_om, stride_on,
    HAS_BIAS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Goliath FP8 fused dequant-matmul kernel.

    Loads E4M3-encoded uint8 weights, dequantizes in registers via
    _decode_e4m3_triton, applies per-block FP32 scales, accumulates via tl.dot().
    No separate dequantization step — everything happens in-register.
    32-element scale blocks for shared-memory-friendly tile sizes.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # fp32 accumulator for the output tile.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # One scale per 32-element group along K.
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 32

    for k_start in range(0, K, BLOCK_K):
        # Activation tile [BLOCK_M, BLOCK_K], zero-padded out of range.
        offs_k = k_start + tl.arange(0, BLOCK_K)
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        mask_a = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        a_tile = tl.load(a_ptrs, mask=mask_a, other=0.0)

        # Raw E4M3 bytes [BLOCK_K, BLOCK_N], widened to int32 for bit-decode.
        w_ptrs = w_data_ptr + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
        mask_w = (offs_k[:, None] < K) & (offs_n[None, :] < N)
        w_raw = tl.load(w_ptrs, mask=mask_w, other=0).to(tl.int32)

        # In-register E4M3 -> fp32 decode; the dequantized matrix never
        # exists in global memory.
        w_decoded = _decode_e4m3_triton(w_raw)

        # Build scale_bc[k, n] = scale of weight row k, column n by
        # broadcasting each 32-row scale group across its rows.
        scale_start = k_start // 32
        offs_local_k = tl.arange(0, BLOCK_K)
        group_idx = offs_local_k // 32

        scale_bc = tl.zeros((BLOCK_K, BLOCK_N), dtype=tl.float32)
        for sg in tl.static_range(0, SCALES_PER_TILE):
            sg_row = scale_start + sg
            sg_ptrs = w_scales_ptr + sg_row * stride_sk + offs_n * stride_sn
            sg_mask = (sg_row < (K // 32)) & (offs_n < N)
            # other=1.0 keeps masked lanes neutral under multiplication.
            sg_val = tl.load(sg_ptrs, mask=sg_mask, other=1.0)
            sg_match = (group_idx == sg)
            scale_bc = tl.where(sg_match[:, None], sg_val[None, :], scale_bc)

        w_scaled = (w_decoded * scale_bc).to(tl.bfloat16)

        # BF16 tensor-core matmul with fp32 accumulation.
        acc += tl.dot(a_tile.to(tl.bfloat16), w_scaled)

    if HAS_BIAS:
        bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # Store the finished tile in bf16.
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _goliath_fp4_matmul(
    activations: torch.Tensor,
    weights: GoliathFP4Weights,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Launch Goliath FP4 fused dequant-matmul kernel.

    Dispatches to the native dot_scaled path when available (SM >= 10.0 +
    Triton MXFP4 support), then to the NVFP4 bridge if present, and finally
    falls back to the manual in-register dequant Triton kernel.

    Args:
        activations: Input [M, K]; cast to BF16 if needed.
        weights: GoliathFP4Weights with packed E2M1 data and block scales.
        bias: Optional [N] bias added after accumulation.

    Returns:
        Output [M, N] in BF16.
    """
    # Fast path 1: native MXFP4 tensor cores via tl.dot_scaled.
    if _can_use_goliath_dot_scaled():
        return _goliath_fp4_dot_scaled_matmul(
            activations, GoliathFP4NativeWeights.from_goliath_fp4(weights), bias)

    # Fast path 2: NVFP4 bridge — cuBLAS scaled-mm when supported, else the
    # bridge's fused Triton kernel.
    if _NVFP4_BRIDGE_AVAILABLE:
        bridge_weights = _BridgeNVFP4Weights(
            packed=weights.packed,
            block_scales=weights.block_scales,
            tensor_scale=weights.tensor_scale,
            shape=weights.shape,
            residual=weights.residual,
            residual_scales=weights.residual_scales,
        )
        if _bridge_can_use_cublas_fp4 and _bridge_can_use_cublas_fp4():
            return _bridge_scaled_mm_fp4(activations, bridge_weights, bias)
        return _bridge_fused_nvfp4(activations, bridge_weights, bias)

    # Fallback: manual in-register dequant kernel.
    M, K = activations.shape
    _, N = weights.shape

    acts = activations.contiguous()
    if acts.dtype != torch.bfloat16:
        acts = acts.to(torch.bfloat16)

    packed = weights.packed.contiguous()
    scales = weights.block_scales.contiguous()

    result = torch.empty(M, N, device=acts.device, dtype=torch.bfloat16)

    # The kernel always receives a bias pointer; a zero-length buffer stands
    # in when no bias is requested (HAS_BIAS gates the load).
    if bias is None:
        bias_buf = torch.empty(0, device=acts.device, dtype=torch.float32)
    else:
        bias_buf = bias.contiguous().float()

    def grid(meta):
        return (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))

    _goliath_fp4_kernel[grid](
        acts, packed, scales, result, bias_buf,
        weights.tensor_scale,
        M, N, K,
        acts.stride(0), acts.stride(1),
        packed.stride(0), packed.stride(1),
        scales.stride(0), scales.stride(1),
        result.stride(0), result.stride(1),
        HAS_BIAS=bias is not None,
        ACC_DTYPE=0,
    )

    return result
| |
|
| |
|
def _goliath_fp8_matmul(
    activations: torch.Tensor,
    weights: GoliathFP8Weights,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Launch Goliath FP8 fused dequant-matmul kernel.

    Prepares contiguous BF16 activations, allocates the BF16 output, and
    hands everything to the autotuned Triton kernel; E4M3 dequantization
    happens entirely inside the kernel.
    """
    M, K = activations.shape
    _, N = weights.shape

    acts = activations.contiguous()
    if acts.dtype != torch.bfloat16:
        acts = acts.to(torch.bfloat16)

    data = weights.data.contiguous()
    scales = weights.block_scales.contiguous()

    result = torch.empty(M, N, device=acts.device, dtype=torch.bfloat16)

    # Zero-length bias buffer when none is given; HAS_BIAS gates the load.
    if bias is None:
        bias_buf = torch.empty(0, device=acts.device, dtype=torch.float32)
    else:
        bias_buf = bias.contiguous().float()

    def grid(meta):
        return (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))

    _goliath_fp8_kernel[grid](
        acts, data, scales, result, bias_buf,
        M, N, K,
        acts.stride(0), acts.stride(1),
        data.stride(0), data.stride(1),
        scales.stride(0), scales.stride(1),
        result.stride(0), result.stride(1),
        HAS_BIAS=bias is not None,
    )

    return result
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _goliath_fp4_dot_scaled_matmul(
    activations: torch.Tensor,
    weights: GoliathFP4NativeWeights,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Launch Goliath FP4 dot_scaled kernel (native MXFP4 tensor cores).

    Feeds BF16 activations plus the column-major packed E2M1 weights and
    E8M0 scales to the autotuned tl.dot_scaled kernel.
    """
    M, K = activations.shape
    _, N = weights.shape

    acts = activations.contiguous()
    if acts.dtype != torch.bfloat16:
        acts = acts.to(torch.bfloat16)

    packed = weights.packed_col_major.contiguous()
    scales = weights.scales_e8m0.contiguous()

    result = torch.empty(M, N, device=acts.device, dtype=torch.bfloat16)

    # Zero-length bias buffer when none is given; HAS_BIAS gates the load.
    bias_buf = (bias.contiguous().float() if bias is not None
                else torch.empty(0, device=acts.device, dtype=torch.float32))

    def grid(meta):
        return (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))

    _goliath_fp4_dot_scaled_kernel[grid](
        acts, packed, scales, result, bias_buf,
        M, N, K,
        acts.stride(0), acts.stride(1),
        packed.stride(0), packed.stride(1),
        scales.stride(0), scales.stride(1),
        result.stride(0), result.stride(1),
        HAS_BIAS=bias is not None,
    )

    return result
| |
|
| |
|
| | |
| | _DOT_SCALED_PROBE_RESULT: Optional[bool] = None |
| |
|
| |
|
def _can_use_goliath_dot_scaled() -> bool:
    """Probe whether tl.dot_scaled maps to real FP4 tensor cores.

    Checks:
    1. SM >= 10.0 (Blackwell or later)
    2. tl.dot_scaled exists in Triton
    3. Runtime probe: run a small matmul with both the manual-dequant kernel
       and the dot_scaled kernel. If outputs are bit-identical, Triton is
       falling back to BF16 MMA — return False. If outputs differ with
       rel_err < 0.15 — native FP4 TCs are active — return True.
    """
    global _DOT_SCALED_PROBE_RESULT
    # Probe at most once per process; later calls return the cached verdict.
    if _DOT_SCALED_PROBE_RESULT is not None:
        return _DOT_SCALED_PROBE_RESULT

    # Pre-seed False BEFORE probing: the probe below calls
    # _goliath_fp4_matmul, which itself consults this flag — seeding it
    # breaks the recursion and forces the reference run onto the
    # manual-dequant path.
    _DOT_SCALED_PROBE_RESULT = False

    # Hardware gate: require an SM 10.x+ (Blackwell) GPU.
    if not torch.cuda.is_available():
        return False
    cap = torch.cuda.get_device_capability(0)
    if cap[0] < 10:
        return False

    # Software gate: older Triton builds have no tl.dot_scaled at all.
    if not hasattr(tl, 'dot_scaled'):
        return False

    # Runtime probe: compare manual-dequant vs dot_scaled on a tiny matmul.
    try:
        M, N, K = 64, 64, 64
        torch.manual_seed(42)
        w_f = torch.randn(K, N, device='cuda', dtype=torch.float32)
        a = torch.randn(M, K, device='cuda', dtype=torch.bfloat16)

        # Quantize once; both kernels see the same FP4 data.
        w_q = GoliathFP4Weights.from_float(w_f)

        # Reference: manual in-register dequant kernel (flag is still False).
        out_manual = _goliath_fp4_matmul(a, w_q)

        # Candidate: native dot_scaled kernel.
        w_native = GoliathFP4NativeWeights.from_goliath_fp4(w_q)
        out_scaled = _goliath_fp4_dot_scaled_matmul(a, w_native)

        if torch.equal(out_manual, out_scaled):
            # Bit-identical => Triton emulated dot_scaled with the same BF16
            # MMA path; nothing to gain from it.
            _DOT_SCALED_PROBE_RESULT = False
        else:
            # Different but close => genuine FP4 tensor-core arithmetic.
            ref = out_manual.float()
            diff = (out_scaled.float() - ref).abs().mean()
            rel_err = diff / ref.abs().mean().clamp(min=1e-10)
            _DOT_SCALED_PROBE_RESULT = rel_err.item() < 0.15

    except Exception:
        # Any probe failure (quantize, compile, launch) => stay on manual path.
        _DOT_SCALED_PROBE_RESULT = False

    return _DOT_SCALED_PROBE_RESULT
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _estimate_accuracy_need(weights: torch.Tensor) -> int: |
| | """ |
| | Estimate whether FP4 or FP8 is appropriate for these weights. |
| | |
| | Heuristic: if the weight distribution has high kurtosis (heavy tails) |
| | or many outliers, FP8 preserves more fidelity. Otherwise FP4 suffices. |
| | |
| | Returns: |
| | 4 for FP4, 8 for FP8 |
| | """ |
| | w = weights.float() |
| | absmax = w.abs().max() |
| | mean_abs = w.abs().mean() |
| |
|
| | |
| | outlier_ratio = absmax / mean_abs.clamp(min=1e-10) |
| |
|
| | |
| | if outlier_ratio > 20.0: |
| | return 8 |
| |
|
| | |
| | |
| | std = w.std() |
| | |
| | if std / absmax.clamp(min=1e-10) < 0.05: |
| | return 8 |
| |
|
| | return 4 |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def goliath_quantize(
    weights: torch.Tensor,
    bits: Union[int, str] = 4,
    training: bool = False,
    sr_seed: Optional[int] = None,
    compute_residual: bool = False,
) -> GoliathWeights:
    """
    Quantize weights for Goliath fused GEMM.

    Args:
        weights: Input tensor [K, N] in any float dtype
        bits: 4 for FP4, 8 for FP8, 'auto' for automatic selection
        training: If True, use stochastic rounding for unbiased quantization
        sr_seed: Seed for stochastic rounding RNG
        compute_residual: Compute FP8 residual correction for FP4 weights (double-buff)

    Returns:
        GoliathFP4Weights or GoliathFP8Weights

    Raises:
        ValueError: if bits is neither 4, 8 nor 'auto'.
    """
    # Resolve 'auto' to a concrete bit-width via the distribution heuristic.
    bits = _estimate_accuracy_need(weights) if bits == 'auto' else bits

    if bits == 4:
        return GoliathFP4Weights.from_float(
            weights,
            training=training,
            sr_seed=sr_seed,
            compute_residual=compute_residual,
        )
    if bits == 8:
        return GoliathFP8Weights.from_float(weights)
    raise ValueError(f"bits must be 4, 8, or 'auto', got {bits}")
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | MAX_EXPERTS: int = 16 |
| |
|
@triton.jit
def _goliath_fp4_multi_expert_kernel(
    a_ptr,                      # activations base pointer
    # Per-expert packed FP4 weight pointers — fixed 8 slots.
    w0_ptr, w1_ptr, w2_ptr, w3_ptr, w4_ptr, w5_ptr, w6_ptr, w7_ptr,
    # Per-expert E4M3 block-scale pointers — fixed 8 slots.
    s0_ptr, s1_ptr, s2_ptr, s3_ptr, s4_ptr, s5_ptr, s6_ptr, s7_ptr,
    tscale_ptr,                 # per-expert fp32 tensor scales, indexed by pid_expert
    out_ptr,                    # [num_experts, M, N] bf16 output
    M, N, K,
    num_experts,
    stride_am, stride_ak,
    stride_wk, stride_wn,       # shared by all experts — assumes uniform layout
    stride_sk, stride_sn,
    stride_oe, stride_om, stride_on,
    a_expert_stride,            # 0 => shared input; >0 => per-expert input rows
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Multi-expert FP4 fused dequant-matmul: all experts in ONE launch.

    When a_expert_stride=0: all experts share the same input (gate_up case).
    When a_expert_stride>0: each expert reads from row pid_expert*M of a_ptr (down case).

    NOTE(review): the grid has no M dimension and offs_m = arange(0, BLOCK_M),
    so only the first BLOCK_M rows per expert are computed; and pid_expert
    values above 7 fall through the pointer-select chain to expert 0's
    weights — callers must keep num_experts <= 8 and M <= BLOCK_M.
    """
    pid_expert = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Guard against over-provisioned grids.
    if pid_expert >= num_experts:
        return

    # Per-expert input base; stride 0 makes every expert read the same rows.
    a_base = a_ptr + pid_expert * a_expert_stride

    # Select this expert's weight/scale pointers. Triton has no pointer
    # arrays, hence the explicit chain over the 8 fixed slots.
    w_packed_ptr = w0_ptr
    w_scales_ptr = s0_ptr
    if pid_expert == 1:
        w_packed_ptr = w1_ptr; w_scales_ptr = s1_ptr
    elif pid_expert == 2:
        w_packed_ptr = w2_ptr; w_scales_ptr = s2_ptr
    elif pid_expert == 3:
        w_packed_ptr = w3_ptr; w_scales_ptr = s3_ptr
    elif pid_expert == 4:
        w_packed_ptr = w4_ptr; w_scales_ptr = s4_ptr
    elif pid_expert == 5:
        w_packed_ptr = w5_ptr; w_scales_ptr = s5_ptr
    elif pid_expert == 6:
        w_packed_ptr = w6_ptr; w_scales_ptr = s6_ptr
    elif pid_expert == 7:
        w_packed_ptr = w7_ptr; w_scales_ptr = s7_ptr

    tensor_scale = tl.load(tscale_ptr + pid_expert)

    # Single M tile — only rows [0, BLOCK_M) are ever addressed.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Two E2M1 nibbles per packed byte; one E4M3 scale per 16 weights.
    HALF_BLOCK_K: tl.constexpr = BLOCK_K // 2
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 16

    for k_start in range(0, K, BLOCK_K):
        # Activation columns split by K parity to line up with nibble packing.
        even_k = k_start + tl.arange(0, HALF_BLOCK_K) * 2
        odd_k = k_start + tl.arange(0, HALF_BLOCK_K) * 2 + 1

        a_even_ptrs = a_base + offs_m[:, None] * stride_am + even_k[None, :] * stride_ak
        a_odd_ptrs = a_base + offs_m[:, None] * stride_am + odd_k[None, :] * stride_ak
        mask_ae = (offs_m[:, None] < M) & (even_k[None, :] < K)
        mask_ao = (offs_m[:, None] < M) & (odd_k[None, :] < K)
        a_even = tl.load(a_even_ptrs, mask=mask_ae, other=0.0)
        a_odd = tl.load(a_odd_ptrs, mask=mask_ao, other=0.0)

        # Packed weight bytes [HALF_BLOCK_K, BLOCK_N] for this expert.
        pk_start = k_start // 2
        offs_pk = pk_start + tl.arange(0, HALF_BLOCK_K)
        w_ptrs = w_packed_ptr + offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn
        mask_w = (offs_pk[:, None] < (K // 2)) & (offs_n[None, :] < N)
        packed = tl.load(w_ptrs, mask=mask_w, other=0).to(tl.int32)

        # In-register E2M1 decode: low nibble = even k, high nibble = odd k.
        low_f = _e2m1_decode(packed & 0xF)
        high_f = _e2m1_decode((packed >> 4) & 0xF)

        # Each packed row holds 2 weights, so 8 packed rows share one
        # 16-element scale group.
        scale_start = k_start // 16
        offs_local = tl.arange(0, HALF_BLOCK_K)
        group_idx = offs_local // 8

        # Broadcast the E4M3 block scales (x expert tensor scale) to a full
        # [HALF_BLOCK_K, BLOCK_N] tile.
        scale_bc = tl.zeros((HALF_BLOCK_K, BLOCK_N), dtype=tl.float32)
        for sg in tl.static_range(0, SCALES_PER_TILE):
            sg_row = scale_start + sg
            sg_ptrs = w_scales_ptr + sg_row * stride_sk + offs_n * stride_sn
            sg_mask = (sg_row < (K // 16)) & (offs_n < N)
            sg_raw = tl.load(sg_ptrs, mask=sg_mask, other=0).to(tl.int32)
            sg_val = _decode_e4m3_triton(sg_raw) * tensor_scale
            sg_match = (group_idx == sg)
            scale_bc = tl.where(sg_match[:, None], sg_val[None, :], scale_bc)

        w_even = (low_f * scale_bc).to(tl.bfloat16)
        w_odd = (high_f * scale_bc).to(tl.bfloat16)

        # Two half-width dots (even/odd K lanes) accumulate the full product.
        acc += tl.dot(a_even.to(tl.bfloat16), w_even)
        acc += tl.dot(a_odd.to(tl.bfloat16), w_odd)

    # Write this expert's output slab.
    out_ptrs = out_ptr + pid_expert * stride_oe + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
| |
|
| |
|
def goliath_multi_expert_gemm(
    activations: torch.Tensor,
    expert_weights: list,
    num_experts: int = 8,
    per_expert_input: bool = False,
) -> torch.Tensor:
    """
    Fused multi-expert FP4 GEMM: all experts in ONE kernel launch.

    Args:
        activations: Input in BF16.
            If per_expert_input=False: [M, K] shared across all experts (gate_up case)
            If per_expert_input=True: [num_experts*M, K] stacked per-expert inputs (down case)
        expert_weights: List of GoliathFP4Weights (one per active expert)
        num_experts: Number of active experts (1-8)
        per_expert_input: If True, each expert reads from its own M rows of activations

    Returns:
        Output [num_experts, M, N] in BF16

    Raises:
        AssertionError: if num_experts is outside 1..8, fewer weight sets than
            experts are supplied, weights are not GoliathFP4Weights, or the
            per-expert row count exceeds the kernel's single 16-row M tile.
    """
    # BUGFIX: the kernel declares exactly 8 weight/scale pointer slots
    # (w0..w7) and its pointer-select chain only covers pid_expert 0..7, so
    # expert ids >= 8 would silently reuse expert 0's weights. Reject them
    # here instead of allowing up to MAX_EXPERTS.
    assert 1 <= num_experts <= 8, f"multi-expert kernel supports 1-8 experts, got {num_experts}"
    assert len(expert_weights) >= num_experts, (
        f"need >= {num_experts} expert weight sets, got {len(expert_weights)}")
    assert all(isinstance(w, GoliathFP4Weights) for w in expert_weights)

    _, N = expert_weights[0].shape

    a = activations.contiguous()
    if a.dtype != torch.bfloat16:
        a = a.to(torch.bfloat16)

    if per_expert_input:
        # Stacked inputs: expert e reads rows [e*M, (e+1)*M).
        total_rows, K = a.shape
        M = total_rows // num_experts
        a_expert_stride = M * a.stride(0)
    else:
        # Shared input: every expert reads the same [M, K] block.
        M, K = a.shape
        a_expert_stride = 0

    # BUGFIX: the kernel has no M grid dimension (offs_m = arange(0, BLOCK_M)),
    # so rows beyond BLOCK_M=16 would silently never be computed.
    assert M <= 16, f"multi-expert kernel handles at most 16 rows per expert, got {M}"

    out = torch.empty(num_experts, M, N, device=a.device, dtype=torch.bfloat16)

    # Per-expert global tensor scales, gathered once onto the GPU. Only the
    # first num_experts entries are ever read by the kernel.
    tscales = torch.tensor(
        [w.tensor_scale for w in expert_weights[:num_experts]],
        dtype=torch.float32, device=a.device)

    packed_bufs = [w.packed.contiguous() for w in expert_weights[:num_experts]]
    scale_bufs = [w.block_scales.contiguous() for w in expert_weights[:num_experts]]

    # Pad the pointer-argument lists to the kernel's fixed 8 slots; padding
    # entries are never selected because pid_expert < num_experts.
    while len(packed_bufs) < 8:
        packed_bufs.append(packed_bufs[0])
        scale_bufs.append(scale_bufs[0])

    grid = (num_experts, triton.cdiv(N, 64))

    _goliath_fp4_multi_expert_kernel[grid](
        a,
        packed_bufs[0], packed_bufs[1], packed_bufs[2], packed_bufs[3],
        packed_bufs[4], packed_bufs[5], packed_bufs[6], packed_bufs[7],
        scale_bufs[0], scale_bufs[1], scale_bufs[2], scale_bufs[3],
        scale_bufs[4], scale_bufs[5], scale_bufs[6], scale_bufs[7],
        tscales, out,
        M, N, K, num_experts,
        a.stride(0), a.stride(1),
        # Strides of expert 0 are used for all experts — assumes every
        # expert's packed/scale tensors share one layout.
        packed_bufs[0].stride(0), packed_bufs[0].stride(1),
        scale_bufs[0].stride(0), scale_bufs[0].stride(1),
        out.stride(0), out.stride(1), out.stride(2),
        a_expert_stride,
        BLOCK_M=16, BLOCK_N=64, BLOCK_K=128,
    )

    return out
| |
|
| |
|
def goliath_gemm(
    activations: torch.Tensor,
    weights: GoliathWeights,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Goliath fused dequant-matmul GEMM.

    Automatically dispatches to the FP4 or FP8 kernel based on weight type;
    dequantization happens inside Triton registers — no extra memory traffic.
    CPU tensors take a dense dequantize-then-matmul reference path.

    Args:
        activations: Input [M, K] in BF16/FP16/FP32
        weights: GoliathFP4Weights or GoliathFP8Weights
        bias: Optional bias [N]

    Returns:
        Output [M, N] in BF16

    Raises:
        TypeError: if weights is neither FP4 nor FP8 Goliath weights.
    """
    if not activations.is_cuda:
        # CPU reference path: dequantize and fall back to a dense matmul.
        dense = weights.to_float()
        result = torch.matmul(activations.float(), dense)
        if bias is not None:
            result = result + bias.float()
        return result.to(torch.bfloat16)

    if isinstance(weights, GoliathFP4Weights):
        return _goliath_fp4_matmul(activations, weights, bias)
    if isinstance(weights, GoliathFP8Weights):
        return _goliath_fp8_matmul(activations, weights, bias)
    raise TypeError(f"Expected GoliathFP4Weights or GoliathFP8Weights, got {type(weights)}")
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
@triton.jit
def _goliath_fp4_packed_moe_kernel(
    a_ptr,                      # activations base pointer
    packed_w_ptr,               # [E_total, K//2, N] packed FP4 weights, all experts
    packed_s_ptr,               # [E_total, K//16, N] E4M3 block scales, all experts
    tscale_ptr,                 # [E_total] per-expert fp32 tensor scales
    expert_ids_ptr,             # [num_active] selected expert indices (GPU tensor)
    out_ptr,                    # [num_active, M, N] bf16 output
    M, N, K,
    num_active,
    stride_ew,                  # expert stride of packed_w
    stride_wk, stride_wn,
    stride_es,                  # expert stride of packed_s
    stride_sk, stride_sn,
    stride_oe, stride_om, stride_on,
    stride_am, stride_ak,
    a_expert_stride,            # 0 => shared input; >0 => per-expert input rows
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Packed MoE FP4 fused dequant-matmul: GPU-resident expert selection.

    Grid: (num_active, ceil(N/BLOCK_N), ceil(M/BLOCK_M))
    Each program handles one active expert's matmul for one tile of output rows × columns.
    Expert IDs are read from a GPU tensor — no CPU involvement.
    """
    pid_active = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_m = tl.program_id(2)

    # Guard against over-provisioned grids.
    if pid_active >= num_active:
        return

    # Expert selection stays on-device: look up which expert this active
    # slot maps to.
    expert_id = tl.load(expert_ids_ptr + pid_active)

    # Index the contiguous packed buffers by expert id.
    w_packed_ptr = packed_w_ptr + expert_id * stride_ew
    w_scales_ptr = packed_s_ptr + expert_id * stride_es
    tensor_scale = tl.load(tscale_ptr + expert_id)

    # Input base is indexed by active SLOT (not expert id); stride 0 shares
    # one input across all slots.
    a_base = a_ptr + pid_active * a_expert_stride

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Two E2M1 nibbles per packed byte; one E4M3 scale per 16 weights.
    HALF_BLOCK_K: tl.constexpr = BLOCK_K // 2
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 16

    for k_start in range(0, K, BLOCK_K):
        # Activation columns split by K parity to match nibble packing.
        even_k = k_start + tl.arange(0, HALF_BLOCK_K) * 2
        odd_k = k_start + tl.arange(0, HALF_BLOCK_K) * 2 + 1

        a_even_ptrs = a_base + offs_m[:, None] * stride_am + even_k[None, :] * stride_ak
        a_odd_ptrs = a_base + offs_m[:, None] * stride_am + odd_k[None, :] * stride_ak
        mask_ae = (offs_m[:, None] < M) & (even_k[None, :] < K)
        mask_ao = (offs_m[:, None] < M) & (odd_k[None, :] < K)
        a_even = tl.load(a_even_ptrs, mask=mask_ae, other=0.0)
        a_odd = tl.load(a_odd_ptrs, mask=mask_ao, other=0.0)

        # Packed weight bytes [HALF_BLOCK_K, BLOCK_N] for this expert.
        pk_start = k_start // 2
        offs_pk = pk_start + tl.arange(0, HALF_BLOCK_K)
        w_ptrs = w_packed_ptr + offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn
        mask_w = (offs_pk[:, None] < (K // 2)) & (offs_n[None, :] < N)
        packed = tl.load(w_ptrs, mask=mask_w, other=0).to(tl.int32)

        # In-register E2M1 decode: low nibble = even k, high nibble = odd k.
        low_f = _e2m1_decode(packed & 0xF)
        high_f = _e2m1_decode((packed >> 4) & 0xF)

        # Each packed row holds 2 weights, so 8 packed rows share one
        # 16-element scale group.
        scale_start = k_start // 16
        offs_local = tl.arange(0, HALF_BLOCK_K)
        group_idx = offs_local // 8

        # Broadcast the E4M3 block scales (x expert tensor scale) over the tile.
        scale_bc = tl.zeros((HALF_BLOCK_K, BLOCK_N), dtype=tl.float32)
        for sg in tl.static_range(0, SCALES_PER_TILE):
            sg_row = scale_start + sg
            sg_ptrs = w_scales_ptr + sg_row * stride_sk + offs_n * stride_sn
            sg_mask = (sg_row < (K // 16)) & (offs_n < N)
            sg_raw = tl.load(sg_ptrs, mask=sg_mask, other=0).to(tl.int32)
            sg_val = _decode_e4m3_triton(sg_raw) * tensor_scale
            sg_match = (group_idx == sg)
            scale_bc = tl.where(sg_match[:, None], sg_val[None, :], scale_bc)

        w_even = (low_f * scale_bc).to(tl.bfloat16)
        w_odd = (high_f * scale_bc).to(tl.bfloat16)

        # Two half-width dots (even/odd K lanes) accumulate the full product.
        acc += tl.dot(a_even.to(tl.bfloat16), w_even)
        acc += tl.dot(a_odd.to(tl.bfloat16), w_odd)

    # Output is indexed by active slot, not expert id.
    out_ptrs = out_ptr + pid_active * stride_oe + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
| |
|
| |
|
def goliath_packed_moe_gemm(
    activations: torch.Tensor,
    packed_w: torch.Tensor,
    packed_s: torch.Tensor,
    packed_ts: torch.Tensor,
    expert_ids: torch.Tensor,
    num_active: int = 8,
    per_expert_input: bool = False,
) -> torch.Tensor:
    """Packed MoE FP4 GEMM: contiguous expert buffer + GPU-resident expert IDs.

    All expert selection happens on GPU — zero .item() calls, CUDA-graph-safe.

    Args:
        activations: Input in BF16.
            If per_expert_input=False: [M, K] shared across all experts
            If per_expert_input=True: [num_active*M, K] stacked per-expert inputs
        packed_w: [E_total, K//2, N] contiguous packed FP4 weights for all experts
        packed_s: [E_total, K//16, N] contiguous FP8 block scales
        packed_ts: [E_total] per-expert tensor scales (float32)
        expert_ids: [num_active] selected expert indices (GPU tensor, int64)
        num_active: Number of active experts
        per_expert_input: If True, each expert reads from its own M rows

    Returns:
        Output [num_active, M, N] in BF16
    """
    # Fixed tile sizes for the packed-MoE kernel launch.
    tile_m, tile_n, tile_k = 16, 64, 128

    N = packed_w.shape[2]

    acts = activations.contiguous()
    if acts.dtype != torch.bfloat16:
        acts = acts.to(torch.bfloat16)

    if per_expert_input:
        # Stacked inputs: active slot i reads rows [i*M, (i+1)*M).
        rows_total, K = acts.shape
        M = rows_total // num_active
        expert_row_stride = M * acts.stride(0)
    else:
        # Shared input: every active expert reads the same [M, K] block.
        M, K = acts.shape
        expert_row_stride = 0

    out = torch.empty(num_active, M, N, device=acts.device, dtype=torch.bfloat16)

    launch_grid = (num_active, triton.cdiv(N, tile_n), triton.cdiv(M, tile_m))

    _goliath_fp4_packed_moe_kernel[launch_grid](
        acts,
        packed_w, packed_s, packed_ts,
        expert_ids,
        out,
        M, N, K,
        num_active,
        packed_w.stride(0),
        packed_w.stride(1), packed_w.stride(2),
        packed_s.stride(0),
        packed_s.stride(1), packed_s.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        acts.stride(0), acts.stride(1),
        expert_row_stride,
        BLOCK_M=tile_m, BLOCK_N=tile_n, BLOCK_K=tile_k,
    )

    return out
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
@triton.jit
def _goliath_fp4_swiglu_down_kernel(
    gate_up_ptr,
    packed_w_ptr,
    packed_s_ptr,
    tscale_ptr,
    expert_ids_ptr,
    out_ptr,
    M, N, K,
    num_active,
    inter_size,
    stride_ge, stride_gm, stride_gk,
    stride_ew, stride_wk, stride_wn,
    stride_es, stride_sk, stride_sn,
    stride_oe, stride_om, stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Fused SwiGLU + FP4 down_proj GEMM.

    Reads gate_up output [num_active, M, 2*inter], applies SiLU(gate)*up inline,
    then performs FP4 GEMM with down_proj weights. Eliminates hidden tensor write/read.

    Grid: (num_active, cdiv(N, BLOCK_N)) — axis 0 is the active-expert slot,
    axis 1 the BLOCK_N-wide output-column tile.

    Layout (as implied by the arithmetic below):
      - packed_w packs two E2M1 (FP4) values per byte along K: the low nibble
        is the even K-position, the high nibble the odd one.
      - packed_s holds one E4M3-encoded scale per 16 K-values per output
        column, further multiplied by a per-expert tensor scale.
      - ``_e2m1_decode`` / ``_decode_e4m3_triton`` are helpers defined
        elsewhere in this file.
    """
    pid_active = tl.program_id(0)  # active-expert slot
    pid_n = tl.program_id(1)       # output-column tile

    # Guard against an over-provisioned grid.
    if pid_active >= num_active:
        return

    # Map the slot to a physical expert, then locate its weight/scale slabs.
    expert_id = tl.load(expert_ids_ptr + pid_active)

    w_packed_ptr = packed_w_ptr + expert_id * stride_ew
    w_scales_ptr = packed_s_ptr + expert_id * stride_es
    tensor_scale = tl.load(tscale_ptr + expert_id)

    # Each slot reads its own [M, 2*inter] slab of the gate_up output.
    gate_up_base = gate_up_ptr + pid_active * stride_ge

    offs_m = tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # FP32 accumulator; result is stored as BF16 at the end.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    HALF_BLOCK_K: tl.constexpr = BLOCK_K // 2      # packed bytes per K tile
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 16  # 16 K-values per scale group

    for k_start in range(0, K, BLOCK_K):
        # Even/odd K lanes mirror the low/high-nibble packing of the weights.
        even_k = k_start + tl.arange(0, HALF_BLOCK_K) * 2
        odd_k = k_start + tl.arange(0, HALF_BLOCK_K) * 2 + 1

        # Gate lives in columns [0, inter_size); up in [inter_size, 2*inter_size).
        gate_even_ptrs = gate_up_base + offs_m[:, None] * stride_gm + even_k[None, :] * stride_gk
        gate_odd_ptrs = gate_up_base + offs_m[:, None] * stride_gm + odd_k[None, :] * stride_gk
        up_even_ptrs = gate_up_base + offs_m[:, None] * stride_gm + (inter_size + even_k[None, :]) * stride_gk
        up_odd_ptrs = gate_up_base + offs_m[:, None] * stride_gm + (inter_size + odd_k[None, :]) * stride_gk

        mask_e = (offs_m[:, None] < M) & (even_k[None, :] < K)
        mask_o = (offs_m[:, None] < M) & (odd_k[None, :] < K)

        gate_even = tl.load(gate_even_ptrs, mask=mask_e, other=0.0).to(tl.float32)
        gate_odd = tl.load(gate_odd_ptrs, mask=mask_o, other=0.0).to(tl.float32)
        up_even = tl.load(up_even_ptrs, mask=mask_e, other=0.0).to(tl.float32)
        up_odd = tl.load(up_odd_ptrs, mask=mask_o, other=0.0).to(tl.float32)

        # SwiGLU in registers: silu(g) * u = sigmoid(g) * g * u.
        a_even = (tl.sigmoid(gate_even) * gate_even * up_even).to(tl.bfloat16)
        a_odd = (tl.sigmoid(gate_odd) * gate_odd * up_odd).to(tl.bfloat16)

        # Byte at packed row pk covers K-positions (2*pk, 2*pk + 1).
        pk_start = k_start // 2
        offs_pk = pk_start + tl.arange(0, HALF_BLOCK_K)
        w_ptrs = w_packed_ptr + offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn
        mask_w = (offs_pk[:, None] < (K // 2)) & (offs_n[None, :] < N)
        packed = tl.load(w_ptrs, mask=mask_w, other=0).to(tl.int32)

        # Decode both nibbles (E2M1 FP4) to float.
        low_f = _e2m1_decode(packed & 0xF)
        high_f = _e2m1_decode((packed >> 4) & 0xF)

        # One scale per 16 K-values = per 8 packed bytes.
        scale_start = k_start // 16
        offs_local = tl.arange(0, HALF_BLOCK_K)
        group_idx = offs_local // 8

        # Broadcast each group's decoded scale onto its 8 packed rows.
        scale_bc = tl.zeros((HALF_BLOCK_K, BLOCK_N), dtype=tl.float32)
        for sg in tl.static_range(0, SCALES_PER_TILE):
            sg_row = scale_start + sg
            sg_ptrs = w_scales_ptr + sg_row * stride_sk + offs_n * stride_sn
            sg_mask = (sg_row < (K // 16)) & (offs_n < N)
            sg_raw = tl.load(sg_ptrs, mask=sg_mask, other=0).to(tl.int32)
            sg_val = _decode_e4m3_triton(sg_raw) * tensor_scale
            sg_match = (group_idx == sg)
            scale_bc = tl.where(sg_match[:, None], sg_val[None, :], scale_bc)

        # Dequantized BF16 weight tiles for the even/odd K lanes.
        w_even = (low_f * scale_bc).to(tl.bfloat16)
        w_odd = (high_f * scale_bc).to(tl.bfloat16)

        acc += tl.dot(a_even, w_even)
        acc += tl.dot(a_odd, w_odd)

    out_ptrs = out_ptr + pid_active * stride_oe + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
| |
|
| |
|
def goliath_packed_moe_swiglu_down(
    gate_up_output: torch.Tensor,
    packed_w: torch.Tensor,
    packed_s: torch.Tensor,
    packed_ts: torch.Tensor,
    expert_ids: torch.Tensor,
    intermediate_size: int,
    num_active: int = 8,
) -> torch.Tensor:
    """Fused SwiGLU + FP4 down_proj GEMM.

    Takes gate_up output, applies SiLU(gate)*up inline, performs down_proj.
    Eliminates the hidden tensor write/read cycle.

    Args:
        gate_up_output: [num_active, M, 2*intermediate] output from gate_up GEMM
        packed_w: Down projection FP4 weights [E_total, K//2, N]
        packed_s: Down projection FP8 block scales
        packed_ts: Down projection tensor scales
        expert_ids: Selected expert indices on GPU
        intermediate_size: K dimension (matches gate_up output's second half)
        num_active: Number of active experts

    Returns:
        [num_active, M, N] final output
    """
    num_active_in, M, gu_dim = gate_up_output.shape
    assert gu_dim == 2 * intermediate_size, f"gate_up dim {gu_dim} != 2*{intermediate_size}"
    # The kernel indexes slot pid_active < num_active into gate_up_output;
    # fewer slots than num_active would read out of bounds.
    assert num_active_in >= num_active, \
        f"gate_up has {num_active_in} expert slots, need at least {num_active}"

    K = intermediate_size
    N = packed_w.shape[2]
    # FP4 packs two values per byte along K, so packed_w must be [E, K//2, N].
    assert packed_w.shape[1] * 2 == K, \
        f"packed_w K-dim {packed_w.shape[1]} inconsistent with K={K}"

    gate_up = gate_up_output.contiguous()
    if gate_up.dtype != torch.bfloat16:
        gate_up = gate_up.to(torch.bfloat16)

    out = torch.empty(num_active, M, N, device=gate_up.device, dtype=torch.bfloat16)

    # Keep the launch grid and the kernel's column tile width in sync.
    BLOCK_N = 64
    grid = (num_active, triton.cdiv(N, BLOCK_N))

    _goliath_fp4_swiglu_down_kernel[grid](
        gate_up,
        packed_w, packed_s, packed_ts, expert_ids,
        out,
        M, N, K,
        num_active,
        intermediate_size,
        gate_up.stride(0), gate_up.stride(1), gate_up.stride(2),
        packed_w.stride(0), packed_w.stride(1), packed_w.stride(2),
        packed_s.stride(0), packed_s.stride(1), packed_s.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        BLOCK_M=1 if M == 1 else 16,
        BLOCK_N=BLOCK_N,
        BLOCK_K=32,
    )

    return out
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
@triton.jit
def _goliath_int2_packed_moe_kernel(
    a_ptr,
    packed_w_ptr,
    packed_s_ptr,
    expert_ids_ptr,
    out_ptr,
    M, N, K,
    num_active,
    stride_ew,
    stride_wk, stride_wn,
    stride_es,
    stride_sk, stride_sn,
    stride_oe, stride_om, stride_on,
    stride_am, stride_ak,
    a_expert_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Packed MoE INT2 fused dequant-matmul: GPU-resident expert selection.

    Similar to FP4 kernel but unpacks 4 weights per byte instead of 2.
    Used for cold experts where 2-bit precision is acceptable.

    Grid: (num_active, cdiv(N, BLOCK_N)) — axis 0 is the active-expert slot,
    axis 1 the BLOCK_N-wide output tile.

    Layout (as implied by the arithmetic below):
      - packed_w [E, K//4, N]: the byte at packed row pk holds K-positions
        4*pk .. 4*pk+3 in its four 2-bit lanes; a lane decodes to
        (code - 2), i.e. values in [-2, 1].
      - packed_s: one scale per 32 K-values (= 8 packed rows) per column.
      - a_expert_stride selects per-expert activation rows; 0 makes every
        expert read the same activation rows.

    NOTE(review): the activation masks below only test the lane-0 index
    (k_start + 4*q < K) while lanes 1..3 read up to 3 elements further, so
    K is assumed to be a multiple of 4 — confirm with callers.
    """
    pid_active = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Guard against an over-provisioned grid.
    if pid_active >= num_active:
        return

    # Resolve the slot to a physical expert and its weight/scale slabs.
    expert_id = tl.load(expert_ids_ptr + pid_active)

    w_packed_ptr = packed_w_ptr + expert_id * stride_ew
    w_scales_ptr = packed_s_ptr + expert_id * stride_es

    a_base = a_ptr + pid_active * a_expert_stride

    offs_m = tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # FP32 accumulator; result is stored as BF16.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    QUARTER_BLOCK_K: tl.constexpr = BLOCK_K // 4   # packed bytes per K tile
    SCALES_PER_TILE: tl.constexpr = BLOCK_K // 32  # 32 K-values per scale group

    for k_start in range(0, K, BLOCK_K):
        # Load one byte per 4 K-positions for this tile.
        pk_start = k_start // 4
        offs_pk = pk_start + tl.arange(0, QUARTER_BLOCK_K)
        w_ptrs = w_packed_ptr + offs_pk[:, None] * stride_wk + offs_n[None, :] * stride_wn
        mask_w = (offs_pk[:, None] < (K // 4)) & (offs_n[None, :] < N)
        packed = tl.load(w_ptrs, mask=mask_w, other=0).to(tl.int32)

        # Unpack the four 2-bit lanes; codes 0..3 map to -2..1.
        w0 = ((packed >> 0) & 0x3).to(tl.float32) - 2.0
        w1 = ((packed >> 2) & 0x3).to(tl.float32) - 2.0
        w2 = ((packed >> 4) & 0x3).to(tl.float32) - 2.0
        w3 = ((packed >> 6) & 0x3).to(tl.float32) - 2.0

        # One scale row per 8 packed rows (= 32 K-values).
        scale_start = k_start // 32
        offs_local_k = tl.arange(0, QUARTER_BLOCK_K)

        # Broadcast each group's scale onto its 8 packed rows.
        scale_bc = tl.zeros((QUARTER_BLOCK_K, BLOCK_N), dtype=tl.float32)
        for sg in tl.static_range(0, SCALES_PER_TILE):
            sg_row = scale_start + sg
            sg_ptrs = w_scales_ptr + sg_row * stride_sk + offs_n * stride_sn
            sg_mask = (sg_row < (K // 32)) & (offs_n < N)
            sg_val = tl.load(sg_ptrs, mask=sg_mask, other=1.0).to(tl.float32)
            sg_match = (offs_local_k // 8 == sg)
            scale_bc = tl.where(sg_match[:, None], sg_val[None, :], scale_bc)

        # Dequantized BF16 weight tiles, one per 2-bit lane.
        w0_scaled = (w0 * scale_bc).to(tl.bfloat16)
        w1_scaled = (w1 * scale_bc).to(tl.bfloat16)
        w2_scaled = (w2 * scale_bc).to(tl.bfloat16)
        w3_scaled = (w3 * scale_bc).to(tl.bfloat16)

        # Activations gathered in four strided lanes (k % 4 == 0..3) so each
        # lane lines up with the matching weight lane.
        offs_qk = tl.arange(0, QUARTER_BLOCK_K)
        mask_a_qk = (offs_m[:, None] < M) & ((k_start + offs_qk[None, :] * 4) < K)
        a0 = tl.load(a_base + offs_m[:, None] * stride_am + (k_start + offs_qk[None, :] * 4 + 0) * stride_ak,
                     mask=mask_a_qk, other=0.0).to(tl.bfloat16)
        a1 = tl.load(a_base + offs_m[:, None] * stride_am + (k_start + offs_qk[None, :] * 4 + 1) * stride_ak,
                     mask=mask_a_qk, other=0.0).to(tl.bfloat16)
        a2 = tl.load(a_base + offs_m[:, None] * stride_am + (k_start + offs_qk[None, :] * 4 + 2) * stride_ak,
                     mask=mask_a_qk, other=0.0).to(tl.bfloat16)
        a3 = tl.load(a_base + offs_m[:, None] * stride_am + (k_start + offs_qk[None, :] * 4 + 3) * stride_ak,
                     mask=mask_a_qk, other=0.0).to(tl.bfloat16)

        acc += tl.dot(a0, w0_scaled)
        acc += tl.dot(a1, w1_scaled)
        acc += tl.dot(a2, w2_scaled)
        acc += tl.dot(a3, w3_scaled)

    out_ptrs = out_ptr + pid_active * stride_oe + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    mask_out = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_out)
| |
|
| |
|
def goliath_packed_moe_int2_gemm(
    activations: torch.Tensor,
    packed_w: torch.Tensor,
    packed_s: torch.Tensor,
    expert_ids: torch.Tensor,
    num_active: int = 8,
    per_expert_input: bool = False,
) -> torch.Tensor:
    """Packed MoE INT2 GEMM: 2-bit weights for cold experts.

    2x smaller than FP4 = 2x less bandwidth. Use for rarely-routed experts.

    Args:
        activations: Input in BF16. [M, K] or [num_active*M, K]
        packed_w: [E_total, K//4, N] INT2 packed weights
        packed_s: [E_total, K//32, N] FP16 block scales
        expert_ids: [num_active] selected expert indices on GPU
        num_active: Number of active experts
        per_expert_input: If True, each expert reads from its own M rows

    Returns:
        Output [num_active, M, N] in BF16
    """
    N = packed_w.shape[2]

    a = activations.contiguous()
    if a.dtype != torch.bfloat16:
        a = a.to(torch.bfloat16)

    if per_expert_input:
        # Each expert slot consumes its own contiguous chunk of M rows.
        total_rows, K = a.shape
        M = total_rows // num_active
        a_expert_stride = M * a.stride(0)
    else:
        # All experts share the same activation rows.
        M, K = a.shape
        a_expert_stride = 0

    # INT2 packs four values per byte along K, so packed_w must be [E, K//4, N].
    assert packed_w.shape[1] * 4 == K, \
        f"packed_w K-dim {packed_w.shape[1]} inconsistent with K={K}"

    out = torch.empty(num_active, M, N, device=a.device, dtype=torch.bfloat16)

    # Keep the launch grid and the kernel's column tile width in sync.
    BLOCK_N = 64
    grid = (num_active, triton.cdiv(N, BLOCK_N))

    _goliath_int2_packed_moe_kernel[grid](
        a,
        packed_w, packed_s,
        expert_ids,
        out,
        M, N, K,
        num_active,
        packed_w.stride(0),
        packed_w.stride(1), packed_w.stride(2),
        packed_s.stride(0),
        packed_s.stride(1), packed_s.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        a.stride(0), a.stride(1),
        a_expert_stride,
        BLOCK_M=16 if M > 1 else 1,
        BLOCK_N=BLOCK_N,
        BLOCK_K=128,
    )

    return out
| |
|
| |
|
def pack_experts_int2(
    expert_weights: list,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pack INT2 expert weights into contiguous buffers.

    Args:
        expert_weights: List of GoliathINT2Weights (one per expert)

    Returns:
        packed_w: [E, K//4, N] contiguous INT2 weights
        packed_s: [E, K//32, N] contiguous FP16 scales
    """
    num_experts = len(expert_weights)
    K, N = expert_weights[0].shape
    dev = expert_weights[0].packed.device

    # Allocate the contiguous buffers up front, then copy expert by expert.
    weights_buf = torch.zeros(num_experts, K // 4, N, dtype=torch.uint8, device=dev)
    scales_buf = torch.zeros(num_experts, K // 32, N, dtype=torch.float16, device=dev)

    for idx, expert in enumerate(expert_weights):
        weights_buf[idx] = expert.packed
        scales_buf[idx] = expert.block_scales

    return weights_buf, scales_buf
| |
|
| |
|
| | |
| | |
| | |
| |
|
def fexc_precompute_psumbook(
    codebooks: torch.Tensor,
    x: torch.Tensor,
) -> torch.Tensor:
    """Precompute psumbook: dot products of all codebook centroids with input groups.

    psumbook[m, c, j] = dot(codebooks[m][c], x[j*8:(j+1)*8])

    This is computed ONCE per token and reused across all 8 active experts.
    For N=2048: psumbook is [2, 256, 256] float32 = 512KB.

    Args:
        codebooks: [2, 256, 8] float16 - shared codebooks for this layer
        x: [N] or [1, N] input vector

    Returns:
        psumbook: [2, 256, N//8] float32
    """
    group = 8
    vec = x.view(-1).float()
    n_elems = vec.shape[0]

    # Split the input into contiguous groups of 8 elements.
    segments = vec.view(n_elems // group, group)

    # Broadcasted batched matmul: [2, 256, 8] @ [8, N//8] -> [2, 256, N//8].
    return codebooks.float() @ segments.t()
| |
|
| |
|
@triton.jit
def _goliath_fexc_packed_moe_kernel(
    psumbook_ptr,
    codes_ptr,
    scales_ptr,
    expert_ids_ptr,
    out_ptr,
    K: tl.constexpr,
    num_groups: tl.constexpr,
    num_active,
    stride_ec,
    stride_ck,
    stride_cg,
    stride_es,
    stride_oe,
    BLOCK_K: tl.constexpr,
    BLOCK_G: tl.constexpr,
):
    """FE-XC packed MoE kernel: CodeGEMM-style psumbook gather for M=1.

    For each active expert and each output row, gathers precomputed partial
    sums from psumbook using codebook indices. This replaces the traditional
    dequant-matmul with scalar gather+add operations.

    Grid: (num_active, ceil(K / BLOCK_K))

    Layout assumptions (from the arithmetic below):
      - psumbook is assumed to be a contiguous fp32 [2, 256, num_groups]
        tensor: its strides are hardcoded as (256*num_groups, num_groups, 1)
        rather than passed in.
      - codes stores the two sub-codebook indices of each (k, group) pair at
        consecutive byte addresses (+0 / +1), i.e. last dim has stride 1.
      - scales holds one per-output-row scale applied after the gather-sum.
    """
    pid_expert = tl.program_id(0)
    pid_k = tl.program_id(1)

    # Guard against an over-provisioned grid.
    if pid_expert >= num_active:
        return

    # Resolve the slot to a physical expert.
    expert_id = tl.load(expert_ids_ptr + pid_expert)

    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    mask_k = offs_k < K

    # One running partial sum per output row handled by this program.
    acc = tl.zeros((BLOCK_K,), dtype=tl.float32)

    # Base of this expert's code table.
    code_base = codes_ptr + expert_id * stride_ec

    # Hardcoded strides of the contiguous [2, 256, num_groups] psumbook
    # (sub-codebook axis, centroid axis; group axis has stride 1).
    ps_stride_m = 256 * num_groups
    ps_stride_c = num_groups

    # Walk the input groups in BLOCK_G chunks.
    for g_start in range(0, num_groups, BLOCK_G):
        offs_g = g_start + tl.arange(0, BLOCK_G)
        mask_g = offs_g < num_groups

        # The two sub-codebook indices for each (k, group) pair sit at
        # consecutive byte offsets.
        code_ptrs_0 = code_base + offs_k[:, None] * stride_ck + offs_g[None, :] * stride_cg + 0
        code_ptrs_1 = code_base + offs_k[:, None] * stride_ck + offs_g[None, :] * stride_cg + 1
        mask_kg = mask_k[:, None] & mask_g[None, :]

        codes_0 = tl.load(code_ptrs_0, mask=mask_kg, other=0).to(tl.int32)
        codes_1 = tl.load(code_ptrs_1, mask=mask_kg, other=0).to(tl.int32)

        # Gather the precomputed centroid-input dot products for both
        # sub-codebooks.
        ps_ptrs_0 = psumbook_ptr + 0 * ps_stride_m + codes_0 * ps_stride_c + offs_g[None, :]
        ps_ptrs_1 = psumbook_ptr + 1 * ps_stride_m + codes_1 * ps_stride_c + offs_g[None, :]

        ps_vals_0 = tl.load(ps_ptrs_0, mask=mask_kg, other=0.0)
        ps_vals_1 = tl.load(ps_ptrs_1, mask=mask_kg, other=0.0)

        # Reduce over the group axis into the per-row accumulator.
        acc += tl.sum(ps_vals_0 + ps_vals_1, axis=1)

    # Apply the per-output-row scale.
    scale_ptrs = scales_ptr + expert_id * stride_es + offs_k
    scales = tl.load(scale_ptrs, mask=mask_k, other=1.0).to(tl.float32)
    acc = acc * scales

    # Store out[pid_expert, offs_k] in BF16.
    out_ptrs = out_ptr + pid_expert * stride_oe + offs_k
    tl.store(out_ptrs, acc.to(tl.bfloat16), mask=mask_k)
| |
|
| |
|
def goliath_packed_moe_fexc_gemm(
    activations: torch.Tensor,
    packed_codes: torch.Tensor,
    codebooks: torch.Tensor,
    packed_scales: torch.Tensor,
    expert_ids: torch.Tensor,
    psumbook: torch.Tensor = None,
    num_active: int = 8,
) -> torch.Tensor:
    """FE-XC packed MoE GEMM: codebook 2-bit with CodeGEMM psumbook.

    Near-FP16 quality at 2 bits/weight. Precomputes psumbook once per token,
    then uses scalar gather+add instead of dequant+matmul.

    Args:
        activations: [1, N] input in BF16 (the psumbook path is M=1 only)
        packed_codes: [E_total, K, N//8, 2] codebook indices (last dim contiguous)
        codebooks: [2, 256, 8] shared codebooks for this layer
        packed_scales: [E_total, K] per-output-channel scales
        expert_ids: [num_active] expert indices on GPU
        psumbook: Precomputed [2, 256, N//8] float32 (computed if None)
        num_active: Number of active experts

    Returns:
        Output [num_active, 1, K] in BF16 (M=1)
    """
    M, N = activations.shape
    # The kernel only ever reads activations[0]; reject M > 1 instead of
    # silently dropping rows.
    assert M == 1, f"FE-XC psumbook path supports M=1 only, got M={M}"
    K = packed_codes.shape[1]
    num_groups = N // 8

    # The kernel reads the two code bytes of each group at offsets +0/+1,
    # which requires the last dim to be contiguous.
    assert packed_codes.stride(3) == 1, "packed_codes last dim must be contiguous"

    # Compute the psumbook once per token if the caller did not supply one.
    if psumbook is None:
        psumbook = fexc_precompute_psumbook(codebooks, activations[0])
    # The kernel hardcodes contiguous [2, 256, num_groups] strides; this is a
    # no-op for already-contiguous inputs.
    psumbook = psumbook.contiguous()

    out = torch.empty(num_active, K, device=activations.device, dtype=torch.bfloat16)

    BLOCK_K = min(64, K)
    BLOCK_G = min(64, num_groups)

    grid = (num_active, triton.cdiv(K, BLOCK_K))

    _goliath_fexc_packed_moe_kernel[grid](
        psumbook,
        packed_codes,
        packed_scales,
        expert_ids,
        out,
        K, num_groups, num_active,
        packed_codes.stride(0),
        packed_codes.stride(1),
        packed_codes.stride(2),
        packed_scales.stride(0),
        out.stride(0),
        BLOCK_K=BLOCK_K,
        BLOCK_G=BLOCK_G,
    )

    # Add the M=1 row axis expected by callers.
    return out.unsqueeze(1)
| |
|
| |
|
def pack_experts_fexc(
    expert_weights: list,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pack FE-XC expert weights into contiguous buffers.

    Args:
        expert_weights: List of GoliathFEXCWeights (one per expert)

    Returns:
        packed_codes: [E, K, N//8, 2] contiguous uint8
        packed_scales: [E, K] contiguous float16
        codebooks: [2, 256, 8] float16 (shared, from first expert)
    """
    num_experts = len(expert_weights)
    K, N = expert_weights[0].shape
    group = expert_weights[0].group_size
    dev = expert_weights[0].codes.device

    # Allocate contiguous buffers, then copy expert by expert.
    codes_buf = torch.zeros(num_experts, K, N // group, 2, dtype=torch.uint8, device=dev)
    scales_buf = torch.zeros(num_experts, K, dtype=torch.float16, device=dev)

    for idx, expert in enumerate(expert_weights):
        codes_buf[idx] = expert.codes
        scales_buf[idx] = expert.scales

    # Codebooks are shared per layer; take the first expert's.
    shared_codebooks = expert_weights[0].codebooks

    return codes_buf, scales_buf, shared_codebooks
| |
|
| |
|
| | |
| | |
| | |
| |
|
def benchmark_goliath(M=4096, N=4096, K=4096, warmup=10, iters=100):
    """Benchmark Goliath FP4/FP8 vs BF16 cuBLAS.

    Quantizes one random [K, N] weight matrix at FP4 and FP8, times the
    fused GEMMs against torch.matmul in BF16, and reports TFLOPS, memory
    footprint, and relative error versus an FP32 reference.

    Args:
        M, N, K: GEMM dimensions.
        warmup: Untimed iterations (JIT compilation, clock spin-up).
        iters: Timed iterations per variant (mean time is reported).
    """
    import time

    if not torch.cuda.is_available():
        print("No CUDA available.")
        return

    print("=" * 60)
    print("Goliath FP4/FP8 Benchmark")
    print("=" * 60)

    w = torch.randn(K, N, device='cuda', dtype=torch.float32)
    w_fp4 = goliath_quantize(w, bits=4)
    w_fp8 = goliath_quantize(w, bits=8)
    w_bf16 = w.to(torch.bfloat16)
    a = torch.randn(M, K, device='cuda', dtype=torch.bfloat16)

    # Memory footprint: packed payload + scales. The FP8 scale term assumes
    # 4-byte (fp32) block scales - TODO confirm against GoliathFP8Weights.
    fp4_bytes = w_fp4.packed.numel() + w_fp4.block_scales.numel()
    fp8_bytes = w_fp8.data.numel() + w_fp8.block_scales.numel() * 4
    bf16_bytes = w_bf16.numel() * 2
    print("Weight memory:")
    print(f" FP4: {fp4_bytes / 1e6:.1f} MB ({bf16_bytes / fp4_bytes:.1f}x compression)")
    print(f" FP8: {fp8_bytes / 1e6:.1f} MB ({bf16_bytes / fp8_bytes:.1f}x compression)")
    print(f" BF16: {bf16_bytes / 1e6:.1f} MB")
    print()

    def _time_avg(fn):
        # Mean wall-clock seconds per call; single sync after the loop.
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters

    # Warmup: trigger Triton compilation and let clocks settle.
    for _ in range(warmup):
        goliath_gemm(a, w_fp4)
        goliath_gemm(a, w_fp8)
        torch.matmul(a, w_bf16)
    torch.cuda.synchronize()

    fp4_t = _time_avg(lambda: goliath_gemm(a, w_fp4))
    fp8_t = _time_avg(lambda: goliath_gemm(a, w_fp8))
    bf16_t = _time_avg(lambda: torch.matmul(a, w_bf16))

    flops = 2 * M * N * K
    print(f"{M}x{N}x{K} GEMM:")
    print(f" Goliath FP4: {flops/fp4_t/1e12:.1f} TFLOPS ({fp4_t*1000:.2f}ms)")
    print(f" Goliath FP8: {flops/fp8_t/1e12:.1f} TFLOPS ({fp8_t*1000:.2f}ms)")
    print(f" BF16 cuBLAS: {flops/bf16_t/1e12:.1f} TFLOPS ({bf16_t*1000:.2f}ms)")

    # Accuracy versus an FP32 reference matmul.
    out_fp4 = goliath_gemm(a, w_fp4)
    out_fp8 = goliath_gemm(a, w_fp8)
    out_ref = torch.matmul(a.float(), w).bfloat16()

    err_fp4 = (out_fp4 - out_ref).abs().mean() / out_ref.abs().mean()
    err_fp8 = (out_fp8 - out_ref).abs().mean() / out_ref.abs().mean()
    print(f" FP4 vs FP32 ref: rel_err={err_fp4:.4f}")
    print(f" FP8 vs FP32 ref: rel_err={err_fp8:.4f}")

    # Auto-dispatch sanity check.
    w_auto = goliath_quantize(w, bits='auto')
    print(f" Auto-selected: FP{w_auto.bits}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | import torch.nn as nn |
| |
|
| |
|
class _GoliathLinearFunction(torch.autograd.Function):
    """Custom autograd for fused FP4/FP8 forward + FP32 backward.

    Forward runs the fused quantized GEMM; backward computes exact gradients
    from the FP32 master weights saved alongside the input.
    """

    @staticmethod
    def forward(ctx, input, goliath_weights, bias, weight_fp32, bits):
        # Fused dequant-matmul on the quantized weights.
        result = goliath_gemm(input, goliath_weights, bias)
        ctx.save_for_backward(input, weight_fp32, bias)
        ctx.bits = bits
        return result

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, master_weight, bias = ctx.saved_tensors
        grad_f32 = grad_output.float()

        # dX = dY @ W, using the FP32 masters (not the quantized copy).
        grad_input = grad_f32.matmul(master_weight.float())

        # dW = dY^T @ X.
        grad_weight = grad_f32.T.matmul(saved_input.float())

        # dB = column-wise sum of dY (only when a bias exists).
        grad_bias = grad_f32.sum(dim=0) if bias is not None else None

        # Grad slots line up with forward args:
        # (input, goliath_weights, bias, weight_fp32, bits).
        return grad_input.to(saved_input.dtype), None, grad_bias, grad_weight, None
| |
|
| |
|
class GoliathLinear(nn.Module):
    """Linear layer with Goliath fused FP4/FP8 quantized forward + FP32 backward.

    Keeps FP32 master weights as an ``nn.Parameter``. On CUDA the forward
    pass quantizes them on demand (cached until the weight data mutates) and
    runs the fused dequant-GEMM; in training mode gradients flow through the
    FP32 masters via ``_GoliathLinearFunction``. On CPU it degrades to a
    plain ``F.linear``.

    Args:
        in_features: Input dimension (K)
        out_features: Output dimension (N)
        bias: Whether to include a bias term
        bits: Quantization bits - 4, 8, or 'auto'
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 bits: Union[int, str] = 4):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits

        # FP32 master weights, [out, in] like nn.Linear.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

        nn.init.kaiming_uniform_(self.weight, a=2.23606797749979)  # a = sqrt(5)

        # Lazily refreshed quantized copy of the master weights.
        self._goliath_weights: Optional[GoliathWeights] = None
        self._weight_version: int = -1

    def _ensure_quantized(self):
        """Re-quantize from master weights when weight data has changed."""
        # NOTE: relies on torch's private in-place version counter to detect
        # weight mutations (optimizer steps bump it).
        version = self.weight._version
        if self._goliath_weights is None or self._weight_version != version:
            # The kernel expects [K, N]; masters are stored [N, K].
            w_kn = self.weight.data.T.contiguous().float()
            self._goliath_weights = goliath_quantize(w_kn, bits=self.bits)
            self._weight_version = version

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lead_dims = x.shape[:-1]
        flat = x.reshape(-1, self.in_features)

        if not flat.is_cuda:
            # CPU fallback: plain FP32 linear.
            return torch.nn.functional.linear(x, self.weight, self.bias)

        self._ensure_quantized()
        if self.training:
            # Quantized forward, FP32-master backward.
            result = _GoliathLinearFunction.apply(
                flat, self._goliath_weights, self.bias,
                self.weight, self.bits,
            )
        else:
            # Inference: fused quantized GEMM, no autograd bookkeeping.
            result = goliath_gemm(flat, self._goliath_weights, self.bias)
        return result.view(*lead_dims, self.out_features)

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, out_features={self.out_features}, '
                f'bias={self.bias is not None}, bits={self.bits}')
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
# Lazy handle to quantum_optimized_matmul (kernel/quantum/), resolved on
# first use; the attempt flag makes the probe run at most once.
_quantum_matmul = None
_quantum_import_attempted = False


def _ensure_quantum_import():
    """Lazily import quantum_optimized_matmul from kernel/quantum/."""
    global _quantum_matmul, _quantum_import_attempted
    if _quantum_import_attempted:
        return _quantum_matmul is not None
    _quantum_import_attempted = True

    try:
        from quantum import quantum_optimized_matmul
    except ImportError:
        import os as _os
        import sys as _sys
        # Fall back to putting the kernel/ root (two levels up) on sys.path
        # and retrying once.
        kernel_root = _os.path.dirname(_os.path.dirname(_os.path.abspath(__file__)))
        if kernel_root not in _sys.path:
            _sys.path.insert(0, kernel_root)
        try:
            from quantum import quantum_optimized_matmul
        except ImportError:
            return False

    _quantum_matmul = quantum_optimized_matmul
    return True
| |
|
| |
|
class _GoliathQuantumFunction(torch.autograd.Function):
    """Custom autograd: Goliath FP8 forward + Quantum Gold L2-swizzled backward.

    Forward: fused FP8 dequant-matmul (2x bandwidth vs BF16).
    Backward: quantum_optimized_matmul for dX and dW when it resolved at
    import time; plain torch.matmul otherwise.
    """

    @staticmethod
    def forward(ctx, input, weight_bf16, bias, goliath_fp8):
        # Fused FP8 dequant-GEMM.
        result = goliath_gemm(input, goliath_fp8, bias)
        ctx.save_for_backward(input, weight_bf16, bias)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, weight_bf16, bias = ctx.saved_tensors

        if _quantum_matmul is None:
            # Fallback: plain torch GEMMs.
            dX = torch.matmul(grad_output, weight_bf16)
            dW = torch.matmul(grad_output.T, saved_input)
        else:
            # L2-swizzled Triton GEMMs: dX = dY @ W, dW = dY^T @ X.
            dX = _quantum_matmul(
                grad_output.contiguous(),
                weight_bf16.contiguous()
            )
            dW = _quantum_matmul(
                grad_output.T.contiguous(),
                saved_input.contiguous()
            )

        # dB = column-wise sum of dY (only when a bias exists).
        db = grad_output.sum(dim=0) if bias is not None else None

        # Grad slots line up with forward args:
        # (input, weight_bf16, bias, goliath_fp8).
        return dX.to(saved_input.dtype), dW.to(weight_bf16.dtype), db, None
| |
|
| |
|
class GoliathQuantumLinear(nn.Module):
    """FireEcho training-optimized linear layer.

    Pairs the Goliath FP8 fused forward (2x bandwidth savings) with a
    Quantum Gold backward (L2-swizzled Triton GEMM) when available. Master
    weights are kept in BF16, halving memory versus the FP32 masters used
    by GoliathLinear; that class otherwise plays the same role with a
    standard torch.matmul backward.

    Args:
        in_features: Input dimension (K)
        out_features: Output dimension (N)
        bias: Whether to include a bias term
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # BF16 master weights, [out, in] like nn.Linear.
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=torch.bfloat16))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.bfloat16))
        else:
            self.register_parameter('bias', None)

        nn.init.kaiming_uniform_(self.weight, a=2.23606797749979)  # a = sqrt(5)

        # Lazily refreshed FP8 copy of the master weights.
        self._goliath_fp8: Optional[GoliathFP8Weights] = None
        self._weight_version: int = -1

        # Resolve the quantum backward kernel once, at construction.
        _ensure_quantum_import()

    def _ensure_quantized(self):
        """Re-quantize BF16 master -> FP8 when weights have changed."""
        # NOTE: relies on torch's private in-place version counter.
        version = self.weight._version
        if self._goliath_fp8 is None or self._weight_version != version:
            # The kernel expects [K, N]; masters are stored [N, K].
            w_kn = self.weight.data.T.contiguous().float()
            self._goliath_fp8 = GoliathFP8Weights.from_float(w_kn)
            self._weight_version = version

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lead_dims = x.shape[:-1]
        flat = x.reshape(-1, self.in_features)

        if not flat.is_cuda:
            # CPU fallback: plain BF16 linear.
            return torch.nn.functional.linear(x, self.weight, self.bias)

        self._ensure_quantized()
        if self.training:
            # FP8 forward + quantum (or torch) backward.
            result = _GoliathQuantumFunction.apply(
                flat, self.weight, self.bias, self._goliath_fp8,
            )
        else:
            # Inference: fused FP8 GEMM only.
            result = goliath_gemm(flat, self._goliath_fp8, self.bias)
        return result.view(*lead_dims, self.out_features)

    def extra_repr(self) -> str:
        backend = "quantum" if _quantum_matmul is not None else "fallback"
        return (f'in_features={self.in_features}, out_features={self.out_features}, '
                f'bias={self.bias is not None}, backward={backend}')
| |
|
| |
|
if __name__ == "__main__":
    # Smoke-test entry point: run the benchmark at two sizes when a GPU is
    # present, otherwise just report the lack of CUDA.
    print("Goliath β Native FP4/FP8 Fused Triton GEMM Kernel")
    print("=" * 60)

    if not torch.cuda.is_available():
        print("No CUDA available.")
    else:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        benchmark_goliath(M=2048, N=2048, K=2048)
        print()
        benchmark_goliath(M=4096, N=4096, K=4096)
|