# FastVLM_SANA / ml-stable-diffusion / convert_vae_to_mlx.py
# Uploaded by Fahad-S via huggingface_hub (commit 5723b4b, verified)
#!/usr/bin/env python3
"""
Convert Sana DC-AE VAE from PyTorch/Diffusers to MLX format
This avoids Core ML conversion issues by using MLX, Apple's optimized
framework for Apple Silicon.
Usage:
python convert_vae_to_mlx.py \
--model-version Efficient-Large-Model/Sana_600M_512px_diffusers \
--output sana_vae_mlx.npz
"""
import argparse
import json
import numpy as np
from pathlib import Path
import torch
def get_arguments():
    """Build the CLI parser and return the parsed command-line options."""
    parser = argparse.ArgumentParser()
    # Table of (flag, add_argument kwargs) so the option set reads at a glance.
    option_table = [
        (
            "--model-version",
            {
                "default": "Efficient-Large-Model/Sana_600M_512px_diffusers",
                "help": "Sana model from Hugging Face",
            },
        ),
        (
            "--output",
            {
                "required": True,
                "help": "Output .npz file for MLX weights",
            },
        ),
        (
            "--component",
            {
                "choices": ["decoder", "encoder", "both"],
                "default": "decoder",
                "help": "Which component to convert",
            },
        ),
    ]
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def convert_pytorch_to_mlx(pytorch_weights, prefix="decoder."):
    """
    Convert PyTorch weights to MLX format.

    Keeps only entries whose key starts with *prefix* (the prefix is stripped
    from the output keys). Conv kernels are moved from PyTorch's
    (out_c, in_c, h, w) layout to MLX's channels-last (out_c, h, w, in_c);
    linear weights already share the (out, in) layout, so they pass through
    unchanged. All values are cast to float32 numpy arrays.
    """
    converted = {}
    skip = len(prefix)
    for full_key, tensor in pytorch_weights.items():
        if not full_key.startswith(prefix):
            continue
        short_key = full_key[skip:]
        # Torch tensors are detached to numpy; anything else is assumed
        # to already be array-like.
        array = tensor.cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
        # NOTE(review): conv kernels are detected by a name-substring
        # heuristic; this assumes every 4-D conv weight key contains
        # "conv" -- confirm against the DC-AE checkpoint's key names.
        if array.ndim == 4 and "conv" in short_key:
            array = array.transpose(0, 2, 3, 1)
        converted[short_key] = array.astype(np.float32)
    return converted
def main():
    """Download the Sana VAE, convert its weights to MLX layout, save as .npz.

    Steps: fetch only the ``vae/`` subfolder from the Hugging Face Hub, load
    its config and PyTorch weights (safetensors preferred, .bin fallback),
    convert the requested component(s) with ``convert_pytorch_to_mlx``, then
    bundle the converted tensors plus the JSON config into a single .npz.

    Raises:
        FileNotFoundError: if no VAE weights file exists in the download.
    """
    args = get_arguments()
    print("=" * 80)
    print("Converting Sana VAE to MLX Format")
    print("=" * 80)
    print()

    # Download model (only the VAE subfolder, to keep the transfer small).
    print(f"Downloading model: {args.model_version}")
    from huggingface_hub import snapshot_download
    local_path = snapshot_download(
        repo_id=args.model_version,
        allow_patterns=["vae/*"],
    )
    # FIX: the checkmark glyphs below were double-encoded mojibake ("βœ“");
    # restored to the intended U+2713.
    print(f"✓ Downloaded to: {local_path}")
    print()

    # Load config
    config_path = Path(local_path) / "vae" / "config.json"
    with open(config_path) as f:
        config = json.load(f)
    print("VAE Configuration:")
    print(f" Latent channels: {config.get('latent_channels', 32)}")
    print(f" Scaling factor: {config.get('scaling_factor', 1.0)}")
    print(f" Block channels: {config.get('block_out_channels')}")
    print()

    # Load PyTorch weights: prefer safetensors, fall back to the .bin format.
    weights_path = Path(local_path) / "vae" / "diffusion_pytorch_model.safetensors"
    if not weights_path.exists():
        weights_path = Path(local_path) / "vae" / "diffusion_pytorch_model.bin"
    if not weights_path.exists():
        # Fail loudly here instead of letting torch.load raise an opaque error.
        raise FileNotFoundError(
            f"No VAE weights found under {Path(local_path) / 'vae'}"
        )
    print(f"Loading weights from: {weights_path.name}")
    if weights_path.suffix == ".safetensors":
        from safetensors.torch import load_file
        pytorch_weights = load_file(str(weights_path))
    else:
        pytorch_weights = torch.load(weights_path, map_location="cpu")
    print(f"✓ Loaded {len(pytorch_weights)} weight tensors")
    print()

    # Convert weights for the requested component(s); output keys are
    # re-prefixed so decoder/encoder can coexist in one archive.
    output_weights = {}
    if args.component in ["decoder", "both"]:
        print("Converting decoder weights...")
        decoder_weights = convert_pytorch_to_mlx(pytorch_weights, "decoder.")
        print(f" ✓ Converted {len(decoder_weights)} decoder weights")
        output_weights.update({f"decoder.{k}": v for k, v in decoder_weights.items()})
    if args.component in ["encoder", "both"]:
        print("Converting encoder weights...")
        encoder_weights = convert_pytorch_to_mlx(pytorch_weights, "encoder.")
        print(f" ✓ Converted {len(encoder_weights)} encoder weights")
        output_weights.update({f"encoder.{k}": v for k, v in encoder_weights.items()})
    print()

    # Add config to weights (np.savez stores the JSON string as a 0-d array).
    output_weights["config"] = json.dumps(config)

    # Save MLX weights
    print(f"Saving MLX weights to: {args.output}")
    np.savez(args.output, **output_weights)
    print("✓ Conversion complete!")
    print()
    print("=" * 80)
    print("Usage Example:")
    print("=" * 80)
    print()
    print("import mlx.core as mx")
    print("from sana_vae_mlx import DCAEDecoder")
    print()
    print(f'weights = np.load("{args.output}")')
    print("decoder = DCAEDecoder(...)")
    print("decoder.load_weights([(k, mx.array(v)) for k, v in weights.items()])")
    print()
    print("# Or use the built-in loader:")
    print(f'decoder = DCAEDecoder.from_pretrained("{args.model_version}")')
    print()
    print("# Decode latents")
    print("latents = mx.random.normal((1, 32, 16, 16)) # [B, C, H, W]")
    print("image = decoder.decode(latents) # [B, 512, 512, 3]")
    print()

    # Print size info (arrays only; the embedded config string has no nbytes).
    total_size = sum(v.nbytes for v in output_weights.values() if isinstance(v, np.ndarray))
    print(f"Total size: {total_size / 1024 / 1024:.1f} MB")


if __name__ == "__main__":
    main()