# Z-Image-Turbo-FP8-Full / create-image.py
# Author: ykarout. Uploaded with huggingface_hub (commit 1cd269a, verified).
import os
import time
import torch
from safetensors.torch import load_file
from diffusers import ZImagePipeline, ZImageTransformer2DModel
from diffusers.models.attention_dispatch import attention_backend
from huggingface_hub import hf_hub_download
# Model source. MODEL_ID defaults to the published Hub repo. Set USE_LOCAL=1,
# or set MODEL_ID to an empty string, to load everything from the folder that
# contains this script instead.
MODEL_ID = os.environ.get("MODEL_ID", "ykarout/Z-Image-Turbo-FP8-Full")
REPO_DIR = os.path.dirname(os.path.abspath(__file__))
USE_LOCAL = os.environ.get("USE_LOCAL", "") == "1" or not MODEL_ID
MODEL_SRC = REPO_DIR if USE_LOCAL else MODEL_ID

# FP8 transformer checkpoints shipped with the model, keyed by FP8 format.
# Choose one with the FP8_VARIANT environment variable (default: e4m3fn).
FP8_TRANSFORMER_FILES = {
    "e4m3fn": "transformer/diffusion_pytorch_model.safetensors",  # E4M3FN
    "e5m2": "transformer/diffusion_pytorch_model_e5m2.safetensors",  # E5M2
}
FP8_VARIANT = os.environ.get("FP8_VARIANT", "e4m3fn").strip().lower()
if FP8_VARIANT not in FP8_TRANSFORMER_FILES:
    raise ValueError(
        f"Unknown FP8_VARIANT={FP8_VARIANT!r}; "
        f"expected one of {sorted(FP8_TRANSFORMER_FILES)}"
    )
_fp8_relpath = FP8_TRANSFORMER_FILES[FP8_VARIANT]

if USE_LOCAL:
    # Local mode: the checkpoint lives next to this script.
    TRANSFORMER_FP8 = os.path.join(REPO_DIR, *_fp8_relpath.split("/"))
else:
    # Remote mode: fetch just this one file from the Hub repo.
    TRANSFORMER_FP8 = hf_hub_download(repo_id=MODEL_ID, filename=_fp8_relpath)
def _strip_prefix(state_dict, *, prefix="model.diffusion_model."):
    """Return a new dict with *prefix* removed from every key that starts with it.

    Keys without the prefix are kept unchanged, and values are not copied.
    If a stripped key collides with a key already in the dict, the entry that
    appears later in iteration order wins.

    Args:
        state_dict: mapping of parameter names to tensors (or any values).
        prefix: leading string to remove. It defaults to the
            ``model.diffusion_model.`` namespace this script's checkpoint uses.

    Returns:
        A new ``dict`` with the renamed keys.
    """
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
def load_fp8_transformer(fp8_path, config_root):
    """Build a bfloat16 ``ZImageTransformer2DModel`` from an FP8 safetensors file.

    Args:
        fp8_path: path to the FP8 ``.safetensors`` checkpoint.
        config_root: local folder or Hub repo id whose ``transformer``
            subfolder holds the model config.

    Returns:
        The transformer produced by ``from_single_file`` with
        ``torch_dtype=torch.bfloat16``.
    """
    # Read the whole checkpoint, then drop the "model.diffusion_model." key
    # namespace so the names line up with the diffusers module tree.
    state_dict = _strip_prefix(load_file(fp8_path))
    return ZImageTransformer2DModel.from_single_file(
        state_dict,
        config=config_root,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        local_files_only=USE_LOCAL,
        low_cpu_mem_usage=False,
    )
# Substrings of dotted module paths whose nn.Linear layers stay in plain
# PyTorch. Per the original author, these small MLPs violate TE FP8 shape
# constraints.
_TE_SKIP_PATTERNS = ("t_embedder", "adaLN_modulation")


def _to_te_linear(linear):
    """Return a ``transformer_engine`` Linear of the same shape that reuses *linear*'s parameters."""
    # Import here rather than relying on a script-level ``te`` name. That name
    # only exists after the main script's Transformer Engine import has run.
    import transformer_engine.pytorch as te

    te_linear = te.Linear(
        linear.in_features,
        linear.out_features,
        bias=linear.bias is not None,
        params_dtype=linear.weight.dtype,
    )
    # Swap in the already-loaded weights in place of TE's freshly created ones.
    te_linear.weight = torch.nn.Parameter(linear.weight)
    if linear.bias is not None:
        te_linear.bias = torch.nn.Parameter(linear.bias)
    return te_linear


def _replace_linear_with_te(module, prefix="", *, skip=_TE_SKIP_PATTERNS):
    """Recursively replace ``torch.nn.Linear`` submodules of *module* with TE Linears, in place.

    Args:
        module: root module to rewrite. It is mutated; nothing is returned.
        prefix: dotted path of *module* relative to the root, used to build the
            paths that are matched against *skip*.
        skip: substrings. A Linear whose dotted path contains any of them is
            left untouched. A single string is treated as one pattern.
    """
    if isinstance(skip, str):
        skip = (skip,)
    for name, child in module.named_children():
        path = f"{prefix}.{name}" if prefix else name
        if not isinstance(child, torch.nn.Linear):
            _replace_linear_with_te(child, path, skip=skip)
            continue
        if any(pattern in path for pattern in skip):
            continue
        # Replacing an existing child key does not resize the module dict,
        # so this is safe while named_children() is iterating.
        setattr(module, name, _to_te_linear(child))
# 1. Load the pipeline
# Check for Transformer Engine first, so a missing install fails fast instead
# of after a full model download/load. The broad except is deliberate: this is
# a top-level boundary, and importing a native extension can fail with
# ImportError, OSError, or RuntimeError. The cause is chained for debugging.
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe as te_recipe
except Exception as e:
    raise RuntimeError(
        "Transformer Engine is required for FP8 execution. Install TE and rerun. "
        f"Import error: {e}"
    ) from e

# Use bfloat16 for optimal performance on supported GPUs.
# Hand the FP8-derived transformer to from_pretrained, so the pipeline does not
# load the repo's transformer a second time only to throw it away.
pipe = ZImagePipeline.from_pretrained(
    MODEL_SRC,
    transformer=load_fp8_transformer(TRANSFORMER_FP8, MODEL_SRC),
    torch_dtype=torch.bfloat16,
    local_files_only=USE_LOCAL,
    low_cpu_mem_usage=False,
)

# Errors raised here are real failures, not import problems, so they are
# allowed to propagate unchanged.
_replace_linear_with_te(pipe.transformer)
fp8_recipe = te_recipe.DelayedScaling()
# [Optional] Attention backend
# Diffusers defaults to SDPA, but the generation step in this script enters
# attention_backend("_native_flash") (the native flash SDPA backend) around
# the pipeline call. Other backends, if installed:
# pipe.transformer.set_attention_backend("flash")  # Flash-Attention-2
# pipe.transformer.set_attention_backend("_flash_3")  # Flash-Attention-3

# [Optional] Model compilation
# Compiling the DiT speeds up inference, but the first run takes longer.
# pipe.transformer.compile()

# CPU offloading (for memory-constrained devices) is on by default, as before.
# Set CPU_OFFLOAD=0 to move the whole pipeline onto the GPU instead.
if os.environ.get("CPU_OFFLOAD", "1") != "0":
    pipe.enable_model_cpu_offload()
else:
    pipe.to("cuda")
# Demo prompt, kept verbatim (including the emoji and the Chinese landmark name).
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."

# Clear the CUDA peak-memory counter and drain queued GPU work, so the timer
# and peak figure below cover only the generation call.
# NOTE(review): this is a single, un-warmed run, so elapsed_s may include
# one-time setup cost. Confirm before comparing numbers across runs.
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()

# 2. Generate Image
# Nested form of the three context managers: flash SDPA attention, then TE
# FP8 autocast, then inference mode.
with attention_backend("_native_flash"):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        with torch.inference_mode():
            image = pipe(
                prompt=prompt,
                height=1024,
                width=1024,
                num_inference_steps=9,  # This actually results in 8 DiT forwards
                guidance_scale=0.0,  # Guidance should be 0 for the Turbo models
                generator=torch.Generator("cuda").manual_seed(42),
            ).images[0]

# Wait for the GPU to finish before stopping the clock.
torch.cuda.synchronize()
elapsed_s = time.perf_counter() - start
peak_gb = torch.cuda.max_memory_allocated() / 2**30  # bytes -> GiB

image.save("example.png")
print(f"saved example.png | elapsed_s={elapsed_s:.3f} peak_allocated_gb={peak_gb:.3f}")