| """Z-Image-Turbo MLX Pipeline β end-to-end text-to-image generation. |
| |
| Flow: |
| 1. Tokenize prompt β token IDs |
| 2. Qwen3 Encoder β text hidden states (MLX) |
| 3. Initialize random latents |
| 4. Denoise loop: DiT forward pass Γ N steps (MLX) |
| 5. VAE decode latents β RGB image (MLX native) |
| 6. Save to PNG |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import time |
| from pathlib import Path |
|
|
| import mlx.core as mx |
| import numpy as np |
| from PIL import Image |
|
|
| from .autoencoder import Decoder |
| from .qwen3_encoder import Qwen3Encoder, Qwen3EncoderConfig |
| from .zimage_dit import ZImageTransformer, ZImageDiTConfig |
| from .scheduler import FlowMatchEulerScheduler |
| from .tokenizer import Qwen2Tokenizer |
| from .weight_loader import ( |
| _find_model_path, |
| _log_memory, |
| load_text_encoder_weights, |
| load_transformer_weights, |
| load_vae_decoder_weights, |
| ) |
|
|
| logger = logging.getLogger("zimage-mlx") |
|
|
|
|
def _cast_to_bf16(model):
    """Cast all parameters of an nn.Module to bfloat16 in-place.

    This halves memory and speeds up Metal compute for the DiT.
    """
    from mlx.utils import tree_map

    def _leaf_to_bf16(leaf):
        # Only arrays get cast; any other leaf passes through untouched.
        if isinstance(leaf, mx.array):
            return leaf.astype(mx.bfloat16)
        return leaf

    model.update(tree_map(_leaf_to_bf16, model.parameters()))
    return model
|
|
|
|
class ZImageMLXPipeline:
    """End-to-end Z-Image-Turbo inference pipeline — 100% MLX.

    All stages run on Apple Silicon via MLX: text encoding,
    DiT denoising, and VAE decoding. No PyTorch dependency.
    """

    def __init__(self, model_id: str = "Tongyi-MAI/Z-Image-Turbo"):
        self.model_id = model_id
        self._model_path: Path | None = None
        self._tokenizer: Qwen2Tokenizer | None = None
        self._encoder: Qwen3Encoder | None = None
        self._dit: ZImageTransformer | None = None
        # mx.compile'd wrapper around self._dit; holds a reference to the
        # DiT weights, so it must be dropped together with self._dit.
        self._dit_compiled = None
        self._scheduler = FlowMatchEulerScheduler(shift=3.0)
        self._vae: Decoder | None = None
        self._loaded = False

    def load(self, model_path: Path | None = None):
        """Load all model components.

        Memory strategy (staged loading):
        - Encoder, DiT, VAE are loaded sequentially.
        - During generation, encoder is released after text encoding
          to reduce peak memory (see ``generate()``).

        Args:
            model_path: Local model directory; resolved from
                ``self.model_id`` when None.
        """
        t0 = time.monotonic()
        self._model_path = model_path or _find_model_path(self.model_id)
        _log_memory("before load")

        logger.info("[ZImage-MLX] Loading tokenizer...")
        self._tokenizer = Qwen2Tokenizer(self._model_path)

        logger.info("[ZImage-MLX] Loading text encoder (Qwen3, 36 layers)...")
        self._encoder = Qwen3Encoder(Qwen3EncoderConfig())
        te_weights = load_text_encoder_weights(self._model_path)
        self._encoder.load_weights(list(te_weights.items()))
        # Force materialization now so the weight dict can be freed.
        mx.eval(self._encoder.parameters())
        del te_weights
        logger.info("[ZImage-MLX] Text encoder loaded (bfloat16)")
        _log_memory("after text encoder")

        logger.info("[ZImage-MLX] Loading transformer (S3-DiT, 30+2+2 layers)...")
        self._dit = ZImageTransformer(ZImageDiTConfig())
        dit_weights = load_transformer_weights(self._model_path)
        self._dit.load_weights(list(dit_weights.items()))
        # bf16 halves DiT memory and speeds up Metal compute.
        self._dit = _cast_to_bf16(self._dit)
        mx.eval(self._dit.parameters())
        del dit_weights
        # Compile once; the compiled graph is reused for every denoise step.
        self._dit_compiled = mx.compile(self._dit)
        logger.info("[ZImage-MLX] Transformer loaded (bfloat16 + compiled)")
        _log_memory("after transformer")

        logger.info("[ZImage-MLX] Loading VAE decoder...")
        self._vae = Decoder()
        vae_weights = load_vae_decoder_weights(self._model_path)
        self._vae.load_weights(vae_weights)
        mx.eval(self._vae.parameters())
        del vae_weights
        logger.info("[ZImage-MLX] VAE decoder loaded")

        elapsed = time.monotonic() - t0
        self._loaded = True
        _log_memory("after full load")
        logger.info("[ZImage-MLX] Pipeline loaded in %.1fs", elapsed)

    def generate(
        self,
        prompt: str,
        width: int = 768,
        height: int = 768,
        num_steps: int = 8,
        seed: int | None = None,
        guidance_scale: float = 0.0,
        max_text_len: int = 256,
    ) -> np.ndarray:
        """Generate an image from a text prompt.

        Args:
            prompt: Text description (Chinese or English)
            width: Output width (must be divisible by 16)
            height: Output height (must be divisible by 16)
            num_steps: Number of denoising steps
            seed: Random seed (None for random)
            guidance_scale: Accepted for API compatibility; classifier-free
                guidance is NOT currently applied by this pipeline.
            max_text_len: Max text token length

        Returns:
            RGB image as numpy array (H, W, 3) uint8

        Raises:
            RuntimeError: If ``load()`` has not been called.
        """
        if not self._loaded:
            raise RuntimeError("Pipeline not loaded. Call load() first.")

        t0 = time.monotonic()
        if seed is None:
            seed = int(time.time()) % (2**31)

        # Encoder may have been released after a previous generation.
        self._reload_encoder()

        # Tokenize (chat-template formatting, truncated to max_text_len).
        chat_result = self._tokenizer.apply_chat_template(prompt, max_length=max_text_len)
        token_ids = chat_result["input_ids"]

        input_ids = mx.array([token_ids])

        # Text encoding (Qwen3).
        t_enc = time.monotonic()
        if self._encoder is None:
            raise RuntimeError("Text encoder not loaded. Call load() first.")
        all_hidden = self._encoder(input_ids)
        cap_feats = all_hidden
        mx.eval(cap_feats)
        logger.info("[ZImage-MLX] Text encoded in %.2fs, %d tokens", time.monotonic() - t_enc, cap_feats.shape[1])

        # Free encoder memory before the denoise loop (see _release_encoder).
        self._release_encoder()

        # Latent space is 8x downsampled, 16 channels.
        latent_h = height // 8
        latent_w = width // 8
        mx.random.seed(seed)
        latents = mx.random.normal((1, 16, latent_h, latent_w)).astype(mx.bfloat16)

        # Match DiT compute dtype.
        cap_feats = cap_feats.astype(mx.bfloat16)

        sigmas = self._scheduler.get_sigmas(num_steps)
        mx.eval(sigmas)
        sigmas_list = sigmas.tolist()

        # Prefer the compiled DiT when available.
        dit_fn = self._dit_compiled if self._dit_compiled is not None else self._dit

        t_denoise = time.monotonic()
        for i in range(num_steps):
            sigma = sigmas_list[i]
            sigma_next = sigmas_list[i + 1]

            # Timestep conditioning uses t = 1 - sigma.
            t_step = mx.array([1.0 - sigma], dtype=mx.bfloat16)
            noise_pred = dit_fn(latents, t_step, cap_feats)

            # Sign flip: the scheduler step expects the opposite sign of the
            # DiT output (velocity-prediction convention of this model).
            noise_pred = -noise_pred

            latents = self._scheduler.step(noise_pred, sigma, sigma_next, latents)
            # Evaluate per step so progress logging reflects real work.
            mx.eval(latents)

            logger.info("[ZImage-MLX] Step %d/%d (sigma %.4f → %.4f)", i + 1, num_steps, sigma, sigma_next)

        denoise_time = time.monotonic() - t_denoise
        logger.info("[ZImage-MLX] Denoised in %.2fs (%.2fs/step)", denoise_time, denoise_time / num_steps)

        # Decode latents to pixels.
        t_vae = time.monotonic()
        image = self._vae_decode(latents)
        logger.info("[ZImage-MLX] VAE decoded in %.2fs", time.monotonic() - t_vae)

        total = time.monotonic() - t0
        logger.info("[ZImage-MLX] Total generation: %.2fs", total)

        return image

    def _vae_decode(self, latents: mx.array) -> np.ndarray:
        """Decode latents → RGB image using MLX VAE.

        Diffusers formula:
            z = latents / scaling_factor + shift_factor
            raw = vae.decode(z)       # output in [-1, 1]
            image = raw / 2 + 0.5     # denormalize to [0, 1]
        """
        scaling_factor = 0.3611
        shift_factor = 0.1159

        # NCHW → NHWC for the MLX decoder; fp32 for numerically stable decode.
        z = latents.transpose(0, 2, 3, 1)
        z = z.astype(mx.float32)
        z = z / scaling_factor + shift_factor

        decoded = self._vae(z)
        mx.eval(decoded)

        # [-1, 1] → [0, 255] uint8.
        img = decoded[0] / 2.0 + 0.5
        img = mx.clip(img, 0.0, 1.0)
        img = np.array(img)
        img = (img * 255).astype(np.uint8)
        return img

    def generate_and_save(
        self,
        prompt: str,
        output_path: str,
        width: int = 768,
        height: int = 768,
        num_steps: int = 8,
        seed: int | None = None,
    ) -> dict:
        """Generate an image and save to file.

        Args:
            prompt: Text description of the desired image.
            output_path: Destination image file path (format from extension).
            width: Output width in pixels.
            height: Output height in pixels.
            num_steps: Number of denoising steps.
            seed: Random seed (None for time-derived).

        Returns:
            Dict with generation metadata.
        """
        t0 = time.monotonic()
        # Resolve the seed here so the metadata reports the actual seed used.
        if seed is None:
            seed = int(time.time()) % (2**31)

        image = self.generate(
            prompt=prompt,
            width=width,
            height=height,
            num_steps=num_steps,
            seed=seed,
        )

        img = Image.fromarray(image)
        img.save(output_path)

        elapsed = time.monotonic() - t0
        return {
            "image_path": output_path,
            "width": width,
            "height": height,
            "seed": seed,
            "num_steps": num_steps,
            "elapsed_s": round(elapsed, 2),
            "prompt": prompt,
        }

    def _release_encoder(self):
        """Release text encoder to free ~5 GB before denoising."""
        if self._encoder is not None:
            self._encoder = None
            mx.clear_cache()
            _log_memory("after releasing encoder")

    def _reload_encoder(self):
        """Reload encoder for next generation (lazy, on-demand)."""
        if self._encoder is None and self._model_path is not None:
            logger.info("[ZImage-MLX] Reloading text encoder...")
            self._encoder = Qwen3Encoder(Qwen3EncoderConfig())
            te_weights = load_text_encoder_weights(self._model_path)
            self._encoder.load_weights(list(te_weights.items()))
            # Materialize before dropping the weight dict.
            mx.eval(self._encoder.parameters())
            del te_weights
            _log_memory("after reloading encoder")

    def unload(self):
        """Release all model memory."""
        self._encoder = None
        self._dit = None
        # Must also drop the compiled graph: it retains the DiT weights, so
        # clearing self._dit alone would not actually free the transformer.
        self._dit_compiled = None
        self._vae = None
        self._tokenizer = None
        self._loaded = False
        mx.clear_cache()
        _log_memory("after full unload")
        logger.info("[ZImage-MLX] Pipeline unloaded")
|
|