import torch import logging import math logger = logging.getLogger(__name__) # NVFP4 (E2M1) Table # exp=0, mant=0 -> 0.0 # exp=0, mant=1 -> 0.5 # exp=1, mant=0 -> 1.0 # exp=1, mant=1 -> 1.5 # exp=2, mant=0 -> 2.0 # exp=2, mant=1 -> 3.0 # exp=3, mant=0 -> 4.0 # exp=3, mant=1 -> 6.0 NVFP4_TABLE = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32) def stochastic_float_to_fp4_e2m1(x, generator=None): """Convert float tensor to packed 4-bit E2M1 format with stochastic rounding.""" device = x.device # Ensure the last dimension is even for packing orig_last_dim = x.shape[-1] if orig_last_dim % 2 != 0: x = torch.nn.functional.pad(x, (0, 1)) orig_shape = x.shape # Calculate exponent for stochastic noise scaling # x.abs() log2 + 1 gives a rough exponent exp = torch.floor(torch.log2(x.abs() + 1e-8) + 1.0).clamp(0, 3) # Add stochastic noise scaled by exponent if generator is provided if generator is not None: noise = (torch.rand(x.size(), dtype=x.dtype, device=device, generator=generator) - 0.5) x = x + noise * (2 ** (exp - 2.0)) * 1.25 sign = torch.signbit(x).to(torch.uint8) x = x.abs() # Recalculate exponent after noise exp = torch.floor(torch.log2(x + 1e-8) + 1.1925).clamp(0, 3) # Calculate mantissa # If exp > 0: val = (1 + m/2) * 2^(exp-1) => m = (val / 2^(exp-1) - 1) * 2 # If exp = 0: val = m/2 => m = val * 2 mantissa = torch.where( exp > 0, (x / (2.0 ** (exp - 1)) - 1.0) * 2.0, (x * 2.0) ).round().clamp(0, 1).to(torch.uint8) # Pack into 4 bits: [sign:1, exp:2, mantissa:1] fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa # Pack two 4-bit values into one uint8 fp4_flat = fp4.view(-1) # We already padded x to be even, so fp4_flat.numel() is even packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2] new_shape = list(orig_shape) new_shape[-1] = new_shape[-1] // 2 return packed.reshape(new_shape) def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: """ Rearrange a matrix by breaking it into blocks and applying the rearrangement pattern. Matches NVIDIA's block scaling factors layout. """ def ceil_div(a, b): return (a + b - 1) // b rows, cols = input_matrix.shape n_row_blocks = ceil_div(rows, 128) n_col_blocks = ceil_div(cols, 4) padded_rows = n_row_blocks * 128 padded_cols = n_col_blocks * 4 if (rows, cols) != (padded_rows, padded_cols): padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype) padded[:rows, :cols] = input_matrix else: padded = input_matrix # Rearrange the blocks: [n_row_blocks, 128, n_col_blocks, 4] -> [n_row_blocks, n_col_blocks, 128, 4] blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) if flatten: return rearranged.flatten() return rearranged.reshape(padded_rows, padded_cols) def from_blocked(blocked_matrix, original_rows, original_cols): """Inverse of to_blocked.""" def ceil_div(a, b): return (a + b - 1) // b n_row_blocks = ceil_div(original_rows, 128) n_col_blocks = ceil_div(original_cols, 4) padded_rows = n_row_blocks * 128 padded_cols = n_col_blocks * 4 # [Total_Blocks, 32, 16] rearranged = blocked_matrix.reshape(-1, 32, 16) # [Total_Blocks, 4, 32, 4] -> [Total_Blocks, 128, 4] blocks = rearranged.reshape(-1, 32, 4, 4).transpose(1, 2).reshape(n_row_blocks, n_col_blocks, 128, 4) # [n_row_blocks, 128, n_col_blocks, 4] padded = blocks.permute(0, 2, 1, 3).reshape(padded_rows, padded_cols) return padded[:original_rows, :original_cols] def quantize_nvfp4(tensor, stochastic_rounding=0): """Quantize tensor to NVFP4 format.""" if tensor.dim() != 2: raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D") F4_E2M1_MAX = 6.0 F8_E4M3_MAX = 448.0 orig_shape = tensor.shape device = tensor.device # Calculate per-tensor scale # We want max(abs(x)) / (tensor_scale * block_scale) <= F4_MAX # And block_scale <= F8_MAX # So tensor_scale ~ max(abs(x)) / (F8_MAX * F4_MAX) tensor_scale = torch.amax(tensor.abs()) / (F8_E4M3_MAX * F4_E2M1_MAX) if tensor_scale == 0: tensor_scale = torch.tensor(1.0, device=device) # Block size is 16 elements along the last dimension block_size = 16 rows, cols = tensor.shape padded_cols = (cols + block_size - 1) // block_size * block_size if cols != padded_cols: x = torch.nn.functional.pad(tensor, (0, padded_cols - cols)) else: x = tensor x = x.reshape(rows, -1, block_size) # Calculate per-block scales (FP8 E4M3) # block_scale = max(abs(block)) / (tensor_scale * F4_MAX) block_scales = (torch.amax(torch.abs(x), dim=-1) / (tensor_scale * F4_E2M1_MAX)).clamp(max=F8_E4M3_MAX) # Normalize by scales # x_norm = x / (tensor_scale * block_scale) x = x / (tensor_scale * block_scales.unsqueeze(-1) + 1e-12) x = x.view(rows, padded_cols)[:, :cols].reshape(orig_shape).nan_to_num() generator = None if stochastic_rounding > 0: generator = torch.Generator(device=device) generator.manual_seed(stochastic_rounding) qdata = stochastic_float_to_fp4_e2m1(x, generator=generator) # ComfyUI expects block_scales in a specific "blocked" layout blocked_scales = to_blocked(block_scales, flatten=False) return qdata, tensor_scale, blocked_scales def dequantize_nvfp4(qdata, tensor_scale, blocked_scales, original_shape): """Dequantize NVFP4 data back to float.""" device = qdata.device # Ensure scales are on the correct device if isinstance(tensor_scale, torch.Tensor): tensor_scale = tensor_scale.to(device) else: tensor_scale = torch.tensor(tensor_scale, device=device) blocked_scales = blocked_scales.to(device) # 1. Unpack uint8 to two 4-bit values high = (qdata >> 4) & 0x0F low = qdata & 0x0F rows, cols = original_shape # Each row in qdata has (cols + 1) // 2 elements # So we stack them and reshape to (rows, -1) to get the padded width fp4 = torch.stack([high, low], dim=-1).reshape(rows, -1)[:, :cols].reshape(original_shape) # 2. Map indices to values # sign: bit 3, index: bits 0-2 sign = (fp4 >> 3).to(torch.float32) sign = 1.0 - 2.0 * sign # 0 -> 1.0, 1 -> -1.0 indices = fp4 & 0x07 values = NVFP4_TABLE.to(device)[indices.long()] x = sign * values # 3. Undo block scaling # blocked_scales shape: [padded_rows, padded_cols] rows, cols = original_shape block_cols = (cols + 15) // 16 if blocked_scales.shape == (rows, block_cols): block_scales = blocked_scales else: block_scales = from_blocked(blocked_scales, rows, block_cols) # block_scales is [rows, block_cols], each scale covers 16 elements padded_cols = block_cols * 16 if cols != padded_cols: x_padded = torch.nn.functional.pad(x.view(rows, cols), (0, padded_cols - cols)) else: x_padded = x.view(rows, cols) x_padded = x_padded.reshape(rows, -1, 16) x_padded = x_padded * block_scales.to(x.dtype).unsqueeze(-1) x = x_padded.view(rows, padded_cols)[:, :cols].reshape(original_shape) # 4. Apply per-tensor scale x = x * tensor_scale.to(x.dtype) return x