catplusplus's picture
Upload folder using huggingface_hub
1e103b7 verified
raw
history blame
13.1 kB
"""
NVFP4 text encoder loader for diffusers image pipelines.
Loads a compressed-tensors NVFP4-pack-quantized HuggingFace causal LM and wraps
it so it can be plugged into ``diffusers.ZImagePipeline`` (or any pipeline
calling ``self.text_encoder(input_ids, attention_mask, output_hidden_states=True)``).
Strategy:
- Instantiate the HF model on the ``meta`` device (no real allocation).
- Walk every ``torch.nn.Linear`` and swap it for vLLM's ``ReplicatedLinear`` with
``CompressedTensorsConfig`` derived from the checkpoint's
``quantization_config``. This registers ``weight_packed`` / ``weight_scale`` /
``*_global_scale`` parameters in the exact layout vLLM's
``CompressedTensorsW4A4Fp4`` scheme expects.
- Materialise remaining (non-Linear) parameters (embeddings, RMSNorm, k/q norms)
on the target device & dtype.
- Stream the safetensors shards and dispatch each tensor through the parameter's
registered vLLM ``weight_loader``; layout swizzling happens afterwards, in a
separate ``process_weights_after_loading`` pass over the model.
- Tie the LM head to the input embedding when ``config.tie_word_embeddings``.
The result is a regular ``nn.Module`` matching the HF model's call signature
(``forward(input_ids, attention_mask, output_hidden_states)``) -- usable directly
as ``ZImagePipeline.text_encoder``.
vLLM requires a minimal global context (distributed process group + model
parallel state + active VllmConfig) even at TP=1 because ``ReplicatedLinear``
queries the TP world size at construction. We bootstrap that lazily once.
Default kernel: we default ``VLLM_NVFP4_GEMM_BACKEND`` to ``cutlass`` (a value
already present in the environment is left untouched) to skip the
flashinfer-cutlass JIT, which needs the ``ninja`` binary on PATH. The vLLM
CUTLASS kernel is built into the wheel.
"""
from __future__ import annotations
import json
import os
from collections.abc import Iterator
from typing import Optional
import torch
import torch.nn as nn
# ----------------------------------------------------------------------------
# One-time vLLM bootstrap (TP=1, no engine, just enough context for ReplicatedLinear)
# ----------------------------------------------------------------------------
# Module-level bootstrap state. ``_VLLM_BOOTSTRAPPED`` flips to True after the
# first successful bootstrap; ``_VLLM_CONFIG_CTX`` keeps the entered
# ``set_current_vllm_config`` context manager referenced (it is never exited).
_VLLM_BOOTSTRAPPED = False
_VLLM_CONFIG_CTX = None  # holds the entered set_current_vllm_config context manager


def _bootstrap_vllm_once() -> None:
    """Initialise the bits of vLLM that ReplicatedLinear needs at TP=1.

    Idempotent: once a call has completed, later calls return immediately.
    The steps are:

    1. Default ``VLLM_NVFP4_GEMM_BACKEND`` to ``cutlass`` (an existing value
       is kept) to avoid the flashinfer-cutlass JIT, which requires ``ninja``.
    2. If ``torch.distributed`` is not initialised yet, initialise a
       single-rank environment over ``gloo`` on a free localhost TCP port.
    3. Enter a ``VllmConfig`` context that is intentionally never exited.
    4. Call ``ensure_model_parallel_initialized(1, 1)``.

    Raises:
        ImportError: if vLLM is not installed.
    """
    global _VLLM_BOOTSTRAPPED, _VLLM_CONFIG_CTX
    if _VLLM_BOOTSTRAPPED:
        return

    # Force CUTLASS to avoid flashinfer-cutlass JIT (requires `ninja` on PATH).
    os.environ.setdefault("VLLM_NVFP4_GEMM_BACKEND", "cutlass")

    from vllm.config import VllmConfig
    from vllm.config.vllm import set_current_vllm_config
    from vllm.distributed import (
        ensure_model_parallel_initialized,
        init_distributed_environment,
    )

    if not torch.distributed.is_initialized():
        import socket

        # Ask the OS for a free port. The probe socket is only needed when we
        # actually initialise the process group, and the ``with`` block closes
        # it even if bind()/getsockname() raises.
        with socket.socket() as probe:
            probe.bind(("127.0.0.1", 0))
            port = probe.getsockname()[1]

        init_distributed_environment(
            world_size=1,
            rank=0,
            local_rank=0,
            distributed_init_method=f"tcp://127.0.0.1:{port}",
            backend="gloo",
        )

    # Enter a long-lived VllmConfig context. We never exit it -- the encoder
    # may construct submodules lazily and ReplicatedLinear calls
    # get_current_vllm_config() at init.
    _VLLM_CONFIG_CTX = set_current_vllm_config(VllmConfig())
    _VLLM_CONFIG_CTX.__enter__()

    ensure_model_parallel_initialized(1, 1)
    _VLLM_BOOTSTRAPPED = True
# ----------------------------------------------------------------------------
# Module: linear replacement
# ----------------------------------------------------------------------------
def _replace_linears_with_replicated(
    model: nn.Module,
    quant_config,
    params_dtype: torch.dtype = torch.bfloat16,
) -> None:
    """Recursively swap every ``nn.Linear`` for vLLM ``ReplicatedLinear``.

    The module tree is walked depth-first in ``named_children`` order. Each
    ``nn.Linear`` found is replaced in place on its parent; replaced layers
    are not descended into.

    Args:
        model: root module whose ``nn.Linear`` descendants are replaced.
        quant_config: vLLM quantisation config handed to every
            ``ReplicatedLinear``.
        params_dtype: dtype passed to ``ReplicatedLinear`` as ``params_dtype``.
            Defaults to ``torch.bfloat16``, the value that was previously
            hard-coded, so existing callers are unaffected.

    The dotted module path is passed as ``prefix``; this is intended to let
    the quant config's ``ignore`` patterns (e.g. ``lm_head``) apply.
    """
    from vllm.model_executor.layers.linear import ReplicatedLinear

    def _walk(parent: nn.Module, prefix: str) -> None:
        # Snapshot the children: we mutate ``parent`` while iterating.
        for child_name, child in list(parent.named_children()):
            qname = f"{prefix}.{child_name}" if prefix else child_name
            if not isinstance(child, nn.Linear):
                _walk(child, qname)
                continue
            replacement = ReplicatedLinear(
                input_size=child.in_features,
                output_size=child.out_features,
                bias=child.bias is not None,
                quant_config=quant_config,
                prefix=qname,
                return_bias=False,
                params_dtype=params_dtype,
            )
            setattr(parent, child_name, replacement)

    _walk(model, prefix="")
def _materialize_remaining_meta_params(
    model: nn.Module, dtype: torch.dtype, device: torch.device
) -> None:
    """Replace every ``meta`` parameter and buffer with empty real storage.

    Tensors that already live on a real device are left untouched. The new
    storage comes from ``torch.empty`` and is therefore uninitialised; the
    caller is expected to load real values afterwards.

    Args:
        model: module whose ``meta`` tensors are materialised in place.
        dtype: dtype for materialised floating-point parameters.
            Non-floating-point parameters keep their own dtype.
        device: device on which the new storage is allocated.

    Materialised parameters are created with ``requires_grad=False``.
    Buffers keep their original dtype and their original persistence, so a
    buffer that was part of ``state_dict`` stays part of it.
    """
    for name, param in list(model.named_parameters(recurse=True)):
        if param.device.type != "meta":
            continue
        # Only floating-point params follow ``dtype``; casting an integer
        # parameter to a float dtype would change its meaning.
        real_dtype = dtype if param.is_floating_point() else param.dtype
        owner_path, _, leaf = name.rpartition(".")
        owner = model.get_submodule(owner_path)
        real = nn.Parameter(
            torch.empty(param.shape, dtype=real_dtype, device=device),
            requires_grad=False,
        )
        setattr(owner, leaf, real)

    # Same for buffers (e.g. rotary inv_freq if registered as buffer on meta).
    for name, buf in list(model.named_buffers(recurse=True)):
        if buf.device.type != "meta":
            continue
        owner_path, _, leaf = name.rpartition(".")
        owner = model.get_submodule(owner_path)
        # Re-register with the same persistence the buffer had before, so
        # persistent buffers do not silently vanish from ``state_dict``.
        non_persistent = getattr(owner, "_non_persistent_buffers_set", ())
        owner.register_buffer(
            leaf,
            torch.empty(buf.shape, dtype=buf.dtype, device=device),
            persistent=leaf not in non_persistent,
        )
# ----------------------------------------------------------------------------
# Weight loading
# ----------------------------------------------------------------------------
def _iter_safetensors(model_dir: str) -> Iterator[tuple[str, torch.Tensor]]:
    """Yield (name, tensor) pairs from all *.safetensors shards in ``model_dir``.

    Shard discovery prefers ``model.safetensors.index.json``: when it exists,
    the distinct file names in its ``weight_map`` are used. Otherwise every
    ``*.safetensors`` file in the directory is used. Shards are visited in
    sorted file-name order.

    Args:
        model_dir: directory holding the checkpoint shards.

    Yields:
        ``(key, tensor)`` for every key of every shard.

    Raises:
        FileNotFoundError: if no shard can be found. Without this check an
            empty directory would yield nothing and the caller would be left
            with uninitialised weights.
    """
    # Single-file checkpoint or sharded? Prefer ``model.safetensors.index.json``.
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f)["weight_map"]
        shards = sorted(set(weight_map.values()))
    else:
        shards = sorted(
            fn for fn in os.listdir(model_dir) if fn.endswith(".safetensors")
        )

    if not shards:
        raise FileNotFoundError(
            f"no *.safetensors shards found in {model_dir!r}"
        )

    # Imported only once there is something to read.
    from safetensors import safe_open

    for shard in shards:
        with safe_open(os.path.join(model_dir, shard), framework="pt") as f:
            for key in f.keys():
                yield key, f.get_tensor(key)
def _load_weights_into_model(model: nn.Module, model_dir: str) -> None:
    """Stream safetensors into the (already-structured) model.

    Each checkpoint key is matched by exact name:

    - A parameter that carries a ``weight_loader`` attribute is loaded by
      calling ``weight_loader(param, tensor)``.
    - Any other parameter is overwritten with the tensor cast to the
      parameter's device and dtype.
    - A buffer is overwritten with the tensor moved to the buffer's device.
    - Anything else is recorded as unexpected.

    When ``model.config.tie_word_embeddings`` is truthy, ``lm_head.weight``
    is pointed at the input embedding weight afterwards. Tying is best-effort:
    a failure is reported as a warning and does not abort loading.

    Missing parameters and unused checkpoint keys are reported with
    ``print``; neither raises.

    Args:
        model: module whose parameters and buffers are filled in place.
        model_dir: checkpoint directory passed to ``_iter_safetensors``.
    """
    # Keys are matched verbatim; no prefix (such as "text_encoder.") is
    # stripped, since a standalone HF checkpoint is expected to use
    # "model.layers..." / "model.embed_tokens..." / "lm_head...".
    name_to_param: dict[str, nn.Parameter] = dict(model.named_parameters(recurse=True))
    name_to_buffer: dict[str, torch.Tensor] = dict(model.named_buffers(recurse=True))
    missing = set(name_to_param)
    unexpected: list[str] = []

    for key, tensor in _iter_safetensors(model_dir):
        param = name_to_param.get(key)
        if param is not None:
            loader = getattr(param, "weight_loader", None)
            if loader is not None:
                loader(param, tensor.to(param.device))
            else:
                with torch.no_grad():
                    param.data.copy_(tensor.to(param.device, dtype=param.dtype))
            missing.discard(key)
            continue

        buf = name_to_buffer.get(key)
        if buf is not None:
            with torch.no_grad():
                buf.copy_(tensor.to(buf.device))
            continue

        unexpected.append(key)

    # Tied embeddings (lm_head.weight not in checkpoint when tie_word_embeddings=True)
    cfg = getattr(model, "config", None)
    if cfg is not None and getattr(cfg, "tie_word_embeddings", False):
        try:
            model.lm_head.weight = model.get_input_embeddings().weight  # share storage
        except Exception as exc:
            # Still best-effort, but no longer silent: an untied head would
            # otherwise only show up as a generic "missing" warning.
            print(
                "[NVFP4TextEncoder] WARN: could not tie lm_head to input "
                f"embeddings: {exc!r}"
            )
        else:
            missing.discard("lm_head.weight")

    if missing:
        leftover = sorted(missing)
        print(
            f"[NVFP4TextEncoder] WARN: {len(leftover)} params missing from checkpoint; "
            f"first 5: {leftover[:5]}"
        )
    if unexpected:
        print(
            f"[NVFP4TextEncoder] WARN: {len(unexpected)} keys in checkpoint unused; "
            f"first 5: {unexpected[:5]}"
        )
def _process_weights_after_loading(model: nn.Module) -> None:
    """Run the post-load hook of every quantisation method in ``model``.

    Visits all modules (the root included). A module is processed when it has
    a non-``None`` ``quant_method`` attribute that exposes
    ``process_weights_after_loading``; the hook is called with the module as
    its only argument. Per the original author, this is where vLLM renames
    ``weight_packed`` -> ``weight``, computes ``alpha`` and swizzles scales
    for the CUTLASS kernel -- that logic lives in vLLM, not in this file.
    """
    for layer in model.modules():
        method = getattr(layer, "quant_method", None)
        if method is None:
            continue
        if not hasattr(method, "process_weights_after_loading"):
            continue
        method.process_weights_after_loading(layer)
# ----------------------------------------------------------------------------
# Public API
# ----------------------------------------------------------------------------
def load_nvfp4_text_encoder(
    model_dir: str,
    device: str | torch.device = "cuda",
    dtype: torch.dtype = torch.bfloat16,
) -> nn.Module:
    """Load an NVFP4-quantised HuggingFace causal LM as a plug-in text encoder.

    Args:
        model_dir: checkpoint directory containing ``config.json`` and the
            safetensors shards. The config must carry a non-empty
            ``quantization_config`` (expected format:
            ``"nvfp4-pack-quantized"``; only its presence is checked here).
        device: device the finished model is placed on.
        dtype: dtype used when materialising the non-Linear parameters. It is
            not forwarded to the Linear swap, which is called without it.

    Returns:
        The model built by ``AutoModelForCausalLM.from_config`` with its
        ``nn.Linear`` layers replaced, weights loaded, in eval mode and with
        ``config.use_cache`` set to ``False``.

    Raises:
        ValueError: if the config has no ``quantization_config``.
    """
    _bootstrap_vllm_once()

    from transformers import AutoConfig, AutoModelForCausalLM
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
        CompressedTensorsConfig,
    )
    from vllm.model_executor.models.transformers.utils import (
        init_on_device_without_buffers,
    )

    config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
    raw_quant_cfg = getattr(config, "quantization_config", None)
    if not raw_quant_cfg:
        raise ValueError(
            f"{model_dir}/config.json has no `quantization_config`; "
            "this loader only handles NVFP4-quantised checkpoints."
        )
    quant_config = CompressedTensorsConfig.from_config(config.quantization_config)

    # Step 1: build the skeleton under the "meta" device context so that
    # parameters are created without real storage.
    with init_on_device_without_buffers("meta"):
        encoder = AutoModelForCausalLM.from_config(config)

    # Step 2: swap Linear -> ReplicatedLinear(quant_config). The replacement
    # layers are constructed outside the meta context.
    destination = torch.device(device)
    _replace_linears_with_replicated(encoder, quant_config)

    # Step 3: give every remaining meta tensor (embeddings, RMSNorms, ...)
    # real storage on the destination device.
    _materialize_remaining_meta_params(encoder, dtype=dtype, device=destination)

    # Step 4: move the whole model, so the ReplicatedLinear parameters end up
    # on the destination too, wherever they were first allocated.
    encoder.to(destination)

    # Step 5: fill in the weights (per-param weight_loader where present).
    _load_weights_into_model(encoder, model_dir)

    # Step 6: run the quant methods' post-load processing.
    _process_weights_after_loading(encoder)

    # Step 7: match HF inference semantics for downstream pipelines.
    encoder.eval()
    encoder.config.use_cache = False
    return encoder