File size: 6,075 Bytes
1e103b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import torch
from diffusers import ZImagePipeline
from nunchaku.models.transformers.transformer_zimage import NunchakuZImageTransformer2DModel
from nunchaku.utils import get_gpu_memory


class ZImageTurboBackend:
    def __init__(
        self,
        model_id,
        optimized_model_path=None,
        optimized_edit_model_path=None,
        uma=False,
        nvfp4_text_encoder_path: str | None = None,
    ):
        self.model_id = model_id
        self.optimized_model_path = optimized_model_path
        self.pipeline = None
        self.uma = uma
        # Optional path to an NVFP4-pack-quantized Qwen3 text encoder. When set,
        # we load the encoder via vLLM's CompressedTensorsW4A4Fp4 (CUTLASS NVFP4
        # GEMM) instead of the bf16 text_encoder shipped inside the Z-Image
        # base repo. Cuts encoder VRAM ~4x with negligible quality loss
        # (cosine >0.999 vs the bf16 reference on Thor).
        self.nvfp4_text_encoder_path = nvfp4_text_encoder_path

    def _build_nvfp4_text_encoder(self):
        """Load the NVFP4 text encoder if requested, returns (encoder, tokenizer) or (None, None)."""
        if not self.nvfp4_text_encoder_path:
            return None, None
        print(
            f"[ZImageTurboBackend] Loading NVFP4 text encoder from {self.nvfp4_text_encoder_path} "
            "(vLLM CompressedTensorsW4A4Fp4 + CUTLASS NVFP4 GEMM)"
        )
        from NVFP4TextEncoder import load_nvfp4_text_encoder
        from transformers import AutoTokenizer

        encoder = load_nvfp4_text_encoder(
            self.nvfp4_text_encoder_path,
            device="cuda",
            dtype=torch.bfloat16,
        )
        tokenizer = AutoTokenizer.from_pretrained(self.nvfp4_text_encoder_path)
        return encoder, tokenizer

    def load(self):
        print(f"Loading ZImageTurboBackend from {self.model_id}...")
        print(f"Loading NunchakuZImageTransformer2DModel from {self.optimized_model_path}...")

        # Load transformer (optimized model)
        transformer = NunchakuZImageTransformer2DModel.from_pretrained(self.optimized_model_path)

        # If requested, build the NVFP4 text encoder before constructing the pipeline so
        # diffusers does not also load the bf16 text_encoder from disk (it would double VRAM).
        nvfp4_encoder, nvfp4_tokenizer = self._build_nvfp4_text_encoder()

        # Load pipeline
        print("Initializing ZImagePipeline...")
        pipeline_kwargs = dict(
            transformer=transformer,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=False,  # standard for HF example
        )
        if nvfp4_encoder is not None:
            # Pass our pre-built encoder so diffusers skips loading the bf16 subfolder.
            pipeline_kwargs["text_encoder"] = nvfp4_encoder
            if nvfp4_tokenizer is not None:
                pipeline_kwargs["tokenizer"] = nvfp4_tokenizer

        pipeline = ZImagePipeline.from_pretrained(self.model_id, **pipeline_kwargs)

        gpu_mem = get_gpu_memory()
        print(f"GPU memory available: {gpu_mem} GB")

        # Enable Flash Attention 2
        try:
            if hasattr(pipeline.transformer, "set_attention_backend"):
                pipeline.transformer.set_attention_backend("native")
                print("Enabled Native SDPA for Z-Image transformer")
            if hasattr(pipeline.vae, "set_attention_backend"):
                pipeline.vae.set_attention_backend("native")
                print("Enabled Native SDPA for Z-Image VAE")
        except Exception as e:
            print(f"Could not enable Flash Attention 2: {e}")

        if self.uma:
            print("UMA mode enabled: Loading all components to GPU and disabling offloads")
            # When using the NVFP4 encoder, it is already on CUDA and its quantised parameters
            # are not compatible with diffusers' generic .to() pathway (e.g. uint8 weight_packed).
            # We move only the diffusers-managed components (vae, transformer if not nunchaku, ...).
            if nvfp4_encoder is not None:
                # Exclude text_encoder from blanket .to('cuda'); it is already on cuda.
                excl = getattr(pipeline, "_exclude_from_cpu_offload", [])
                if "text_encoder" not in excl:
                    excl.append("text_encoder")
                    pipeline._exclude_from_cpu_offload = excl
                for name, comp in pipeline.components.items():
                    if name == "text_encoder":
                        continue
                    if isinstance(comp, torch.nn.Module):
                        try:
                            comp.to("cuda")
                        except Exception:
                            pass
            else:
                pipeline.to("cuda")
        elif gpu_mem <= 18:
            print("GPU memory <= 18GB, using sequential cpu offload for low VRAM")
            # The prompt requested sequential offloading without splitting layers for Nunchaku
            pipeline._exclude_from_cpu_offload.append("transformer")
            if nvfp4_encoder is not None:
                # NVFP4 weights live entirely on CUDA; do not let accelerate move them.
                pipeline._exclude_from_cpu_offload.append("text_encoder")
            pipeline.enable_sequential_cpu_offload()
            transformer.to("cuda")
            if nvfp4_encoder is not None:
                nvfp4_encoder.to("cuda")
        else:
            print("GPU memory > 18GB, using cpu offload")
            if nvfp4_encoder is not None:
                if not hasattr(pipeline, "_exclude_from_cpu_offload"):
                    pipeline._exclude_from_cpu_offload = []
                pipeline._exclude_from_cpu_offload.append("text_encoder")
            pipeline.enable_model_cpu_offload()
            if nvfp4_encoder is not None:
                nvfp4_encoder.to("cuda")

        self.pipeline = pipeline
        # Return twice for pipeline and edit_pipeline (though Z-Image-Turbo is T2I only)
        return self.pipeline, self.pipeline