# Z-Image-Turbo-MLX / weight_loader.py
# Uploaded via huggingface_hub by illusion615 (commit 64566e4, verified)
"""Weight loader for Z-Image-Turbo MLX backend.
Loads safetensors weights from HuggingFace cache and maps them
to the MLX module hierarchy.
"""
from __future__ import annotations
import glob
import logging
from pathlib import Path
import mlx.core as mx
logger = logging.getLogger("zimage-mlx")
# Default HF cache path for Z-Image-Turbo
_DEFAULT_MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
_HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub"
# Local weights directory (project-local, survives HF cache cleanup)
_LOCAL_WEIGHTS_DIR = Path(__file__).parent / "weights"
def _find_model_path(model_id: str = _DEFAULT_MODEL_ID) -> Path:
    """Locate the on-disk weight directory for *model_id*.

    Search order:
    1. Project-local ``backends/mlx_zimage/weights/`` — used when it
       contains a ``text_encoder/`` subdirectory.
    2. Most recently modified snapshot in the HuggingFace hub cache
       (``~/.cache/huggingface/hub/models--Tongyi-MAI--Z-Image-Turbo/``).

    Raises:
        FileNotFoundError: if neither location holds the model weights.
    """
    # Prefer the project-local copy: it survives HF cache cleanup.
    local = _LOCAL_WEIGHTS_DIR
    if local.is_dir() and (local / "text_encoder").is_dir():
        logger.info("[ZImage] Using local weights: %s", local)
        return local
    # Fall back to the HF cache layout: models--<org>--<name>/snapshots/<rev>
    model_dir = _HF_CACHE / "models--{}".format(model_id.replace("/", "--"))
    if not model_dir.exists():
        raise FileNotFoundError(
            f"Model not found. Neither local ({_LOCAL_WEIGHTS_DIR}) "
            f"nor HF cache ({model_dir}) available."
        )
    candidates = list(model_dir.glob("snapshots/*"))
    if not candidates:
        raise FileNotFoundError(f"No snapshots found in {model_dir}")
    # Newest snapshot wins (mtime ties resolve to glob order, as before).
    latest = max(candidates, key=lambda p: p.stat().st_mtime)
    logger.info("[ZImage] Using HF cache: %s", latest)
    return latest
def _log_memory(label: str) -> None:
    """Best-effort log of Metal GPU memory usage, tagged with *label*.

    Silently does nothing when ``mx.metal`` is unavailable (e.g. CI or
    non-Apple hardware) — this is diagnostics only, never fatal.
    """
    gib = 1024 ** 3
    try:
        active_gb = mx.metal.get_active_memory() / gib
        peak_gb = mx.metal.get_peak_memory() / gib
        logger.info("[ZImage] MEM %s: active=%.2f GB, peak=%.2f GB", label, active_gb, peak_gb)
    except Exception:
        pass  # deliberate no-op when the Metal backend is absent
def _load_safetensors_shards(
    shard_dir: Path,
    pattern: str = "*.safetensors",
    *,
    key_filter: str | None = None,
) -> dict[str, mx.array]:
    """Read every safetensors shard in *shard_dir* into one flat dict.

    ``mx.load`` reads safetensors natively into ``mx.array`` (zero-copy,
    bfloat16 preserved).

    Args:
        shard_dir: Directory holding the shard files.
        pattern: Glob pattern selecting the shard files.
        key_filter: Optional key prefix; when given, keys that do not
            start with it are dropped before merging.

    Raises:
        FileNotFoundError: if nothing in *shard_dir* matches *pattern*.
    """
    shard_files = sorted(shard_dir.glob(pattern))
    if not shard_files:
        raise FileNotFoundError(f"No safetensors files in {shard_dir}")
    merged: dict[str, mx.array] = {}
    for shard_file in shard_files:
        tensors = mx.load(str(shard_file))
        if key_filter:
            tensors = {name: arr for name, arr in tensors.items() if name.startswith(key_filter)}
        merged.update(tensors)
        logger.info("[ZImage] Loaded shard %s (%d keys)", shard_file.name, len(tensors))
    logger.info("[ZImage] Total: %d keys from %d files in %s", len(merged), len(shard_files), shard_dir.name)
    _log_memory(f"after loading {shard_dir.name}")
    return merged
# ── Text Encoder weight mapping ──────────────────────────────────
def load_text_encoder_weights(model_path: Path | None = None) -> dict[str, mx.array]:
    """Load and map Qwen3 text encoder weights for MLX.

    The safetensors keys use the pattern:
        model.embed_tokens.weight
        model.layers.N.input_layernorm.weight
        model.layers.N.self_attn.q_proj.weight
        ...
        model.norm.weight
    Our MLX module uses the same tree without the leading "model."
    prefix, so each key is stripped of it.

    Args:
        model_path: Weight directory; resolved via ``_find_model_path()``
            when omitted.

    Returns:
        Mapping of MLX parameter names to arrays.
    """
    if model_path is None:
        model_path = _find_model_path()
    raw = _load_safetensors_shards(model_path / "text_encoder", "model-*.safetensors")
    # str.removeprefix is a no-op for keys lacking the "model." prefix,
    # matching the original pass-through behavior.
    mapped = {key.removeprefix("model."): tensor for key, tensor in raw.items()}
    logger.info("[ZImage] Text encoder: %d parameters mapped", len(mapped))
    return mapped
# ── Transformer weight mapping ───────────────────────────────────
def load_transformer_weights(model_path: Path | None = None) -> dict[str, mx.array]:
    """Load ZImageTransformer2DModel weights.

    Keys in the transformer shards are already flat (no "model." prefix),
    so they are returned unmodified.
    """
    resolved = _find_model_path() if model_path is None else model_path
    params = _load_safetensors_shards(
        resolved / "transformer", "diffusion_pytorch_model-*.safetensors"
    )
    logger.info("[ZImage] Transformer: %d parameters loaded", len(params))
    return params
# ── VAE weight mapping ───────────────────────────────────────────
def load_vae_weights(model_path: Path | None = None) -> dict[str, mx.array]:
    """Load the full AutoencoderKL weight set (encoder and decoder)."""
    resolved = _find_model_path() if model_path is None else model_path
    params = _load_safetensors_shards(resolved / "vae")
    logger.info("[ZImage] VAE: %d parameters loaded", len(params))
    return params
def load_vae_decoder_weights(model_path: Path | None = None) -> list[tuple[str, mx.array]]:
    """Load VAE decoder weights, mapped for the MLX Decoder module.

    Only ``decoder.``-prefixed keys are loaded (encoder weights are
    skipped to save memory). Two transformations are applied:
    1. The ``decoder.`` prefix is stripped so keys match the Decoder
       module tree.
    2. Conv2d weights are transposed from PyTorch (O, I, kH, kW) to the
       MLX layout (O, kH, kW, I).

    All tensors are upcast to float32 for numerical stability
    (force_upcast).

    Returns:
        List of (key, array) pairs ready for ``Decoder.load_weights()``.
    """
    resolved = _find_model_path() if model_path is None else model_path
    # key_filter makes the loader drop encoder.* tensors before merging.
    raw = _load_safetensors_shards(resolved / "vae", key_filter="decoder.")
    prefix_len = len("decoder.")
    mapped: list[tuple[str, mx.array]] = []
    for full_key, tensor in raw.items():
        name = full_key[prefix_len:]
        if tensor.ndim == 4:
            # 4-D tensors are conv kernels: (O, I, kH, kW) -> (O, kH, kW, I)
            tensor = tensor.transpose(0, 2, 3, 1)
        mapped.append((name, tensor.astype(mx.float32)))
    logger.info("[ZImage] VAE decoder: %d parameters mapped", len(mapped))
    return mapped