|
|
import ast |
|
|
import os |
|
|
import torch |
|
|
from torch.utils import _pytree as pytree |
|
|
|
|
|
from optimum.quanto import QModuleMixin |
|
|
from optimum.quanto.tensor.qtensor import QTensor |
|
|
from optimum.quanto.tensor.qtype import qtype as _quanto_qtype, qtypes as _quanto_qtypes |
|
|
|
|
|
def _maybe_add_nvfp4_cu13_dll_dir(): |
|
|
if os.name != "nt": |
|
|
return |
|
|
try: |
|
|
import nvidia.cu13 |
|
|
dll_dir = os.path.join(nvidia.cu13.__path__[0], "bin", "x86_64") |
|
|
if os.path.isdir(dll_dir): |
|
|
os.add_dll_directory(dll_dir) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
# Optional kernel backends: both are soft dependencies, so a failed import
# merely disables that backend instead of breaking this module.
try:
    from comfy_kitchen.backends import cuda as _ck_cuda

    # _EXT_AVAILABLE is set by comfy_kitchen when its compiled CUDA
    # extension loaded successfully.
    _ck_cuda_available = getattr(_ck_cuda, "_EXT_AVAILABLE", False)
except Exception:
    _ck_cuda = None
    _ck_cuda_available = False


try:
    # On Windows the lightx2v wheel needs the CUDA 13 DLL directory
    # registered before import.
    _maybe_add_nvfp4_cu13_dll_dir()
    from lightx2v_kernel import gemm as _lx_gemm

    _lx_gemm_available = True
except Exception:
    _lx_gemm = None
    _lx_gemm_available = False
|
|
|
|
|
# Register the custom "nvfp4" qtype with optimum.quanto (idempotent so that
# re-importing this module is safe).
_NVFP4_QTYPE_NAME = "nvfp4"
if _NVFP4_QTYPE_NAME not in _quanto_qtypes:
    _quanto_qtypes[_NVFP4_QTYPE_NAME] = _quanto_qtype(
        _NVFP4_QTYPE_NAME,
        is_floating_point=True,
        bits=4,
        dtype=torch.uint8,
        # FP4 E2M1 representable range is [-6, 6].
        qmin=-6.0,
        qmax=6.0,
    )
_NVFP4_QTYPE = _quanto_qtypes[_NVFP4_QTYPE_NAME]
# Priority of this handler among quantization-format handlers.
HANDLER_PRIORITY = 1


# Weight packing layouts recognized by this handler.
_NVFP4_LAYOUT_LEGACY = "legacy"
_NVFP4_LAYOUT_TENSORCORE = "tensorcore"


# Kernel backend identifiers.
_NVFP4_BACKEND_AUTO = "auto"
_NVFP4_BACKEND_COMFY = "comfy"
_NVFP4_BACKEND_LIGHTX2V = "lightx2v"


# One-shot log flags plus lazily-computed kernel availability state.
_NVFP4_KERNEL_LOGGED = False
_NVFP4_FALLBACK_LOGGED = False
_NVFP4_LOAD_LOGGED = False
_NVFP4_KERNEL_AVAILABLE = False
_NVFP4_KERNEL_CHECKED = False
_NVFP4_KERNEL_BACKEND = None
# Per-device cache of the constant 1.0 activation scale tensor.
_NVFP4_ACT_SCALE_CACHE = {}


# Backend selection.  BUGFIX: the env var was previously read and then
# unconditionally overwritten with "lightx2v" (debug leftover), which made
# WGP_NVFP4_BACKEND dead.  Keep "lightx2v" as the default but let the
# environment override it again.
_NVFP4_BACKEND = os.environ.get("WGP_NVFP4_BACKEND", _NVFP4_BACKEND_LIGHTX2V).strip().lower()
|
|
|
|
|
def _normalize_nvfp4_backend(name):
    """Canonicalize a user-supplied backend name.

    Returns one of "auto", "comfy", "lightx2v", "fallback", or — for an
    unrecognized value — the lowercased, stripped input unchanged.
    """
    if name is None:
        return _NVFP4_BACKEND_AUTO
    normalized = str(name).strip().lower()
    alias_table = {
        "": _NVFP4_BACKEND_AUTO,
        "auto": _NVFP4_BACKEND_AUTO,
        "default": _NVFP4_BACKEND_AUTO,
        "comfy": _NVFP4_BACKEND_COMFY,
        "comfy-kitchen": _NVFP4_BACKEND_COMFY,
        "comfy_kitchen": _NVFP4_BACKEND_COMFY,
        "ck": _NVFP4_BACKEND_COMFY,
        "lightx2v": _NVFP4_BACKEND_LIGHTX2V,
        "lightx2v_kernel": _NVFP4_BACKEND_LIGHTX2V,
        "lightx2v-kernel": _NVFP4_BACKEND_LIGHTX2V,
        "lx": _NVFP4_BACKEND_LIGHTX2V,
        "off": "fallback",
        "none": "fallback",
        "fallback": "fallback",
        "disable": "fallback",
        "disabled": "fallback",
    }
    return alias_table.get(normalized, normalized)
|
|
|
|
|
|
|
|
# Canonicalize the environment-provided backend string at import time.
_NVFP4_BACKEND = _normalize_nvfp4_backend(_NVFP4_BACKEND)
|
|
|
|
|
|
|
|
def _nvfp4_backend_candidates():
    """Return the list of backends to probe, in preference order.

    "auto" tries comfy-kitchen first, then lightx2v; an explicit backend
    yields just itself; anything else ("fallback", unknown) yields nothing.
    """
    selected = _NVFP4_BACKEND
    if selected == _NVFP4_BACKEND_AUTO:
        return [_NVFP4_BACKEND_COMFY, _NVFP4_BACKEND_LIGHTX2V]
    explicit = (_NVFP4_BACKEND_COMFY, _NVFP4_BACKEND_LIGHTX2V)
    return [selected] if selected in explicit else []
|
|
|
|
|
|
|
|
def _nvfp4_backend_label(backend):
    """Human-readable backend name for log messages."""
    labels = {
        _NVFP4_BACKEND_LIGHTX2V: "lightx2v",
        _NVFP4_BACKEND_COMFY: "comfy-kitchen",
    }
    return labels.get(backend, backend)
|
|
|
|
|
|
|
|
def _nvfp4_lightx2v_device_ok(device): |
|
|
force = os.environ.get("WGP_NVFP4_LIGHTX2V_FORCE", "").strip().lower() |
|
|
if force in ("1", "true", "yes", "y"): |
|
|
return True |
|
|
try: |
|
|
props = torch.cuda.get_device_properties(device) |
|
|
except Exception: |
|
|
return False |
|
|
return props.major >= 12 |
|
|
|
|
|
|
|
|
def set_nvfp4_backend(name):
    """Select the NVFP4 kernel backend and re-run availability detection.

    Accepts any alias understood by ``_normalize_nvfp4_backend``.  Resets the
    cached detection state and one-shot log flags so the new backend is
    probed (and reported) fresh, then probes immediately.
    """
    global _NVFP4_BACKEND, _NVFP4_KERNEL_CHECKED, _NVFP4_KERNEL_AVAILABLE, _NVFP4_KERNEL_BACKEND
    global _NVFP4_KERNEL_LOGGED, _NVFP4_LOAD_LOGGED
    _NVFP4_BACKEND = _normalize_nvfp4_backend(name)
    # Invalidate cached detection so _init_nvfp4_kernel_support re-probes.
    _NVFP4_KERNEL_CHECKED = False
    _NVFP4_KERNEL_AVAILABLE = False
    _NVFP4_KERNEL_BACKEND = None
    _NVFP4_KERNEL_LOGGED = False
    _NVFP4_LOAD_LOGGED = False
    _init_nvfp4_kernel_support()
|
|
|
|
|
|
|
|
def _nvfp4_note_kernel():
    """Print (once per reset) which kernel backend is in use."""
    global _NVFP4_KERNEL_LOGGED
    if not _NVFP4_KERNEL_LOGGED:
        label = _nvfp4_backend_label(_NVFP4_KERNEL_BACKEND) if _NVFP4_KERNEL_BACKEND else "CUDA"
        print(f"NVFP4: using {label} kernel")
        _NVFP4_KERNEL_LOGGED = True
|
|
|
|
|
|
|
|
def _nvfp4_note_fallback():
    """Print (once per reset) that the dequantize+linear fallback was taken.

    The message distinguishes whether the fused kernel was also in use for
    other weights, or the fallback is the only active path.
    """
    global _NVFP4_FALLBACK_LOGGED
    global _NVFP4_KERNEL_LOGGED
    if not _NVFP4_FALLBACK_LOGGED:
        if _NVFP4_KERNEL_LOGGED:
            print("NVFP4: linear fallback needed on some weights")
        else:
            print("NVFP4: linear fallback")
        _NVFP4_FALLBACK_LOGGED = True
|
|
|
|
|
def _nvfp4_note_reset():
    """Re-arm the one-shot log messages (e.g. before loading a new model)."""
    global _NVFP4_FALLBACK_LOGGED
    global _NVFP4_KERNEL_LOGGED
    global _NVFP4_LOAD_LOGGED
    _NVFP4_KERNEL_LOGGED = False
    _NVFP4_FALLBACK_LOGGED = False
    _NVFP4_LOAD_LOGGED = False
|
|
|
|
|
def _nvfp4_note_load_backend():
    """Print (once per reset) kernel availability at model-load time."""
    global _NVFP4_LOAD_LOGGED
    if _NVFP4_LOAD_LOGGED:
        return
    _NVFP4_LOAD_LOGGED = True
    if _NVFP4_KERNEL_AVAILABLE:
        label = _nvfp4_backend_label(_NVFP4_KERNEL_BACKEND) if _NVFP4_KERNEL_BACKEND else "unknown"
        print(f"NVFP4: kernels available ({label}); optimized path will be used when compatible.")
    else:
        print("NVFP4: kernels unavailable; using fallback.")
|
|
|
|
|
|
|
|
def _check_nvfp4_kernel_support(device, backend):
    """Return True when `backend`'s NVFP4 kernels can run on `device`.

    Verifies that the optional extension imported, that it exposes both the
    quantize and scaled-matmul entry points (as Python wrappers and as
    registered torch ops), and that the CUDA compute capability is high
    enough: comfy-kitchen requires sm100+, lightx2v requires sm120+.
    """
    if device.type != "cuda":
        return False
    if backend == _NVFP4_BACKEND_COMFY:
        if not _ck_cuda_available:
            return False
        if not hasattr(_ck_cuda, "scaled_mm_nvfp4"):
            return False
        if not hasattr(_ck_cuda, "quantize_nvfp4"):
            return False
        # The torch custom-op registration must also have succeeded.
        if not (hasattr(torch.ops, "comfy_kitchen") and hasattr(torch.ops.comfy_kitchen, "scaled_mm_nvfp4")):
            return False
        major, minor = torch.cuda.get_device_capability(device)
        return (major, minor) >= (10, 0)
    if backend == _NVFP4_BACKEND_LIGHTX2V:
        if not _lx_gemm_available:
            return False
        if not _nvfp4_lightx2v_device_ok(device):
            return False
        if not (hasattr(torch.ops, "lightx2v_kernel") and hasattr(torch.ops.lightx2v_kernel, "cutlass_scaled_nvfp4_mm_sm120")):
            return False
        if not hasattr(torch.ops.lightx2v_kernel, "scaled_nvfp4_quant_sm120"):
            return False
        major, minor = torch.cuda.get_device_capability(device)
        return (major, minor) >= (12, 0)
    # Unknown backend (e.g. "fallback"): no kernel support.
    return False
|
|
|
|
|
|
|
|
def _init_nvfp4_kernel_support():
    """Probe backend candidates once and cache which (if any) is usable.

    Sets ``_NVFP4_KERNEL_AVAILABLE`` / ``_NVFP4_KERNEL_BACKEND``.  Becomes a
    no-op after the first call until ``set_nvfp4_backend`` clears
    ``_NVFP4_KERNEL_CHECKED``.
    """
    global _NVFP4_KERNEL_AVAILABLE, _NVFP4_KERNEL_CHECKED, _NVFP4_KERNEL_BACKEND
    if _NVFP4_KERNEL_CHECKED:
        return
    _NVFP4_KERNEL_CHECKED = True
    _NVFP4_KERNEL_AVAILABLE = False
    _NVFP4_KERNEL_BACKEND = None
    if not torch.cuda.is_available():
        return
    device = torch.device("cuda")
    for backend in _nvfp4_backend_candidates():
        try:
            if _check_nvfp4_kernel_support(device, backend):
                _NVFP4_KERNEL_AVAILABLE = True
                _NVFP4_KERNEL_BACKEND = backend
                break
        except Exception:
            # A broken extension must not disable the remaining candidates.
            continue
|
|
|
|
|
|
|
|
def _supports_nvfp4_kernel(device):
    """True when a fused NVFP4 kernel is available for `device`.

    Lazily runs backend detection on first use; non-CUDA devices never
    qualify.
    """
    if device.type == "cuda":
        if not _NVFP4_KERNEL_CHECKED:
            _init_nvfp4_kernel_support()
        return _NVFP4_KERNEL_AVAILABLE
    return False
|
|
|
|
|
|
|
|
# Probe kernel availability eagerly at import so load-time logging is accurate.
_init_nvfp4_kernel_support()
|
|
|
|
|
|
|
|
def _nvfp4_layout(weight):
    """Return the weight's packing layout, defaulting to the legacy layout."""
    try:
        return weight._layout
    except AttributeError:
        return _NVFP4_LAYOUT_LEGACY
|
|
|
|
|
|
|
|
def _nvfp4_can_use_kernel(input, weight):
    """Gate for the fused CUDA path.

    True only when the input/weight pair satisfies the active backend's
    constraints: kernel availability on the input's device, per-backend
    divisibility of the feature/output dimensions, packed-data width
    matching the input features, 16-element block size, and finite,
    materialized scale tensors.
    """
    if not torch.is_tensor(input):
        return False
    # Weights created from incomplete checkpoints set _allow_kernel=False.
    if not getattr(weight, "_allow_kernel", True):
        return False
    if not _supports_nvfp4_kernel(input.device):
        return False
    backend = _NVFP4_KERNEL_BACKEND
    if backend is None:
        return False
    layout = _nvfp4_layout(weight)
    if backend == _NVFP4_BACKEND_LIGHTX2V:
        # lightx2v requires K and N divisible by 32.
        if input.shape[-1] % 32 != 0:
            return False
        if weight.size(0) % 32 != 0:
            return False
    else:
        # comfy-kitchen: legacy layout needs K % 64, tensorcore K % 16,
        # and N % 8 in either case.
        if layout == _NVFP4_LAYOUT_LEGACY:
            if input.shape[-1] % 64 != 0:
                return False
        else:
            if input.shape[-1] % 16 != 0:
                return False
        if weight.size(0) % 8 != 0:
            return False
    # Packed data holds two FP4 values per byte, so width*2 must equal K.
    if weight._data.shape[1] * 2 != input.shape[-1]:
        return False
    if weight._block_size != 16:
        return False
    if not torch.is_tensor(weight._input_global_scale) or not torch.is_tensor(weight._alpha):
        return False
    # Meta tensors (not yet materialized) cannot be fed to the kernel.
    if getattr(weight._input_global_scale, "is_meta", False):
        return False
    try:
        # NaN placeholder scales (see _load_from_state_dict) disable the kernel.
        if not torch.isfinite(weight._input_global_scale).all():
            return False
    except Exception:
        return False
    return True
|
|
|
|
|
|
|
|
def _nvfp4_get_act_scale(device):
    """Return a cached scalar 1.0 activation-scale tensor for `device`."""
    cached = _NVFP4_ACT_SCALE_CACHE.get(device)
    if cached is not None:
        return cached
    unit_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
    _NVFP4_ACT_SCALE_CACHE[device] = unit_scale
    return unit_scale
|
|
|
|
|
|
|
|
def _nvfp4_swap_nibbles(tensor): |
|
|
return ((tensor & 0x0F) << 4) | ((tensor & 0xF0) >> 4) |
|
|
|
|
|
|
|
|
def _nvfp4_linear_cuda_comfy(input, weight, bias=None):
    """NVFP4 linear via comfy-kitchen's fused quantize + scaled-matmul kernels.

    Flattens the input to 2D, quantizes the activations to NVFP4 on the fly,
    then runs ``scaled_mm_nvfp4`` against the pre-packed weight.  Rows are
    padded to a multiple of 16 by the quantizer when needed and trimmed from
    the output.  The result is reshaped back and converted to the input's
    original dtype.
    """
    _nvfp4_note_kernel()
    x2d = input.reshape(-1, input.shape[-1])
    if not x2d.is_floating_point():
        x2d = x2d.to(torch.float16)
    orig_dtype = x2d.dtype
    # Kernel supports fp16/bf16 activations only; compute in fp16 otherwise.
    if orig_dtype not in (torch.float16, torch.bfloat16):
        x2d = x2d.to(torch.float16)
        out_dtype = torch.float16
    else:
        out_dtype = orig_dtype
    if not x2d.is_contiguous():
        x2d = x2d.contiguous()
    weight_fp4 = weight._data
    weight_scale = weight._scale
    input_scale = weight._input_global_scale
    alpha = weight._alpha
    layout = _nvfp4_layout(weight)
    device = x2d.device
    # Move weight-side tensors onto the activation device if needed.
    if weight_fp4.device != device:
        weight_fp4 = weight_fp4.to(device)
    if weight_scale.device != device:
        weight_scale = weight_scale.to(device)
    if input_scale.device != device:
        input_scale = input_scale.to(device)
    if alpha.device != device:
        alpha = alpha.to(device)
    if bias is not None and torch.is_tensor(bias) and bias.dtype != out_dtype:
        bias = bias.to(out_dtype)
    orig_rows = x2d.shape[0]
    # Ask the quantizer to pad the row count up to a multiple of 16.
    pad_16x = (orig_rows % 16) != 0
    if layout == _NVFP4_LAYOUT_TENSORCORE:
        input_scale = input_scale.to(torch.float32)
        tensor_scale = alpha.to(torch.float32)
        qx, qx_scale = _ck_cuda.quantize_nvfp4(x2d, input_scale, 0.0, pad_16x)
        out = _ck_cuda.scaled_mm_nvfp4(
            qx,
            weight_fp4,
            tensor_scale_a=input_scale,
            tensor_scale_b=tensor_scale,
            block_scale_a=qx_scale,
            block_scale_b=weight_scale,
            bias=bias,
            out_dtype=out_dtype,
        )
    else:
        # Legacy layout: fold the activation global scale into alpha and
        # quantize the activations with a unit scale instead.
        alpha = alpha * input_scale
        if alpha.dtype != torch.float32:
            alpha = alpha.to(torch.float32)
        act_scale = _nvfp4_get_act_scale(device)
        qx, qx_scale = _ck_cuda.quantize_nvfp4(x2d, act_scale, 0.0, pad_16x)
        # Legacy weights store nibbles in the opposite order to the kernel.
        weight_fp4 = _nvfp4_swap_nibbles(weight_fp4)
        out = _ck_cuda.scaled_mm_nvfp4(
            qx,
            weight_fp4,
            act_scale,
            input_scale,
            qx_scale,
            weight_scale,
            bias=bias,
            out_dtype=out_dtype,
            alpha=alpha,
        )
    if pad_16x:
        # Trim the padded rows back off.
        out = out[:orig_rows]
    if out.dtype != orig_dtype:
        out = out.to(orig_dtype)
    return out.reshape(*input.shape[:-1], weight.size(0))
|
|
|
|
|
|
|
|
def _nvfp4_linear_cuda_lightx2v(input, weight, bias=None):
    """NVFP4 linear via lightx2v_kernel's SM120 cutlass kernels.

    Flattens the input to 2D, quantizes the activations with
    ``scaled_nvfp4_quant``, then runs ``cutlass_scaled_nvfp4_mm`` against the
    pre-packed weight.  All operands are forced contiguous and the bias is
    converted to bf16, as the kernel requires.  The result is reshaped back
    and converted to the input's original dtype.
    """
    _nvfp4_note_kernel()
    x2d = input.reshape(-1, input.shape[-1])
    if not x2d.is_floating_point():
        x2d = x2d.to(torch.float16)
    orig_dtype = x2d.dtype
    # Kernel supports fp16/bf16 activations only; compute in fp16 otherwise.
    if orig_dtype not in (torch.float16, torch.bfloat16):
        x2d = x2d.to(torch.float16)
        out_dtype = torch.float16
    else:
        out_dtype = orig_dtype
    if not x2d.is_contiguous():
        x2d = x2d.contiguous()
    weight_fp4 = weight._data
    weight_scale = weight._scale
    input_scale = weight._input_global_scale
    alpha = weight._alpha
    layout = _nvfp4_layout(weight)
    device = x2d.device
    # Move weight-side tensors onto the activation device if needed.
    if weight_fp4.device != device:
        weight_fp4 = weight_fp4.to(device)
    if weight_scale.device != device:
        weight_scale = weight_scale.to(device)
    if not weight_fp4.is_contiguous():
        weight_fp4 = weight_fp4.contiguous()
    if not weight_scale.is_contiguous():
        weight_scale = weight_scale.contiguous()
    if input_scale.device != device:
        input_scale = input_scale.to(device)
    if alpha.device != device:
        alpha = alpha.to(device)
    if input_scale.dtype != torch.float32:
        input_scale = input_scale.to(torch.float32)
    if alpha.dtype != torch.float32:
        alpha = alpha.to(torch.float32)
    # The kernel expects a contiguous bf16 bias.
    if bias is not None and torch.is_tensor(bias):
        if bias.dtype != torch.bfloat16:
            bias = bias.to(torch.bfloat16)
        if not bias.is_contiguous():
            bias = bias.contiguous()
    if layout == _NVFP4_LAYOUT_TENSORCORE:
        # Tensorcore checkpoints: quantize with the reciprocal of the global
        # scale (clamped away from zero) and fold the scale into alpha.
        quant_scale = torch.reciprocal(torch.clamp(input_scale, min=1e-8))
        alpha = alpha * input_scale
    else:
        quant_scale = input_scale
    qx, qx_scale = _lx_gemm.scaled_nvfp4_quant(x2d, quant_scale)
    if layout == _NVFP4_LAYOUT_TENSORCORE:
        # Nibble order differs between the two layouts.
        qx = _nvfp4_swap_nibbles(qx)
    if not qx.is_contiguous():
        qx = qx.contiguous()
    if not qx_scale.is_contiguous():
        qx_scale = qx_scale.contiguous()
    out = _lx_gemm.cutlass_scaled_nvfp4_mm(
        qx,
        weight_fp4,
        qx_scale,
        weight_scale,
        alpha=alpha,
        bias=bias,
    )
    if out.dtype != orig_dtype:
        out = out.to(orig_dtype)
    return out.reshape(*input.shape[:-1], weight.size(0))
|
|
|
|
|
|
|
|
def _nvfp4_linear_cuda(input, weight, bias=None):
    """Dispatch to the active kernel backend's fused linear implementation."""
    use_lightx2v = _NVFP4_KERNEL_BACKEND == _NVFP4_BACKEND_LIGHTX2V
    impl = _nvfp4_linear_cuda_lightx2v if use_lightx2v else _nvfp4_linear_cuda_comfy
    return impl(input, weight, bias=bias)
|
|
|
|
|
|
|
|
@torch.compiler.disable()
def _nvfp4_linear(input, weight, bias=None, op=None):
    """Linear entry point for NVFP4 weights.

    Uses the fused CUDA kernel when ``_nvfp4_can_use_kernel`` accepts the
    input/weight pair; otherwise dequantizes the weight and runs a standard
    linear (through `op` when dispatched from aten, else F.linear).
    Excluded from torch.compile because of the data-dependent branching and
    custom-op calls.
    """
    if _nvfp4_can_use_kernel(input, weight):
        return _nvfp4_linear_cuda(input, weight, bias=bias)
    _nvfp4_note_fallback()
    dtype = input.dtype if torch.is_tensor(input) else weight.dtype
    device = input.device if torch.is_tensor(input) else weight.device
    w = weight.dequantize(dtype=dtype, device=device)
    if bias is not None and torch.is_tensor(bias) and bias.dtype != dtype:
        bias = bias.to(dtype)
    if op is not None:
        return op(input, w, bias)
    return torch.nn.functional.linear(input, w, bias)
|
|
|
|
|
|
|
|
def _is_float8_dtype(dtype): |
|
|
return "float8" in str(dtype).lower() or "f8" in str(dtype).lower() |
|
|
|
|
|
# Lookup table mapping the 16 FP4 (E2M1) codes to their real values.
# Codes 0-7 are the positive values; codes 8-15 are their negated mirror
# (code 8 encodes -0.0, listed here as 0.0).
_FP4_LUT_BASE = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)
# Per-(device, dtype) caches of the 16-entry LUT and the expanded
# 256-entry byte LUT used during dequantization.
_FP4_LUT_CACHE = {}
_FP4_BYTE_LUT_CACHE = {}
|
|
|
|
|
|
|
|
def _get_fp4_lut(device, dtype):
    """Return the 16-entry FP4 value table on `device` in `dtype`, cached."""
    key = (device, dtype)
    cached = _FP4_LUT_CACHE.get(key)
    if cached is not None:
        return cached
    table = _FP4_LUT_BASE.to(device=device, dtype=dtype)
    _FP4_LUT_CACHE[key] = table
    return table
|
|
|
|
|
|
|
|
def _get_fp4_byte_lut(device, dtype):
    """Return a (256, 2) table decoding each packed byte to its two FP4 values.

    Column 0 holds the value of the byte's low nibble, column 1 the high
    nibble.  Cached per (device, dtype).
    """
    key = (device, dtype)
    cached = _FP4_BYTE_LUT_CACHE.get(key)
    if cached is not None:
        return cached
    lut16 = _get_fp4_lut(device, dtype)
    codes = torch.arange(256, device=device, dtype=torch.int32)
    table = torch.empty((256, 2), device=device, dtype=dtype)
    table[:, 0] = lut16[codes & 0x0F]
    table[:, 1] = lut16[codes >> 4]
    _FP4_BYTE_LUT_CACHE[key] = table
    return table
|
|
|
|
|
|
|
|
def _deswizzle_nvfp4_scale(scale, in_features, block_size=16, dtype=None):
    """Undo the tensor-core swizzle of an NVFP4 block-scale matrix.

    `scale` is stored in the tiled layout used by the CUDA kernels (128-row
    tiles with a 32x4x4 interleave); returns a plain row-major
    (rows, in_features // block_size) matrix.  Extra padding columns are
    dropped and an optional dtype conversion is applied.

    Raises RuntimeError when `scale` has fewer groups than `in_features`
    requires.
    """
    k_groups = in_features // block_size
    if scale.shape[1] < k_groups:
        raise RuntimeError(
            f"NVFP4 scale shape mismatch: expected at least {k_groups} groups, got {scale.shape[1]}"
        )
    if scale.shape[1] > k_groups:
        scale = scale[:, :k_groups]

    m, _ = scale.shape
    # Tiles span 128 rows by (block_size * 4) input features.
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (in_features + f - 1) // f
    tmp = scale if dtype is None else scale.to(dtype)
    # Inverse of the swizzle permutation.
    # NOTE(review): the reshape assumes m and the trimmed width exactly fill
    # whole tiles — confirm callers always pass tile-padded scales.
    tmp = tmp.reshape(1, m_tiles, k_tiles, 32, 4, 4)
    tmp = tmp.permute(0, 1, 4, 3, 2, 5)
    out = tmp.reshape(m_tiles * 128, k_tiles * 4)
    # Trim the tile padding back to the true dimensions.
    return out[:m, :k_groups]
|
|
|
|
|
|
|
|
def _dequantize_nvfp4_weight(
    weight_u8,
    weight_scale,
    input_global_scale,
    alpha,
    dtype,
    device,
    block_size=16,
    layout=_NVFP4_LAYOUT_LEGACY,
):
    """Dequantize packed NVFP4 weight bytes to a dense (m, k) tensor.

    `weight_u8` holds two FP4 codes per byte; `weight_scale` holds the
    swizzled per-block scales; `alpha` (and, for the legacy layout, also
    `input_global_scale`) supplies the final tensor-wide scale factor.
    Prefers comfy-kitchen's CUDA dequantize kernel for the tensorcore layout
    on CUDA devices and falls back to a pure-torch LUT expansion otherwise.
    """
    if weight_u8.device != device:
        weight_u8 = weight_u8.to(device)
    scale = weight_scale if weight_scale.device == device else weight_scale.to(device)
    if alpha.device != device:
        alpha = alpha.to(device)
    if input_global_scale.device != device:
        input_global_scale = input_global_scale.to(device)
    if layout == _NVFP4_LAYOUT_TENSORCORE and device.type == "cuda" and _ck_cuda_available:
        try:
            return _ck_cuda.dequantize_nvfp4(weight_u8, alpha.to(torch.float32), scale, output_type=dtype)
        except Exception:
            # Fall through to the pure-torch path on any kernel failure.
            pass

    m, k_bytes = weight_u8.shape
    byte_lut = _get_fp4_byte_lut(device, dtype)
    if layout == _NVFP4_LAYOUT_TENSORCORE:
        # Tensorcore layout stores nibbles swapped relative to the LUT order.
        idx = _nvfp4_swap_nibbles(weight_u8).to(torch.int32)
    else:
        idx = weight_u8.to(torch.int32)
    # Expand each byte into its two FP4 values: (m, k_bytes, 2) -> (m, k).
    out = byte_lut[idx].reshape(m, k_bytes * 2)

    # Apply the per-block scales...
    scale = _deswizzle_nvfp4_scale(scale, out.shape[1], block_size=block_size, dtype=dtype)
    out = out.view(out.shape[0], scale.shape[1], block_size)
    out.mul_(scale.unsqueeze(-1))
    out = out.view(out.shape[0], -1)

    # ...then the tensor-wide scale (the legacy layout also folds in the
    # input global scale).
    if layout == _NVFP4_LAYOUT_TENSORCORE:
        scale_factor = alpha.to(dtype)
    else:
        scale_factor = alpha.to(dtype) * input_global_scale.to(dtype)
    out.mul_(scale_factor)
    return out
|
|
|
|
|
|
|
|
def _collect_nvfp4_specs(state_dict):
    """Scan a state dict for NVFP4-quantized linear weights.

    A layer qualifies when ``<base>.weight`` is uint8 and
    ``<base>.weight_scale`` is a float8 tensor.  Two flavors are recognized:

    * tensorcore: has ``<base>.weight_scale_2`` (+ optional
      ``<base>.input_scale``).
    * legacy: has ``<base>.input_global_scale`` and ``<base>.alpha``, or the
      older ``<base>.input_absmax`` / ``<base>.weight_global_scale`` pair
      from which those two are derived.

    Returns a list of per-layer spec dicts (tensors referenced, not copied).
    """
    specs = []
    for key, tensor in state_dict.items():
        if not key.endswith(".weight"):
            continue
        if tensor.dtype != torch.uint8:
            continue
        base = key[:-7]  # strip ".weight"
        scale_key = base + ".weight_scale"
        if scale_key not in state_dict:
            continue
        if not _is_float8_dtype(state_dict[scale_key].dtype):
            continue

        weight_scale_2_key = base + ".weight_scale_2"
        input_scale_key = base + ".input_scale"
        if weight_scale_2_key in state_dict:
            # Tensorcore flavor.
            specs.append(
                {
                    "name": base,
                    "weight": tensor,
                    "weight_scale": state_dict[scale_key],
                    "weight_scale_2": state_dict[weight_scale_2_key],
                    "input_scale": state_dict.get(input_scale_key, None),
                    "bias": state_dict.get(base + ".bias", None),
                    "layout": _NVFP4_LAYOUT_TENSORCORE,
                }
            )
            continue

        # Legacy flavor: either direct scales or the older absmax pair.
        input_global_key = base + ".input_global_scale"
        alpha_key = base + ".alpha"
        input_absmax_key = base + ".input_absmax"
        weight_global_scale_key = base + ".weight_global_scale"
        if input_global_key not in state_dict or alpha_key not in state_dict:
            if input_absmax_key not in state_dict or weight_global_scale_key not in state_dict:
                continue
            input_absmax = state_dict[input_absmax_key]
            weight_global_scale = state_dict[weight_global_scale_key]
            # 2688 = 448 * 6 — presumably fp8-e4m3 max times fp4 max, the
            # usual NVFP4 global-scale construction; TODO confirm against
            # the exporter that produced these checkpoints.
            input_global_scale = (2688.0 / input_absmax).to(torch.float32)
            alpha = 1.0 / (input_global_scale * weight_global_scale.to(torch.float32))
        else:
            input_global_scale = state_dict[input_global_key]
            alpha = state_dict[alpha_key]
        specs.append(
            {
                "name": base,
                "weight": tensor,
                "weight_scale": state_dict[scale_key],
                "input_global_scale": input_global_scale,
                "alpha": alpha,
                "bias": state_dict.get(base + ".bias", None),
                "layout": _NVFP4_LAYOUT_LEGACY,
            }
        )
    return specs
|
|
|
|
|
|
|
|
def detect_nvfp4_state_dict(state_dict):
    """True when the state dict contains at least one NVFP4-quantized layer."""
    return bool(_collect_nvfp4_specs(state_dict))
|
|
|
|
|
|
|
|
def describe_nvfp4_state_dict(state_dict, max_names=8):
    """Summarize NVFP4 layers: total count plus up to `max_names` layer names."""
    layer_names = [spec["name"] for spec in _collect_nvfp4_specs(state_dict)]
    return {"count": len(layer_names), "names": layer_names[:max_names]}
|
|
|
|
|
|
|
|
def convert_nvfp4_to_quanto(state_dict, default_dtype=None, verboseLevel=1):
    """Build a quanto quantization map for every NVFP4 layer found.

    The state dict itself is returned untouched — the actual tensor
    conversion happens later when modules load their weights.  Each layer is
    registered under both its module name and its ``.weight`` key.
    """
    specs = _collect_nvfp4_specs(state_dict)
    if not specs:
        return {"state_dict": state_dict, "quant_map": {}}
    _nvfp4_note_load_backend()
    quant_map = {}
    for spec in specs:
        shared_cfg = {"weights": "nvfp4", "activations": "none"}
        for map_key in (spec["name"], spec["name"] + ".weight"):
            quant_map[map_key] = shared_cfg
    return {"state_dict": state_dict, "quant_map": quant_map}
|
|
|
|
|
|
|
|
def detect(state_dict, verboseLevel=1):
    """Format-autodetection entry point for the handler registry."""
    if detect_nvfp4_state_dict(state_dict):
        return {"matched": True, "kind": "nvfp4", "details": describe_nvfp4_state_dict(state_dict)}
    return {"matched": False, "kind": "none", "details": {}}
|
|
|
|
|
|
|
|
def convert_to_quanto(state_dict, default_dtype, verboseLevel=1, detection=None):
    """Conversion entry point; honors a prior `detection` result when given."""
    not_matched = detection is not None and not detection.get("matched", False)
    if not_matched:
        return {"state_dict": state_dict, "quant_map": {}}
    # New model load: re-arm the one-shot log messages.
    _nvfp4_note_reset()
    return convert_nvfp4_to_quanto(state_dict, default_dtype=default_dtype, verboseLevel=verboseLevel)
|
|
|
|
|
|
|
|
def apply_pre_quantization(model, state_dict, quantization_map, default_dtype=None, verboseLevel=1):
    """No-op hook: NVFP4 needs no model surgery before quantization.

    Returns the quantization map unchanged and an empty list of handled keys.
    """
    handled_keys = []
    return quantization_map, handled_keys
|
|
|
|
|
|
|
|
def _nvfp4_qfallback(fn, *args, **kwargs):
    """Run `fn` after dequantizing any NVFP4 tensors among its arguments.

    Default path in __torch_dispatch__ for ops without a dedicated NVFP4
    implementation.  (The first parameter was previously named ``callable``,
    shadowing the builtin; it is only ever passed positionally.)
    """
    args, kwargs = pytree.tree_map_only(NVFP4WeightTensor, lambda x: x.dequantize(), (args, kwargs or {}))
    return fn(*args, **kwargs)
|
|
|
|
|
|
|
|
class NVFP4WeightTensor(QTensor):
    """quanto QTensor subclass wrapping an NVFP4-packed linear weight.

    Inner tensors:
      _data               -- uint8, two FP4 (E2M1) codes per byte
      _scale              -- swizzled per-block (16-element) scales
      _input_global_scale -- activation global quantization scale
      _alpha              -- tensor-wide output scale factor
    ``_layout`` distinguishes the "legacy" and "tensorcore" packings;
    ``_allow_kernel`` gates the fused CUDA path (cleared when the scales are
    non-finite placeholders).
    """

    @staticmethod
    def create(
        weight_u8,
        weight_scale,
        size,
        stride,
        dtype,
        input_global_scale=None,
        alpha=None,
        input_scale=None,
        weight_scale_2=None,
        device=None,
        requires_grad=False,
        layout=_NVFP4_LAYOUT_LEGACY,
        allow_kernel=True,
    ):
        """Normalized constructor accepting either naming convention.

        Tensorcore-style checkpoints provide input_scale/weight_scale_2,
        legacy ones input_global_scale/alpha; the former are aliased onto
        the latter, and supplying tensorcore names promotes the layout.
        Raises ValueError when neither pair is complete.
        """
        if input_global_scale is None and input_scale is not None:
            input_global_scale = input_scale
        if alpha is None and weight_scale_2 is not None:
            alpha = weight_scale_2
        if layout == _NVFP4_LAYOUT_LEGACY and (weight_scale_2 is not None or input_scale is not None):
            layout = _NVFP4_LAYOUT_TENSORCORE
        if input_global_scale is None or alpha is None:
            raise ValueError("NVFP4WeightTensor.create requires input_global_scale/alpha or input_scale/weight_scale_2")
        if torch.is_tensor(input_global_scale):
            try:
                # Non-finite scales (NaN placeholders) disable the kernel path.
                if not torch.isfinite(input_global_scale).all():
                    allow_kernel = False
            except Exception:
                allow_kernel = False
        device = weight_u8.device if device is None else device
        # Co-locate all inner tensors on the target device.
        if weight_u8.device != device:
            weight_u8 = weight_u8.to(device)
        if weight_scale.device != device:
            weight_scale = weight_scale.to(device)
        if input_global_scale.device != device:
            input_global_scale = input_global_scale.to(device)
        if alpha.device != device:
            alpha = alpha.to(device)
        return NVFP4WeightTensor(
            qtype=_NVFP4_QTYPE,
            axis=0,
            size=size,
            stride=stride,
            weight_u8=weight_u8,
            weight_scale=weight_scale,
            input_global_scale=input_global_scale,
            alpha=alpha,
            allow_kernel=allow_kernel,
            dtype=dtype,
            requires_grad=requires_grad,
            layout=layout,
        )

    @staticmethod
    def __new__(
        cls,
        qtype,
        axis,
        size,
        stride,
        weight_u8,
        weight_scale,
        input_global_scale,
        alpha,
        dtype,
        allow_kernel=True,
        requires_grad=False,
        layout=_NVFP4_LAYOUT_LEGACY,
    ):
        # Wrapper subclass: the outer tensor advertises the logical
        # (dequantized) size/stride/dtype without allocating storage.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            size,
            strides=stride,
            dtype=dtype,
            device=weight_u8.device,
            requires_grad=requires_grad,
        )

    def __init__(
        self,
        qtype,
        axis,
        size,
        stride,
        weight_u8,
        weight_scale,
        input_global_scale,
        alpha,
        dtype,
        requires_grad=False,
        layout=_NVFP4_LAYOUT_LEGACY,
        allow_kernel=True,
    ):
        super().__init__(qtype, axis)
        # Packed FP4 codes (two per byte).
        self._data = weight_u8
        # Swizzled per-block scales.
        self._scale = weight_scale
        self._input_global_scale = input_global_scale
        self._alpha = alpha
        # Elements per quantization block; kernels only support 16.
        self._block_size = 16
        self._layout = layout
        self._allow_kernel = allow_kernel

    def dequantize(self, dtype=None, device=None):
        """Materialize the dense weight (defaults to this tensor's dtype/device)."""
        if dtype is None:
            dtype = self.dtype
        if device is None:
            device = self.device
        return _dequantize_nvfp4_weight(
            weight_u8=self._data,
            weight_scale=self._scale,
            input_global_scale=self._input_global_scale,
            alpha=self._alpha,
            dtype=dtype,
            device=device,
            block_size=self._block_size,
            layout=self._layout,
        )

    def get_quantized_subtensors(self):
        """Return (name, tensor) pairs using the layout's serialization names."""
        if self._layout == _NVFP4_LAYOUT_TENSORCORE:
            return [
                ("weight_u8", self._data),
                ("weight_scale", self._scale),
                ("weight_scale_2", self._alpha),
                ("input_scale", self._input_global_scale),
            ]
        return [
            ("weight_u8", self._data),
            ("weight_scale", self._scale),
            ("input_global_scale", self._input_global_scale),
            ("alpha", self._alpha),
        ]

    def set_quantized_subtensors(self, sub_tensors):
        """Replace inner tensors from a dict or (name, tensor) pairs.

        Accepts both layouts' names; missing/None entries leave the current
        tensor in place.
        """
        if isinstance(sub_tensors, dict):
            sub_map = sub_tensors
        else:
            sub_map = {name: tensor for name, tensor in sub_tensors}
        # "data" accepted as a legacy alias for "weight_u8".
        data = sub_map.get("weight_u8", sub_map.get("data"))
        if data is not None:
            self._data = data
        if "weight_scale" in sub_map and sub_map["weight_scale"] is not None:
            self._scale = sub_map["weight_scale"]
        if "input_scale" in sub_map and sub_map["input_scale"] is not None:
            self._input_global_scale = sub_map["input_scale"]
        elif "input_global_scale" in sub_map and sub_map["input_global_scale"] is not None:
            self._input_global_scale = sub_map["input_global_scale"]
        if "weight_scale_2" in sub_map and sub_map["weight_scale_2"] is not None:
            self._alpha = sub_map["weight_scale_2"]
        elif "alpha" in sub_map and sub_map["alpha"] is not None:
            self._alpha = sub_map["alpha"]

    def __tensor_flatten__(self):
        """Flatten for torch serialization/tracing: inner tensor names + string metadata."""
        inner_tensors = ["_data", "_scale", "_input_global_scale", "_alpha"]
        meta = {
            "qtype": self._qtype.name,
            "axis": str(self._axis),
            "size": str(list(self.size())),
            "stride": str(list(self.stride())),
            "dtype": str(self.dtype),
            "layout": self._layout,
            "allow_kernel": "1" if self._allow_kernel else "0",
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        """Inverse of __tensor_flatten__: rebuild from inner tensors + metadata."""
        qtype = _quanto_qtypes[meta["qtype"]]
        axis = ast.literal_eval(meta["axis"])
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        dtype_str = meta.get("dtype", "torch.float16")
        # dtype was stored as str(torch.dtype), e.g. "torch.float16".
        if dtype_str.startswith("torch."):
            dtype_name = dtype_str.split(".", 1)[1]
            dtype = getattr(torch, dtype_name, torch.float16)
        else:
            dtype = getattr(torch, dtype_str, torch.float16)
        layout = meta.get("layout", _NVFP4_LAYOUT_LEGACY)
        allow_kernel = str(meta.get("allow_kernel", "1")).strip().lower() not in ("0", "false", "no")
        return NVFP4WeightTensor(
            qtype=qtype,
            axis=axis,
            size=size,
            stride=stride,
            weight_u8=inner_tensors["_data"],
            weight_scale=inner_tensors["_scale"],
            input_global_scale=inner_tensors["_input_global_scale"],
            alpha=inner_tensors["_alpha"],
            allow_kernel=allow_kernel,
            dtype=dtype,
            layout=layout,
        )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Intercept F.linear to route through the NVFP4 linear path."""
        kwargs = kwargs or {}
        if func is torch.nn.functional.linear:
            input = args[0] if len(args) > 0 else kwargs.get("input", None)
            weight = args[1] if len(args) > 1 else kwargs.get("weight", None)
            bias = args[2] if len(args) > 2 else kwargs.get("bias", None)
            if isinstance(weight, NVFP4WeightTensor):
                return _nvfp4_linear(input, weight, bias=bias)
        # Everything else: defer to the default torch behavior.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        """aten-level dispatch: linear, detach and to/_to_copy get dedicated
        handling; all other ops fall back to dequantize-then-run."""
        op = op.overloadpacket
        if op is torch.ops.aten.linear:
            input = args[0]
            weight = args[1]
            bias = args[2] if len(args) > 2 else None
            if isinstance(weight, NVFP4WeightTensor):
                return _nvfp4_linear(input, weight, bias=bias, op=op)
        if op is torch.ops.aten.detach:
            # Detach each inner tensor, keeping the quantized wrapper.
            t = args[0]
            return NVFP4WeightTensor.create(
                weight_u8=op(t._data),
                weight_scale=op(t._scale),
                input_global_scale=op(t._input_global_scale),
                alpha=op(t._alpha),
                allow_kernel=getattr(t, "_allow_kernel", True),
                size=t.size(),
                stride=t.stride(),
                dtype=t.dtype,
                device=t.device,
                requires_grad=t.requires_grad,
                layout=t._layout,
            )
        if op in (torch.ops.aten._to_copy, torch.ops.aten.to):
            t = args[0]
            dtype = kwargs.pop("dtype", t.dtype) if kwargs else t.dtype
            device = kwargs.pop("device", t.device) if kwargs else t.device
            # A dtype change cannot be represented in packed form.
            if dtype != t.dtype:
                return t.dequantize(dtype=dtype, device=device)
            # Device-only move: copy the inner tensors, keep quantization.
            out_data = op(t._data, device=device, **(kwargs or {}))
            out_scale = op(t._scale, device=device, **(kwargs or {}))
            out_igs = op(t._input_global_scale, device=device, **(kwargs or {}))
            out_alpha = op(t._alpha, device=device, **(kwargs or {}))
            return NVFP4WeightTensor.create(
                weight_u8=out_data,
                weight_scale=out_scale,
                input_global_scale=out_igs,
                alpha=out_alpha,
                allow_kernel=getattr(t, "_allow_kernel", True),
                size=t.size(),
                stride=t.stride(),
                dtype=t.dtype,
                device=device,
                requires_grad=t.requires_grad,
                layout=t._layout,
            )
        return _nvfp4_qfallback(op, *args, **(kwargs or {}))
|
|
|
|
|
|
|
|
class QLinearNVFP4(QModuleMixin, torch.nn.Linear): |
|
|
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        weights=None,
        activations=None,
        optimizer=None,
        quantize_input=True,
    ):
        """Quanto linear module whose weight is an NVFP4WeightTensor.

        Parameters mirror QModuleMixin / torch.nn.Linear; `dtype` is also
        remembered as the default dequantization dtype.
        """
        super().__init__(
            in_features,
            out_features,
            bias=bias,
            device=device,
            dtype=dtype,
            weights=weights,
            activations=activations,
            optimizer=optimizer,
            quantize_input=quantize_input,
        )
        # Preferred dtype for dequantize/fallback computation.
        self._nvfp4_default_dtype = dtype
|
|
|
|
|
    @classmethod
    def qcreate(
        cls,
        module,
        weights,
        activations=None,
        optimizer=None,
        device=None,
    ):
        """Build a QLinearNVFP4 mirroring an existing nn.Linear `module`.

        The working dtype is taken from the source module's weight (or bias)
        when it is floating point, defaulting to float16 otherwise.
        """
        if torch.is_tensor(module.weight) and module.weight.dtype.is_floating_point:
            weight_dtype = module.weight.dtype
        elif torch.is_tensor(getattr(module, "bias", None)) and module.bias.dtype.is_floating_point:
            weight_dtype = module.bias.dtype
        else:
            weight_dtype = torch.float16
        return cls(
            module.in_features,
            module.out_features,
            module.bias is not None,
            device=device,
            dtype=weight_dtype,
            weights=weights,
            activations=activations,
            optimizer=optimizer,
            quantize_input=True,
        )
|
|
|
|
|
    def set_default_dtype(self, dtype):
        """Set the dtype used for dequantization/fallback computation."""
        self._nvfp4_default_dtype = dtype
|
|
|
|
|
    @property
    def qweight(self):
        """The quantized weight.

        For NVFP4 the stored weight already is the quantized representation,
        so quanto's usual on-the-fly quantization is bypassed.
        """
        if self.weight_qtype == _NVFP4_QTYPE:
            return self.weight
        return super().qweight
|
|
|
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor: |
|
|
return torch.nn.functional.linear(input, self.qweight, bias=self.bias) |
|
|
|
|
|
def _load_from_state_dict( |
|
|
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs |
|
|
): |
|
|
if self.weight_qtype != _NVFP4_QTYPE: |
|
|
return super()._load_from_state_dict( |
|
|
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs |
|
|
) |
|
|
|
|
|
weight_key = prefix + "weight" |
|
|
scale_key = prefix + "weight_scale" |
|
|
scale2_key = prefix + "weight_scale_2" |
|
|
igs_key = prefix + "input_global_scale" |
|
|
alpha_key = prefix + "alpha" |
|
|
input_absmax_key = prefix + "input_absmax" |
|
|
weight_global_scale_key = prefix + "weight_global_scale" |
|
|
bias_key = prefix + "bias" |
|
|
input_scale_key = prefix + "input_scale" |
|
|
output_scale_key = prefix + "output_scale" |
|
|
|
|
|
weight_u8 = state_dict.pop(weight_key, None) |
|
|
weight_scale = state_dict.pop(scale_key, None) |
|
|
weight_scale_2 = state_dict.pop(scale2_key, None) |
|
|
input_global_scale = state_dict.pop(igs_key, None) |
|
|
alpha = state_dict.pop(alpha_key, None) |
|
|
input_absmax = state_dict.pop(input_absmax_key, None) |
|
|
weight_global_scale = state_dict.pop(weight_global_scale_key, None) |
|
|
bias = state_dict.pop(bias_key, None) |
|
|
input_scale = state_dict.pop(input_scale_key, None) |
|
|
output_scale = state_dict.pop(output_scale_key, None) |
|
|
|
|
|
if weight_u8 is None: |
|
|
missing_keys.append(weight_key) |
|
|
if weight_scale is None: |
|
|
missing_keys.append(scale_key) |
|
|
layout = _NVFP4_LAYOUT_LEGACY |
|
|
allow_kernel = True |
|
|
if weight_scale_2 is not None or input_scale is not None: |
|
|
layout = _NVFP4_LAYOUT_TENSORCORE |
|
|
if weight_scale_2 is None: |
|
|
missing_keys.append(scale2_key) |
|
|
if input_scale is None: |
|
|
allow_kernel = False |
|
|
if torch.is_tensor(weight_scale_2): |
|
|
input_scale = torch.full( |
|
|
(), |
|
|
float("nan"), |
|
|
dtype=weight_scale_2.dtype, |
|
|
device=weight_scale_2.device, |
|
|
) |
|
|
elif torch.is_tensor(weight_u8): |
|
|
input_scale = torch.full((), float("nan"), dtype=torch.float32, device=weight_u8.device) |
|
|
else: |
|
|
input_scale = torch.tensor(float("nan"), dtype=torch.float32) |
|
|
else: |
|
|
if input_global_scale is None or alpha is None: |
|
|
if input_absmax is not None and weight_global_scale is not None: |
|
|
input_global_scale = (2688.0 / input_absmax).to(torch.float32) |
|
|
alpha = 1.0 / (input_global_scale * weight_global_scale.to(torch.float32)) |
|
|
else: |
|
|
if input_global_scale is None: |
|
|
missing_keys.append(igs_key) |
|
|
if alpha is None: |
|
|
missing_keys.append(alpha_key) |
|
|
|
|
|
target_dtype = self._nvfp4_default_dtype or self.weight.dtype |
|
|
if layout == _NVFP4_LAYOUT_TENSORCORE: |
|
|
if weight_u8 is not None and weight_scale is not None and weight_scale_2 is not None and input_scale is not None: |
|
|
nvfp4_weight = NVFP4WeightTensor.create( |
|
|
weight_u8=weight_u8, |
|
|
weight_scale=weight_scale, |
|
|
input_global_scale=input_scale, |
|
|
alpha=weight_scale_2, |
|
|
allow_kernel=allow_kernel, |
|
|
size=self.weight.size(), |
|
|
stride=self.weight.stride(), |
|
|
dtype=target_dtype, |
|
|
device=weight_u8.device, |
|
|
requires_grad=False, |
|
|
layout=layout, |
|
|
) |
|
|
self.weight = torch.nn.Parameter(nvfp4_weight, requires_grad=False) |
|
|
else: |
|
|
if weight_u8 is not None and weight_scale is not None and input_global_scale is not None and alpha is not None: |
|
|
nvfp4_weight = NVFP4WeightTensor.create( |
|
|
weight_u8=weight_u8, |
|
|
weight_scale=weight_scale, |
|
|
input_global_scale=input_global_scale, |
|
|
alpha=alpha, |
|
|
size=self.weight.size(), |
|
|
stride=self.weight.stride(), |
|
|
dtype=target_dtype, |
|
|
device=weight_u8.device, |
|
|
requires_grad=False, |
|
|
layout=layout, |
|
|
) |
|
|
self.weight = torch.nn.Parameter(nvfp4_weight, requires_grad=False) |
|
|
|
|
|
if bias is not None: |
|
|
if target_dtype is not None and bias.dtype != target_dtype: |
|
|
bias = bias.to(target_dtype) |
|
|
self.bias = torch.nn.Parameter(bias) |
|
|
|
|
|
if torch.is_tensor(weight_u8): |
|
|
scale_device = weight_u8.device |
|
|
elif torch.is_tensor(self.weight): |
|
|
scale_device = self.weight.device |
|
|
elif torch.is_tensor(bias): |
|
|
scale_device = bias.device |
|
|
else: |
|
|
scale_device = torch.device("cpu") |
|
|
|
|
|
if input_scale is not None: |
|
|
self.input_scale = input_scale.to(scale_device) |
|
|
else: |
|
|
if not hasattr(self, "input_scale") or self.input_scale.is_meta: |
|
|
scale_dtype = self.input_scale.dtype if hasattr(self, "input_scale") else torch.float32 |
|
|
self.input_scale = torch.ones((), dtype=scale_dtype, device=scale_device) |
|
|
|
|
|
if output_scale is not None: |
|
|
self.output_scale = output_scale.to(scale_device) |
|
|
else: |
|
|
if not hasattr(self, "output_scale") or self.output_scale.is_meta: |
|
|
scale_dtype = self.output_scale.dtype if hasattr(self, "output_scale") else torch.float32 |
|
|
self.output_scale = torch.ones((), dtype=scale_dtype, device=scale_device) |
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
def validate_nvfp4_kernel(
    state_dict=None,
    checkpoint_path=None,
    device=None,
    max_layers=4,
    seed=0,
    batch_size=2,
    dtype=torch.bfloat16,
    verbose=True,
):
    """Compare kernel vs fallback outputs for a few NVFP4 layers.

    Args:
        state_dict: mapping of checkpoint tensors; when None it is loaded
            from ``checkpoint_path``.
        checkpoint_path: safetensors file to load when ``state_dict`` is None.
        device: target device (``torch.device`` or string); defaults to CUDA
            when available, else CPU.
        max_layers: validate at most this many (smallest-weight) layers.
        seed: RNG seed for the random activation inputs.
        batch_size: rows of the random input batch.
        dtype: activation/reference dtype.
        verbose: print a per-layer summary when True.

    Returns:
        dict with ``ok`` and, on success, per-layer ``results`` entries
        containing max/mean absolute kernel-vs-reference differences.

    Raises:
        ValueError: when neither ``state_dict`` nor ``checkpoint_path`` is given.
    """
    if state_dict is None:
        if checkpoint_path is None:
            raise ValueError("state_dict or checkpoint_path is required")
        from mmgp import safetensors2

        state_dict = {}
        with safetensors2.safe_open(checkpoint_path, framework="pt", device="cpu", writable_tensors=False) as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)

    specs = _collect_nvfp4_specs(state_dict)
    if not specs:
        return {"ok": False, "reason": "no nvfp4 weights found"}

    # Validate the cheapest layers first and cap how many we run.
    candidates = sorted(specs, key=lambda spec: spec["weight"].numel())
    if isinstance(max_layers, int) and max_layers > 0:
        candidates = candidates[:max_layers]

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        # Normalize so string inputs (e.g. "cuda") support the `device.type`
        # check below; previously a plain string crashed with AttributeError.
        device = torch.device(device)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    results = []
    with torch.no_grad():
        for spec in candidates:
            weight = spec["weight"]
            layout = spec.get("layout", _NVFP4_LAYOUT_LEGACY)
            # FP4 weights are packed two values per uint8 byte.
            in_features = weight.shape[1] * 2
            bias = spec.get("bias")

            # Each layout stores its scales under different checkpoint keys.
            if layout == _NVFP4_LAYOUT_TENSORCORE:
                input_scale = spec.get("input_scale")
                tensor_scale = spec.get("weight_scale_2")
            else:
                input_scale = spec.get("input_global_scale")
                tensor_scale = spec.get("alpha")

            if input_scale is None or tensor_scale is None:
                results.append({"name": spec["name"], "layout": layout, "kernel": False, "reason": "missing scales"})
                continue

            nvfp4_weight = NVFP4WeightTensor.create(
                weight_u8=weight,
                weight_scale=spec["weight_scale"],
                input_global_scale=input_scale,
                alpha=tensor_scale,
                size=(weight.shape[0], in_features),
                stride=(in_features, 1),
                dtype=dtype,
                device=device,
                requires_grad=False,
                layout=layout,
            )

            x = torch.randn(batch_size, in_features, device=device, dtype=dtype)
            if bias is not None:
                bias = bias.to(device=device, dtype=dtype)

            # Skip the (expensive) dequantize + reference matmul when the
            # kernel cannot run; previously the reference output was computed
            # and then thrown away.
            if not _nvfp4_can_use_kernel(x, nvfp4_weight):
                results.append({"name": spec["name"], "layout": layout, "kernel": False})
                continue
            y_kernel = _nvfp4_linear_cuda(x, nvfp4_weight, bias=bias)

            # For the tensorcore layout, quantize/dequantize the activations
            # the same way the kernel does so the reference sees identical
            # input rounding.
            x_ref = x
            if (
                layout == _NVFP4_LAYOUT_TENSORCORE
                and _ck_cuda_available
                and device.type == "cuda"
                and torch.is_tensor(input_scale)
            ):
                input_scale_fp32 = input_scale.to(device=device, dtype=torch.float32)
                pad_16x = (x.shape[0] % 16) != 0
                qx, qx_scale = _ck_cuda.quantize_nvfp4(x, input_scale_fp32, 0.0, pad_16x)
                x_ref = _ck_cuda.dequantize_nvfp4(qx, input_scale_fp32, qx_scale, output_type=dtype)
                if pad_16x:
                    # Drop the rows added to satisfy the 16-row alignment.
                    x_ref = x_ref[: x.shape[0]]
            y_ref = torch.nn.functional.linear(
                x_ref,
                nvfp4_weight.dequantize(dtype=dtype, device=device),
                bias,
            )

            diff = (y_kernel - y_ref).float()
            results.append(
                {
                    "name": spec["name"],
                    "layout": layout,
                    "kernel": True,
                    "max_abs": diff.abs().max().item(),
                    "mean_abs": diff.abs().mean().item(),
                }
            )

    if verbose:
        print("NVFP4 kernel validation:")
        for entry in results:
            if not entry.get("kernel"):
                print(f"  {entry['name']}: kernel skipped ({entry.get('reason', 'incompatible')})")
                continue
            print(
                f"  {entry['name']}: max_abs={entry['max_abs']:.6f} mean_abs={entry['mean_abs']:.6f}"
            )

    return {"ok": True, "results": results}
|
|
|