"""Weight loader for Z-Image-Turbo MLX backend. Loads safetensors weights from HuggingFace cache and maps them to the MLX module hierarchy. """ from __future__ import annotations import glob import logging from pathlib import Path import mlx.core as mx logger = logging.getLogger("zimage-mlx") # Default HF cache path for Z-Image-Turbo _DEFAULT_MODEL_ID = "Tongyi-MAI/Z-Image-Turbo" _HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub" # Local weights directory (project-local, survives HF cache cleanup) _LOCAL_WEIGHTS_DIR = Path(__file__).parent / "weights" def _find_model_path(model_id: str = _DEFAULT_MODEL_ID) -> Path: """Find local weight path for a model. Priority: 1. Project-local ``backends/mlx_zimage/weights/`` (if text_encoder/ exists) 2. HF cache ``~/.cache/huggingface/hub/models--Tongyi-MAI--Z-Image-Turbo/`` """ # 1. Local weights directory if _LOCAL_WEIGHTS_DIR.is_dir() and (_LOCAL_WEIGHTS_DIR / "text_encoder").is_dir(): logger.info("[ZImage] Using local weights: %s", _LOCAL_WEIGHTS_DIR) return _LOCAL_WEIGHTS_DIR # 2. HF cache safe_id = model_id.replace("/", "--") model_dir = _HF_CACHE / f"models--{safe_id}" if not model_dir.exists(): raise FileNotFoundError( f"Model not found. Neither local ({_LOCAL_WEIGHTS_DIR}) " f"nor HF cache ({model_dir}) available." ) # Find the latest snapshot snapshots = sorted(model_dir.glob("snapshots/*"), key=lambda p: p.stat().st_mtime, reverse=True) if not snapshots: raise FileNotFoundError(f"No snapshots found in {model_dir}") logger.info("[ZImage] Using HF cache: %s", snapshots[0]) return snapshots[0] def _log_memory(label: str) -> None: """Log Metal memory usage (safe no-op if unavailable).""" try: active = mx.metal.get_active_memory() / (1024 ** 3) peak = mx.metal.get_peak_memory() / (1024 ** 3) logger.info("[ZImage] MEM %s: active=%.2f GB, peak=%.2f GB", label, active, peak) except Exception: pass # mx.metal not available (e.g. CI / non-Apple) def _load_safetensors_shards( shard_dir: Path, pattern: str = "*.safetensors", *, key_filter: str | None = None, ) -> dict[str, mx.array]: """Load safetensors files via mx.load() — zero-copy, preserves bfloat16. Args: shard_dir: Directory containing safetensors shard files. pattern: Glob pattern for shard files. key_filter: If set, only load keys starting with this prefix. """ files = sorted(shard_dir.glob(pattern)) if not files: raise FileNotFoundError(f"No safetensors files in {shard_dir}") params: dict[str, mx.array] = {} for f in files: # mx.load() natively reads safetensors → mx.array (preserves bfloat16) shard = mx.load(str(f)) if key_filter: shard = {k: v for k, v in shard.items() if k.startswith(key_filter)} params.update(shard) logger.info("[ZImage] Loaded shard %s (%d keys)", f.name, len(shard)) logger.info("[ZImage] Total: %d keys from %d files in %s", len(params), len(files), shard_dir.name) _log_memory(f"after loading {shard_dir.name}") return params # ── Text Encoder weight mapping ────────────────────────────────── def load_text_encoder_weights(model_path: Path | None = None) -> dict[str, mx.array]: """Load and map Qwen3 text encoder weights for MLX. The safetensors keys use the pattern: model.embed_tokens.weight model.layers.N.input_layernorm.weight model.layers.N.self_attn.q_proj.weight ... model.norm.weight Our MLX module uses: embed_tokens.weight layers.N.input_layernorm.weight layers.N.self_attn.q_proj.weight ... norm.weight So we strip the leading "model." prefix. """ if model_path is None: model_path = _find_model_path() te_dir = model_path / "text_encoder" raw = _load_safetensors_shards(te_dir, "model-*.safetensors") mapped: dict[str, mx.array] = {} for key, tensor in raw.items(): # Strip "model." prefix if key.startswith("model."): new_key = key[len("model."):] else: new_key = key mapped[new_key] = tensor logger.info("[ZImage] Text encoder: %d parameters mapped", len(mapped)) return mapped # ── Transformer weight mapping ─────────────────────────────────── def load_transformer_weights(model_path: Path | None = None) -> dict[str, mx.array]: """Load ZImageTransformer2DModel weights.""" if model_path is None: model_path = _find_model_path() dit_dir = model_path / "transformer" raw = _load_safetensors_shards(dit_dir, "diffusion_pytorch_model-*.safetensors") # Keys are already flat (no "model." prefix), use as-is logger.info("[ZImage] Transformer: %d parameters loaded", len(raw)) return raw # ── VAE weight mapping ─────────────────────────────────────────── def load_vae_weights(model_path: Path | None = None) -> dict[str, mx.array]: """Load AutoencoderKL weights.""" if model_path is None: model_path = _find_model_path() vae_dir = model_path / "vae" raw = _load_safetensors_shards(vae_dir) logger.info("[ZImage] VAE: %d parameters loaded", len(raw)) return raw def load_vae_decoder_weights(model_path: Path | None = None) -> list[tuple[str, mx.array]]: """Load VAE decoder weights, mapped for the MLX Decoder module. Only loads keys starting with ``decoder.`` (skips encoder weights to avoid wasting memory). Performs two transformations: 1. Strips the ``decoder.`` prefix so keys match the Decoder module tree. 2. Transposes Conv2d weights from PyTorch (O,I,kH,kW) → MLX (O,kH,kW,I). Returns a list of (key, array) tuples ready for ``Decoder.load_weights()``. """ if model_path is None: model_path = _find_model_path() vae_dir = model_path / "vae" # Only load decoder.* keys — skip encoder weights entirely raw = _load_safetensors_shards(vae_dir, key_filter="decoder.") weights: list[tuple[str, mx.array]] = [] for key, val in raw.items(): key = key[len("decoder."):] # Conv2d weight: (O, I, kH, kW) → (O, kH, kW, I) if val.ndim == 4: val = val.transpose(0, 2, 3, 1) # force_upcast: ensure float32 for numerical stability val = val.astype(mx.float32) weights.append((key, val)) logger.info("[ZImage] VAE decoder: %d parameters mapped", len(weights)) return weights