#!/usr/bin/env python3
"""
Convert Sana DC-AE VAE from PyTorch/Diffusers to MLX format

This avoids Core ML conversion issues by using MLX, Apple's optimized
framework for Apple Silicon.

Usage:
python convert_vae_to_mlx.py \
    --model-version Efficient-Large-Model/Sana_600M_512px_diffusers \
    --output sana_vae_mlx.npz
"""
import argparse
import json
import numpy as np
from pathlib import Path
import torch


def get_arguments():
    """Parse and return the command-line arguments for this script.

    Returns:
        argparse.Namespace with ``model_version``, ``output`` and
        ``component`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-version",
        default="Efficient-Large-Model/Sana_600M_512px_diffusers",
        help="Sana model from Hugging Face",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Output .npz file for MLX weights",
    )
    parser.add_argument(
        "--component",
        choices=["decoder", "encoder", "both"],
        default="decoder",
        help="Which component to convert",
    )
    return parser.parse_args()


def convert_pytorch_to_mlx(pytorch_weights, prefix="decoder."):
    """Convert a PyTorch state dict to MLX layout.

    MLX convolutions use channels-last kernels while PyTorch uses
    channels-first, so 4-D conv kernels are transposed from
    (out_c, in_c, h, w) to (out_c, h, w, in_c). Linear weights keep the
    PyTorch (out, in) layout, which MLX shares, so they pass through.

    Args:
        pytorch_weights: Mapping of parameter name -> torch.Tensor (or
            ndarray). Keys not starting with ``prefix`` are skipped.
        prefix: Key prefix selecting the sub-module; it is stripped from
            the returned keys.

    Returns:
        dict mapping stripped key -> float32 np.ndarray in MLX layout.
    """
    mlx_weights = {}
    for key, value in pytorch_weights.items():
        if not key.startswith(prefix):
            continue
        # Strip the sub-module prefix (e.g. "decoder.").
        mlx_key = key[len(prefix):]
        if isinstance(value, torch.Tensor):
            value = value.cpu().numpy()
        # Conv kernels are identified by name substring plus rank.
        # NOTE(review): assumes every 4-D conv kernel in the DC-AE state
        # dict has "conv" in its key — verify against the checkpoint's
        # actual naming before trusting the converted weights.
        if "conv" in mlx_key and value.ndim == 4:
            value = np.transpose(value, (0, 2, 3, 1))
        mlx_weights[mlx_key] = value.astype(np.float32)
    return mlx_weights


def _download_vae(model_version):
    """Download only the ``vae/`` subfolder of the repo; return local path."""
    # Deferred import: huggingface_hub is only needed when downloading.
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id=model_version,
        allow_patterns=["vae/*"],
    )


def _load_config(local_path):
    """Load and return the VAE ``config.json`` from the snapshot."""
    config_path = Path(local_path) / "vae" / "config.json"
    with open(config_path) as f:
        return json.load(f)


def _load_state_dict(local_path):
    """Load the VAE state dict, preferring safetensors over pickle .bin."""
    weights_path = Path(local_path) / "vae" / "diffusion_pytorch_model.safetensors"
    if not weights_path.exists():
        weights_path = Path(local_path) / "vae" / "diffusion_pytorch_model.bin"
    print(f"Loading weights from: {weights_path.name}")
    if weights_path.suffix == ".safetensors":
        from safetensors.torch import load_file

        return load_file(str(weights_path))
    # weights_only=True blocks arbitrary-code execution from a pickled
    # checkpoint downloaded off the network (requires torch >= 1.13).
    return torch.load(weights_path, map_location="cpu", weights_only=True)


def _print_usage_example(args):
    """Print a copy-pasteable snippet showing how to use the .npz output."""
    print("=" * 80)
    print("Usage Example:")
    print("=" * 80)
    print()
    print("import mlx.core as mx")
    print("from sana_vae_mlx import DCAEDecoder")
    print()
    print(f'weights = np.load("{args.output}")')
    print("decoder = DCAEDecoder(...)")
    print("decoder.load_weights([(k, mx.array(v)) for k, v in weights.items()])")
    print()
    print("# Or use the built-in loader:")
    print(f'decoder = DCAEDecoder.from_pretrained("{args.model_version}")')
    print()
    print("# Decode latents")
    print("latents = mx.random.normal((1, 32, 16, 16))  # [B, C, H, W]")
    print("image = decoder.decode(latents)  # [B, 512, 512, 3]")
    print()


def main():
    """Drive the full conversion: download, convert, save, report."""
    args = get_arguments()

    print("=" * 80)
    print("Converting Sana VAE to MLX Format")
    print("=" * 80)
    print()

    print(f"Downloading model: {args.model_version}")
    local_path = _download_vae(args.model_version)
    print(f"✓ Downloaded to: {local_path}")
    print()

    config = _load_config(local_path)
    print("VAE Configuration:")
    print(f"  Latent channels: {config.get('latent_channels', 32)}")
    print(f"  Scaling factor: {config.get('scaling_factor', 1.0)}")
    print(f"  Block channels: {config.get('block_out_channels')}")
    print()

    pytorch_weights = _load_state_dict(local_path)
    print(f"✓ Loaded {len(pytorch_weights)} weight tensors")
    print()

    # Convert the requested component(s); keys are re-prefixed so both
    # components can coexist in one archive.
    output_weights = {}
    if args.component in ("decoder", "both"):
        print("Converting decoder weights...")
        decoder_weights = convert_pytorch_to_mlx(pytorch_weights, "decoder.")
        print(f"  ✓ Converted {len(decoder_weights)} decoder weights")
        output_weights.update({f"decoder.{k}": v for k, v in decoder_weights.items()})

    if args.component in ("encoder", "both"):
        print("Converting encoder weights...")
        encoder_weights = convert_pytorch_to_mlx(pytorch_weights, "encoder.")
        print(f"  ✓ Converted {len(encoder_weights)} encoder weights")
        output_weights.update({f"encoder.{k}": v for k, v in encoder_weights.items()})

    print()

    # Stash the JSON config alongside the weights so a loader can
    # reconstruct the architecture from the .npz alone.
    output_weights["config"] = json.dumps(config)

    print(f"Saving MLX weights to: {args.output}")
    np.savez(args.output, **output_weights)
    print("✓ Conversion complete!")
    print()

    _print_usage_example(args)

    # The config string is not an ndarray, hence the isinstance filter.
    total_size = sum(
        v.nbytes for v in output_weights.values() if isinstance(v, np.ndarray)
    )
    print(f"Total size: {total_size / 1024 / 1024:.1f} MB")


if __name__ == "__main__":
    main()