# Z-Image-Turbo-MLX / pipeline.py
"""Z-Image-Turbo MLX Pipeline β€” end-to-end text-to-image generation.
Flow:
1. Tokenize prompt β†’ token IDs
2. Qwen3 Encoder β†’ text hidden states (MLX)
3. Initialize random latents
4. Denoise loop: DiT forward pass Γ— N steps (MLX)
5. VAE decode latents β†’ RGB image (MLX native)
6. Save to PNG
"""
from __future__ import annotations
import logging
import time
from pathlib import Path
import mlx.core as mx
import numpy as np
from PIL import Image
from .autoencoder import Decoder
from .qwen3_encoder import Qwen3Encoder, Qwen3EncoderConfig
from .zimage_dit import ZImageTransformer, ZImageDiTConfig
from .scheduler import FlowMatchEulerScheduler
from .tokenizer import Qwen2Tokenizer
from .weight_loader import (
_find_model_path,
_log_memory,
load_text_encoder_weights,
load_transformer_weights,
load_vae_decoder_weights,
)
logger = logging.getLogger("zimage-mlx")
def _cast_to_bf16(model):
"""Cast all parameters of an nn.Module to bfloat16 in-place.
This halves memory and speeds up Metal compute for the DiT.
"""
from mlx.utils import tree_map
params = model.parameters()
bf16_params = tree_map(lambda x: x.astype(mx.bfloat16) if isinstance(x, mx.array) else x, params)
model.update(bf16_params)
return model
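
# A quick way to verify the cast (illustrative sketch; ``tree_flatten`` lives
# in mlx.utils and yields (path, leaf) pairs):
#
#     from mlx.utils import tree_flatten
#     assert all(leaf.dtype == mx.bfloat16
#                for _, leaf in tree_flatten(model.parameters()))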
class ZImageMLXPipeline:
"""End-to-end Z-Image-Turbo inference pipeline β€” 100% MLX.
All stages run on Apple Silicon via MLX: text encoding,
DiT denoising, and VAE decoding. No PyTorch dependency.
"""
def __init__(self, model_id: str = "Tongyi-MAI/Z-Image-Turbo"):
self.model_id = model_id
self._model_path: Path | None = None
self._tokenizer: Qwen2Tokenizer | None = None
self._encoder: Qwen3Encoder | None = None
self._dit: ZImageTransformer | None = None
self._dit_compiled = None # mx.compile'd forward pass
self._scheduler = FlowMatchEulerScheduler(shift=3.0)
self._vae: Decoder | None = None
self._loaded = False
def load(self, model_path: Path | None = None):
"""Load all model components.
Memory strategy (staged loading):
- Encoder, DiT, VAE are loaded sequentially.
- During generation, encoder is released after text encoding
to reduce peak memory (see ``generate()``).
"""
t0 = time.monotonic()
self._model_path = model_path or _find_model_path(self.model_id)
_log_memory("before load")
# 1. Tokenizer
logger.info("[ZImage-MLX] Loading tokenizer...")
self._tokenizer = Qwen2Tokenizer(self._model_path)
# 2. Text encoder (Qwen3)
logger.info("[ZImage-MLX] Loading text encoder (Qwen3, 36 layers)...")
self._encoder = Qwen3Encoder(Qwen3EncoderConfig())
te_weights = load_text_encoder_weights(self._model_path)
self._encoder.load_weights(list(te_weights.items()))
# Weights are already bfloat16 on disk; keep them as-is for memory savings
mx.eval(self._encoder.parameters())
del te_weights # release weight dict immediately
logger.info("[ZImage-MLX] Text encoder loaded (bfloat16)")
_log_memory("after text encoder")
# 3. DiT (ZImageTransformer)
logger.info("[ZImage-MLX] Loading transformer (S3-DiT, 30+2+2 layers)...")
self._dit = ZImageTransformer(ZImageDiTConfig())
dit_weights = load_transformer_weights(self._model_path)
self._dit.load_weights(list(dit_weights.items()))
        # Cast the DiT to bfloat16 for faster inference (~2× speedup, ~50% memory).
        # PyTorch diffusers also runs at bfloat16 on MPS.
self._dit = _cast_to_bf16(self._dit)
mx.eval(self._dit.parameters())
del dit_weights # release weight dict immediately
# Compile DiT forward for additional Metal kernel fusion speedup
self._dit_compiled = mx.compile(self._dit)
logger.info("[ZImage-MLX] Transformer loaded (bfloat16 + compiled)")
_log_memory("after transformer")
# 4. VAE decoder (MLX native)
logger.info("[ZImage-MLX] Loading VAE decoder...")
self._vae = Decoder()
vae_weights = load_vae_decoder_weights(self._model_path)
self._vae.load_weights(vae_weights)
mx.eval(self._vae.parameters())
del vae_weights # release weight list immediately
logger.info("[ZImage-MLX] VAE decoder loaded")
elapsed = time.monotonic() - t0
self._loaded = True
_log_memory("after full load")
logger.info("[ZImage-MLX] Pipeline loaded in %.1fs", elapsed)
def generate(
self,
prompt: str,
width: int = 768,
height: int = 768,
num_steps: int = 8,
seed: int | None = None,
guidance_scale: float = 0.0, # Z-Image-Turbo typically uses 0
max_text_len: int = 256,
) -> np.ndarray:
"""Generate an image from a text prompt.
Args:
prompt: Text description (Chinese or English)
width: Output width (must be divisible by 16)
height: Output height (must be divisible by 16)
num_steps: Number of denoising steps
seed: Random seed (None for random)
guidance_scale: CFG scale (0 = no guidance)
max_text_len: Max text token length
Returns:
RGB image as numpy array (H, W, 3) uint8
"""
if not self._loaded:
raise RuntimeError("Pipeline not loaded. Call load() first.")
t0 = time.monotonic()
if seed is None:
seed = int(time.time()) % (2**31)
# Ensure encoder is available (may have been released after prev gen)
self._reload_encoder()
# ── 1. Tokenize (with chat template like diffusers) ──
chat_result = self._tokenizer.apply_chat_template(prompt, max_length=max_text_len)
token_ids = chat_result["input_ids"] # list[int]
        attn_mask = chat_result["attention_mask"]  # list[int] (currently unused)
input_ids = mx.array([token_ids]) # (1, L)
# ── 2. Text encode ──
t_enc = time.monotonic()
if self._encoder is None:
raise RuntimeError("Text encoder not loaded. Call load() first.")
        all_hidden = self._encoder(input_ids)  # (1, L, 2560), bfloat16
        cap_feats = all_hidden  # (1, L, 2560)
mx.eval(cap_feats)
logger.info("[ZImage-MLX] Text encoded in %.2fs, %d tokens", time.monotonic() - t_enc, cap_feats.shape[1])
# Release encoder to free memory before DiT denoising.
self._release_encoder()
# ── 3. Initialize latents ──
latent_h = height // 8
latent_w = width // 8
mx.random.seed(seed)
# Use bfloat16 latents to match DiT precision
latents = mx.random.normal((1, 16, latent_h, latent_w)).astype(mx.bfloat16)
# Ensure cap_feats is bfloat16 for DiT
cap_feats = cap_feats.astype(mx.bfloat16)
# ── 4. Denoise loop ──
sigmas = self._scheduler.get_sigmas(num_steps)
mx.eval(sigmas)
sigmas_list = sigmas.tolist()
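        # The sigma schedule is assumed to be the standard shifted flow-matching
        # one (sigma = shift*s / (1 + (shift - 1)*s) for s descending from 1),
        # with a trailing 0.0, so sigmas_list holds num_steps + 1 values and the
        # i + 1 lookup below stays in range. Illustrative values for
        # num_steps=2, shift=3.0: s = [1.0, 0.5] → sigma = [1.0, 0.75, 0.0].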
dit_fn = self._dit_compiled if self._dit_compiled is not None else self._dit
t_denoise = time.monotonic()
for i in range(num_steps):
sigma = sigmas_list[i]
sigma_next = sigmas_list[i + 1]
t_step = mx.array([1.0 - sigma], dtype=mx.bfloat16)
noise_pred = dit_fn(latents, t_step, cap_feats)
# Diffusers negates the model output before passing to scheduler:
# noise_pred = -noise_pred
# Then scheduler does: prev_sample = sample + (sigma_next - sigma) * model_output
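            # Combined with the scheduler update, the two lines below perform a
            # single Euler step:
            #   latents <- latents + (sigma - sigma_next) * v
            # where v is the raw DiT output before negation.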
noise_pred = -noise_pred
latents = self._scheduler.step(noise_pred, sigma, sigma_next, latents)
mx.eval(latents)
logger.info("[ZImage-MLX] Step %d/%d (sigma %.4f β†’ %.4f)", i + 1, num_steps, sigma, sigma_next)
denoise_time = time.monotonic() - t_denoise
logger.info("[ZImage-MLX] Denoised in %.2fs (%.2fs/step)", denoise_time, denoise_time / num_steps)
# ── 5. VAE decode (MLX native) ──
t_vae = time.monotonic()
image = self._vae_decode(latents)
logger.info("[ZImage-MLX] VAE decoded in %.2fs", time.monotonic() - t_vae)
total = time.monotonic() - t0
logger.info("[ZImage-MLX] Total generation: %.2fs", total)
return image
    def _vae_decode(self, latents: mx.array) -> np.ndarray:
        """Decode latents → RGB image using the MLX-native VAE.

        Diffusers formula:
            z = latents / scaling_factor + shift_factor
            raw = vae.decode(z)       # output in [-1, 1]
            image = raw / 2 + 0.5     # denormalize to [0, 1]
        """
scaling_factor = 0.3611
shift_factor = 0.1159
        # NCHW → NHWC for MLX convolutions
        z = latents.transpose(0, 2, 3, 1)  # (B, C, H, W) → (B, H, W, C)
z = z.astype(mx.float32) # force_upcast
z = z / scaling_factor + shift_factor
decoded = self._vae(z) # (B,8H,8W,3) in [-1, 1]
mx.eval(decoded)
        # Denormalize [-1, 1] → [0, 1], clamp, then convert to uint8
img = decoded[0] / 2.0 + 0.5
img = mx.clip(img, 0.0, 1.0)
img = np.array(img)
img = (img * 255).astype(np.uint8)
return img
def generate_and_save(
self,
prompt: str,
output_path: str,
width: int = 768,
height: int = 768,
num_steps: int = 8,
seed: int | None = None,
) -> dict:
"""Generate an image and save to file.
Returns:
Dict with generation metadata.
"""
t0 = time.monotonic()
if seed is None:
seed = int(time.time()) % (2**31)
image = self.generate(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
seed=seed,
)
# Save
img = Image.fromarray(image)
img.save(output_path)
elapsed = time.monotonic() - t0
return {
"image_path": output_path,
"width": width,
"height": height,
"seed": seed,
"num_steps": num_steps,
"elapsed_s": round(elapsed, 2),
"prompt": prompt,
}
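
    # Illustrative return value (numbers are made up for the example):
    #   {"image_path": "fox.png", "width": 768, "height": 768, "seed": 42,
    #    "num_steps": 8, "elapsed_s": 21.37, "prompt": "a watercolor fox"}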
def _release_encoder(self):
"""Release text encoder to free ~5 GB before denoising."""
if self._encoder is not None:
self._encoder = None
mx.clear_cache()
_log_memory("after releasing encoder")
def _reload_encoder(self):
"""Reload encoder for next generation (lazy, on-demand)."""
if self._encoder is None and self._model_path is not None:
logger.info("[ZImage-MLX] Reloading text encoder...")
self._encoder = Qwen3Encoder(Qwen3EncoderConfig())
te_weights = load_text_encoder_weights(self._model_path)
self._encoder.load_weights(list(te_weights.items()))
# Weights are bfloat16 on disk; keep as-is
mx.eval(self._encoder.parameters())
del te_weights
_log_memory("after reloading encoder")
def unload(self):
"""Release all model memory."""
self._encoder = None
self._dit = None
self._vae = None
self._tokenizer = None
self._loaded = False
mx.clear_cache()
_log_memory("after full unload")
logger.info("[ZImage-MLX] Pipeline unloaded")