"""NVFP4 (4-bit float, block-scaled) weight support for optimum-quanto.

Registers an ``nvfp4`` qtype, wraps packed FP4 weights in a
``NVFP4WeightTensor`` subclass, and dispatches ``F.linear`` either to a CUDA
kernel backend (comfy-kitchen or lightx2v) or to a dequantize-then-matmul
fallback.  Two on-disk layouts are handled:

* ``legacy``      — ``input_global_scale`` / ``alpha`` global scales.
* ``tensorcore``  — ``input_scale`` / ``weight_scale_2`` (ModelOpt-style).
"""

import ast
import os

import torch
from torch.utils import _pytree as pytree

from optimum.quanto import QModuleMixin
from optimum.quanto.tensor.qtensor import QTensor
from optimum.quanto.tensor.qtype import qtype as _quanto_qtype, qtypes as _quanto_qtypes


def _maybe_add_nvfp4_cu13_dll_dir():
    """On Windows, add the nvidia cu13 wheel's DLL dir so kernel DLLs resolve."""
    if os.name != "nt":
        return
    try:
        import nvidia.cu13

        dll_dir = os.path.join(nvidia.cu13.__path__[0], "bin", "x86_64")
        if os.path.isdir(dll_dir):
            os.add_dll_directory(dll_dir)
    except Exception:
        # Best effort only: missing wheel simply means no DLL dir to add.
        pass


# Optional backend 1: comfy-kitchen CUDA extension.
try:
    from comfy_kitchen.backends import cuda as _ck_cuda

    _ck_cuda_available = getattr(_ck_cuda, "_EXT_AVAILABLE", False)
except Exception:
    _ck_cuda = None
    _ck_cuda_available = False

# Optional backend 2: lightx2v_kernel (needs the cu13 DLL dir on Windows).
try:
    _maybe_add_nvfp4_cu13_dll_dir()
    from lightx2v_kernel import gemm as _lx_gemm

    _lx_gemm_available = True
except Exception:
    _lx_gemm = None
    _lx_gemm_available = False

# Register the nvfp4 qtype with quanto once (idempotent across re-imports).
_NVFP4_QTYPE_NAME = "nvfp4"
if _NVFP4_QTYPE_NAME not in _quanto_qtypes:
    _quanto_qtypes[_NVFP4_QTYPE_NAME] = _quanto_qtype(
        _NVFP4_QTYPE_NAME,
        is_floating_point=True,
        bits=4,
        dtype=torch.uint8,
        qmin=-6.0,
        qmax=6.0,
    )
_NVFP4_QTYPE = _quanto_qtypes[_NVFP4_QTYPE_NAME]

HANDLER_PRIORITY = 1

_NVFP4_LAYOUT_LEGACY = "legacy"
_NVFP4_LAYOUT_TENSORCORE = "tensorcore"
_NVFP4_BACKEND_AUTO = "auto"
_NVFP4_BACKEND_COMFY = "comfy"
_NVFP4_BACKEND_LIGHTX2V = "lightx2v"

# One-shot logging flags plus cached kernel-availability state.
_NVFP4_KERNEL_LOGGED = False
_NVFP4_FALLBACK_LOGGED = False
_NVFP4_LOAD_LOGGED = False
_NVFP4_KERNEL_AVAILABLE = False
_NVFP4_KERNEL_CHECKED = False
_NVFP4_KERNEL_BACKEND = None
_NVFP4_ACT_SCALE_CACHE = {}

# Backend selection comes from the environment (default: auto).
# FIX: a debug leftover used to hard-code lightx2v right after this read,
# silently discarding WGP_NVFP4_BACKEND; use set_nvfp4_backend() to override
# at runtime instead.
_NVFP4_BACKEND = os.environ.get("WGP_NVFP4_BACKEND", _NVFP4_BACKEND_AUTO).strip().lower()


def _normalize_nvfp4_backend(name):
    """Map a user/env backend name to one of the canonical backend ids."""
    if name is None:
        return _NVFP4_BACKEND_AUTO
    norm = str(name).strip().lower()
    if norm in ("", "auto", "default"):
        return _NVFP4_BACKEND_AUTO
    if norm in ("comfy", "comfy-kitchen", "comfy_kitchen", "ck"):
        return _NVFP4_BACKEND_COMFY
    if norm in ("lightx2v", "lightx2v_kernel", "lightx2v-kernel", "lx"):
        return _NVFP4_BACKEND_LIGHTX2V
    if norm in ("off", "none", "fallback", "disable", "disabled"):
        return "fallback"
    return norm


_NVFP4_BACKEND = _normalize_nvfp4_backend(_NVFP4_BACKEND)


def _nvfp4_backend_candidates():
    """Return the ordered list of backends to probe for kernel support."""
    if _NVFP4_BACKEND == _NVFP4_BACKEND_AUTO:
        return [_NVFP4_BACKEND_COMFY, _NVFP4_BACKEND_LIGHTX2V]
    if _NVFP4_BACKEND in (_NVFP4_BACKEND_COMFY, _NVFP4_BACKEND_LIGHTX2V):
        return [_NVFP4_BACKEND]
    return []


def _nvfp4_backend_label(backend):
    """Human-readable name for log messages."""
    if backend == _NVFP4_BACKEND_LIGHTX2V:
        return "lightx2v"
    if backend == _NVFP4_BACKEND_COMFY:
        return "comfy-kitchen"
    return backend


def _nvfp4_lightx2v_device_ok(device):
    """lightx2v sm120 kernels require compute capability major >= 12 (or force env)."""
    force = os.environ.get("WGP_NVFP4_LIGHTX2V_FORCE", "").strip().lower()
    if force in ("1", "true", "yes", "y"):
        return True
    try:
        props = torch.cuda.get_device_properties(device)
    except Exception:
        return False
    return props.major >= 12


def set_nvfp4_backend(name):
    """Select the NVFP4 kernel backend at runtime and re-probe availability."""
    global _NVFP4_BACKEND, _NVFP4_KERNEL_CHECKED, _NVFP4_KERNEL_AVAILABLE, _NVFP4_KERNEL_BACKEND
    global _NVFP4_KERNEL_LOGGED, _NVFP4_LOAD_LOGGED
    _NVFP4_BACKEND = _normalize_nvfp4_backend(name)
    _NVFP4_KERNEL_CHECKED = False
    _NVFP4_KERNEL_AVAILABLE = False
    _NVFP4_KERNEL_BACKEND = None
    _NVFP4_KERNEL_LOGGED = False
    _NVFP4_LOAD_LOGGED = False
    _init_nvfp4_kernel_support()


def _nvfp4_note_kernel():
    """Log (once) that the optimized kernel path is in use."""
    global _NVFP4_KERNEL_LOGGED
    if not _NVFP4_KERNEL_LOGGED:
        label = _nvfp4_backend_label(_NVFP4_KERNEL_BACKEND) if _NVFP4_KERNEL_BACKEND else "CUDA"
        print(f"NVFP4: using {label} kernel")
        _NVFP4_KERNEL_LOGGED = True


def _nvfp4_note_fallback():
    """Log (once) that the dequantize fallback is in use."""
    global _NVFP4_FALLBACK_LOGGED
    global _NVFP4_KERNEL_LOGGED
    if not _NVFP4_FALLBACK_LOGGED:
        if _NVFP4_KERNEL_LOGGED:
            print("NVFP4: linear fallback needed on some weights")
        else:
            print("NVFP4: linear fallback")
        _NVFP4_FALLBACK_LOGGED = True


def _nvfp4_note_reset():
    """Reset one-shot log flags (called at the start of a new conversion)."""
    global _NVFP4_FALLBACK_LOGGED
    global _NVFP4_KERNEL_LOGGED
    global _NVFP4_LOAD_LOGGED
    _NVFP4_KERNEL_LOGGED = False
    _NVFP4_FALLBACK_LOGGED = False
    _NVFP4_LOAD_LOGGED = False


def _nvfp4_note_load_backend():
    """Log (once per load) whether kernels are available for this checkpoint."""
    global _NVFP4_LOAD_LOGGED
    if _NVFP4_LOAD_LOGGED:
        return
    _NVFP4_LOAD_LOGGED = True
    if _NVFP4_KERNEL_AVAILABLE:
        label = _nvfp4_backend_label(_NVFP4_KERNEL_BACKEND) if _NVFP4_KERNEL_BACKEND else "unknown"
        print(f"NVFP4: kernels available ({label}); optimized path will be used when compatible.")
    else:
        print("NVFP4: kernels unavailable; using fallback.")


def _check_nvfp4_kernel_support(device, backend):
    """Return True if `backend` has working NVFP4 kernels for `device`."""
    if device.type != "cuda":
        return False
    if backend == _NVFP4_BACKEND_COMFY:
        if not _ck_cuda_available:
            return False
        if not hasattr(_ck_cuda, "scaled_mm_nvfp4"):
            return False
        if not hasattr(_ck_cuda, "quantize_nvfp4"):
            return False
        if not (hasattr(torch.ops, "comfy_kitchen") and hasattr(torch.ops.comfy_kitchen, "scaled_mm_nvfp4")):
            return False
        # comfy-kitchen kernels target Blackwell datacenter (sm100+).
        major, minor = torch.cuda.get_device_capability(device)
        return (major, minor) >= (10, 0)
    if backend == _NVFP4_BACKEND_LIGHTX2V:
        if not _lx_gemm_available:
            return False
        if not _nvfp4_lightx2v_device_ok(device):
            return False
        if not (
            hasattr(torch.ops, "lightx2v_kernel")
            and hasattr(torch.ops.lightx2v_kernel, "cutlass_scaled_nvfp4_mm_sm120")
        ):
            return False
        if not hasattr(torch.ops.lightx2v_kernel, "scaled_nvfp4_quant_sm120"):
            return False
        # lightx2v kernels target sm120 (consumer Blackwell).
        major, minor = torch.cuda.get_device_capability(device)
        return (major, minor) >= (12, 0)
    return False


def _init_nvfp4_kernel_support():
    """Probe backend candidates once and cache the result in module globals."""
    global _NVFP4_KERNEL_AVAILABLE, _NVFP4_KERNEL_CHECKED, _NVFP4_KERNEL_BACKEND
    if _NVFP4_KERNEL_CHECKED:
        return
    _NVFP4_KERNEL_CHECKED = True
    _NVFP4_KERNEL_AVAILABLE = False
    _NVFP4_KERNEL_BACKEND = None
    if not torch.cuda.is_available():
        return
    device = torch.device("cuda")
    for backend in _nvfp4_backend_candidates():
        try:
            if _check_nvfp4_kernel_support(device, backend):
                _NVFP4_KERNEL_AVAILABLE = True
                _NVFP4_KERNEL_BACKEND = backend
                break
        except Exception:
            continue


def _supports_nvfp4_kernel(device):
    """True when the cached probe found a usable CUDA kernel backend."""
    if device.type != "cuda":
        return False
    if not _NVFP4_KERNEL_CHECKED:
        _init_nvfp4_kernel_support()
    return _NVFP4_KERNEL_AVAILABLE


_init_nvfp4_kernel_support()


def _nvfp4_layout(weight):
    """Layout tag of a quantized weight (defaults to legacy)."""
    return getattr(weight, "_layout", _NVFP4_LAYOUT_LEGACY)


def _nvfp4_can_use_kernel(input, weight):
    """Check all shape/dtype/scale preconditions for the kernel path."""
    if not torch.is_tensor(input):
        return False
    if not getattr(weight, "_allow_kernel", True):
        return False
    if not _supports_nvfp4_kernel(input.device):
        return False
    backend = _NVFP4_KERNEL_BACKEND
    if backend is None:
        return False
    layout = _nvfp4_layout(weight)
    if backend == _NVFP4_BACKEND_LIGHTX2V:
        if input.shape[-1] % 32 != 0:
            return False
        if weight.size(0) % 32 != 0:
            return False
    else:
        if layout == _NVFP4_LAYOUT_LEGACY:
            if input.shape[-1] % 64 != 0:
                return False
        else:
            if input.shape[-1] % 16 != 0:
                return False
        if weight.size(0) % 8 != 0:
            return False
    # Packed weight holds two FP4 values per byte, so bytes * 2 == in_features.
    if weight._data.shape[1] * 2 != input.shape[-1]:
        return False
    if weight._block_size != 16:
        return False
    if not torch.is_tensor(weight._input_global_scale) or not torch.is_tensor(weight._alpha):
        return False
    if getattr(weight._input_global_scale, "is_meta", False):
        return False
    try:
        # NaN input scale marks a checkpoint that lacked input_scale entirely.
        if not torch.isfinite(weight._input_global_scale).all():
            return False
    except Exception:
        return False
    return True


def _nvfp4_get_act_scale(device):
    """Per-device cached scalar 1.0 used as the activation global scale."""
    act_scale = _NVFP4_ACT_SCALE_CACHE.get(device)
    if act_scale is None:
        act_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
        _NVFP4_ACT_SCALE_CACHE[device] = act_scale
    return act_scale


def _nvfp4_swap_nibbles(tensor):
    """Swap high/low nibbles of each byte (layout conversion for packed FP4)."""
    return ((tensor & 0x0F) << 4) | ((tensor & 0xF0) >> 4)


def _nvfp4_linear_cuda_comfy(input, weight, bias=None):
    """Linear via comfy-kitchen's quantize_nvfp4 + scaled_mm_nvfp4 kernels."""
    _nvfp4_note_kernel()
    x2d = input.reshape(-1, input.shape[-1])
    if not x2d.is_floating_point():
        x2d = x2d.to(torch.float16)
    orig_dtype = x2d.dtype
    if orig_dtype not in (torch.float16, torch.bfloat16):
        x2d = x2d.to(torch.float16)
        out_dtype = torch.float16
    else:
        out_dtype = orig_dtype
    if not x2d.is_contiguous():
        x2d = x2d.contiguous()
    weight_fp4 = weight._data
    weight_scale = weight._scale
    input_scale = weight._input_global_scale
    alpha = weight._alpha
    layout = _nvfp4_layout(weight)
    device = x2d.device
    if weight_fp4.device != device:
        weight_fp4 = weight_fp4.to(device)
    if weight_scale.device != device:
        weight_scale = weight_scale.to(device)
    if input_scale.device != device:
        input_scale = input_scale.to(device)
    if alpha.device != device:
        alpha = alpha.to(device)
    if bias is not None and torch.is_tensor(bias) and bias.dtype != out_dtype:
        bias = bias.to(out_dtype)
    # The kernel wants row counts in multiples of 16; request padding if not.
    orig_rows = x2d.shape[0]
    pad_16x = (orig_rows % 16) != 0
    if layout == _NVFP4_LAYOUT_TENSORCORE:
        input_scale = input_scale.to(torch.float32)
        tensor_scale = alpha.to(torch.float32)
        qx, qx_scale = _ck_cuda.quantize_nvfp4(x2d, input_scale, 0.0, pad_16x)
        out = _ck_cuda.scaled_mm_nvfp4(
            qx,
            weight_fp4,
            tensor_scale_a=input_scale,
            tensor_scale_b=tensor_scale,
            block_scale_a=qx_scale,
            block_scale_b=weight_scale,
            bias=bias,
            out_dtype=out_dtype,
        )
    else:
        # Legacy layout: fold the input global scale into alpha and quantize
        # activations against a unit scale; nibble order also differs.
        alpha = alpha * input_scale
        if alpha.dtype != torch.float32:
            alpha = alpha.to(torch.float32)
        act_scale = _nvfp4_get_act_scale(device)
        qx, qx_scale = _ck_cuda.quantize_nvfp4(x2d, act_scale, 0.0, pad_16x)
        weight_fp4 = _nvfp4_swap_nibbles(weight_fp4)
        out = _ck_cuda.scaled_mm_nvfp4(
            qx,
            weight_fp4,
            act_scale,
            input_scale,
            qx_scale,
            weight_scale,
            bias=bias,
            out_dtype=out_dtype,
            alpha=alpha,
        )
    if pad_16x:
        out = out[:orig_rows]
    if out.dtype != orig_dtype:
        out = out.to(orig_dtype)
    return out.reshape(*input.shape[:-1], weight.size(0))


def _nvfp4_linear_cuda_lightx2v(input, weight, bias=None):
    """Linear via lightx2v's scaled_nvfp4_quant + cutlass_scaled_nvfp4_mm."""
    _nvfp4_note_kernel()
    x2d = input.reshape(-1, input.shape[-1])
    if not x2d.is_floating_point():
        x2d = x2d.to(torch.float16)
    orig_dtype = x2d.dtype
    if orig_dtype not in (torch.float16, torch.bfloat16):
        x2d = x2d.to(torch.float16)
        out_dtype = torch.float16
    else:
        out_dtype = orig_dtype
    if not x2d.is_contiguous():
        x2d = x2d.contiguous()
    weight_fp4 = weight._data
    weight_scale = weight._scale
    input_scale = weight._input_global_scale
    alpha = weight._alpha
    layout = _nvfp4_layout(weight)
    device = x2d.device
    if weight_fp4.device != device:
        weight_fp4 = weight_fp4.to(device)
    if weight_scale.device != device:
        weight_scale = weight_scale.to(device)
    if not weight_fp4.is_contiguous():
        weight_fp4 = weight_fp4.contiguous()
    if not weight_scale.is_contiguous():
        weight_scale = weight_scale.contiguous()
    if input_scale.device != device:
        input_scale = input_scale.to(device)
    if alpha.device != device:
        alpha = alpha.to(device)
    if input_scale.dtype != torch.float32:
        input_scale = input_scale.to(torch.float32)
    if alpha.dtype != torch.float32:
        alpha = alpha.to(torch.float32)
    # lightx2v's GEMM expects a bf16, contiguous bias.
    if bias is not None and torch.is_tensor(bias):
        if bias.dtype != torch.bfloat16:
            bias = bias.to(torch.bfloat16)
        if not bias.is_contiguous():
            bias = bias.contiguous()
    if layout == _NVFP4_LAYOUT_TENSORCORE:
        # Tensorcore checkpoints store input_scale directly; the quantizer
        # wants its reciprocal, and the scale is folded back into alpha.
        quant_scale = torch.reciprocal(torch.clamp(input_scale, min=1e-8))
        alpha = alpha * input_scale
    else:
        quant_scale = input_scale
    qx, qx_scale = _lx_gemm.scaled_nvfp4_quant(x2d, quant_scale)
    if layout == _NVFP4_LAYOUT_TENSORCORE:
        qx = _nvfp4_swap_nibbles(qx)
    if not qx.is_contiguous():
        qx = qx.contiguous()
    if not qx_scale.is_contiguous():
        qx_scale = qx_scale.contiguous()
    out = _lx_gemm.cutlass_scaled_nvfp4_mm(
        qx,
        weight_fp4,
        qx_scale,
        weight_scale,
        alpha=alpha,
        bias=bias,
    )
    if out.dtype != orig_dtype:
        out = out.to(orig_dtype)
    return out.reshape(*input.shape[:-1], weight.size(0))


def _nvfp4_linear_cuda(input, weight, bias=None):
    """Dispatch to the active kernel backend's linear implementation."""
    if _NVFP4_KERNEL_BACKEND == _NVFP4_BACKEND_LIGHTX2V:
        return _nvfp4_linear_cuda_lightx2v(input, weight, bias=bias)
    return _nvfp4_linear_cuda_comfy(input, weight, bias=bias)


@torch.compiler.disable()
def _nvfp4_linear(input, weight, bias=None, op=None):
    """F.linear with an NVFP4 weight: kernel if compatible, else dequantize."""
    if _nvfp4_can_use_kernel(input, weight):
        return _nvfp4_linear_cuda(input, weight, bias=bias)
    _nvfp4_note_fallback()
    dtype = input.dtype if torch.is_tensor(input) else weight.dtype
    device = input.device if torch.is_tensor(input) else weight.device
    w = weight.dequantize(dtype=dtype, device=device)
    if bias is not None and torch.is_tensor(bias) and bias.dtype != dtype:
        bias = bias.to(dtype)
    if op is not None:
        return op(input, w, bias)
    return torch.nn.functional.linear(input, w, bias)


def _is_float8_dtype(dtype):
    """Heuristic: does this dtype's name look like an 8-bit float type?"""
    return "float8" in str(dtype).lower() or "f8" in str(dtype).lower()


# The 16 representable FP4 (E2M1) magnitudes, low nibble = positive half.
_FP4_LUT_BASE = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)
_FP4_LUT_CACHE = {}
_FP4_BYTE_LUT_CACHE = {}


def _get_fp4_lut(device, dtype):
    """Cached 16-entry FP4 value LUT on (device, dtype)."""
    key = (device, dtype)
    lut = _FP4_LUT_CACHE.get(key)
    if lut is None:
        lut = _FP4_LUT_BASE.to(device=device, dtype=dtype)
        _FP4_LUT_CACHE[key] = lut
    return lut


def _get_fp4_byte_lut(device, dtype):
    """Cached 256x2 LUT mapping a packed byte to its (low, high) FP4 values."""
    key = (device, dtype)
    byte_lut = _FP4_BYTE_LUT_CACHE.get(key)
    if byte_lut is None:
        lut16 = _get_fp4_lut(device, dtype)
        b = torch.arange(256, device=device, dtype=torch.int32)
        byte_lut = torch.empty((256, 2), device=device, dtype=dtype)
        byte_lut[:, 0] = lut16[b & 0x0F]
        byte_lut[:, 1] = lut16[b >> 4]
        _FP4_BYTE_LUT_CACHE[key] = byte_lut
    return byte_lut


def _deswizzle_nvfp4_scale(scale, in_features, block_size=16, dtype=None):
    """Undo the 128x4 tile swizzle applied to the per-block scale tensor.

    Returns a (rows, in_features // block_size) scale matrix in natural order.
    """
    k_groups = in_features // block_size
    if scale.shape[1] < k_groups:
        raise RuntimeError(
            f"NVFP4 scale shape mismatch: expected at least {k_groups} groups, got {scale.shape[1]}"
        )
    if scale.shape[1] > k_groups:
        scale = scale[:, :k_groups]
    m, _ = scale.shape
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (in_features + f - 1) // f
    tmp = scale if dtype is None else scale.to(dtype)
    tmp = tmp.reshape(1, m_tiles, k_tiles, 32, 4, 4)
    tmp = tmp.permute(0, 1, 4, 3, 2, 5)
    out = tmp.reshape(m_tiles * 128, k_tiles * 4)
    return out[:m, :k_groups]


def _dequantize_nvfp4_weight(
    weight_u8,
    weight_scale,
    input_global_scale,
    alpha,
    dtype,
    device,
    block_size=16,
    layout=_NVFP4_LAYOUT_LEGACY,
):
    """Expand packed FP4 weights to a dense `dtype` matrix on `device`."""
    if weight_u8.device != device:
        weight_u8 = weight_u8.to(device)
    scale = weight_scale if weight_scale.device == device else weight_scale.to(device)
    if alpha.device != device:
        alpha = alpha.to(device)
    if input_global_scale.device != device:
        input_global_scale = input_global_scale.to(device)
    # Prefer the CUDA dequant kernel when present; fall through on any error.
    if layout == _NVFP4_LAYOUT_TENSORCORE and device.type == "cuda" and _ck_cuda_available:
        try:
            return _ck_cuda.dequantize_nvfp4(weight_u8, alpha.to(torch.float32), scale, output_type=dtype)
        except Exception:
            pass
    m, k_bytes = weight_u8.shape
    byte_lut = _get_fp4_byte_lut(device, dtype)
    if layout == _NVFP4_LAYOUT_TENSORCORE:
        idx = _nvfp4_swap_nibbles(weight_u8).to(torch.int32)
    else:
        idx = weight_u8.to(torch.int32)
    out = byte_lut[idx].reshape(m, k_bytes * 2)
    scale = _deswizzle_nvfp4_scale(scale, out.shape[1], block_size=block_size, dtype=dtype)
    out = out.view(out.shape[0], scale.shape[1], block_size)
    out.mul_(scale.unsqueeze(-1))
    out = out.view(out.shape[0], -1)
    if layout == _NVFP4_LAYOUT_TENSORCORE:
        scale_factor = alpha.to(dtype)
    else:
        scale_factor = alpha.to(dtype) * input_global_scale.to(dtype)
    out.mul_(scale_factor)
    return out


def _collect_nvfp4_specs(state_dict):
    """Find NVFP4-quantized linear layers in a raw checkpoint state dict.

    Returns a list of per-layer dicts tagged with the detected layout.
    """
    specs = []
    for key, tensor in state_dict.items():
        if not key.endswith(".weight"):
            continue
        if tensor.dtype != torch.uint8:
            continue
        base = key[:-7]
        scale_key = base + ".weight_scale"
        if scale_key not in state_dict:
            continue
        if not _is_float8_dtype(state_dict[scale_key].dtype):
            continue
        weight_scale_2_key = base + ".weight_scale_2"
        input_scale_key = base + ".input_scale"
        if weight_scale_2_key in state_dict:
            # ModelOpt/tensorcore layout.
            specs.append(
                {
                    "name": base,
                    "weight": tensor,
                    "weight_scale": state_dict[scale_key],
                    "weight_scale_2": state_dict[weight_scale_2_key],
                    "input_scale": state_dict.get(input_scale_key, None),
                    "bias": state_dict.get(base + ".bias", None),
                    "layout": _NVFP4_LAYOUT_TENSORCORE,
                }
            )
            continue
        input_global_key = base + ".input_global_scale"
        alpha_key = base + ".alpha"
        input_absmax_key = base + ".input_absmax"
        weight_global_scale_key = base + ".weight_global_scale"
        if input_global_key not in state_dict or alpha_key not in state_dict:
            if input_absmax_key not in state_dict or weight_global_scale_key not in state_dict:
                continue
            # Derive the global scales from absmax metadata.
            # NOTE(review): 2688 = 448 * 6 — fp8 max times fp4 max; confirm.
            input_absmax = state_dict[input_absmax_key]
            weight_global_scale = state_dict[weight_global_scale_key]
            input_global_scale = (2688.0 / input_absmax).to(torch.float32)
            alpha = 1.0 / (input_global_scale * weight_global_scale.to(torch.float32))
        else:
            input_global_scale = state_dict[input_global_key]
            alpha = state_dict[alpha_key]
        specs.append(
            {
                "name": base,
                "weight": tensor,
                "weight_scale": state_dict[scale_key],
                "input_global_scale": input_global_scale,
                "alpha": alpha,
                "bias": state_dict.get(base + ".bias", None),
                "layout": _NVFP4_LAYOUT_LEGACY,
            }
        )
    return specs


def detect_nvfp4_state_dict(state_dict):
    """True if the checkpoint contains at least one NVFP4 layer."""
    return len(_collect_nvfp4_specs(state_dict)) > 0


def describe_nvfp4_state_dict(state_dict, max_names=8):
    """Summary dict: layer count plus up to `max_names` layer names."""
    specs = _collect_nvfp4_specs(state_dict)
    names = [spec["name"] for spec in specs]
    return {"count": len(names), "names": names[:max_names]}


def convert_nvfp4_to_quanto(state_dict, default_dtype=None, verboseLevel=1):
    """Build a quanto quantization map for every detected NVFP4 layer."""
    specs = _collect_nvfp4_specs(state_dict)
    if not specs:
        return {"state_dict": state_dict, "quant_map": {}}
    _nvfp4_note_load_backend()
    quant_map = {}
    for spec in specs:
        qcfg = {"weights": "nvfp4", "activations": "none"}
        quant_map[spec["name"]] = qcfg
        quant_map[spec["name"] + ".weight"] = qcfg
    return {"state_dict": state_dict, "quant_map": quant_map}


def detect(state_dict, verboseLevel=1):
    """Handler entry point: detect whether this checkpoint is NVFP4."""
    matched = detect_nvfp4_state_dict(state_dict)
    details = describe_nvfp4_state_dict(state_dict) if matched else {}
    return {"matched": matched, "kind": "nvfp4" if matched else "none", "details": details}


def convert_to_quanto(state_dict, default_dtype, verboseLevel=1, detection=None):
    """Handler entry point: produce the quanto quant map (resets log flags)."""
    if detection is not None and not detection.get("matched", False):
        return {"state_dict": state_dict, "quant_map": {}}
    _nvfp4_note_reset()
    return convert_nvfp4_to_quanto(state_dict, default_dtype=default_dtype, verboseLevel=verboseLevel)


def apply_pre_quantization(model, state_dict, quantization_map, default_dtype=None, verboseLevel=1):
    """Handler entry point: nothing to pre-apply for NVFP4."""
    return quantization_map, []


def _nvfp4_qfallback(callable, *args, **kwargs):
    """Run an aten op on dequantized copies of any NVFP4 tensors in args."""
    args, kwargs = pytree.tree_map_only(NVFP4WeightTensor, lambda x: x.dequantize(), (args, kwargs or {}))
    return callable(*args, **kwargs)


class NVFP4WeightTensor(QTensor):
    """Wrapper-subclass tensor holding packed FP4 weight data plus scales.

    Inner tensors: `_data` (uint8, two FP4 values per byte), `_scale`
    (per-block fp8 scales), `_input_global_scale`, `_alpha`.
    """

    @staticmethod
    def create(
        weight_u8,
        weight_scale,
        size,
        stride,
        dtype,
        input_global_scale=None,
        alpha=None,
        input_scale=None,
        weight_scale_2=None,
        device=None,
        requires_grad=False,
        layout=_NVFP4_LAYOUT_LEGACY,
        allow_kernel=True,
    ):
        """Normalized constructor accepting either layout's scale names."""
        # tensorcore-style names map onto the legacy slots.
        if input_global_scale is None and input_scale is not None:
            input_global_scale = input_scale
        if alpha is None and weight_scale_2 is not None:
            alpha = weight_scale_2
        if layout == _NVFP4_LAYOUT_LEGACY and (weight_scale_2 is not None or input_scale is not None):
            layout = _NVFP4_LAYOUT_TENSORCORE
        if input_global_scale is None or alpha is None:
            raise ValueError("NVFP4WeightTensor.create requires input_global_scale/alpha or input_scale/weight_scale_2")
        # Non-finite input scales (NaN placeholder) disable the kernel path.
        if torch.is_tensor(input_global_scale):
            try:
                if not torch.isfinite(input_global_scale).all():
                    allow_kernel = False
            except Exception:
                allow_kernel = False
        device = weight_u8.device if device is None else device
        if weight_u8.device != device:
            weight_u8 = weight_u8.to(device)
        if weight_scale.device != device:
            weight_scale = weight_scale.to(device)
        if input_global_scale.device != device:
            input_global_scale = input_global_scale.to(device)
        if alpha.device != device:
            alpha = alpha.to(device)
        return NVFP4WeightTensor(
            qtype=_NVFP4_QTYPE,
            axis=0,
            size=size,
            stride=stride,
            weight_u8=weight_u8,
            weight_scale=weight_scale,
            input_global_scale=input_global_scale,
            alpha=alpha,
            allow_kernel=allow_kernel,
            dtype=dtype,
            requires_grad=requires_grad,
            layout=layout,
        )

    @staticmethod
    def __new__(
        cls,
        qtype,
        axis,
        size,
        stride,
        weight_u8,
        weight_scale,
        input_global_scale,
        alpha,
        dtype,
        allow_kernel=True,
        requires_grad=False,
        layout=_NVFP4_LAYOUT_LEGACY,
    ):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            size,
            strides=stride,
            dtype=dtype,
            device=weight_u8.device,
            requires_grad=requires_grad,
        )

    def __init__(
        self,
        qtype,
        axis,
        size,
        stride,
        weight_u8,
        weight_scale,
        input_global_scale,
        alpha,
        dtype,
        requires_grad=False,
        layout=_NVFP4_LAYOUT_LEGACY,
        allow_kernel=True,
    ):
        super().__init__(qtype, axis)
        self._data = weight_u8
        self._scale = weight_scale
        self._input_global_scale = input_global_scale
        self._alpha = alpha
        self._block_size = 16
        self._layout = layout
        self._allow_kernel = allow_kernel

    def dequantize(self, dtype=None, device=None):
        """Materialize the full-precision weight matrix."""
        if dtype is None:
            dtype = self.dtype
        if device is None:
            device = self.device
        return _dequantize_nvfp4_weight(
            weight_u8=self._data,
            weight_scale=self._scale,
            input_global_scale=self._input_global_scale,
            alpha=self._alpha,
            dtype=dtype,
            device=device,
            block_size=self._block_size,
            layout=self._layout,
        )

    def get_quantized_subtensors(self):
        """Expose inner tensors under the layout's checkpoint key names."""
        if self._layout == _NVFP4_LAYOUT_TENSORCORE:
            return [
                ("weight_u8", self._data),
                ("weight_scale", self._scale),
                ("weight_scale_2", self._alpha),
                ("input_scale", self._input_global_scale),
            ]
        return [
            ("weight_u8", self._data),
            ("weight_scale", self._scale),
            ("input_global_scale", self._input_global_scale),
            ("alpha", self._alpha),
        ]

    def set_quantized_subtensors(self, sub_tensors):
        """Replace inner tensors from a dict or (name, tensor) pairs."""
        if isinstance(sub_tensors, dict):
            sub_map = sub_tensors
        else:
            sub_map = {name: tensor for name, tensor in sub_tensors}
        data = sub_map.get("weight_u8", sub_map.get("data"))
        if data is not None:
            self._data = data
        if "weight_scale" in sub_map and sub_map["weight_scale"] is not None:
            self._scale = sub_map["weight_scale"]
        if "input_scale" in sub_map and sub_map["input_scale"] is not None:
            self._input_global_scale = sub_map["input_scale"]
        elif "input_global_scale" in sub_map and sub_map["input_global_scale"] is not None:
            self._input_global_scale = sub_map["input_global_scale"]
        if "weight_scale_2" in sub_map and sub_map["weight_scale_2"] is not None:
            self._alpha = sub_map["weight_scale_2"]
        elif "alpha" in sub_map and sub_map["alpha"] is not None:
            self._alpha = sub_map["alpha"]

    def __tensor_flatten__(self):
        inner_tensors = ["_data", "_scale", "_input_global_scale", "_alpha"]
        # Metadata is string-encoded; __tensor_unflatten__ parses it back.
        meta = {
            "qtype": self._qtype.name,
            "axis": str(self._axis),
            "size": str(list(self.size())),
            "stride": str(list(self.stride())),
            "dtype": str(self.dtype),
            "layout": self._layout,
            "allow_kernel": "1" if self._allow_kernel else "0",
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        qtype = _quanto_qtypes[meta["qtype"]]
        axis = ast.literal_eval(meta["axis"])
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        dtype_str = meta.get("dtype", "torch.float16")
        if dtype_str.startswith("torch."):
            dtype_name = dtype_str.split(".", 1)[1]
            dtype = getattr(torch, dtype_name, torch.float16)
        else:
            dtype = getattr(torch, dtype_str, torch.float16)
        layout = meta.get("layout", _NVFP4_LAYOUT_LEGACY)
        allow_kernel = str(meta.get("allow_kernel", "1")).strip().lower() not in ("0", "false", "no")
        return NVFP4WeightTensor(
            qtype=qtype,
            axis=axis,
            size=size,
            stride=stride,
            weight_u8=inner_tensors["_data"],
            weight_scale=inner_tensors["_scale"],
            input_global_scale=inner_tensors["_input_global_scale"],
            alpha=inner_tensors["_alpha"],
            allow_kernel=allow_kernel,
            dtype=dtype,
            layout=layout,
        )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Intercept F.linear so the quantized path handles the matmul.
        if func is torch.nn.functional.linear:
            input = args[0] if len(args) > 0 else kwargs.get("input", None)
            weight = args[1] if len(args) > 1 else kwargs.get("weight", None)
            bias = args[2] if len(args) > 2 else kwargs.get("bias", None)
            if isinstance(weight, NVFP4WeightTensor):
                return _nvfp4_linear(input, weight, bias=bias)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        op = op.overloadpacket
        if op is torch.ops.aten.linear:
            input = args[0]
            weight = args[1]
            bias = args[2] if len(args) > 2 else None
            if isinstance(weight, NVFP4WeightTensor):
                return _nvfp4_linear(input, weight, bias=bias, op=op)
        if op is torch.ops.aten.detach:
            # Detach each inner tensor and rebuild the wrapper.
            t = args[0]
            return NVFP4WeightTensor.create(
                weight_u8=op(t._data),
                weight_scale=op(t._scale),
                input_global_scale=op(t._input_global_scale),
                alpha=op(t._alpha),
                allow_kernel=getattr(t, "_allow_kernel", True),
                size=t.size(),
                stride=t.stride(),
                dtype=t.dtype,
                device=t.device,
                requires_grad=t.requires_grad,
                layout=t._layout,
            )
        if op in (torch.ops.aten._to_copy, torch.ops.aten.to):
            t = args[0]
            dtype = kwargs.pop("dtype", t.dtype) if kwargs else t.dtype
            device = kwargs.pop("device", t.device) if kwargs else t.device
            # A dtype change cannot be represented in packed form: dequantize.
            if dtype != t.dtype:
                return t.dequantize(dtype=dtype, device=device)
            out_data = op(t._data, device=device, **(kwargs or {}))
            out_scale = op(t._scale, device=device, **(kwargs or {}))
            out_igs = op(t._input_global_scale, device=device, **(kwargs or {}))
            out_alpha = op(t._alpha, device=device, **(kwargs or {}))
            return NVFP4WeightTensor.create(
                weight_u8=out_data,
                weight_scale=out_scale,
                input_global_scale=out_igs,
                alpha=out_alpha,
                allow_kernel=getattr(t, "_allow_kernel", True),
                size=t.size(),
                stride=t.stride(),
                dtype=t.dtype,
                device=device,
                requires_grad=t.requires_grad,
                layout=t._layout,
            )
        # Anything else: dequantize and run the op on plain tensors.
        return _nvfp4_qfallback(op, *args, **(kwargs or {}))


class QLinearNVFP4(QModuleMixin, torch.nn.Linear):
    """quanto linear module whose weight is loaded as an NVFP4WeightTensor."""

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        device=None,
        dtype=None,
        weights=None,
        activations=None,
        optimizer=None,
        quantize_input=True,
    ):
        super().__init__(
            in_features,
            out_features,
            bias=bias,
            device=device,
            dtype=dtype,
            weights=weights,
            activations=activations,
            optimizer=optimizer,
            quantize_input=quantize_input,
        )
        # Target dtype for dequantized outputs / bias when loading.
        self._nvfp4_default_dtype = dtype

    @classmethod
    def qcreate(
        cls,
        module,
        weights,
        activations=None,
        optimizer=None,
        device=None,
    ):
        """Build a QLinearNVFP4 mirroring an existing nn.Linear."""
        if torch.is_tensor(module.weight) and module.weight.dtype.is_floating_point:
            weight_dtype = module.weight.dtype
        elif torch.is_tensor(getattr(module, "bias", None)) and module.bias.dtype.is_floating_point:
            weight_dtype = module.bias.dtype
        else:
            weight_dtype = torch.float16
        return cls(
            module.in_features,
            module.out_features,
            module.bias is not None,
            device=device,
            dtype=weight_dtype,
            weights=weights,
            activations=activations,
            optimizer=optimizer,
            quantize_input=True,
        )

    def set_default_dtype(self, dtype):
        self._nvfp4_default_dtype = dtype

    @property
    def qweight(self):
        # The NVFP4 weight is already quantized; return it as-is.
        if self.weight_qtype == _NVFP4_QTYPE:
            return self.weight
        return super().qweight

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(input, self.qweight, bias=self.bias)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Consume NVFP4 checkpoint keys and assemble the wrapper weight."""
        if self.weight_qtype != _NVFP4_QTYPE:
            return super()._load_from_state_dict(
                state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
            )
        weight_key = prefix + "weight"
        scale_key = prefix + "weight_scale"
        scale2_key = prefix + "weight_scale_2"
        igs_key = prefix + "input_global_scale"
        alpha_key = prefix + "alpha"
        input_absmax_key = prefix + "input_absmax"
        weight_global_scale_key = prefix + "weight_global_scale"
        bias_key = prefix + "bias"
        input_scale_key = prefix + "input_scale"
        output_scale_key = prefix + "output_scale"
        weight_u8 = state_dict.pop(weight_key, None)
        weight_scale = state_dict.pop(scale_key, None)
        weight_scale_2 = state_dict.pop(scale2_key, None)
        input_global_scale = state_dict.pop(igs_key, None)
        alpha = state_dict.pop(alpha_key, None)
        input_absmax = state_dict.pop(input_absmax_key, None)
        weight_global_scale = state_dict.pop(weight_global_scale_key, None)
        bias = state_dict.pop(bias_key, None)
        input_scale = state_dict.pop(input_scale_key, None)
        output_scale = state_dict.pop(output_scale_key, None)
        if weight_u8 is None:
            missing_keys.append(weight_key)
        if weight_scale is None:
            missing_keys.append(scale_key)
        layout = _NVFP4_LAYOUT_LEGACY
        allow_kernel = True
        if weight_scale_2 is not None or input_scale is not None:
            layout = _NVFP4_LAYOUT_TENSORCORE
            if weight_scale_2 is None:
                missing_keys.append(scale2_key)
            if input_scale is None:
                # Missing input scale: keep a NaN placeholder so dequant still
                # works, but never take the kernel path.
                allow_kernel = False
                if torch.is_tensor(weight_scale_2):
                    input_scale = torch.full(
                        (),
                        float("nan"),
                        dtype=weight_scale_2.dtype,
                        device=weight_scale_2.device,
                    )
                elif torch.is_tensor(weight_u8):
                    input_scale = torch.full((), float("nan"), dtype=torch.float32, device=weight_u8.device)
                else:
                    input_scale = torch.tensor(float("nan"), dtype=torch.float32)
        else:
            if input_global_scale is None or alpha is None:
                if input_absmax is not None and weight_global_scale is not None:
                    # Same derivation as _collect_nvfp4_specs.
                    input_global_scale = (2688.0 / input_absmax).to(torch.float32)
                    alpha = 1.0 / (input_global_scale * weight_global_scale.to(torch.float32))
                else:
                    if input_global_scale is None:
                        missing_keys.append(igs_key)
                    if alpha is None:
                        missing_keys.append(alpha_key)
        target_dtype = self._nvfp4_default_dtype or self.weight.dtype
        if layout == _NVFP4_LAYOUT_TENSORCORE:
            if weight_u8 is not None and weight_scale is not None and weight_scale_2 is not None and input_scale is not None:
                nvfp4_weight = NVFP4WeightTensor.create(
                    weight_u8=weight_u8,
                    weight_scale=weight_scale,
                    input_global_scale=input_scale,
                    alpha=weight_scale_2,
                    allow_kernel=allow_kernel,
                    size=self.weight.size(),
                    stride=self.weight.stride(),
                    dtype=target_dtype,
                    device=weight_u8.device,
                    requires_grad=False,
                    layout=layout,
                )
                self.weight = torch.nn.Parameter(nvfp4_weight, requires_grad=False)
        else:
            if weight_u8 is not None and weight_scale is not None and input_global_scale is not None and alpha is not None:
                nvfp4_weight = NVFP4WeightTensor.create(
                    weight_u8=weight_u8,
                    weight_scale=weight_scale,
                    input_global_scale=input_global_scale,
                    alpha=alpha,
                    size=self.weight.size(),
                    stride=self.weight.stride(),
                    dtype=target_dtype,
                    device=weight_u8.device,
                    requires_grad=False,
                    layout=layout,
                )
                self.weight = torch.nn.Parameter(nvfp4_weight, requires_grad=False)
        if bias is not None:
            if target_dtype is not None and bias.dtype != target_dtype:
                bias = bias.to(target_dtype)
            self.bias = torch.nn.Parameter(bias)
        # Pick a device for quanto's input/output activation scale buffers.
        if torch.is_tensor(weight_u8):
            scale_device = weight_u8.device
        elif torch.is_tensor(self.weight):
            scale_device = self.weight.device
        elif torch.is_tensor(bias):
            scale_device = bias.device
        else:
            scale_device = torch.device("cpu")
        if input_scale is not None:
            self.input_scale = input_scale.to(scale_device)
        else:
            if not hasattr(self, "input_scale") or self.input_scale.is_meta:
                scale_dtype = self.input_scale.dtype if hasattr(self, "input_scale") else torch.float32
                self.input_scale = torch.ones((), dtype=scale_dtype, device=scale_device)
        if output_scale is not None:
            self.output_scale = output_scale.to(scale_device)
        else:
            if not hasattr(self, "output_scale") or self.output_scale.is_meta:
                scale_dtype = self.output_scale.dtype if hasattr(self, "output_scale") else torch.float32
                self.output_scale = torch.ones((), dtype=scale_dtype, device=scale_device)
        return


def validate_nvfp4_kernel(
    state_dict=None,
    checkpoint_path=None,
    device=None,
    max_layers=4,
    seed=0,
    batch_size=2,
    dtype=torch.bfloat16,
    verbose=True,
):
    """Compare kernel vs fallback outputs for a few NVFP4 layers."""
    if state_dict is None:
        if checkpoint_path is None:
            raise ValueError("state_dict or checkpoint_path is required")
        from mmgp import safetensors2

        state_dict = {}
        with safetensors2.safe_open(checkpoint_path, framework="pt", device="cpu", writable_tensors=False) as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
    specs = _collect_nvfp4_specs(state_dict)
    if not specs:
        return {"ok": False, "reason": "no nvfp4 weights found"}
    # Validate the smallest layers first to keep this cheap.
    candidates = sorted(specs, key=lambda spec: spec["weight"].numel())
    if isinstance(max_layers, int) and max_layers > 0:
        candidates = candidates[:max_layers]
    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    results = []
    with torch.no_grad():
        for spec in candidates:
            weight = spec["weight"]
            layout = spec.get("layout", _NVFP4_LAYOUT_LEGACY)
            in_features = weight.shape[1] * 2
            bias = spec.get("bias")
            if layout == _NVFP4_LAYOUT_TENSORCORE:
                input_scale = spec.get("input_scale")
                tensor_scale = spec.get("weight_scale_2")
            else:
                input_scale = spec.get("input_global_scale")
                tensor_scale = spec.get("alpha")
            if input_scale is None or tensor_scale is None:
                results.append({"name": spec["name"], "layout": layout, "kernel": False, "reason": "missing scales"})
                continue
            nvfp4_weight = NVFP4WeightTensor.create(
                weight_u8=weight,
                weight_scale=spec["weight_scale"],
                input_global_scale=input_scale,
                alpha=tensor_scale,
                size=(weight.shape[0], in_features),
                stride=(in_features, 1),
                dtype=dtype,
                device=device,
                requires_grad=False,
                layout=layout,
            )
            x = torch.randn(batch_size, in_features, device=device, dtype=dtype)
            if bias is not None:
                bias = bias.to(device=device, dtype=dtype)
            kernel_ok = _nvfp4_can_use_kernel(x, nvfp4_weight)
            y_kernel = _nvfp4_linear_cuda(x, nvfp4_weight, bias=bias) if kernel_ok else None
            # Quantize/dequantize the activations for the reference so both
            # paths see the same FP4 activation error.
            x_ref = x
            if (
                layout == _NVFP4_LAYOUT_TENSORCORE
                and _ck_cuda_available
                and device.type == "cuda"
                and torch.is_tensor(input_scale)
            ):
                input_scale_fp32 = input_scale.to(device=device, dtype=torch.float32)
                pad_16x = (x.shape[0] % 16) != 0
                qx, qx_scale = _ck_cuda.quantize_nvfp4(x, input_scale_fp32, 0.0, pad_16x)
                x_ref = _ck_cuda.dequantize_nvfp4(qx, input_scale_fp32, qx_scale, output_type=dtype)
                if pad_16x:
                    x_ref = x_ref[: x.shape[0]]
            y_ref = torch.nn.functional.linear(
                x_ref,
                nvfp4_weight.dequantize(dtype=dtype, device=device),
                bias,
            )
            if y_kernel is None:
                results.append({"name": spec["name"], "layout": layout, "kernel": False})
                continue
            diff = (y_kernel - y_ref).float()
            results.append(
                {
                    "name": spec["name"],
                    "layout": layout,
                    "kernel": True,
                    "max_abs": diff.abs().max().item(),
                    "mean_abs": diff.abs().mean().item(),
                }
            )
    if verbose:
        print("NVFP4 kernel validation:")
        for entry in results:
            if not entry.get("kernel"):
                print(f" {entry['name']}: kernel skipped ({entry.get('reason', 'incompatible')})")
                continue
            print(
                f" {entry['name']}: max_abs={entry['max_abs']:.6f} mean_abs={entry['mean_abs']:.6f}"
            )
    return {"ok": True, "results": results}