|
|
import os |
|
|
import time |
|
|
import torch |
|
|
from safetensors.torch import load_file |
|
|
from diffusers import ZImagePipeline, ZImageTransformer2DModel |
|
|
from diffusers.models.attention_dispatch import attention_backend |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# Hub repo id (or override via MODEL_ID env var) holding the FP8 Z-Image-Turbo checkpoint.
MODEL_ID = os.environ.get("MODEL_ID", "ykarout/Z-Image-Turbo-FP8-Full")


# Directory containing this script; used as the model root when loading locally.
REPO_DIR = os.path.dirname(os.path.abspath(__file__))


# USE_LOCAL=1 (or an explicitly empty MODEL_ID) forces fully-local loading from REPO_DIR.
USE_LOCAL = os.environ.get("USE_LOCAL", "") == "1" or not MODEL_ID


# Source passed to from_pretrained(): local checkout or the Hub repo id.
MODEL_SRC = REPO_DIR if USE_LOCAL else MODEL_ID
|
|
|
|
|
|
|
|
# Resolve the FP8 transformer checkpoint: download from the Hub unless the
# script is pinned to the local checkout.
if not USE_LOCAL:
    TRANSFORMER_FP8 = hf_hub_download(
        repo_id=MODEL_ID,
        filename="transformer/diffusion_pytorch_model.safetensors",
        local_files_only=USE_LOCAL,  # always False on this branch
    )
else:
    # Checkpoint shipped alongside this script.
    TRANSFORMER_FP8 = os.path.join(
        REPO_DIR, "transformer", "diffusion_pytorch_model.safetensors"
    )
|
|
|
|
|
|
|
|
|
|
|
def _strip_prefix(state_dict): |
|
|
prefix = "model.diffusion_model." |
|
|
return {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()} |
|
|
|
|
|
|
|
|
def load_fp8_transformer(fp8_path, config_root):
    """Build a ZImageTransformer2DModel from an FP8 safetensors file.

    The checkpoint at *fp8_path* is loaded, its ComfyUI key prefix is
    stripped, and the state dict is handed to ``from_single_file`` with the
    transformer config taken from *config_root*'s ``transformer`` subfolder.
    Target compute dtype is bfloat16; ``low_cpu_mem_usage=False`` materializes
    real tensors (needed so the weights can be re-wrapped later).
    """
    state_dict = _strip_prefix(load_file(fp8_path))
    model = ZImageTransformer2DModel.from_single_file(
        state_dict,
        config=config_root,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        local_files_only=USE_LOCAL,
        low_cpu_mem_usage=False,
    )
    return model
|
|
|
|
|
|
|
|
def _replace_linear_with_te(module, prefix=""):
    """Recursively swap ``torch.nn.Linear`` children for ``te.Linear``.

    Layers whose dotted path contains ``t_embedder`` or ``adaLN_modulation``
    are left as plain Linear — presumably kept out of FP8 for numerical
    stability (TODO confirm). Replacement layers reuse the already-loaded
    weight/bias tensors, so no extra parameter memory is allocated.
    """
    for child_name, child in module.named_children():
        qualified = f"{prefix}.{child_name}" if prefix else child_name

        if not isinstance(child, torch.nn.Linear):
            # Not a Linear: descend into the submodule.
            _replace_linear_with_te(child, qualified)
            continue

        if "t_embedder" in qualified or "adaLN_modulation" in qualified:
            # Excluded from FP8 replacement.
            continue

        replacement = te.Linear(
            child.in_features,
            child.out_features,
            bias=child.bias is not None,
            params_dtype=child.weight.dtype,
        )
        replacement.weight = torch.nn.Parameter(child.weight)
        if child.bias is not None:
            replacement.bias = torch.nn.Parameter(child.bias)
        setattr(module, child_name, replacement)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the full pipeline (text encoder, VAE, scheduler, transformer) in bf16.
# low_cpu_mem_usage=False materializes real tensors rather than meta tensors.
pipe = ZImagePipeline.from_pretrained(
    MODEL_SRC,
    torch_dtype=torch.bfloat16,
    local_files_only=USE_LOCAL,
    low_cpu_mem_usage=False,
)


# Replace the stock transformer with the one rebuilt from the FP8 checkpoint.
pipe.transformer = load_fp8_transformer(TRANSFORMER_FP8, MODEL_SRC)
|
|
|
|
|
# Transformer Engine supplies the FP8 Linear layers and the fp8_autocast
# context used at inference time. Keep the try body to the imports only so a
# genuine runtime failure inside the layer swap is not misreported as a
# missing-dependency error, and chain the original cause for debuggability.
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe as te_recipe
except Exception as e:
    raise RuntimeError(
        "Transformer Engine is required for FP8 execution. Install TE and rerun. "
        f"Import error: {e}"
    ) from e

# Swap eligible nn.Linear layers for te.Linear, then build the FP8 scaling
# recipe (DelayedScaling: amax history-based scaling factors).
_replace_linear_with_te(pipe.transformer)
fp8_recipe = te_recipe.DelayedScaling()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Offload idle pipeline components to CPU between stages to cap GPU memory.
pipe.enable_model_cpu_offload()


prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."


# Reset CUDA peak-memory counters and start a wall-clock timer; the sync
# ensures no prior kernels bleed into the measurement.
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()


# "_native_flash" selects an attention backend in diffusers' dispatch layer
# (underscore suggests an internal/private backend name — confirm it stays
# available across diffusers versions). fp8_autocast runs the te.Linear
# layers in FP8 under the DelayedScaling recipe set up above.
with attention_backend("_native_flash"), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe), torch.inference_mode():
    image = pipe(
        prompt=prompt,
        height=1024,
        width=1024,
        num_inference_steps=9,  # few steps — presumably tuned for the Turbo (distilled) checkpoint
        guidance_scale=0.0,  # CFG disabled
        generator=torch.Generator("cuda").manual_seed(42),  # reproducible output
    ).images[0]


# Sync again so all GPU work is included in the elapsed time, then read the
# high-water mark of CUDA allocations in GiB.
torch.cuda.synchronize()
elapsed_s = time.perf_counter() - start
peak_gb = torch.cuda.max_memory_allocated() / (1024**3)


image.save("example.png")
print(f"saved example.png | elapsed_s={elapsed_s:.3f} peak_allocated_gb={peak_gb:.3f}")
|
|
|