#!/usr/bin/env python3
"""Convert Sana VAE (DC-AE) to Core ML.

Standalone script for VAE conversion following Stable Diffusion's working
approach.

Usage:
    python convert_sana_vae.py \
        --model-version Efficient-Large-Model/Sana_600M_512px_diffusers \
        --latent-h 16 \
        --latent-w 16 \
        -o ./sana_coreml_models
"""
import argparse
import coremltools as ct
import gc
import logging
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: F401 -- kept for parity with prior imports

# Inference-only script: disable autograd globally.
torch.set_grad_enabled(False)

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def replace_movedim_with_permute(module):
    """Recursively wrap each child module's ``forward`` with a pass-through.

    Intended as a hook point for rewriting ``movedim``-based ops (which Core ML
    cannot convert) into ``permute``-based equivalents.

    NOTE(review): the original wrapper caught every exception with a bare
    ``except:`` and then invoked the original forward a *second* time as its
    "fallback" -- duplicating any side effects before the same exception was
    raised again. The wrapper now delegates directly: the success path is
    identical, and a failing forward raises exactly once.

    Args:
        module: the ``nn.Module`` tree to process in place.

    Returns:
        The same ``module`` instance, for call chaining.
    """
    for _name, child in module.named_children():
        if hasattr(child, 'forward'):
            original_forward = child.forward

            # Factory binds `orig_forward` per child, avoiding the
            # late-binding-closure pitfall inside the loop.
            def make_safe_forward(orig_forward):
                def safe_forward(*args, **kwargs):
                    return orig_forward(*args, **kwargs)
                return safe_forward

            child.forward = make_safe_forward(original_forward)
        # Recursively process children.
        replace_movedim_with_permute(child)
    return module


def patch_attention_for_coreml(attention_module):
    """Patch an attention module so its output is always a single tensor.

    Some attention implementations return ``(hidden_states, ...)`` tuples,
    which confuses tracing; this wrapper unwraps the first element.

    Args:
        attention_module: module whose ``forward`` should be wrapped.

    Returns:
        The same module (patched in place), or unchanged if it has no
        ``forward`` attribute.
    """
    if not hasattr(attention_module, 'forward'):
        return attention_module

    original_forward = attention_module.forward

    def patched_forward(hidden_states, encoder_hidden_states=None,
                        attention_mask=None, **kwargs):
        # Call original but ensure output is in correct format.
        output = original_forward(
            hidden_states, encoder_hidden_states, attention_mask, **kwargs
        )
        # If output is a tuple, return only the hidden-states tensor.
        if isinstance(output, tuple):
            return output[0]
        return output

    attention_module.forward = patched_forward
    return attention_module


def get_arguments():
    """Parse and return the command-line arguments for the conversion run."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--conversion-method",
        choices=["direct", "onnx"],
        default="direct",
        help="Conversion method: direct (PyTorch->CoreML) or onnx (PyTorch->ONNX->CoreML)",
    )
    parser.add_argument(
        "--model-version",
        default="Efficient-Large-Model/Sana_600M_512px_diffusers",
        type=str,
        help="Model version from Hugging Face",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        required=True,
        help="Directory to save Core ML model",
    )
    parser.add_argument(
        "--latent-h",
        type=int,
        default=16,
        help="Latent height (default: 16 for 512px images)",
    )
    parser.add_argument(
        "--latent-w",
        type=int,
        default=16,
        help="Latent width (default: 16 for 512px images)",
    )
    parser.add_argument(
        "--compute-unit",
        default="ALL",
        choices=["ALL", "CPU_AND_GPU", "CPU_AND_NE"],
        help="Core ML compute unit",
    )
    parser.add_argument(
        "--quantize-nbits",
        type=int,
        choices=[16, 8, 6, 4],
        default=None,
        help="Quantization bits",
    )
    parser.add_argument(
        "--check-output-correctness",
        action="store_true",
        help="Check numerical correctness",
    )
    return parser.parse_args()


class SanaVAEDecoder(nn.Module):
    """Wrapper for Sana's DC-AE VAE decoder with Core ML compatibility fixes.

    Exposes only the decode path (latents -> image); the scaling factor is
    stored for callers but NOT applied here -- inputs must be pre-unscaled.
    """

    def __init__(self, vae):
        super().__init__()
        # DC-AE exposes a plain `decoder` submodule (no post_quant_conv step).
        self.decoder = vae.decoder.to(dtype=torch.float32)
        # Kept for reference by callers; forward() does not use it.
        self.scaling_factor = getattr(vae.config, 'scaling_factor', 1.0)
        # Patch attention modules to avoid movedim.
        self._patch_decoder_attention()

    def _patch_decoder_attention(self):
        """Walk the decoder and log attention modules that may need patching.

        Currently a no-op placeholder: attention submodules are only logged,
        not modified (replacing them with Identity would degrade quality).
        """
        def patch_module(module):
            for name, child in module.named_children():
                # Look for attention modules by conventional naming.
                if 'attn' in name.lower() or 'attention' in name.lower():
                    logger.info(f" Patching attention module: {name}")
                    # Intentionally left unmodified; see docstring.
                else:
                    # NOTE: children of attention-named modules are not
                    # recursed into, matching the original traversal.
                    patch_module(child)

        patch_module(self.decoder)

    def forward(self, z):
        """Decode latents to an image.

        Args:
            z: [B, C, H, W] latent representation (unscaled).

        Returns:
            image: [B, 3, H*32, W*32] decoded image.
        """
        # DC-AE decoder takes latents directly.
        image = self.decoder(z)
        return image


def convert_vae_decoder(args):
    """Convert Sana VAE Decoder to Core ML.

    Loads the DC-AE VAE from Hugging Face, wraps its decoder, traces/scripts
    it, converts to an .mlpackage, optionally quantizes, and optionally checks
    numerical parity against the PyTorch baseline.

    Args:
        args: parsed namespace from ``get_arguments()``.

    Raises:
        Re-raises any scripting/tracing failure after logging guidance.
    """
    logger.info("=" * 80)
    logger.info("Converting Sana VAE Decoder (DC-AE)")
    logger.info("=" * 80)

    output_path = os.path.join(args.output_dir, "VAEDecoder.mlpackage")
    if os.path.exists(output_path):
        logger.info(f"VAE Decoder already exists at {output_path}, skipping.")
        return
    os.makedirs(args.output_dir, exist_ok=True)

    # Load VAE.
    logger.info(f"Loading VAE from {args.model_version}...")
    from diffusers import AutoencoderDC
    vae = AutoencoderDC.from_pretrained(
        args.model_version,
        subfolder="vae",
        torch_dtype=torch.float32,
    )
    logger.info("✓ VAE loaded")

    # Inspect VAE structure.
    logger.info("\nVAE structure:")
    logger.info(f"  Type: {type(vae).__name__}")
    logger.info(f"  Has encoder: {hasattr(vae, 'encoder')}")
    logger.info(f"  Has decoder: {hasattr(vae, 'decoder')}")
    logger.info(f"  Has post_quant_conv: {hasattr(vae, 'post_quant_conv')}")
    logger.info(f"  Has quant_conv: {hasattr(vae, 'quant_conv')}")

    # Print config.
    if hasattr(vae, 'config'):
        logger.info(f"\nVAE config:")
        if hasattr(vae.config, 'scaling_factor'):
            logger.info(f"  scaling_factor: {vae.config.scaling_factor}")
        if hasattr(vae.config, 'latent_channels'):
            logger.info(f"  latent_channels: {vae.config.latent_channels}")
    logger.info("")

    # BUGFIX: capture the scaling factor NOW -- the original read
    # `vae.config.scaling_factor` in the final log line AFTER `del vae`,
    # raising NameError at the end of every successful run.
    scaling_factor = getattr(
        getattr(vae, 'config', None), 'scaling_factor', 'check model config'
    )

    # Infer latent channel count from the decoder's input convolution.
    latent_channels = None
    for name, param in vae.named_parameters():
        if 'decoder.conv_in.weight' in name:
            latent_channels = param.shape[1]
            break
        elif 'post_quant_conv.weight' in name:
            latent_channels = param.shape[1]
            break
    if latent_channels is None:
        latent_channels = 32  # Default for Sana.

    logger.info(f"VAE config: latent_channels={latent_channels}")
    logger.info(f"Latent size: {args.latent_h}x{args.latent_w}")
    # DC-AE uses 32x spatial compression, hence the *32 output size.
    logger.info(f"Output size: {args.latent_h * 32}x{args.latent_w * 32}")

    # Create wrapper following SD pattern.
    baseline_decoder = SanaVAEDecoder(vae).eval()

    # Create sample input: [B, C, H, W].
    z_shape = (1, latent_channels, args.latent_h, args.latent_w)
    sample_z = torch.rand(*z_shape, dtype=torch.float32)
    logger.info(f"Sample input shape: {sample_z.shape}")

    # Sanity-run the decoder before attempting to trace it.
    logger.info("Testing decoder...")
    with torch.no_grad():
        baseline_output = baseline_decoder(sample_z)
    logger.info(f"Baseline output shape: {baseline_output.shape}")

    # Trace the model (following SD's approach).
    logger.info("Tracing VAE decoder...")
    logger.info("Note: DC-AE uses attention which may cause tracing issues")

    # Prefer scripting; fall back to non-strict tracing.
    try:
        logger.info("Attempting torch.jit.script (better for complex models)...")
        traced_vae_decoder = torch.jit.script(baseline_decoder)
        logger.info("✓ Scripting successful")
    except Exception as e:
        logger.warning(f"Scripting failed: {e}")
        logger.info("Falling back to tracing with strict=False...")
        try:
            traced_vae_decoder = torch.jit.trace(
                baseline_decoder,
                (sample_z,),
                strict=False,
                check_trace=False,
            )
            logger.info("✓ Tracing complete (with warnings suppressed)")
        except Exception as e2:
            logger.error(f"Both scripting and tracing failed: {e2}")
            logger.error("\nThe DC-AE decoder contains attention operations that Core ML cannot convert.")
            logger.error("This is a known limitation. Possible solutions:")
            logger.error("1. Use ONNX as intermediate format: PyTorch -> ONNX -> Core ML")
            logger.error("2. Remove attention layers from decoder (may affect quality)")
            logger.error("3. Use a different VAE without attention")
            logger.error("4. Keep VAE on CPU/GPU and only convert transformer to Core ML")
            raise

    # Convert to Core ML.
    logger.info("Converting to Core ML...")
    compute_unit_map = {
        "ALL": ct.ComputeUnit.ALL,
        "CPU_AND_GPU": ct.ComputeUnit.CPU_AND_GPU,
        "CPU_AND_NE": ct.ComputeUnit.CPU_AND_NE,
    }
    compute_unit = compute_unit_map[args.compute_unit]

    # Use float32 precision for better compatibility.
    coreml_vae_decoder = ct.convert(
        traced_vae_decoder,
        inputs=[ct.TensorType(name="z", shape=sample_z.shape, dtype=np.float32)],
        outputs=[ct.TensorType(name="image")],
        compute_units=compute_unit,
        minimum_deployment_target=ct.target.macOS13,
        compute_precision=ct.precision.FLOAT32,
    )
    logger.info("✓ Core ML conversion complete")

    # Quantize if requested.
    if args.quantize_nbits:
        logger.info(f"Quantizing to {args.quantize_nbits} bits...")
        if args.quantize_nbits == 16:
            # NOTE(review): quantization_utils.quantize_weights targets the
            # legacy neuralnetwork format; ct.convert with macOS13 target
            # produces an mlprogram, so this path may fail -- consider
            # compute_precision=ct.precision.FLOAT16 at convert time instead.
            coreml_vae_decoder = ct.models.neural_network.quantization_utils.quantize_weights(
                coreml_vae_decoder, nbits=16
            )
        else:
            from coremltools.optimize.coreml import (
                OpPalettizerConfig,
                OptimizationConfig,
                palettize_weights,
            )
            config = OptimizationConfig(
                global_config=OpPalettizerConfig(mode="kmeans", nbits=args.quantize_nbits)
            )
            coreml_vae_decoder = palettize_weights(coreml_vae_decoder, config)
        logger.info("✓ Quantization complete")

    # Set metadata.
    coreml_vae_decoder.author = f"Hugging Face - {args.model_version}"
    coreml_vae_decoder.license = "NVIDIA License - See model card"
    coreml_vae_decoder.version = args.model_version
    coreml_vae_decoder.short_description = (
        "Sana DC-AE VAE Decoder with 32x compression for efficient high-resolution image synthesis"
    )

    # Set input/output descriptions.
    coreml_vae_decoder.input_description["z"] = (
        "The denoised latent embeddings from the transformer after the last diffusion step. "
        "NOTE: Input should be unscaled (divide by scaling_factor before passing to decoder)"
    )
    coreml_vae_decoder.output_description["image"] = (
        "Generated image normalized to range [-1, 1]"
    )

    # Save.
    logger.info(f"Saving to {output_path}...")
    coreml_vae_decoder.save(output_path)
    logger.info("✓ Saved successfully")

    # Check correctness if requested.
    if args.check_output_correctness:
        logger.info("Checking output correctness...")
        baseline_out = baseline_decoder(sample_z).numpy()
        coreml_out = coreml_vae_decoder.predict({"z": sample_z.numpy()})["image"]

        # Compute metrics.
        max_diff = np.max(np.abs(baseline_out - coreml_out))
        mean_diff = np.mean(np.abs(baseline_out - coreml_out))

        # Compute PSNR (peak-signal-to-noise vs the PyTorch baseline).
        mse = np.mean((baseline_out - coreml_out) ** 2)
        if mse > 0:
            psnr = 20 * np.log10(np.max(np.abs(baseline_out)) / np.sqrt(mse))
        else:
            psnr = float('inf')

        logger.info(f"  Max difference: {max_diff:.6f}")
        logger.info(f"  Mean difference: {mean_diff:.6f}")
        logger.info(f"  PSNR: {psnr:.2f} dB")
        if psnr > 35:
            logger.info("✓ Output correctness check PASSED (PSNR > 35 dB)")
        else:
            logger.warning(f"⚠ Output correctness check: PSNR is {psnr:.2f} dB (target: >35 dB)")

    # Cleanup: release large objects before exiting the function.
    del traced_vae_decoder, baseline_decoder, vae, coreml_vae_decoder
    gc.collect()

    logger.info("")
    logger.info("=" * 80)
    logger.info("✓ VAE Decoder conversion complete!")
    logger.info("=" * 80)
    logger.info("")
    logger.info(f"Output: {output_path}")
    logger.info(f"Input: z [{1}, {latent_channels}, {args.latent_h}, {args.latent_w}]")
    logger.info(f"Output: image [{1}, {3}, {args.latent_h * 32}, {args.latent_w * 32}]")
    logger.info("")
    logger.info("IMPORTANT: Remember to divide latents by scaling_factor before decoding!")
    # Uses the value captured before `del vae` (original raised NameError here).
    logger.info(f"scaling_factor = {scaling_factor}")


def main():
    """CLI entry point: parse arguments and run the decoder conversion."""
    args = get_arguments()
    convert_vae_decoder(args)


if __name__ == "__main__":
    main()