""" NVFP4 text encoder loader for diffusers image pipelines. Loads a compressed-tensors NVFP4-pack-quantized HuggingFace causal LM and wraps it so it can be plugged into ``diffusers.ZImagePipeline`` (or any pipeline calling ``self.text_encoder(input_ids, attention_mask, output_hidden_states=True)``). Strategy: - Instantiate the HF model on the ``meta`` device (no real allocation). - Walk every ``torch.nn.Linear`` and swap it for vLLM's ``ReplicatedLinear`` with ``CompressedTensorsConfig`` derived from the checkpoint's ``quantization_config``. This registers ``weight_packed`` / ``weight_scale`` / ``*_global_scale`` parameters in the exact layout vLLM's ``CompressedTensorsW4A4Fp4`` scheme expects. - Materialise remaining (non-Linear) parameters (embeddings, RMSNorm, k/q norms) on the target device & dtype. - Stream the safetensors file and dispatch each tensor through the registered vLLM ``weight_loader`` (which handles layout swizzling on ``process_weights_after_loading``). - Tie the LM head to the input embedding when ``config.tie_word_embeddings``. The result is a regular ``nn.Module`` matching the HF model's call signature (``forward(input_ids, attention_mask, output_hidden_states)``) -- usable directly as ``ZImagePipeline.text_encoder``. vLLM requires a minimal global context (distributed process group + model parallel state + active VllmConfig) even at TP=1 because ``ReplicatedLinear`` queries the TP world size at construction. We bootstrap that lazily once. Forced kernel: we set ``VLLM_NVFP4_GEMM_BACKEND=cutlass`` to skip flashinfer-cutlass JIT (which needs the ``ninja`` binary on PATH). The vLLM CUTLASS kernel is built into the wheel. """ from __future__ import annotations import json import os from collections.abc import Iterator from typing import Optional import torch import torch.nn as nn # ---------------------------------------------------------------------------- # One-time vLLM bootstrap (TP=1, no engine, just enough context for ReplicatedLinear) # ---------------------------------------------------------------------------- _VLLM_BOOTSTRAPPED = False _VLLM_CONFIG_CTX = None # holds the entered set_current_vllm_config context manager def _bootstrap_vllm_once() -> None: """Initialise the bits of vLLM that ReplicatedLinear needs at TP=1. Idempotent. Uses ``gloo`` so it works without NCCL/CUDA-aware MPI and even when CUDA is busy with the diffusion transformer. """ global _VLLM_BOOTSTRAPPED, _VLLM_CONFIG_CTX if _VLLM_BOOTSTRAPPED: return # Force CUTLASS to avoid flashinfer-cutlass JIT (requires `ninja` on PATH). os.environ.setdefault("VLLM_NVFP4_GEMM_BACKEND", "cutlass") from vllm.config import VllmConfig from vllm.config.vllm import set_current_vllm_config from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, ) # Pick a free port; world_size=1. import socket s = socket.socket() s.bind(("127.0.0.1", 0)) port = s.getsockname()[1] s.close() if not torch.distributed.is_initialized(): init_distributed_environment( world_size=1, rank=0, local_rank=0, distributed_init_method=f"tcp://127.0.0.1:{port}", backend="gloo", ) # Enter a long-lived VllmConfig context. We never exit it -- the encoder # may construct submodules lazily and ReplicatedLinear calls # get_current_vllm_config() at init. vc = VllmConfig() _VLLM_CONFIG_CTX = set_current_vllm_config(vc) _VLLM_CONFIG_CTX.__enter__() ensure_model_parallel_initialized(1, 1) _VLLM_BOOTSTRAPPED = True # ---------------------------------------------------------------------------- # Module: linear replacement # ---------------------------------------------------------------------------- def _replace_linears_with_replicated( model: nn.Module, quant_config ) -> None: """Recursively swap every ``nn.Linear`` for vLLM ``ReplicatedLinear``. Carries the ``prefix`` so quant_config's ``ignore`` patterns (e.g. ``lm_head``) are correctly applied. """ from vllm.model_executor.layers.linear import ReplicatedLinear def _walk(parent: nn.Module, prefix: str) -> None: for child_name, child in list(parent.named_children()): qname = f"{prefix}.{child_name}" if prefix else child_name if isinstance(child, nn.Linear): new = ReplicatedLinear( input_size=child.in_features, output_size=child.out_features, bias=child.bias is not None, quant_config=quant_config, prefix=qname, return_bias=False, params_dtype=torch.bfloat16, ) setattr(parent, child_name, new) else: _walk(child, qname) _walk(model, prefix="") def _materialize_remaining_meta_params( model: nn.Module, dtype: torch.dtype, device: torch.device ) -> None: """Replace any ``meta`` parameter with empty real storage. Only touches parameters NOT already created on a real device by the ReplicatedLinear swap above (i.e. embeddings, layernorms, biases). """ for name, param in list(model.named_parameters(recurse=True)): if param.device.type == "meta": real = nn.Parameter( torch.empty(param.shape, dtype=dtype, device=device), requires_grad=False, ) # Replace in the parent module parent = model *path, leaf = name.split(".") for p in path: parent = getattr(parent, p) setattr(parent, leaf, real) # Same for buffers (e.g. rotary inv_freq if registered as buffer on meta) for name, buf in list(model.named_buffers(recurse=True)): if buf.device.type == "meta": real = torch.empty(buf.shape, dtype=buf.dtype, device=device) parent = model *path, leaf = name.split(".") for p in path: parent = getattr(parent, p) parent.register_buffer(leaf, real, persistent=False) # ---------------------------------------------------------------------------- # Weight loading # ---------------------------------------------------------------------------- def _iter_safetensors(model_dir: str) -> Iterator[tuple[str, torch.Tensor]]: """Yield (name, tensor) pairs from all *.safetensors shards in ``model_dir``.""" from safetensors import safe_open # Single-file checkpoint or sharded? Prefer ``model.safetensors.index.json``. index_path = os.path.join(model_dir, "model.safetensors.index.json") if os.path.exists(index_path): with open(index_path) as f: index = json.load(f) shards = sorted(set(index["weight_map"].values())) else: # Find all *.safetensors files in dir shards = sorted( fn for fn in os.listdir(model_dir) if fn.endswith(".safetensors") ) for shard in shards: path = os.path.join(model_dir, shard) with safe_open(path, framework="pt") as f: for key in f.keys(): yield key, f.get_tensor(key) def _load_weights_into_model(model: nn.Module, model_dir: str) -> None: """Stream safetensors into the (already-structured) model. Uses each ReplicatedLinear's registered ``weight_loader`` for quantised params (which handles tensor-parallel sharding, even though TP=1 here it keeps casts consistent). Other params (embeddings, layernorms, biases) are copied directly. """ # Strip vllm-omni-style "text_encoder." prefix if present; not applicable # here since we load the standalone HF Qwen3 checkpoint where keys start # with "model.layers..." / "model.embed_tokens..." / "lm_head...". name_to_param: dict[str, nn.Parameter] = dict(model.named_parameters(recurse=True)) name_to_buffer: dict[str, torch.Tensor] = dict(model.named_buffers(recurse=True)) missing = set(name_to_param.keys()) unexpected = [] for key, tensor in _iter_safetensors(model_dir): # Skip rotary inv_freq etc that aren't params (rare in modern HF saves) if key in name_to_param: param = name_to_param[key] wl = getattr(param, "weight_loader", None) if wl is not None: wl(param, tensor.to(param.device)) else: with torch.no_grad(): param.data.copy_(tensor.to(param.device, dtype=param.dtype)) missing.discard(key) elif key in name_to_buffer: with torch.no_grad(): name_to_buffer[key].copy_(tensor.to(name_to_buffer[key].device)) else: unexpected.append(key) # Tied embeddings (lm_head.weight not in checkpoint when tie_word_embeddings=True) cfg = getattr(model, "config", None) if cfg is not None and getattr(cfg, "tie_word_embeddings", False): try: inp_emb = model.get_input_embeddings().weight model.lm_head.weight = inp_emb # share storage missing.discard("lm_head.weight") except Exception: pass if missing: # It's OK if missing entries are *purely* lm_head.weight when tied; we # already handled that above. Anything else is fatal-ish. leftover = sorted(missing) if leftover: print( f"[NVFP4TextEncoder] WARN: {len(leftover)} params missing from checkpoint; " f"first 5: {leftover[:5]}" ) if unexpected: print( f"[NVFP4TextEncoder] WARN: {len(unexpected)} keys in checkpoint unused; " f"first 5: {unexpected[:5]}" ) def _process_weights_after_loading(model: nn.Module) -> None: """Invoke vLLM's per-layer ``process_weights_after_loading`` for each ReplicatedLinear (renames ``weight_packed`` -> ``weight``, computes ``alpha``, swizzles scales for the CUTLASS kernel, etc.).""" for module in model.modules(): qm = getattr(module, "quant_method", None) if qm is not None and hasattr(qm, "process_weights_after_loading"): qm.process_weights_after_loading(module) # ---------------------------------------------------------------------------- # Public API # ---------------------------------------------------------------------------- def load_nvfp4_text_encoder( model_dir: str, device: str | torch.device = "cuda", dtype: torch.dtype = torch.bfloat16, ) -> nn.Module: """Load an NVFP4-quantised HuggingFace causal LM as a plug-in text encoder. Args: model_dir: path to the checkpoint directory containing ``config.json`` and ``model*.safetensors``. The config must carry a ``quantization_config`` block with ``"format": "nvfp4-pack-quantized"``. device: target CUDA device (forwards to ``model.to(device)``-equivalent during materialisation). dtype: activation / non-quantised-param dtype. Returns: A ``PreTrainedModel`` whose ``Linear`` layers are NVFP4 inside the vLLM CUTLASS kernel. Activations flow as ``dtype``. """ _bootstrap_vllm_once() from transformers import AutoConfig, AutoModelForCausalLM from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( CompressedTensorsConfig, ) from vllm.model_executor.models.transformers.utils import ( init_on_device_without_buffers, ) hf_config = AutoConfig.from_pretrained(model_dir, local_files_only=True) if not getattr(hf_config, "quantization_config", None): raise ValueError( f"{model_dir}/config.json has no `quantization_config`; " "this loader only handles NVFP4-quantised checkpoints." ) quant_config = CompressedTensorsConfig.from_config(hf_config.quantization_config) # 1) Build the model skeleton on meta (zero allocation). with init_on_device_without_buffers("meta"): model = AutoModelForCausalLM.from_config(hf_config) # 2) Swap Linear -> ReplicatedLinear(quant_config) (creates real CUDA params # of the quantised shapes). target_device = torch.device(device) _replace_linears_with_replicated(model, quant_config) # 3) Materialise any leftover meta parameters (embeddings, RMSNorms, ...) _materialize_remaining_meta_params(model, dtype=dtype, device=target_device) # 4) Move newly-created quantised params to target device (ReplicatedLinear # creates them on the current default device which is usually CPU). model.to(target_device) # 5) Load weights via per-param weight_loader. _load_weights_into_model(model, model_dir) # 6) Let vLLM swizzle scales / rename weight_packed->weight / compute alpha. _process_weights_after_loading(model) # 7) Match HF semantics for downstream pipelines. model.eval() model.config.use_cache = False return model