"""
Convert the Sana DC-AE VAE from PyTorch/Diffusers to MLX format.

This avoids Core ML conversion issues by using MLX, Apple's optimized
framework for Apple Silicon.

Usage:
    python convert_vae_to_mlx.py \
        --model-version Efficient-Large-Model/Sana_600M_512px_diffusers \
        --output sana_vae_mlx.npz
"""

import argparse
import json
from pathlib import Path

import numpy as np
import torch


def get_arguments(argv=None):
    """Parse the command-line options of the conversion script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            zero-argument callers behave as before.

    Returns:
        argparse.Namespace with the attributes ``model_version``,
        ``output`` and ``component``.
    """
    parser = argparse.ArgumentParser(
        description="Convert the Sana DC-AE VAE weights to an .npz file for MLX.",
    )

    parser.add_argument(
        "--model-version",
        default="Efficient-Large-Model/Sana_600M_512px_diffusers",
        help="Sana model from Hugging Face",
    )

    parser.add_argument(
        "--output",
        required=True,
        help="Output .npz file for MLX weights",
    )

    parser.add_argument(
        "--component",
        choices=["decoder", "encoder", "both"],
        default="decoder",
        help="Which component to convert",
    )

    return parser.parse_args(argv)


def convert_pytorch_to_mlx(pytorch_weights, prefix="decoder."):
    """Select the weights under ``prefix`` and convert them to float32 NumPy arrays.

    Entries whose key does not start with ``prefix`` are skipped, and the
    prefix is stripped from the keys that are kept. 4-D arrays whose stripped
    key contains ``"conv"`` have axis 1 moved to the end (axes
    ``(0, 2, 3, 1)``), because MLX uses a channels-last layout (BHWC) while
    PyTorch uses channels-first (BCHW).

    Args:
        pytorch_weights: Mapping of parameter name to ``torch.Tensor`` (or
            array-like) values, e.g. a loaded state dict.
        prefix: Key prefix selecting the sub-module to convert.

    Returns:
        dict mapping the prefix-stripped key to a ``np.float32`` array.
    """
    mlx_weights = {}

    for key, value in pytorch_weights.items():
        if not key.startswith(prefix):
            continue

        mlx_key = key[len(prefix):]

        if isinstance(value, torch.Tensor):
            # NumPy has no bfloat16 and Tensor.numpy() rejects tensors that
            # require grad, so detach and upcast to float32 while still in torch.
            value = value.detach().to(device="cpu", dtype=torch.float32).numpy()
        else:
            value = np.asarray(value)

        # NOTE(review): only keys containing "conv" are permuted; any other 4-D
        # weight stays in PyTorch layout — confirm the MLX model expects that.
        if "conv" in mlx_key and value.ndim == 4:
            value = np.transpose(value, (0, 2, 3, 1))

        # astype copies, so the result never aliases the source tensor's memory.
        mlx_weights[mlx_key] = value.astype(np.float32)

    return mlx_weights


def _load_vae_state_dict(vae_dir):
    """Load the VAE checkpoint found in ``vae_dir``, preferring safetensors over .bin."""
    weights_path = vae_dir / "diffusion_pytorch_model.safetensors"
    if not weights_path.exists():
        weights_path = vae_dir / "diffusion_pytorch_model.bin"
    if not weights_path.exists():
        raise FileNotFoundError(
            "No diffusion_pytorch_model.safetensors or "
            f"diffusion_pytorch_model.bin found in {vae_dir}"
        )

    print(f"Loading weights from: {weights_path.name}")

    if weights_path.suffix == ".safetensors":
        from safetensors.torch import load_file

        return load_file(str(weights_path))

    # A .bin checkpoint is a pickle downloaded from the network;
    # weights_only=True keeps torch.load from unpickling arbitrary objects.
    return torch.load(weights_path, map_location="cpu", weights_only=True)


def _print_usage_example(saved_path, model_version):
    """Print a short snippet showing how to load the saved weights with MLX."""
    print("=" * 80)
    print("Usage Example:")
    print("=" * 80)
    print()
    print("import mlx.core as mx")
    print("import numpy as np")
    print("from sana_vae_mlx import DCAEDecoder")
    print()
    print(f'weights = np.load("{saved_path}")')
    print("decoder = DCAEDecoder(...)")
    # The archive also stores the JSON config as a string entry, which is not
    # a weight array, so the example filters it out.
    print(
        "decoder.load_weights("
        '[(k, mx.array(v)) for k, v in weights.items() if k != "config"])'
    )
    print()
    print("# Or use the built-in loader:")
    print(f'decoder = DCAEDecoder.from_pretrained("{model_version}")')
    print()
    print("# Decode latents")
    print("latents = mx.random.normal((1, 32, 16, 16)) # [B, C, H, W]")
    print("image = decoder.decode(latents) # [B, 512, 512, 3]")
    print()


def main():
    """Download the Sana VAE, convert the selected components and save an .npz.

    Steps: parse the CLI arguments, download the ``vae/`` folder of the model
    repository, print a few fields of its ``config.json``, load the checkpoint,
    convert the decoder and/or encoder weights with
    ``convert_pytorch_to_mlx`` (re-adding the ``decoder.``/``encoder.`` key
    prefix), store the config as a JSON string under the ``"config"`` key and
    write everything with ``np.savez``.

    Raises:
        FileNotFoundError: If the download contains neither a safetensors nor
            a .bin checkpoint.
    """
    args = get_arguments()

    print("=" * 80)
    print("Converting Sana VAE to MLX Format")
    print("=" * 80)
    print()

    print(f"Downloading model: {args.model_version}")
    # huggingface_hub is only needed for this step, so it is imported locally.
    from huggingface_hub import snapshot_download

    local_path = snapshot_download(
        repo_id=args.model_version,
        allow_patterns=["vae/*"],
    )

    print(f"✓ Downloaded to: {local_path}")
    print()

    vae_dir = Path(local_path) / "vae"

    with open(vae_dir / "config.json", encoding="utf-8") as f:
        config = json.load(f)

    print("VAE Configuration:")
    print(f" Latent channels: {config.get('latent_channels', 32)}")
    print(f" Scaling factor: {config.get('scaling_factor', 1.0)}")
    print(f" Block channels: {config.get('block_out_channels')}")
    print()

    pytorch_weights = _load_vae_state_dict(vae_dir)

    print(f"✓ Loaded {len(pytorch_weights)} weight tensors")
    print()

    output_weights = {}

    if args.component in ["decoder", "both"]:
        print("Converting decoder weights...")
        decoder_weights = convert_pytorch_to_mlx(pytorch_weights, "decoder.")
        print(f" ✓ Converted {len(decoder_weights)} decoder weights")
        output_weights.update({f"decoder.{k}": v for k, v in decoder_weights.items()})

    if args.component in ["encoder", "both"]:
        print("Converting encoder weights...")
        encoder_weights = convert_pytorch_to_mlx(pytorch_weights, "encoder.")
        print(f" ✓ Converted {len(encoder_weights)} encoder weights")
        output_weights.update({f"encoder.{k}": v for k, v in encoder_weights.items()})

    print()

    # Stored as a plain string so the archive stays self-describing.
    output_weights["config"] = json.dumps(config)

    # np.savez appends ".npz" when it is missing; resolve the final name up
    # front so the messages below point at the file that is actually written.
    saved_path = args.output if args.output.endswith(".npz") else f"{args.output}.npz"

    print(f"Saving MLX weights to: {saved_path}")
    np.savez(saved_path, **output_weights)

    print("✓ Conversion complete!")
    print()

    _print_usage_example(saved_path, args.model_version)

    # The config entry is a str, not an ndarray, so it is excluded from the total.
    total_size = sum(v.nbytes for v in output_weights.values() if isinstance(v, np.ndarray))
    print(f"Total size: {total_size / 1024 / 1024:.1f} MB")


# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()