File size: 13,132 Bytes
1e103b7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 | """
NVFP4 text encoder loader for diffusers image pipelines.
Loads a compressed-tensors NVFP4-pack-quantized HuggingFace causal LM and wraps
it so it can be plugged into ``diffusers.ZImagePipeline`` (or any pipeline
calling ``self.text_encoder(input_ids, attention_mask, output_hidden_states=True)``).
Strategy:
- Instantiate the HF model on the ``meta`` device (no real allocation).
- Walk every ``torch.nn.Linear`` and swap it for vLLM's ``ReplicatedLinear`` with
``CompressedTensorsConfig`` derived from the checkpoint's
``quantization_config``. This registers ``weight_packed`` / ``weight_scale`` /
``*_global_scale`` parameters in the exact layout vLLM's
``CompressedTensorsW4A4Fp4`` scheme expects.
- Materialise remaining (non-Linear) parameters (embeddings, RMSNorm, k/q norms)
on the target device & dtype.
- Stream the safetensors shard(s) and dispatch each tensor through its
registered vLLM ``weight_loader`` (plain copy for parameters without one);
layout swizzling happens afterwards in ``process_weights_after_loading``.
- Tie the LM head to the input embedding when ``config.tie_word_embeddings``.
The result is a regular ``nn.Module`` matching the HF model's call signature
(``forward(input_ids, attention_mask, output_hidden_states)``) -- usable directly
as ``ZImagePipeline.text_encoder``.
vLLM requires a minimal global context (distributed process group + model
parallel state + active VllmConfig) even at TP=1 because ``ReplicatedLinear``
queries the TP world size at construction. We bootstrap that lazily once.
Kernel selection: unless the caller has already set it, we default
``VLLM_NVFP4_GEMM_BACKEND=cutlass`` to skip the flashinfer-cutlass JIT (which
needs the ``ninja`` binary on PATH). The vLLM CUTLASS kernel is built into the
wheel.
"""
from __future__ import annotations
import json
import os
from collections.abc import Iterator
from typing import Optional
import torch
import torch.nn as nn
# ----------------------------------------------------------------------------
# One-time vLLM bootstrap (TP=1, no engine, just enough context for ReplicatedLinear)
# ----------------------------------------------------------------------------
_VLLM_BOOTSTRAPPED = False  # True once _bootstrap_vllm_once() has fully completed
_VLLM_CONFIG_CTX = None  # holds the entered set_current_vllm_config context manager


def _pick_free_port(host: str = "127.0.0.1") -> int:
    """Ask the OS for a currently-unused TCP port on ``host``.

    The probe socket is closed before returning, so the port is only *likely*
    to still be free when the caller binds it (an unavoidable race).
    """
    import socket

    # ``with`` closes the probe socket even if bind() raises.
    with socket.socket() as probe:
        probe.bind((host, 0))
        return probe.getsockname()[1]


def _bootstrap_vllm_once() -> None:
    """Initialise the bits of vLLM that ``ReplicatedLinear`` needs at TP=1.

    Steps:
      1. Default ``VLLM_NVFP4_GEMM_BACKEND`` to ``cutlass`` (an existing value
         is left untouched) to avoid the flashinfer-cutlass JIT, which needs
         ``ninja`` on PATH.
      2. If ``torch.distributed`` is not initialised yet, create a single-rank
         process group with vLLM's ``init_distributed_environment`` using the
         ``gloo`` backend on a free localhost port (so no NCCL is required).
      3. Enter a ``set_current_vllm_config(VllmConfig())`` context that is
         intentionally never exited.
      4. Call ``ensure_model_parallel_initialized(1, 1)``.

    Idempotent: once it has succeeded, later calls return immediately. If a
    step raises, a retry will not enter the config context a second time.
    """
    global _VLLM_BOOTSTRAPPED, _VLLM_CONFIG_CTX
    if _VLLM_BOOTSTRAPPED:
        return

    # Default to CUTLASS to avoid flashinfer-cutlass JIT (requires `ninja` on PATH).
    os.environ.setdefault("VLLM_NVFP4_GEMM_BACKEND", "cutlass")

    from vllm.config import VllmConfig
    from vllm.config.vllm import set_current_vllm_config
    from vllm.distributed import (
        ensure_model_parallel_initialized,
        init_distributed_environment,
    )

    if not torch.distributed.is_initialized():
        # Only reserve a rendezvous port when we actually create the group.
        port = _pick_free_port()
        init_distributed_environment(
            world_size=1,
            rank=0,
            local_rank=0,
            distributed_init_method=f"tcp://127.0.0.1:{port}",
            backend="gloo",
        )

    if _VLLM_CONFIG_CTX is None:
        # Long-lived VllmConfig context, never exited: the encoder may
        # construct submodules lazily and ReplicatedLinear calls
        # get_current_vllm_config() at init.
        ctx = set_current_vllm_config(VllmConfig())
        ctx.__enter__()
        _VLLM_CONFIG_CTX = ctx

    ensure_model_parallel_initialized(1, 1)
    _VLLM_BOOTSTRAPPED = True
# ----------------------------------------------------------------------------
# Module: linear replacement
# ----------------------------------------------------------------------------
def _replace_linears_with_replicated(
    model: nn.Module,
    quant_config,
    params_dtype: torch.dtype = torch.bfloat16,
) -> None:
    """Recursively swap every ``nn.Linear`` in ``model`` for vLLM ``ReplicatedLinear``.

    The replacement is done in place. Each new layer gets the same
    ``in_features`` / ``out_features`` / bias presence as the layer it
    replaces, plus its fully-qualified module name as ``prefix`` so the quant
    config's ``ignore`` patterns (e.g. ``lm_head``) are applied. Weights of the
    replaced layers are not copied, and children of an ``nn.Linear`` are not
    visited.

    Args:
        model: root module, rewritten in place.
        quant_config: vLLM quantization config handed to every new layer.
        params_dtype: forwarded as ``params_dtype`` to every
            ``ReplicatedLinear``. Defaults to ``torch.bfloat16`` (the value
            this helper always used). NOTE(review): presumably this is the
            dtype of any unquantised (ignored) layer, so pass the pipeline's
            activation dtype when it differs -- confirm against vLLM.
    """
    from vllm.model_executor.layers.linear import ReplicatedLinear

    def _swap(parent: nn.Module, prefix: str) -> None:
        # Snapshot the children: setattr below mutates parent's module registry.
        for child_name, child in list(parent.named_children()):
            qualified = f"{prefix}.{child_name}" if prefix else child_name
            if not isinstance(child, nn.Linear):
                _swap(child, qualified)
                continue
            replacement = ReplicatedLinear(
                input_size=child.in_features,
                output_size=child.out_features,
                bias=child.bias is not None,
                quant_config=quant_config,
                prefix=qualified,
                return_bias=False,
                params_dtype=params_dtype,
            )
            setattr(parent, child_name, replacement)

    _swap(model, prefix="")
def _materialize_remaining_meta_params(
model: nn.Module, dtype: torch.dtype, device: torch.device
) -> None:
"""Replace any ``meta`` parameter with empty real storage.
Only touches parameters NOT already created on a real device by the
ReplicatedLinear swap above (i.e. embeddings, layernorms, biases).
"""
for name, param in list(model.named_parameters(recurse=True)):
if param.device.type == "meta":
real = nn.Parameter(
torch.empty(param.shape, dtype=dtype, device=device),
requires_grad=False,
)
# Replace in the parent module
parent = model
*path, leaf = name.split(".")
for p in path:
parent = getattr(parent, p)
setattr(parent, leaf, real)
# Same for buffers (e.g. rotary inv_freq if registered as buffer on meta)
for name, buf in list(model.named_buffers(recurse=True)):
if buf.device.type == "meta":
real = torch.empty(buf.shape, dtype=buf.dtype, device=device)
parent = model
*path, leaf = name.split(".")
for p in path:
parent = getattr(parent, p)
parent.register_buffer(leaf, real, persistent=False)
# ----------------------------------------------------------------------------
# Weight loading
# ----------------------------------------------------------------------------
def _iter_safetensors(model_dir: str) -> Iterator[tuple[str, torch.Tensor]]:
"""Yield (name, tensor) pairs from all *.safetensors shards in ``model_dir``."""
from safetensors import safe_open
# Single-file checkpoint or sharded? Prefer ``model.safetensors.index.json``.
index_path = os.path.join(model_dir, "model.safetensors.index.json")
if os.path.exists(index_path):
with open(index_path) as f:
index = json.load(f)
shards = sorted(set(index["weight_map"].values()))
else:
# Find all *.safetensors files in dir
shards = sorted(
fn for fn in os.listdir(model_dir) if fn.endswith(".safetensors")
)
for shard in shards:
path = os.path.join(model_dir, shard)
with safe_open(path, framework="pt") as f:
for key in f.keys():
yield key, f.get_tensor(key)
def _load_weights_into_model(model: nn.Module, model_dir: str) -> None:
    """Stream safetensors tensors from ``model_dir`` into the already-built ``model``.

    Each checkpoint key is handled as follows:
      * it names a parameter with a ``weight_loader`` attribute (the vLLM
        layers): the loader is called with the tensor moved to the
        parameter's device;
      * it names any other parameter (embeddings, layernorms, biases, ...): the
        tensor is copied in, cast to the parameter's device and dtype;
      * it names a buffer: the tensor is copied into it;
      * otherwise: the key is recorded as unused.

    If ``model.config.tie_word_embeddings`` is true, ``model.lm_head.weight`` is
    then pointed at the input-embedding weight (best effort). Missing
    parameters, unused checkpoint keys and a failed tie are reported as
    ``WARN`` lines on stdout rather than raised.
    """
    name_to_param: dict[str, nn.Parameter] = dict(model.named_parameters(recurse=True))
    name_to_buffer: dict[str, torch.Tensor] = dict(model.named_buffers(recurse=True))
    missing = set(name_to_param)
    unexpected: list[str] = []

    for key, tensor in _iter_safetensors(model_dir):
        param = name_to_param.get(key)
        if param is not None:
            loader = getattr(param, "weight_loader", None)
            if loader is not None:
                loader(param, tensor.to(param.device))
            else:
                with torch.no_grad():
                    param.data.copy_(tensor.to(param.device, dtype=param.dtype))
            missing.discard(key)
        elif key in name_to_buffer:
            # Buffers saved in the checkpoint (rare in modern HF saves).
            buf = name_to_buffer[key]
            with torch.no_grad():
                buf.copy_(tensor.to(buf.device))
        else:
            unexpected.append(key)

    # Tied embeddings (lm_head.weight not in checkpoint when tie_word_embeddings=True).
    cfg = getattr(model, "config", None)
    if cfg is not None and getattr(cfg, "tie_word_embeddings", False):
        try:
            input_embedding = model.get_input_embeddings().weight
            model.lm_head.weight = input_embedding  # share storage
        except Exception as err:  # best effort: keep loading, but say why
            print(
                "[NVFP4TextEncoder] WARN: could not tie lm_head to input "
                f"embeddings: {err!r}"
            )
        else:
            missing.discard("lm_head.weight")

    if missing:
        leftover = sorted(missing)
        print(
            f"[NVFP4TextEncoder] WARN: {len(leftover)} params missing from checkpoint; "
            f"first 5: {leftover[:5]}"
        )
    if unexpected:
        print(
            f"[NVFP4TextEncoder] WARN: {len(unexpected)} keys in checkpoint unused; "
            f"first 5: {unexpected[:5]}"
        )
def _process_weights_after_loading(model: nn.Module) -> None:
"""Invoke vLLM's per-layer ``process_weights_after_loading`` for each
ReplicatedLinear (renames ``weight_packed`` -> ``weight``, computes ``alpha``,
swizzles scales for the CUTLASS kernel, etc.)."""
for module in model.modules():
qm = getattr(module, "quant_method", None)
if qm is not None and hasattr(qm, "process_weights_after_loading"):
qm.process_weights_after_loading(module)
# ----------------------------------------------------------------------------
# Public API
# ----------------------------------------------------------------------------
def load_nvfp4_text_encoder(
    model_dir: str,
    device: str | torch.device = "cuda",
    dtype: torch.dtype = torch.bfloat16,
) -> nn.Module:
    """Load an NVFP4-quantised HuggingFace causal LM as a plug-in text encoder.

    Args:
        model_dir: path to the checkpoint directory containing ``config.json``
            and ``model*.safetensors``. The config must carry a
            ``quantization_config`` block (expected to use the compressed-tensors
            ``"nvfp4-pack-quantized"`` format; only its presence is checked here).
        device: target device. Leftover meta parameters are materialised
            directly on it, and the whole model is then moved there with
            ``model.to``.
        dtype: dtype used when materialising the non-Linear parameters
            (embeddings, norms, ...).

    Returns:
        The model built by ``AutoModelForCausalLM.from_config``, with its
        ``nn.Linear`` layers replaced by vLLM ``ReplicatedLinear`` layers, put in
        eval mode and with ``config.use_cache = False``.

    Raises:
        ValueError: if ``config.json`` has no ``quantization_config``. This (and
            loading the config / parsing ``device``) happens before vLLM's
            global state is bootstrapped.
    """
    from transformers import AutoConfig, AutoModelForCausalLM

    # Validate inputs *before* bootstrapping vLLM. The bootstrap creates a
    # process group and enters a VllmConfig context that is never exited, so a
    # bad path / config / device string must fail without that side effect.
    target_device = torch.device(device)
    hf_config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
    if not getattr(hf_config, "quantization_config", None):
        raise ValueError(
            f"{model_dir}/config.json has no `quantization_config`; "
            "this loader only handles NVFP4-quantised checkpoints."
        )

    _bootstrap_vllm_once()
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
        CompressedTensorsConfig,
    )
    from vllm.model_executor.models.transformers.utils import (
        init_on_device_without_buffers,
    )

    quant_config = CompressedTensorsConfig.from_config(hf_config.quantization_config)

    # 1) Build the model skeleton on meta (zero allocation).
    with init_on_device_without_buffers("meta"):
        model = AutoModelForCausalLM.from_config(hf_config)

    # 2) Swap Linear -> ReplicatedLinear(quant_config) (creates real params of
    #    the quantised shapes).
    #    NOTE(review): no dtype is passed here, so the new layers use the
    #    helper's own params dtype (bf16 as originally written) regardless of
    #    ``dtype`` -- confirm this is intended for non-bf16 pipelines.
    _replace_linears_with_replicated(model, quant_config)

    # 3) Materialise any leftover meta parameters (embeddings, RMSNorms, ...).
    _materialize_remaining_meta_params(model, dtype=dtype, device=target_device)

    # 4) Move newly-created quantised params to the target device
    #    (ReplicatedLinear creates them on the current default device, which is
    #    usually CPU).
    model.to(target_device)

    # 5) Load weights via per-param weight_loader.
    _load_weights_into_model(model, model_dir)

    # 6) Let vLLM swizzle scales / rename weight_packed->weight / compute alpha.
    _process_weights_after_loading(model)

    # 7) Match HF semantics for downstream pipelines.
    model.eval()
    model.config.use_cache = False
    return model
|