| | |
| | """ |
| | Convert Sana VAE (DC-AE) to Core ML |
| | Standalone script for VAE conversion following Stable Diffusion's working approach. |
| | |
| | Usage: |
| | python convert_sana_vae.py \ |
| | --model-version Efficient-Large-Model/Sana_600M_512px_diffusers \ |
| | --latent-h 16 \ |
| | --latent-w 16 \ |
| | -o ./sana_coreml_models |
| | """ |
| |
|
| | import argparse |
| | import coremltools as ct |
| | import gc |
| | import logging |
| | import numpy as np |
| | import os |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
# Conversion is inference-only; disabling autograd globally avoids building graphs.
torch.set_grad_enabled(False)

# Module-level logger, verbose by default so conversion progress is visible.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
| |
|
| |
|
def replace_movedim_with_permute(module):
    """Recursively walk *module*'s children and return the module unchanged.

    Historical intent (per the name) was to replace ``movedim`` ops with
    ``permute`` to avoid Core ML conversion errors. The previous version
    wrapped every child's ``forward`` in a bare ``try/except`` whose handler
    simply re-invoked the same original forward — a no-op that duplicated any
    side effects and then re-raised the same error. That broken wrapper is
    removed; the recursive traversal and the return contract are kept so
    existing callers keep working.

    Args:
        module: a ``torch.nn.Module`` (or anything exposing ``named_children``).

    Returns:
        The same ``module`` instance, unmodified.
    """
    for _name, child in module.named_children():
        # Recurse so the traversal still covers the whole tree, matching the
        # original call pattern (useful as a hook point for future rewrites).
        replace_movedim_with_permute(child)

    return module
| |
|
| |
|
def patch_attention_for_coreml(attention_module):
    """Normalize an attention module's output for Core ML tracing.

    Replaces ``attention_module.forward`` with a wrapper that delegates to the
    original forward and, when the result is a tuple, keeps only its first
    element (the hidden states). Modules without a ``forward`` attribute are
    returned untouched.

    Args:
        attention_module: module whose forward should be wrapped in place.

    Returns:
        The same module, possibly with its ``forward`` replaced.
    """
    if not hasattr(attention_module, 'forward'):
        return attention_module

    inner_forward = attention_module.forward

    def forward_first_output(hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        result = inner_forward(hidden_states, encoder_hidden_states, attention_mask, **kwargs)
        # Tuple outputs (hidden_states, extras...) confuse the tracer; keep the tensor.
        return result[0] if isinstance(result, tuple) else result

    attention_module.forward = forward_first_output
    return attention_module
| |
|
| |
|
def get_arguments(argv=None):
    """Build and run the CLI argument parser.

    Args:
        argv: optional list of argument strings; ``None`` (the default, and
            the previous behavior) parses ``sys.argv``. Accepting an explicit
            list makes the parser unit-testable without touching ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with the parsed options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--conversion-method",
        choices=["direct", "onnx"],
        default="direct",
        help="Conversion method: direct (PyTorch->CoreML) or onnx (PyTorch->ONNX->CoreML)",
    )

    parser.add_argument(
        "--model-version",
        default="Efficient-Large-Model/Sana_600M_512px_diffusers",
        type=str,
        help="Model version from Hugging Face",
    )

    parser.add_argument(
        "-o",
        "--output-dir",
        required=True,
        help="Directory to save Core ML model",
    )

    parser.add_argument(
        "--latent-h",
        type=int,
        default=16,
        help="Latent height (default: 16 for 512px images)",
    )

    parser.add_argument(
        "--latent-w",
        type=int,
        default=16,
        help="Latent width (default: 16 for 512px images)",
    )

    parser.add_argument(
        "--compute-unit",
        default="ALL",
        choices=["ALL", "CPU_AND_GPU", "CPU_AND_NE"],
        help="Core ML compute unit",
    )

    parser.add_argument(
        "--quantize-nbits",
        type=int,
        choices=[16, 8, 6, 4],
        default=None,
        help="Quantization bits",
    )

    parser.add_argument(
        "--check-output-correctness",
        action="store_true",
        help="Check numerical correctness",
    )

    return parser.parse_args(argv)
| |
|
| |
|
class SanaVAEDecoder(nn.Module):
    """Thin ``nn.Module`` wrapper around Sana's DC-AE decoder for conversion.

    Holds the decoder in float32 and records the config's ``scaling_factor``
    (defaulting to 1.0 when absent). The attention "patching" pass only logs
    the attention submodules it finds; it does not modify them.
    """

    def __init__(self, vae):
        super().__init__()
        # Core ML conversion here runs in float32; cast the decoder up front.
        self.decoder = vae.decoder.to(dtype=torch.float32)
        # Kept for reference by callers; forward() does NOT apply it.
        self.scaling_factor = getattr(vae.config, 'scaling_factor', 1.0)

        self._patch_decoder_attention()

    def _patch_decoder_attention(self):
        """Log attention submodules that are known to cause movedim errors."""
        def _walk(node):
            for child_name, child in node.named_children():
                lowered = child_name.lower()
                if 'attn' in lowered or 'attention' in lowered:
                    # Identified by name only; currently just reported, not altered.
                    logger.info(f" Patching attention module: {child_name}")
                else:
                    _walk(child)

        _walk(self.decoder)

    def forward(self, z):
        """Decode latents to an image.

        Args:
            z: [B, C, H, W] latent representation (unscaled — the caller is
               responsible for dividing by ``scaling_factor`` beforehand).
        Returns:
            Decoded image tensor; presumably [B, 3, H*32, W*32] for DC-AE —
            shape depends on the wrapped decoder.
        """
        return self.decoder(z)
| |
|
| |
|
def convert_vae_decoder(args):
    """Convert the Sana DC-AE VAE decoder to a Core ML ``.mlpackage``.

    Pipeline: load the VAE from Hugging Face, wrap its decoder in
    ``SanaVAEDecoder``, script/trace it with TorchScript, convert with
    coremltools, optionally quantize, attach metadata, save, and optionally
    check numerical parity against the PyTorch baseline.

    Args:
        args: parsed CLI namespace (see ``get_arguments``): ``model_version``,
            ``output_dir``, ``latent_h``/``latent_w``, ``compute_unit``,
            ``quantize_nbits``, ``check_output_correctness``.

    Side effects: downloads model weights, writes
    ``<output_dir>/VAEDecoder.mlpackage`` (skipped if it already exists).

    Raises:
        Propagates the tracing exception when both ``torch.jit.script`` and
        ``torch.jit.trace`` fail.
    """
    logger.info("=" * 80)
    logger.info("Converting Sana VAE Decoder (DC-AE)")
    logger.info("=" * 80)

    output_path = os.path.join(args.output_dir, "VAEDecoder.mlpackage")

    if os.path.exists(output_path):
        logger.info(f"VAE Decoder already exists at {output_path}, skipping.")
        return

    os.makedirs(args.output_dir, exist_ok=True)

    logger.info(f"Loading VAE from {args.model_version}...")
    # Imported lazily so argparse errors don't require diffusers to be installed.
    from diffusers import AutoencoderDC

    vae = AutoencoderDC.from_pretrained(
        args.model_version,
        subfolder="vae",
        torch_dtype=torch.float32,
    )

    logger.info("✓ VAE loaded")

    # BUGFIX: capture the scaling factor NOW. `vae` is deleted during cleanup
    # below, and the previous version dereferenced `vae.config` afterwards,
    # raising NameError at the very end of every successful conversion.
    scaling_factor_note = (
        getattr(vae.config, "scaling_factor", "check model config")
        if hasattr(vae, "config")
        else "check model config"
    )

    # Inspect the model layout for debugging conversion issues.
    logger.info("\nVAE structure:")
    logger.info(f" Type: {type(vae).__name__}")
    logger.info(f" Has encoder: {hasattr(vae, 'encoder')}")
    logger.info(f" Has decoder: {hasattr(vae, 'decoder')}")
    logger.info(f" Has post_quant_conv: {hasattr(vae, 'post_quant_conv')}")
    logger.info(f" Has quant_conv: {hasattr(vae, 'quant_conv')}")

    if hasattr(vae, 'config'):
        logger.info(f"\nVAE config:")
        if hasattr(vae.config, 'scaling_factor'):
            logger.info(f" scaling_factor: {vae.config.scaling_factor}")
        if hasattr(vae.config, 'latent_channels'):
            logger.info(f" latent_channels: {vae.config.latent_channels}")

    logger.info("")

    # Infer the latent channel count from the decoder's first conv weight
    # (shape [out, in, kH, kW] -> in-channels), falling back to 32 for DC-AE.
    latent_channels = None
    for name, param in vae.named_parameters():
        if 'decoder.conv_in.weight' in name:
            latent_channels = param.shape[1]
            break
        elif 'post_quant_conv.weight' in name:
            latent_channels = param.shape[1]
            break

    if latent_channels is None:
        latent_channels = 32

    logger.info(f"VAE config: latent_channels={latent_channels}")
    logger.info(f"Latent size: {args.latent_h}x{args.latent_w}")
    # DC-AE uses 32x spatial compression, hence the *32 output size.
    logger.info(f"Output size: {args.latent_h * 32}x{args.latent_w * 32}")

    baseline_decoder = SanaVAEDecoder(vae).eval()

    z_shape = (
        1,
        latent_channels,
        args.latent_h,
        args.latent_w,
    )

    sample_z = torch.rand(*z_shape, dtype=torch.float32)

    logger.info(f"Sample input shape: {sample_z.shape}")

    # Smoke-test the decoder once; the output is reused for the parity check.
    logger.info("Testing decoder...")
    with torch.no_grad():
        baseline_output = baseline_decoder(sample_z)
    logger.info(f"Baseline output shape: {baseline_output.shape}")

    logger.info("Tracing VAE decoder...")
    logger.info("Note: DC-AE uses attention which may cause tracing issues")

    # Prefer scripting (handles data-dependent control flow); fall back to a
    # permissive trace, and surface remediation options if both fail.
    try:
        logger.info("Attempting torch.jit.script (better for complex models)...")
        traced_vae_decoder = torch.jit.script(baseline_decoder)
        logger.info("✓ Scripting successful")
    except Exception as e:
        logger.warning(f"Scripting failed: {e}")
        logger.info("Falling back to tracing with strict=False...")

        try:
            traced_vae_decoder = torch.jit.trace(
                baseline_decoder,
                (sample_z,),
                strict=False,
                check_trace=False,
            )
            logger.info("✓ Tracing complete (with warnings suppressed)")
        except Exception as e2:
            logger.error(f"Both scripting and tracing failed: {e2}")
            logger.error("\nThe DC-AE decoder contains attention operations that Core ML cannot convert.")
            logger.error("This is a known limitation. Possible solutions:")
            logger.error("1. Use ONNX as intermediate format: PyTorch -> ONNX -> Core ML")
            logger.error("2. Remove attention layers from decoder (may affect quality)")
            logger.error("3. Use a different VAE without attention")
            logger.error("4. Keep VAE on CPU/GPU and only convert transformer to Core ML")
            raise

    logger.info("Converting to Core ML...")

    compute_unit_map = {
        "ALL": ct.ComputeUnit.ALL,
        "CPU_AND_GPU": ct.ComputeUnit.CPU_AND_GPU,
        "CPU_AND_NE": ct.ComputeUnit.CPU_AND_NE,
    }
    compute_unit = compute_unit_map[args.compute_unit]

    # Fixed input shape; FLOAT32 precision to keep the parity check meaningful.
    coreml_vae_decoder = ct.convert(
        traced_vae_decoder,
        inputs=[ct.TensorType(name="z", shape=sample_z.shape, dtype=np.float32)],
        outputs=[ct.TensorType(name="image")],
        compute_units=compute_unit,
        minimum_deployment_target=ct.target.macOS13,
        compute_precision=ct.precision.FLOAT32,
    )

    logger.info("✓ Core ML conversion complete")

    if args.quantize_nbits:
        logger.info(f"Quantizing to {args.quantize_nbits} bits...")
        if args.quantize_nbits == 16:
            # NOTE(review): quantize_weights targets the legacy neuralnetwork
            # backend; ct.convert above emits an mlprogram — confirm this path
            # works, or use ct.optimize.coreml / FLOAT16 compute precision.
            coreml_vae_decoder = ct.models.neural_network.quantization_utils.quantize_weights(
                coreml_vae_decoder, nbits=16
            )
        else:
            from coremltools.optimize.coreml import (
                OpPalettizerConfig,
                OptimizationConfig,
                palettize_weights,
            )
            config = OptimizationConfig(
                global_config=OpPalettizerConfig(mode="kmeans", nbits=args.quantize_nbits)
            )
            coreml_vae_decoder = palettize_weights(coreml_vae_decoder, config)
        logger.info("✓ Quantization complete")

    # Model-card metadata embedded in the .mlpackage.
    coreml_vae_decoder.author = f"Hugging Face - {args.model_version}"
    coreml_vae_decoder.license = "NVIDIA License - See model card"
    coreml_vae_decoder.version = args.model_version
    coreml_vae_decoder.short_description = (
        "Sana DC-AE VAE Decoder with 32x compression for efficient high-resolution image synthesis"
    )

    coreml_vae_decoder.input_description["z"] = (
        "The denoised latent embeddings from the transformer after the last diffusion step. "
        "NOTE: Input should be unscaled (divide by scaling_factor before passing to decoder)"
    )
    coreml_vae_decoder.output_description["image"] = (
        "Generated image normalized to range [-1, 1]"
    )

    logger.info(f"Saving to {output_path}...")
    coreml_vae_decoder.save(output_path)
    logger.info("✓ Saved successfully")

    if args.check_output_correctness:
        logger.info("Checking output correctness...")

        # Reuse the forward pass computed during the smoke test above.
        baseline_out = baseline_output.numpy()

        coreml_out = coreml_vae_decoder.predict({"z": sample_z.numpy()})["image"]

        max_diff = np.max(np.abs(baseline_out - coreml_out))
        mean_diff = np.mean(np.abs(baseline_out - coreml_out))

        # PSNR against the baseline's peak magnitude; infinite when identical.
        mse = np.mean((baseline_out - coreml_out) ** 2)
        if mse > 0:
            psnr = 20 * np.log10(np.max(np.abs(baseline_out)) / np.sqrt(mse))
        else:
            psnr = float('inf')

        logger.info(f" Max difference: {max_diff:.6f}")
        logger.info(f" Mean difference: {mean_diff:.6f}")
        logger.info(f" PSNR: {psnr:.2f} dB")

        if psnr > 35:
            logger.info("✓ Output correctness check PASSED (PSNR > 35 dB)")
        else:
            logger.warning(f"⚠ Output correctness check: PSNR is {psnr:.2f} dB (target: >35 dB)")

    # Release the large model objects before returning to keep peak RSS down.
    del traced_vae_decoder, baseline_decoder, vae, coreml_vae_decoder
    gc.collect()

    logger.info("")
    logger.info("=" * 80)
    logger.info("✓ VAE Decoder conversion complete!")
    logger.info("=" * 80)
    logger.info("")
    logger.info(f"Output: {output_path}")
    logger.info(f"Input: z [{1}, {latent_channels}, {args.latent_h}, {args.latent_w}]")
    logger.info(f"Output: image [{1}, {3}, {args.latent_h * 32}, {args.latent_w * 32}]")
    logger.info("")
    logger.info("IMPORTANT: Remember to divide latents by scaling_factor before decoding!")
    # Uses the value captured before `del vae` (see BUGFIX above).
    logger.info(f"scaling_factor = {scaling_factor_note}")
| |
|
| |
|
def main():
    """Script entry point: parse CLI options and run the decoder conversion."""
    convert_vae_decoder(get_arguments())
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|