| """ |
| NVFP4 text encoder loader for diffusers image pipelines. |
| |
| Loads a compressed-tensors NVFP4-pack-quantized HuggingFace causal LM and wraps |
| it so it can be plugged into ``diffusers.ZImagePipeline`` (or any pipeline |
| calling ``self.text_encoder(input_ids, attention_mask, output_hidden_states=True)``). |
| |
| Strategy: |
| - Instantiate the HF model on the ``meta`` device (no real allocation). |
| - Walk every ``torch.nn.Linear`` and swap it for vLLM's ``ReplicatedLinear`` with |
| ``CompressedTensorsConfig`` derived from the checkpoint's |
| ``quantization_config``. This registers ``weight_packed`` / ``weight_scale`` / |
| ``*_global_scale`` parameters in the exact layout vLLM's |
| ``CompressedTensorsW4A4Fp4`` scheme expects. |
| - Materialise remaining (non-Linear) parameters (embeddings, RMSNorm, k/q norms) |
| on the target device & dtype. |
| - Stream the safetensors shard(s) and dispatch each tensor through the registered |
| vLLM ``weight_loader``; layout swizzling happens afterwards, in the |
| ``process_weights_after_loading`` pass. |
| - Tie the LM head to the input embedding when ``config.tie_word_embeddings``. |
| |
| The result is a regular ``nn.Module`` matching the HF model's call signature |
| (``forward(input_ids, attention_mask, output_hidden_states)``) -- usable directly |
| as ``ZImagePipeline.text_encoder``. |
| |
| vLLM requires a minimal global context (distributed process group + model |
| parallel state + active VllmConfig) even at TP=1 because ``ReplicatedLinear`` |
| queries the TP world size at construction. We bootstrap that lazily once. |
| |
| Default kernel: we default ``VLLM_NVFP4_GEMM_BACKEND`` to ``cutlass`` (a value |
| already present in the environment is left untouched) to skip |
| flashinfer-cutlass JIT (which needs the ``ninja`` binary on PATH). The vLLM |
| CUTLASS kernel is built into the wheel. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| from collections.abc import Iterator |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
| |
| |
| |
# Set to True once ``_bootstrap_vllm_once`` has completed successfully.
_VLLM_BOOTSTRAPPED = False
# The entered ``set_current_vllm_config`` context manager. It is never exited;
# the module-level reference is kept alive for the lifetime of the process
# (presumably so the active config is not torn down -- confirm against vLLM).
_VLLM_CONFIG_CTX = None


def _bootstrap_vllm_once() -> None:
    """Initialise the bits of vLLM that ReplicatedLinear needs at TP=1.

    Idempotent: returns immediately once a previous call has completed. Uses
    the ``gloo`` backend so it works without NCCL/CUDA-aware MPI and even when
    CUDA is busy with the diffusion transformer.

    Side effects (process-wide):
      * defaults ``VLLM_NVFP4_GEMM_BACKEND`` to ``cutlass`` unless the variable
        is already set;
      * initialises a single-rank distributed environment if
        ``torch.distributed`` is not initialised yet;
      * enters (and never exits) a ``set_current_vllm_config`` context;
      * calls ``ensure_model_parallel_initialized(1, 1)``.
    """
    global _VLLM_BOOTSTRAPPED, _VLLM_CONFIG_CTX
    if _VLLM_BOOTSTRAPPED:
        return

    # Must be in the environment before the vLLM modules below are imported.
    os.environ.setdefault("VLLM_NVFP4_GEMM_BACKEND", "cutlass")

    from vllm.config import VllmConfig
    from vllm.config.vllm import set_current_vllm_config
    from vllm.distributed import (
        ensure_model_parallel_initialized,
        init_distributed_environment,
    )

    if not torch.distributed.is_initialized():
        import socket

        # Ask the OS for a free port by binding to port 0. The context manager
        # guarantees the probe socket is closed even if ``bind`` raises. The
        # port is only probed when a process group actually has to be created.
        with socket.socket() as probe:
            probe.bind(("127.0.0.1", 0))
            port = probe.getsockname()[1]

        init_distributed_environment(
            world_size=1,
            rank=0,
            local_rank=0,
            distributed_init_method=f"tcp://127.0.0.1:{port}",
            backend="gloo",
        )

    # Enter the config context at most once, so that a retry after a failure
    # further down does not stack a second context on top of the first.
    if _VLLM_CONFIG_CTX is None:
        config_ctx = set_current_vllm_config(VllmConfig())
        config_ctx.__enter__()
        _VLLM_CONFIG_CTX = config_ctx

    # Tensor-parallel size 1, pipeline-parallel size 1.
    ensure_model_parallel_initialized(1, 1)
    _VLLM_BOOTSTRAPPED = True
|
|
|
|
| |
| |
| |
def _replace_linears_with_replicated(
    model: nn.Module,
    quant_config,
    params_dtype: torch.dtype = torch.bfloat16,
) -> None:
    """Recursively swap every ``nn.Linear`` for vLLM ``ReplicatedLinear``.

    The dotted module path is carried along as ``prefix`` so that
    ``quant_config``'s ``ignore`` patterns (e.g. ``lm_head``) are correctly
    applied to each replacement layer.

    Args:
        model: root module; it is modified in place. Only *children* are
            replaced -- if ``model`` itself is an ``nn.Linear`` nothing happens.
        quant_config: quantisation config handed to every ``ReplicatedLinear``.
        params_dtype: dtype passed to ``ReplicatedLinear`` as ``params_dtype``.
            Defaults to ``torch.bfloat16``, the value that used to be
            hard-coded here.
    """
    from vllm.model_executor.layers.linear import ReplicatedLinear

    def _walk(parent: nn.Module, prefix: str) -> None:
        # Snapshot the children: we re-assign attributes while iterating.
        for child_name, child in list(parent.named_children()):
            qualified = f"{prefix}.{child_name}" if prefix else child_name
            if not isinstance(child, nn.Linear):
                _walk(child, qualified)
                continue
            replacement = ReplicatedLinear(
                input_size=child.in_features,
                output_size=child.out_features,
                bias=child.bias is not None,
                quant_config=quant_config,
                prefix=qualified,
                return_bias=False,
                params_dtype=params_dtype,
            )
            setattr(parent, child_name, replacement)

    _walk(model, prefix="")
|
|
|
|
| def _materialize_remaining_meta_params( |
| model: nn.Module, dtype: torch.dtype, device: torch.device |
| ) -> None: |
| """Replace any ``meta`` parameter with empty real storage. |
| |
| Only touches parameters NOT already created on a real device by the |
| ReplicatedLinear swap above (i.e. embeddings, layernorms, biases). |
| """ |
| for name, param in list(model.named_parameters(recurse=True)): |
| if param.device.type == "meta": |
| real = nn.Parameter( |
| torch.empty(param.shape, dtype=dtype, device=device), |
| requires_grad=False, |
| ) |
| |
| parent = model |
| *path, leaf = name.split(".") |
| for p in path: |
| parent = getattr(parent, p) |
| setattr(parent, leaf, real) |
| |
| for name, buf in list(model.named_buffers(recurse=True)): |
| if buf.device.type == "meta": |
| real = torch.empty(buf.shape, dtype=buf.dtype, device=device) |
| parent = model |
| *path, leaf = name.split(".") |
| for p in path: |
| parent = getattr(parent, p) |
| parent.register_buffer(leaf, real, persistent=False) |
|
|
|
|
| |
| |
| |
| def _iter_safetensors(model_dir: str) -> Iterator[tuple[str, torch.Tensor]]: |
| """Yield (name, tensor) pairs from all *.safetensors shards in ``model_dir``.""" |
| from safetensors import safe_open |
|
|
| |
| index_path = os.path.join(model_dir, "model.safetensors.index.json") |
| if os.path.exists(index_path): |
| with open(index_path) as f: |
| index = json.load(f) |
| shards = sorted(set(index["weight_map"].values())) |
| else: |
| |
| shards = sorted( |
| fn for fn in os.listdir(model_dir) if fn.endswith(".safetensors") |
| ) |
| for shard in shards: |
| path = os.path.join(model_dir, shard) |
| with safe_open(path, framework="pt") as f: |
| for key in f.keys(): |
| yield key, f.get_tensor(key) |
|
|
|
|
def _load_weights_into_model(model: nn.Module, model_dir: str) -> None:
    """Stream safetensors into the (already-structured) model.

    For every checkpoint key:
      * if it names a parameter that carries a ``weight_loader`` attribute,
        that loader is called with the tensor moved to the parameter's device;
      * if it names any other parameter, the tensor is copied in after being
        cast to the parameter's device and dtype;
      * if it names a buffer, the tensor is copied into the buffer;
      * otherwise the key is recorded as unexpected.

    After loading, when ``model.config.tie_word_embeddings`` is truthy,
    ``model.lm_head.weight`` is re-pointed at the input embedding weight.
    Tying is best-effort: a failure is reported, not raised.

    Parameters never seen in the checkpoint and unused checkpoint keys are
    reported on stdout as warnings.

    Args:
        model: module whose parameters / buffers are already allocated.
        model_dir: checkpoint directory passed to ``_iter_safetensors``.
    """
    name_to_param: dict[str, nn.Parameter] = dict(model.named_parameters(recurse=True))
    name_to_buffer: dict[str, torch.Tensor] = dict(model.named_buffers(recurse=True))

    missing = set(name_to_param)
    unexpected: list[str] = []

    for key, tensor in _iter_safetensors(model_dir):
        param = name_to_param.get(key)
        if param is not None:
            loader = getattr(param, "weight_loader", None)
            if loader is not None:
                loader(param, tensor.to(param.device))
            else:
                with torch.no_grad():
                    param.data.copy_(tensor.to(param.device, dtype=param.dtype))
            missing.discard(key)
            continue

        buf = name_to_buffer.get(key)
        if buf is not None:
            with torch.no_grad():
                buf.copy_(tensor.to(buf.device))
        else:
            unexpected.append(key)

    cfg = getattr(model, "config", None)
    if cfg is not None and getattr(cfg, "tie_word_embeddings", False):
        try:
            model.lm_head.weight = model.get_input_embeddings().weight
        except (AttributeError, NotImplementedError, TypeError) as exc:
            # Still best-effort, but no longer silent: an untied head would
            # otherwise only show up as a generic "missing" entry (or not at all).
            print(
                "[NVFP4TextEncoder] WARN: could not tie lm_head to input "
                f"embeddings: {exc!r}"
            )
        else:
            missing.discard("lm_head.weight")

    if missing:
        leftover = sorted(missing)
        print(
            f"[NVFP4TextEncoder] WARN: {len(leftover)} params missing from checkpoint; "
            f"first 5: {leftover[:5]}"
        )
    if unexpected:
        print(
            f"[NVFP4TextEncoder] WARN: {len(unexpected)} keys in checkpoint unused; "
            f"first 5: {unexpected[:5]}"
        )
|
|
|
|
| def _process_weights_after_loading(model: nn.Module) -> None: |
| """Invoke vLLM's per-layer ``process_weights_after_loading`` for each |
| ReplicatedLinear (renames ``weight_packed`` -> ``weight``, computes ``alpha``, |
| swizzles scales for the CUTLASS kernel, etc.).""" |
| for module in model.modules(): |
| qm = getattr(module, "quant_method", None) |
| if qm is not None and hasattr(qm, "process_weights_after_loading"): |
| qm.process_weights_after_loading(module) |
|
|
|
|
| |
| |
| |
def load_nvfp4_text_encoder(
    model_dir: str,
    device: str | torch.device = "cuda",
    dtype: torch.dtype = torch.bfloat16,
) -> nn.Module:
    """Build an NVFP4-quantised HuggingFace causal LM usable as a text encoder.

    Pipeline: bootstrap vLLM, read the HF config, build the model skeleton on
    ``meta``, swap the Linear layers for vLLM ones, allocate the remaining
    parameters, move to the target device, load the checkpoint, then run
    vLLM's post-load weight processing.

    Args:
        model_dir: path to the checkpoint directory containing ``config.json``
            and ``model*.safetensors``. The config must carry a
            ``quantization_config`` block with ``"format": "nvfp4-pack-quantized"``
            (only the presence of the block is checked here).
        device: target device for the materialised model.
        dtype: dtype of the non-Linear parameters (embeddings, norms, ...).
            The swapped Linear layers use the params dtype chosen by
            ``_replace_linears_with_replicated``.

    Returns:
        The model, in eval mode and with ``config.use_cache`` set to ``False``.

    Raises:
        ValueError: if the config has no (or an empty) ``quantization_config``.
    """
    _bootstrap_vllm_once()

    from transformers import AutoConfig, AutoModelForCausalLM
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
        CompressedTensorsConfig,
    )
    from vllm.model_executor.models.transformers.utils import (
        init_on_device_without_buffers,
    )

    config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
    raw_quant_cfg = getattr(config, "quantization_config", None)
    if not raw_quant_cfg:
        raise ValueError(
            f"{model_dir}/config.json has no `quantization_config`; "
            "this loader only handles NVFP4-quantised checkpoints."
        )
    ct_config = CompressedTensorsConfig.from_config(raw_quant_cfg)

    # Skeleton only: parameters are created on ``meta``, so nothing is allocated.
    with init_on_device_without_buffers("meta"):
        encoder = AutoModelForCausalLM.from_config(config)

    destination = torch.device(device)

    # 1. Linear -> ReplicatedLinear (registers the quantised parameter layout).
    _replace_linears_with_replicated(encoder, ct_config)

    # 2. Give everything that is still on ``meta`` real (uninitialised) storage.
    _materialize_remaining_meta_params(encoder, dtype=dtype, device=destination)

    # 3. Move whatever was created on another device onto the target.
    encoder.to(destination)

    # 4. Fill in the checkpoint values.
    _load_weights_into_model(encoder, model_dir)

    # 5. Let vLLM repack the weights for its kernel.
    _process_weights_after_loading(encoder)

    encoder.eval()
    encoder.config.use_cache = False
    return encoder
|
|