"""
NVFP4 text encoder loader for diffusers image pipelines.

Loads a compressed-tensors NVFP4-pack-quantized HuggingFace causal LM and wraps
it so it can be plugged into ``diffusers.ZImagePipeline`` (or any pipeline
calling ``self.text_encoder(input_ids, attention_mask, output_hidden_states=True)``).

Strategy:
- Instantiate the HF model on the ``meta`` device (no real allocation).
- Walk every ``torch.nn.Linear`` and swap it for vLLM's ``ReplicatedLinear`` with
  ``CompressedTensorsConfig`` derived from the checkpoint's
  ``quantization_config``. This registers ``weight_packed`` / ``weight_scale`` /
  ``*_global_scale`` parameters in the exact layout vLLM's
  ``CompressedTensorsW4A4Fp4`` scheme expects.
- Materialise remaining (non-Linear) parameters (embeddings, RMSNorm, k/q norms)
  on the target device and dtype.
- Stream the safetensors shard(s) and dispatch each tensor through its
  registered vLLM ``weight_loader``. Afterwards, ``process_weights_after_loading``
  converts the loaded tensors into the layout the kernel expects (e.g. scale
  swizzling).
- Tie the LM head to the input embedding when ``config.tie_word_embeddings``.

The result is a regular ``nn.Module`` matching the HF model's call signature
(``forward(input_ids, attention_mask, output_hidden_states)``) -- usable directly
as ``ZImagePipeline.text_encoder``.

vLLM requires a minimal global context (distributed process group + model
parallel state + active VllmConfig) even at TP=1 because ``ReplicatedLinear``
queries the TP world size at construction. We bootstrap that lazily once.

Kernel selection: we default ``VLLM_NVFP4_GEMM_BACKEND`` to ``cutlass`` (via
``os.environ.setdefault``, so an explicit user setting still wins). This skips
the flashinfer-cutlass JIT, which needs the ``ninja`` binary on PATH. The vLLM
CUTLASS kernel ships prebuilt in the wheel.
"""
from __future__ import annotations

import json
import os
from collections.abc import Iterator

import torch
import torch.nn as nn
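

# Illustrative sketch (hypothetical helpers, not used by the loader) of what
# the ``weight_packed`` / ``weight_scale`` pair in the module docstring encodes.
# Each uint8 of ``weight_packed`` holds two FP4 (E2M1) codes, and each block of
# 16 consecutive values shares one scale. Assumptions: the low nibble stores
# the even-indexed element, and only a single per-block scale is applied. The
# real scheme also uses FP8 block scales and the ``*_global_scale`` tensors,
# which this pure-Python sketch omits.
_SKETCH_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def _sketch_decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code: bit 3 is the sign, bits 0-2 the magnitude."""
    magnitude = _SKETCH_E2M1_MAGNITUDES[code & 0b0111]
    return -magnitude if code & 0b1000 else magnitude


def _sketch_dequant_row(
    packed: list[int], block_scales: list[float], block: int = 16
) -> list[float]:
    """Unpack a row of uint8 bytes into floats, scaling each 16-value block."""
    values = []
    for byte in packed:
        values.append(_sketch_decode_e2m1(byte & 0x0F))  # low nibble (assumed first)
        values.append(_sketch_decode_e2m1(byte >> 4))  # high nibble
    return [v * block_scales[i // block] for i, v in enumerate(values)]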


# ----------------------------------------------------------------------------
# One-time vLLM bootstrap (TP=1, no engine, just enough context for ReplicatedLinear)
# ----------------------------------------------------------------------------
_VLLM_BOOTSTRAPPED = False
_VLLM_CONFIG_CTX = None  # holds the entered set_current_vllm_config context manager
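

# Sketch (hypothetical, stdlib only) of the two conventions the bootstrap below
# follows. First, it runs at most once, guarded by a module-level flag. Second,
# it only *defaults* the GEMM backend via ``os.environ.setdefault``, so an
# explicit ``VLLM_NVFP4_GEMM_BACKEND`` set by the user is left untouched. The
# ``env`` / ``state`` dicts stand in for ``os.environ`` and the
# ``_VLLM_BOOTSTRAPPED`` global so the behaviour can be checked in isolation.
def _sketch_bootstrap_once(env: dict, state: dict) -> bool:
    """Return True if this call performed the setup, False if it was skipped."""
    if state.get("bootstrapped"):
        return False
    env.setdefault("VLLM_NVFP4_GEMM_BACKEND", "cutlass")
    state["bootstrapped"] = True
    return True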


def _bootstrap_vllm_once() -> None:
    """Initialise the bits of vLLM that ReplicatedLinear needs at TP=1.

    Idempotent. Uses ``gloo`` so it works without NCCL/CUDA-aware MPI and even
    when CUDA is busy with the diffusion transformer.
    """
    global _VLLM_BOOTSTRAPPED, _VLLM_CONFIG_CTX
    if _VLLM_BOOTSTRAPPED:
        return

    # Force CUTLASS to avoid flashinfer-cutlass JIT (requires `ninja` on PATH).
    os.environ.setdefault("VLLM_NVFP4_GEMM_BACKEND", "cutlass")

    from vllm.config import VllmConfig
    from vllm.config.vllm import set_current_vllm_config
    from vllm.distributed import (
        ensure_model_parallel_initialized,
        init_distributed_environment,
    )

    # Pick a free port; world_size=1.
    import socket

    s = socket.socket()
    s.bind(("127.0.0.1", 0))
    port = s.getsockname()[1]
    s.close()

    if not torch.distributed.is_initialized():
        init_distributed_environment(
            world_size=1,
            rank=0,
            local_rank=0,
            distributed_init_method=f"tcp://127.0.0.1:{port}",
            backend="gloo",
        )

    # Enter a long-lived VllmConfig context. We never exit it -- the encoder
    # may construct submodules lazily and ReplicatedLinear calls
    # get_current_vllm_config() at init.
    vc = VllmConfig()
    _VLLM_CONFIG_CTX = set_current_vllm_config(vc)
    _VLLM_CONFIG_CTX.__enter__()

    ensure_model_parallel_initialized(1, 1)
    _VLLM_BOOTSTRAPPED = True
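

# Sketch (hypothetical helper, not called by the loader): the free-port trick
# used in the bootstrap above, written with a context manager. Binding to port
# 0 asks the OS for an unused ephemeral port. The socket is closed before the
# port is reused, so another process could in principle take it in between.
# That is acceptable for a single-process TP=1 setup, but it is not a guarantee.
import socket


def _sketch_find_free_port(host: str = "127.0.0.1") -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))  # port 0: let the OS choose
        return s.getsockname()[1]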


# ----------------------------------------------------------------------------
# Module: linear replacement
# ----------------------------------------------------------------------------
def _replace_linears_with_replicated(
    model: nn.Module, quant_config
) -> None:
    """Recursively swap every ``nn.Linear`` for vLLM ``ReplicatedLinear``.

    Carries the ``prefix`` so quant_config's ``ignore`` patterns (e.g. ``lm_head``)
    are correctly applied.
    """
    from vllm.model_executor.layers.linear import ReplicatedLinear

    def _walk(parent: nn.Module, prefix: str) -> None:
        for child_name, child in list(parent.named_children()):
            qname = f"{prefix}.{child_name}" if prefix else child_name
            if isinstance(child, nn.Linear):
                new = ReplicatedLinear(
                    input_size=child.in_features,
                    output_size=child.out_features,
                    bias=child.bias is not None,
                    quant_config=quant_config,
                    prefix=qname,
                    return_bias=False,
                    params_dtype=torch.bfloat16,
                )
                setattr(parent, child_name, new)
            else:
                _walk(child, qname)

    _walk(model, prefix="")
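

# Sketch (hypothetical helpers, not used by the loader) of how the qualified
# ``prefix`` built by ``_walk`` lets an ``ignore`` list skip layers such as
# ``lm_head``. The module tree is modelled as nested dicts, with a leaf value
# of "Linear" marking a linear layer. The matching rule (exact name, or a
# ``re:``-prefixed full-match regex) is an assumption loosely modelled on
# compressed-tensors ignore entries; vLLM's real matcher may differ.
import re


def _sketch_qualified_linear_names(tree: dict, prefix: str = "") -> list[str]:
    """Return dotted names of every "Linear" leaf, in depth-first order."""
    names = []
    for child_name, child in tree.items():
        qname = f"{prefix}.{child_name}" if prefix else child_name
        if child == "Linear":
            names.append(qname)
        elif isinstance(child, dict):
            names.extend(_sketch_qualified_linear_names(child, qname))
    return names


def _sketch_is_ignored(qname: str, ignore: list[str]) -> bool:
    """True if ``qname`` matches an ignore entry (exact or ``re:<regex>``)."""
    for pattern in ignore:
        if pattern.startswith("re:"):
            if re.fullmatch(pattern[3:], qname):
                return True
        elif pattern == qname:
            return True
    return False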


def _materialize_remaining_meta_params(
    model: nn.Module, dtype: torch.dtype, device: torch.device
) -> None:
    """Replace any ``meta`` parameter with empty (uninitialised) real storage.

    Only touches parameters that the ReplicatedLinear swap did not already
    create on a real device (i.e. embeddings and norm weights). Their values
    are filled in later from the checkpoint.
    """
    for name, param in list(model.named_parameters(recurse=True)):
        if param.device.type == "meta":
            real = nn.Parameter(
                torch.empty(param.shape, dtype=dtype, device=device),
                requires_grad=False,
            )
            # Replace in the parent module
            parent = model
            *path, leaf = name.split(".")
            for p in path:
                parent = getattr(parent, p)
            setattr(parent, leaf, real)
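    # Same for buffers. init_on_device_without_buffers normally keeps buffers
    # (e.g. rotary inv_freq) off meta, so this is only a fallback. torch.empty
    # leaves them uninitialised unless the checkpoint provides them.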
    for name, buf in list(model.named_buffers(recurse=True)):
        if buf.device.type == "meta":
            real = torch.empty(buf.shape, dtype=buf.dtype, device=device)
            parent = model
            *path, leaf = name.split(".")
            for p in path:
                parent = getattr(parent, p)
            parent.register_buffer(leaf, real, persistent=False)
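

# Sketch (hypothetical, pure Python) of resolving a dotted parameter name such
# as ``model.layers.0.input_layernorm.weight`` to (parent object, leaf
# attribute). This is what the parent-walking loops above do before
# ``setattr``. Plain ``SimpleNamespace`` objects stand in for ``nn.Module``s;
# on real modules, ``model.get_submodule(path)`` performs the same walk.
from types import SimpleNamespace


def _sketch_set_by_dotted_name(root, name: str, value) -> None:
    path, _, leaf = name.rpartition(".")  # "a.b.c" -> ("a.b", ".", "c")
    parent = root
    for part in path.split(".") if path else []:
        parent = getattr(parent, part)
    setattr(parent, leaf, value)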


# ----------------------------------------------------------------------------
# Weight loading
# ----------------------------------------------------------------------------
def _iter_safetensors(model_dir: str) -> Iterator[tuple[str, torch.Tensor]]:
    """Yield (name, tensor) pairs from all *.safetensors shards in ``model_dir``."""
    from safetensors import safe_open

    # Single-file checkpoint or sharded? Prefer ``model.safetensors.index.json``.
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path) as f:
            index = json.load(f)
        shards = sorted(set(index["weight_map"].values()))
    else:
        # Find all *.safetensors files in dir
        shards = sorted(
            fn for fn in os.listdir(model_dir) if fn.endswith(".safetensors")
        )
    for shard in shards:
        path = os.path.join(model_dir, shard)
        with safe_open(path, framework="pt") as f:
            for key in f.keys():
                yield key, f.get_tensor(key)
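

# Sketch (hypothetical debugging helper, stdlib only) for listing a shard's
# keys without loading any tensors. This helps when chasing the
# "missing" / "unexpected" warnings printed by the loader. It reads the
# safetensors header directly: a file starts with an 8-byte little-endian
# header length, followed by a JSON header that maps each tensor name to its
# dtype, shape and data offsets, plus an optional "__metadata__" entry.
import json
import struct


def _sketch_read_safetensors_header(path: str) -> dict:
    """Return {tensor_name: (dtype, shape)} from a .safetensors file header."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)
    return {name: (info["dtype"], info["shape"]) for name, info in header.items()}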


def _load_weights_into_model(model: nn.Module, model_dir: str) -> None:
    """Stream safetensors tensors into the (already-structured) model.

    Quantised params go through each ReplicatedLinear's registered
    ``weight_loader``. This reuses vLLM's own loading logic, even though no
    real sharding happens at TP=1. Other params (embeddings, norms) are copied
    directly, with a cast to the param's dtype.
    """
    # Keys are matched verbatim. This expects a standalone HF Qwen3 checkpoint
    # whose keys start with "model.layers...", "model.embed_tokens..." or
    # "lm_head...", so no vllm-omni-style "text_encoder." prefix is stripped.
    name_to_param: dict[str, nn.Parameter] = dict(model.named_parameters(recurse=True))
    name_to_buffer: dict[str, torch.Tensor] = dict(model.named_buffers(recurse=True))

    missing = set(name_to_param.keys())
    unexpected = []

    for key, tensor in _iter_safetensors(model_dir):
        # Match each key to a param first, then to a buffer; anything else is
        # reported as unexpected.
        if key in name_to_param:
            param = name_to_param[key]
            wl = getattr(param, "weight_loader", None)
            if wl is not None:
                wl(param, tensor.to(param.device))
            else:
                with torch.no_grad():
                    param.data.copy_(tensor.to(param.device, dtype=param.dtype))
            missing.discard(key)
        elif key in name_to_buffer:
            with torch.no_grad():
                name_to_buffer[key].copy_(tensor.to(name_to_buffer[key].device))
        else:
            unexpected.append(key)

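    # Tied embeddings: lm_head.weight is absent from the checkpoint when
    # tie_word_embeddings=True, so share the input embedding's storage instead.
    cfg = getattr(model, "config", None)
    if cfg is not None and getattr(cfg, "tie_word_embeddings", False):
        try:
            inp_emb = model.get_input_embeddings().weight
            model.lm_head.weight = inp_emb  # share storage
            missing.discard("lm_head.weight")
        except Exception as exc:  # e.g. no lm_head / get_input_embeddings
            print(f"[NVFP4TextEncoder] WARN: could not tie lm_head to embeddings: {exc}")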

    # With tied embeddings, lm_head.weight was already discarded above.
    # Anything still missing keeps uninitialised storage, so report it.
    if missing:
        leftover = sorted(missing)
        print(
            f"[NVFP4TextEncoder] WARN: {len(leftover)} params missing from checkpoint; "
            f"first 5: {leftover[:5]}"
        )
    if unexpected:
        print(
            f"[NVFP4TextEncoder] WARN: {len(unexpected)} keys in checkpoint unused; "
            f"first 5: {unexpected[:5]}"
        )
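

# Sketch (hypothetical, pure Python) of the "missing" / "unexpected"
# bookkeeping above, isolated from torch. Every param starts out missing.
# A checkpoint key that names a param clears that entry. A key that names
# neither a param nor a buffer is unexpected. With tied embeddings,
# ``lm_head.weight`` is not expected in the checkpoint at all.
def _sketch_reconcile_keys(
    param_names: set[str],
    buffer_names: set[str],
    checkpoint_keys: list[str],
    tied: bool,
) -> tuple[list[str], list[str]]:
    """Return (sorted missing param names, unexpected checkpoint keys)."""
    missing = set(param_names)
    unexpected = []
    for key in checkpoint_keys:
        if key in param_names:
            missing.discard(key)
        elif key not in buffer_names:
            unexpected.append(key)
    if tied:
        # lm_head.weight is shared with the input embedding, not loaded.
        missing.discard("lm_head.weight")
    return sorted(missing), unexpected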


def _process_weights_after_loading(model: nn.Module) -> None:
    """Invoke vLLM's per-layer ``process_weights_after_loading`` for each
    ReplicatedLinear (renames ``weight_packed`` -> ``weight``, computes ``alpha``,
    swizzles scales for the CUTLASS kernel, etc.)."""
    for module in model.modules():
        qm = getattr(module, "quant_method", None)
        if qm is not None and hasattr(qm, "process_weights_after_loading"):
            qm.process_weights_after_loading(module)
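

# Sketch (hypothetical, pure Python) of the duck-typed dispatch above. The
# hook is called only on modules whose ``quant_method`` defines
# ``process_weights_after_loading``. Every other module (norms, embeddings,
# containers) is skipped. Toy classes stand in for ``nn.Module`` and vLLM's
# quant methods, and the ``modules`` list stands in for ``model.modules()``.
class _SketchQuantMethod:
    def __init__(self) -> None:
        self.calls = 0

    def process_weights_after_loading(self, layer) -> None:
        self.calls += 1


class _SketchLayer:
    def __init__(self, quant_method=None) -> None:
        self.quant_method = quant_method


def _sketch_post_process(modules: list) -> int:
    """Call the hook where present; return how many layers were processed."""
    processed = 0
    for module in modules:
        qm = getattr(module, "quant_method", None)
        if qm is not None and hasattr(qm, "process_weights_after_loading"):
            qm.process_weights_after_loading(module)
            processed += 1
    return processed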


# ----------------------------------------------------------------------------
# Public API
# ----------------------------------------------------------------------------
def load_nvfp4_text_encoder(
    model_dir: str,
    device: str | torch.device = "cuda",
    dtype: torch.dtype = torch.bfloat16,
) -> nn.Module:
    """Load an NVFP4-quantised HuggingFace causal LM as a plug-in text encoder.

    Args:
        model_dir: path to the checkpoint directory containing ``config.json``
            and ``model*.safetensors``. The config must carry a
            ``quantization_config`` block with ``"format": "nvfp4-pack-quantized"``.
        device: target device for the materialised and loaded params
            (typically CUDA).
        dtype: dtype for activations and for non-Linear params (the
            ReplicatedLinear params are created in ``torch.bfloat16``).

    Returns:
        A ``PreTrainedModel`` whose ``Linear`` layers (apart from any that the
        quantization config ignores, e.g. ``lm_head``) run NVFP4 GEMMs through
        vLLM's CUTLASS kernel. Activations flow as ``dtype``.
    """
    _bootstrap_vllm_once()

    from transformers import AutoConfig, AutoModelForCausalLM
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
        CompressedTensorsConfig,
    )
    from vllm.model_executor.models.transformers.utils import (
        init_on_device_without_buffers,
    )

    hf_config = AutoConfig.from_pretrained(model_dir, local_files_only=True)
    if not getattr(hf_config, "quantization_config", None):
        raise ValueError(
            f"{model_dir}/config.json has no `quantization_config`; "
            "this loader only handles NVFP4-quantised checkpoints."
        )
    quant_config = CompressedTensorsConfig.from_config(hf_config.quantization_config)

    # 1) Build the model skeleton on meta (zero allocation).
    with init_on_device_without_buffers("meta"):
        model = AutoModelForCausalLM.from_config(hf_config)

    # 2) Swap Linear -> ReplicatedLinear(quant_config). This creates real params
    #    of the quantised shapes on the current default device.
    target_device = torch.device(device)
    _replace_linears_with_replicated(model, quant_config)

    # 3) Materialise any leftover meta parameters (embeddings, RMSNorms, ...)
    _materialize_remaining_meta_params(model, dtype=dtype, device=target_device)

    # 4) Move the newly created quantised params to the target device
    #    (ReplicatedLinear creates them on the current default device, usually CPU).
    model.to(target_device)

    # 5) Load weights via per-param weight_loader.
    _load_weights_into_model(model, model_dir)

    # 6) Let vLLM swizzle scales / rename weight_packed->weight / compute alpha.
    _process_weights_after_loading(model)

    # 7) Match HF inference semantics for downstream pipelines: eval mode, and
    #    no KV cache, since the encoder only reads hidden states.
    model.eval()
    model.config.use_cache = False
    return model
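

# Sketch (hypothetical helper, stdlib only) of a pre-flight check that could
# run before ``load_nvfp4_text_encoder``. It fails fast on checkpoints whose
# ``config.json`` lacks the ``"format": "nvfp4-pack-quantized"`` entry that the
# docstring requires. Reading ``format`` from the top level of
# ``quantization_config`` is an assumption about the compressed-tensors layout.
import json
import os


def _sketch_is_nvfp4_checkpoint(model_dir: str) -> bool:
    with open(os.path.join(model_dir, "config.json")) as f:
        cfg = json.load(f)
    qcfg = cfg.get("quantization_config") or {}
    return qcfg.get("format") == "nvfp4-pack-quantized"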