| """Weight loader for Z-Image-Turbo MLX backend. |
| |
| Loads safetensors weights from HuggingFace cache and maps them |
| to the MLX module hierarchy. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import glob |
| import logging |
| from pathlib import Path |
|
|
| import mlx.core as mx |
|
|
| logger = logging.getLogger("zimage-mlx") |
|
|
| |
# Default HuggingFace model identifier for Z-Image-Turbo.
_DEFAULT_MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
# Standard HuggingFace hub cache root (~/.cache/huggingface/hub).
_HF_CACHE = Path.home() / ".cache" / "huggingface" / "hub"

# Optional project-local weights directory, preferred over the HF cache
# when it contains a text_encoder/ subdirectory (see _find_model_path).
_LOCAL_WEIGHTS_DIR = Path(__file__).parent / "weights"
|
|
|
|
def _find_model_path(model_id: str = _DEFAULT_MODEL_ID) -> Path:
    """Resolve the on-disk directory holding the model weights.

    Resolution order:
    1. Project-local ``backends/mlx_zimage/weights/`` (if text_encoder/ exists)
    2. HF cache ``~/.cache/huggingface/hub/models--Tongyi-MAI--Z-Image-Turbo/``

    Raises:
        FileNotFoundError: if neither location is usable.
    """
    # Prefer the project-local checkout; the text_encoder/ check guards
    # against a stray empty weights/ directory.
    local_usable = (
        _LOCAL_WEIGHTS_DIR.is_dir()
        and (_LOCAL_WEIGHTS_DIR / "text_encoder").is_dir()
    )
    if local_usable:
        logger.info("[ZImage] Using local weights: %s", _LOCAL_WEIGHTS_DIR)
        return _LOCAL_WEIGHTS_DIR

    cache_entry = _HF_CACHE / ("models--" + model_id.replace("/", "--"))
    if not cache_entry.exists():
        raise FileNotFoundError(
            f"Model not found. Neither local ({_LOCAL_WEIGHTS_DIR}) "
            f"nor HF cache ({cache_entry}) available."
        )

    # Newest snapshot first (by mtime) — the HF cache may hold several.
    candidates = sorted(
        cache_entry.glob("snapshots/*"),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    if not candidates:
        raise FileNotFoundError(f"No snapshots found in {cache_entry}")
    newest = candidates[0]
    logger.info("[ZImage] Using HF cache: %s", newest)
    return newest
|
|
|
|
def _log_memory(label: str) -> None:
    """Log Metal memory usage (safe no-op if unavailable)."""
    gib = 1024 ** 3
    try:
        active_gb = mx.metal.get_active_memory() / gib
        peak_gb = mx.metal.get_peak_memory() / gib
        logger.info(
            "[ZImage] MEM %s: active=%.2f GB, peak=%.2f GB",
            label, active_gb, peak_gb,
        )
    except Exception:
        # Deliberate best-effort: the metal backend (or this API) may be
        # absent, and memory logging must never break loading.
        pass
|
|
|
|
def _load_safetensors_shards(
    shard_dir: Path,
    pattern: str = "*.safetensors",
    *,
    key_filter: str | None = None,
) -> dict[str, mx.array]:
    """Load safetensors files via mx.load() -- zero-copy, preserves bfloat16.

    Args:
        shard_dir: Directory containing safetensors shard files.
        pattern: Glob pattern for shard files.
        key_filter: If set, only keep keys starting with this prefix.

    Raises:
        FileNotFoundError: if no file in ``shard_dir`` matches ``pattern``.
    """
    shard_files = sorted(shard_dir.glob(pattern))
    if not shard_files:
        raise FileNotFoundError(f"No safetensors files in {shard_dir}")

    merged: dict[str, mx.array] = {}
    for shard_file in shard_files:
        tensors = mx.load(str(shard_file))
        if key_filter:
            tensors = {
                name: arr
                for name, arr in tensors.items()
                if name.startswith(key_filter)
            }
        merged.update(tensors)
        logger.info("[ZImage] Loaded shard %s (%d keys)", shard_file.name, len(tensors))

    logger.info("[ZImage] Total: %d keys from %d files in %s", len(merged), len(shard_files), shard_dir.name)
    _log_memory(f"after loading {shard_dir.name}")
    return merged
|
|
|
|
| |
|
|
def load_text_encoder_weights(model_path: Path | None = None) -> dict[str, mx.array]:
    """Load and map Qwen3 text encoder weights for MLX.

    The checkpoint stores keys like::

        model.embed_tokens.weight
        model.layers.N.self_attn.q_proj.weight
        model.norm.weight

    whereas our MLX module tree uses the same names without the leading
    ``model.`` prefix, so that prefix is stripped from every key that
    carries it; other keys pass through unchanged.
    """
    if model_path is None:
        model_path = _find_model_path()

    raw = _load_safetensors_shards(model_path / "text_encoder", "model-*.safetensors")

    prefix = "model."
    mapped: dict[str, mx.array] = {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in raw.items()
    }

    logger.info("[ZImage] Text encoder: %d parameters mapped", len(mapped))
    return mapped
|
|
|
|
| |
|
|
def load_transformer_weights(model_path: Path | None = None) -> dict[str, mx.array]:
    """Load ZImageTransformer2DModel weights (keys used verbatim)."""
    root = model_path if model_path is not None else _find_model_path()
    weights = _load_safetensors_shards(
        root / "transformer",
        "diffusion_pytorch_model-*.safetensors",
    )
    logger.info("[ZImage] Transformer: %d parameters loaded", len(weights))
    return weights
|
|
|
|
| |
|
|
def load_vae_weights(model_path: Path | None = None) -> dict[str, mx.array]:
    """Load AutoencoderKL weights (encoder + decoder, keys used verbatim)."""
    root = model_path if model_path is not None else _find_model_path()
    weights = _load_safetensors_shards(root / "vae")
    logger.info("[ZImage] VAE: %d parameters loaded", len(weights))
    return weights
|
|
|
|
def load_vae_decoder_weights(model_path: Path | None = None) -> list[tuple[str, mx.array]]:
    """Load VAE decoder weights, mapped for the MLX Decoder module.

    Only keys starting with ``decoder.`` are kept (encoder weights are
    skipped to avoid wasting memory). Each tensor is transformed twice:

    1. The ``decoder.`` prefix is stripped so keys match the Decoder tree.
    2. 4-D (Conv2d) weights are transposed from PyTorch (O, I, kH, kW)
       to MLX (O, kH, kW, I) layout.

    All tensors are cast to float32. Returns a list of (key, array)
    tuples ready for ``Decoder.load_weights()``.
    """
    if model_path is None:
        model_path = _find_model_path()

    raw = _load_safetensors_shards(model_path / "vae", key_filter="decoder.")

    prefix_len = len("decoder.")
    weights: list[tuple[str, mx.array]] = []
    for full_key, tensor in raw.items():
        # ndim == 4 identifies Conv2d kernels; everything else (norms,
        # biases, linear weights) keeps its layout.
        if tensor.ndim == 4:
            tensor = tensor.transpose(0, 2, 3, 1)
        weights.append((full_key[prefix_len:], tensor.astype(mx.float32)))

    logger.info("[ZImage] VAE decoder: %d parameters mapped", len(weights))
    return weights
|
|