"""Weight loader for Z-Image-Turbo MLX backend.
Loads safetensors weights from HuggingFace cache and maps them
to the MLX module hierarchy.
"""
from __future__ import annotations
import glob
import logging
from pathlib import Path
import mlx.core as mx
# Shared module-level logger for all loader helpers in this file.
logger = logging.getLogger("zimage-mlx")
# Default HF repo id for Z-Image-Turbo
_DEFAULT_MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
# HuggingFace hub cache root (standard ~/.cache/huggingface/hub layout)
_HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub"
# Local weights directory (project-local, survives HF cache cleanup)
_LOCAL_WEIGHTS_DIR = Path(__file__).parent / "weights"
def _find_model_path(model_id: str = _DEFAULT_MODEL_ID) -> Path:
    """Locate the on-disk weights for *model_id*.

    Search order:
    1. Project-local ``backends/mlx_zimage/weights/`` — only considered
       valid when it contains a ``text_encoder/`` subdirectory.
    2. Newest snapshot in the HuggingFace hub cache
       (``~/.cache/huggingface/hub/models--Tongyi-MAI--Z-Image-Turbo/``).

    Raises:
        FileNotFoundError: when neither location holds the model.
    """
    local = _LOCAL_WEIGHTS_DIR
    if local.is_dir() and (local / "text_encoder").is_dir():
        logger.info("[ZImage] Using local weights: %s", local)
        return local

    # Fall back to the HF hub cache layout: "/" in a repo id maps to "--".
    cache_dir = _HF_CACHE / ("models--" + model_id.replace("/", "--"))
    if not cache_dir.exists():
        raise FileNotFoundError(
            f"Model not found. Neither local ({_LOCAL_WEIGHTS_DIR}) "
            f"nor HF cache ({cache_dir}) available."
        )

    # Pick the most recently modified snapshot directory.
    candidates = sorted(
        cache_dir.glob("snapshots/*"),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    if not candidates:
        raise FileNotFoundError(f"No snapshots found in {cache_dir}")
    logger.info("[ZImage] Using HF cache: %s", candidates[0])
    return candidates[0]
def _log_memory(label: str) -> None:
    """Best-effort log of active/peak Metal memory in GB for *label*."""
    gib = 1024 ** 3
    try:
        logger.info(
            "[ZImage] MEM %s: active=%.2f GB, peak=%.2f GB",
            label,
            mx.metal.get_active_memory() / gib,
            mx.metal.get_peak_memory() / gib,
        )
    except Exception:
        # mx.metal not available (e.g. CI / non-Apple) — silently skip.
        pass
def _load_safetensors_shards(
    shard_dir: Path,
    pattern: str = "*.safetensors",
    *,
    key_filter: str | None = None,
) -> dict[str, mx.array]:
    """Read every safetensors shard in *shard_dir* into one flat dict.

    Relies on ``mx.load()``, which maps safetensors straight to
    ``mx.array`` (zero-copy, bfloat16 preserved).

    Args:
        shard_dir: Directory containing safetensors shard files.
        pattern: Glob pattern for shard files.
        key_filter: If set, only load keys starting with this prefix.

    Raises:
        FileNotFoundError: when no file matches *pattern*.
    """
    shard_files = sorted(shard_dir.glob(pattern))
    if not shard_files:
        raise FileNotFoundError(f"No safetensors files in {shard_dir}")

    merged: dict[str, mx.array] = {}
    for path in shard_files:
        tensors = mx.load(str(path))
        if key_filter:
            # Drop keys outside the requested prefix before merging.
            tensors = {
                name: arr
                for name, arr in tensors.items()
                if name.startswith(key_filter)
            }
        merged.update(tensors)
        logger.info("[ZImage] Loaded shard %s (%d keys)", path.name, len(tensors))

    logger.info(
        "[ZImage] Total: %d keys from %d files in %s",
        len(merged),
        len(shard_files),
        shard_dir.name,
    )
    _log_memory(f"after loading {shard_dir.name}")
    return merged
# ── Text Encoder weight mapping ──────────────────────────────────
def load_text_encoder_weights(model_path: Path | None = None) -> dict[str, mx.array]:
    """Load Qwen3 text-encoder weights and remap their keys for MLX.

    Checkpoint keys follow the pattern::

        model.embed_tokens.weight
        model.layers.N.input_layernorm.weight
        model.layers.N.self_attn.q_proj.weight
        ...
        model.norm.weight

    while the MLX module tree uses the same names without the leading
    ``model.`` segment, so that prefix is stripped from every key.
    """
    path = _find_model_path() if model_path is None else model_path
    raw = _load_safetensors_shards(path / "text_encoder", "model-*.safetensors")
    # removeprefix() is a no-op on keys that lack the "model." prefix.
    mapped = {key.removeprefix("model."): tensor for key, tensor in raw.items()}
    logger.info("[ZImage] Text encoder: %d parameters mapped", len(mapped))
    return mapped
# ── Transformer weight mapping ───────────────────────────────────
def load_transformer_weights(model_path: Path | None = None) -> dict[str, mx.array]:
    """Load ZImageTransformer2DModel weights.

    The checkpoint keys are already flat (no ``model.`` prefix), so they
    are returned exactly as stored on disk.
    """
    path = _find_model_path() if model_path is None else model_path
    params = _load_safetensors_shards(
        path / "transformer", "diffusion_pytorch_model-*.safetensors"
    )
    logger.info("[ZImage] Transformer: %d parameters loaded", len(params))
    return params
# ── VAE weight mapping ───────────────────────────────────────────
def load_vae_weights(model_path: Path | None = None) -> dict[str, mx.array]:
    """Load the full AutoencoderKL weights (encoder + decoder), unmapped."""
    path = _find_model_path() if model_path is None else model_path
    params = _load_safetensors_shards(path / "vae")
    logger.info("[ZImage] VAE: %d parameters loaded", len(params))
    return params
def load_vae_decoder_weights(model_path: Path | None = None) -> list[tuple[str, mx.array]]:
    """Load VAE decoder weights, mapped for the MLX Decoder module.

    Only ``decoder.*`` keys are read — encoder weights are skipped
    entirely to avoid wasting memory. Each tensor undergoes:

    1. ``decoder.`` prefix stripping so keys match the Decoder module tree.
    2. For 4-D conv kernels, a transpose from PyTorch (O, I, kH, kW) to
       MLX (O, kH, kW, I) layout.

    Every tensor is also upcast to float32 (force_upcast) for numerical
    stability.

    Returns:
        (key, array) tuples ready for ``Decoder.load_weights()``.
    """
    path = _find_model_path() if model_path is None else model_path
    raw = _load_safetensors_shards(path / "vae", key_filter="decoder.")

    mapped: list[tuple[str, mx.array]] = []
    for full_key, tensor in raw.items():
        name = full_key.removeprefix("decoder.")
        if tensor.ndim == 4:
            # Conv2d kernel: (O, I, kH, kW) -> (O, kH, kW, I).
            tensor = tensor.transpose(0, 2, 3, 1)
        mapped.append((name, tensor.astype(mx.float32)))
    logger.info("[ZImage] VAE decoder: %d parameters mapped", len(mapped))
    return mapped
|