File size: 5,035 Bytes
49a0fec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cd269a
49a0fec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cd269a
49a0fec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cd269a
49a0fec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import time
import torch
from safetensors.torch import load_file
from diffusers import ZImagePipeline, ZImageTransformer2DModel
from diffusers.models.attention_dispatch import attention_backend
from huggingface_hub import hf_hub_download

# Model source selection.
# MODEL_ID: HF Hub repo id (env-overridable); set empty to force local loading.
# USE_LOCAL=1 (or an empty MODEL_ID) loads everything from the script's own
# directory instead of the Hub.
MODEL_ID = os.environ.get("MODEL_ID", "ykarout/Z-Image-Turbo-FP8-Full")
REPO_DIR = os.path.dirname(os.path.abspath(__file__))
USE_LOCAL = os.environ.get("USE_LOCAL", "") == "1" or not MODEL_ID
MODEL_SRC = REPO_DIR if USE_LOCAL else MODEL_ID

# Select which FP8 transformer file to use (E4M3FN by default; E5M2 variants
# are available in the commented lines).
if USE_LOCAL:
    TRANSFORMER_FP8 = os.path.join(REPO_DIR, "transformer", "diffusion_pytorch_model.safetensors")  # E4M3FN (local)
    # TRANSFORMER_FP8 = os.path.join(REPO_DIR, "transformer", "diffusion_pytorch_model_e5m2.safetensors")  # E5M2 (local)
else:
    TRANSFORMER_FP8 = hf_hub_download(
        repo_id=MODEL_ID,
        filename="transformer/diffusion_pytorch_model.safetensors",
        # NOTE(review): USE_LOCAL is always False on this branch, so this is
        # effectively local_files_only=False — kept for symmetry.
        local_files_only=USE_LOCAL,
    )
    # TRANSFORMER_FP8 = hf_hub_download(repo_id=MODEL_ID, filename="transformer/diffusion_pytorch_model_e5m2.safetensors")


def _strip_prefix(state_dict):
    prefix = "model.diffusion_model."
    return {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}


def load_fp8_transformer(fp8_path, config_root):
    """Instantiate a ZImageTransformer2DModel from an FP8 safetensors file.

    Reads the checkpoint at *fp8_path*, strips the ComfyUI key prefix, then
    builds the model against the ``transformer`` config under *config_root*
    with compute dtype bfloat16. Honors the module-level USE_LOCAL flag.
    """
    state_dict = _strip_prefix(load_file(fp8_path))
    return ZImageTransformer2DModel.from_single_file(
        state_dict,
        config=config_root,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        local_files_only=USE_LOCAL,
        low_cpu_mem_usage=False,
    )


def _replace_linear_with_te(module, prefix=""):
    for name, child in module.named_children():
        path = f"{prefix}.{name}" if prefix else name
        if isinstance(child, torch.nn.Linear):
            # Skip small MLPs that violate TE FP8 shape constraints.
            if "t_embedder" in path or "adaLN_modulation" in path:
                continue
            te_linear = te.Linear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                params_dtype=child.weight.dtype,
            )
            te_linear.weight = torch.nn.Parameter(child.weight)
            if child.bias is not None:
                te_linear.bias = torch.nn.Parameter(child.bias)
            setattr(module, name, te_linear)
        else:
            _replace_linear_with_te(child, path)


# 1. Load the pipeline
# Use bfloat16 for optimal performance on supported GPUs
pipe = ZImagePipeline.from_pretrained(
    MODEL_SRC,
    torch_dtype=torch.bfloat16,
    local_files_only=USE_LOCAL,
    low_cpu_mem_usage=False,
)
pipe.transformer = load_fp8_transformer(TRANSFORMER_FP8, MODEL_SRC)

# Transformer Engine is a hard requirement for FP8 execution. Only the
# imports sit in the try so a missing package is reported as such; a failure
# inside the layer replacement would otherwise be mislabeled an import error.
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe as te_recipe
except ImportError as e:
    raise RuntimeError(
        "Transformer Engine is required for FP8 execution. Install TE and rerun. "
        f"Import error: {e}"
    ) from e

# Swap eligible nn.Linear layers for TE FP8 Linears and build the scaling
# recipe used by te.fp8_autocast at generation time.
_replace_linear_with_te(pipe.transformer)
fp8_recipe = te_recipe.DelayedScaling()

# [Optional] Attention Backend
# Diffusers uses SDPA by default. Switch to Flash Attention for better efficiency if supported:
# pipe.transformer.set_attention_backend("flash")    # Flash-Attention-2
# pipe.transformer.set_attention_backend("_flash_3") # Flash-Attention-3
# Or use the native flash SDPA backend:
# attention_ctx = attention_backend("_native_flash")

# [Optional] Model Compilation
# Compiling the DiT model accelerates inference, but the first run will take longer to compile.
# pipe.transformer.compile()

# [Optional] CPU Offloading
# Enable CPU offloading for memory-constrained devices; trades some latency
# for a lower peak VRAM footprint by moving idle components to the CPU.
pipe.enable_model_cpu_offload()

prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."

# Reset CUDA memory stats and fence the GPU so timing covers only generation.
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
t_start = time.perf_counter()

# 2. Generate Image
# Run under the native-flash SDPA backend, TE FP8 autocast, and inference mode.
with attention_backend("_native_flash"), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe), torch.inference_mode():
    result = pipe(
        prompt=prompt,
        height=1024,
        width=1024,
        num_inference_steps=9,  # 9 scheduler steps -> 8 actual DiT forwards
        guidance_scale=0.0,     # Turbo models are distilled for zero guidance
        generator=torch.Generator("cuda").manual_seed(42),
    )
image = result.images[0]

# Fence again so elapsed time and peak memory reflect the finished run.
torch.cuda.synchronize()
elapsed_s = time.perf_counter() - t_start
peak_gb = torch.cuda.max_memory_allocated() / (1024**3)

image.save("example.png")
print(f"saved example.png | elapsed_s={elapsed_s:.3f} peak_allocated_gb={peak_gb:.3f}")