#!/usr/bin/env python3
"""
Convert Sana VAE (DC-AE) to Core ML
Standalone script for VAE conversion following Stable Diffusion's working approach.
Usage:
python convert_sana_vae.py \
--model-version Efficient-Large-Model/Sana_600M_512px_diffusers \
--latent-h 16 \
--latent-w 16 \
-o ./sana_coreml_models
"""
import argparse
import coremltools as ct
import gc
import logging
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_grad_enabled(False)
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def replace_movedim_with_permute(module):
    """
    Recursively wrap child forwards so that failures (e.g. from movedim,
    which Core ML cannot convert) are logged with the offending submodule.
    There is no generic automatic movedim -> permute rewrite here; the
    wrapper only surfaces which submodule fails so it can be patched by
    hand (see movedim_as_permute below for the equivalent permute).
    """
    for name, child in module.named_children():
        if hasattr(child, 'forward'):
            original_forward = child.forward
            # Bind the forward and name via factory defaults to avoid the
            # late-binding closure pitfall inside the loop
            def make_safe_forward(orig_forward, module_name=name):
                def safe_forward(*args, **kwargs):
                    try:
                        return orig_forward(*args, **kwargs)
                    except Exception:
                        logger.warning(f"forward failed in submodule: {module_name}")
                        raise
                return safe_forward
            child.forward = make_safe_forward(original_forward)
        # Recursively process children
        replace_movedim_with_permute(child)
    return module
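# Sketch of an explicit movedim -> permute rewrite, for manually patching a
# failing submodule identified by replace_movedim_with_permute above. For
# single, non-negative dim indices, torch.movedim(x, source, destination)
# is equivalent to the permute computed below.
def movedim_as_permute(x, source, destination):
    """Equivalent of torch.movedim for single, non-negative dim indices."""
    dims = list(range(x.dim()))
    dims.insert(destination, dims.pop(source))
    return x.permute(dims)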
def patch_attention_for_coreml(attention_module):
    """
    Normalize an attention module's output to a single tensor: some
    diffusers attention implementations return a tuple, which breaks
    torch.jit tracing and the Core ML converter downstream.
    """
    if not hasattr(attention_module, 'forward'):
        return attention_module
    original_forward = attention_module.forward
    def patched_forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        output = original_forward(hidden_states, encoder_hidden_states, attention_mask, **kwargs)
        # If the module returns a tuple, keep only the hidden states
        if isinstance(output, tuple):
            return output[0]
        return output
    attention_module.forward = patched_forward
    return attention_module
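# Hedged usage sketch: walk a decoder and patch every module whose name
# suggests attention. The substring heuristic is an assumption about
# diffusers' DC-AE module naming, not a guarantee.
def patch_all_attention(root):
    for name, child in root.named_modules():
        if 'attn' in name.lower() or 'attention' in name.lower():
            patch_attention_for_coreml(child)
    return root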
def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--conversion-method",
choices=["direct", "onnx"],
default="direct",
help="Conversion method: direct (PyTorch->CoreML) or onnx (PyTorch->ONNX->CoreML)",
)
parser.add_argument(
"--model-version",
default="Efficient-Large-Model/Sana_600M_512px_diffusers",
type=str,
help="Model version from Hugging Face",
)
parser.add_argument(
"-o",
"--output-dir",
required=True,
help="Directory to save Core ML model",
)
parser.add_argument(
"--latent-h",
type=int,
default=16,
help="Latent height (default: 16 for 512px images)",
)
parser.add_argument(
"--latent-w",
type=int,
default=16,
help="Latent width (default: 16 for 512px images)",
)
parser.add_argument(
"--compute-unit",
default="ALL",
choices=["ALL", "CPU_AND_GPU", "CPU_AND_NE"],
help="Core ML compute unit",
)
parser.add_argument(
"--quantize-nbits",
type=int,
choices=[16, 8, 6, 4],
default=None,
help="Quantization bits",
)
parser.add_argument(
"--check-output-correctness",
action="store_true",
help="Check numerical correctness",
)
return parser.parse_args()
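# Hedged sketch of the "onnx" path referenced by --conversion-method, which
# this script does not implement. Only the PyTorch -> ONNX half is shown;
# the ONNX -> Core ML step historically went through the now-deprecated
# onnx-coreml converter, so treat this as a starting point only.
def export_decoder_to_onnx(decoder, sample_z, onnx_path):
    """Illustrative only: export a wrapped decoder to ONNX."""
    torch.onnx.export(
        decoder,
        (sample_z,),
        onnx_path,
        input_names=["z"],
        output_names=["image"],
        opset_version=17,
    )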
class SanaVAEDecoder(nn.Module):
"""
Wrapper for Sana's DC-AE VAE Decoder with Core ML compatibility fixes
"""
def __init__(self, vae):
super().__init__()
# DC-AE decoder structure - directly use decoder
self.decoder = vae.decoder.to(dtype=torch.float32)
self.scaling_factor = vae.config.scaling_factor if hasattr(vae.config, 'scaling_factor') else 1.0
# Patch attention modules to avoid movedim
self._patch_decoder_attention()
    def _patch_decoder_attention(self):
        """Placeholder: locate attention modules that may cause movedim
        errors. Currently this only logs candidates; nothing is replaced."""
        def patch_module(module):
            for name, child in module.named_children():
                # Look for attention modules by name
                if 'attn' in name.lower() or 'attention' in name.lower():
                    logger.info(f"  Found attention module (left unmodified): {name}")
                else:
                    patch_module(child)
        patch_module(self.decoder)
def forward(self, z):
"""
Args:
z: [B, C, H, W] latent representation (unscaled)
Returns:
image: [B, 3, H*32, W*32] decoded image
"""
# DC-AE decoder takes latents directly
image = self.decoder(z)
return image
def convert_vae_decoder(args):
"""Convert Sana VAE Decoder to Core ML"""
logger.info("=" * 80)
logger.info("Converting Sana VAE Decoder (DC-AE)")
logger.info("=" * 80)
output_path = os.path.join(args.output_dir, "VAEDecoder.mlpackage")
if os.path.exists(output_path):
logger.info(f"VAE Decoder already exists at {output_path}, skipping.")
return
os.makedirs(args.output_dir, exist_ok=True)
# Load VAE
logger.info(f"Loading VAE from {args.model_version}...")
from diffusers import AutoencoderDC
vae = AutoencoderDC.from_pretrained(
args.model_version,
subfolder="vae",
torch_dtype=torch.float32,
)
logger.info("✓ VAE loaded")
# Inspect VAE structure
logger.info("\nVAE structure:")
logger.info(f" Type: {type(vae).__name__}")
logger.info(f" Has encoder: {hasattr(vae, 'encoder')}")
logger.info(f" Has decoder: {hasattr(vae, 'decoder')}")
logger.info(f" Has post_quant_conv: {hasattr(vae, 'post_quant_conv')}")
logger.info(f" Has quant_conv: {hasattr(vae, 'quant_conv')}")
# Print config
if hasattr(vae, 'config'):
logger.info(f"\nVAE config:")
if hasattr(vae.config, 'scaling_factor'):
logger.info(f" scaling_factor: {vae.config.scaling_factor}")
if hasattr(vae.config, 'latent_channels'):
logger.info(f" latent_channels: {vae.config.latent_channels}")
logger.info("")
# Get latent channels from model
latent_channels = None
for name, param in vae.named_parameters():
if 'decoder.conv_in.weight' in name:
latent_channels = param.shape[1]
break
elif 'post_quant_conv.weight' in name:
latent_channels = param.shape[1]
break
if latent_channels is None:
latent_channels = 32 # Default for Sana
logger.info(f"VAE config: latent_channels={latent_channels}")
logger.info(f"Latent size: {args.latent_h}x{args.latent_w}")
logger.info(f"Output size: {args.latent_h * 32}x{args.latent_w * 32}")
# Create wrapper following SD pattern
baseline_decoder = SanaVAEDecoder(vae).eval()
# Create sample input
z_shape = (
1, # B
latent_channels, # C
args.latent_h, # H
args.latent_w, # W
)
sample_z = torch.rand(*z_shape, dtype=torch.float32)
logger.info(f"Sample input shape: {sample_z.shape}")
# Test the decoder first
logger.info("Testing decoder...")
with torch.no_grad():
baseline_output = baseline_decoder(sample_z)
logger.info(f"Baseline output shape: {baseline_output.shape}")
# Trace the model (following SD's approach)
logger.info("Tracing VAE decoder...")
logger.info("Note: DC-AE uses attention which may cause tracing issues")
# Try to use scripting instead of tracing for better compatibility
try:
logger.info("Attempting torch.jit.script (better for complex models)...")
traced_vae_decoder = torch.jit.script(baseline_decoder)
logger.info("✓ Scripting successful")
except Exception as e:
logger.warning(f"Scripting failed: {e}")
logger.info("Falling back to tracing with strict=False...")
try:
traced_vae_decoder = torch.jit.trace(
baseline_decoder,
(sample_z,),
strict=False,
check_trace=False,
)
logger.info("✓ Tracing complete (with warnings suppressed)")
except Exception as e2:
logger.error(f"Both scripting and tracing failed: {e2}")
logger.error("\nThe DC-AE decoder contains attention operations that Core ML cannot convert.")
logger.error("This is a known limitation. Possible solutions:")
logger.error("1. Use ONNX as intermediate format: PyTorch -> ONNX -> Core ML")
logger.error("2. Remove attention layers from decoder (may affect quality)")
logger.error("3. Use a different VAE without attention")
logger.error("4. Keep VAE on CPU/GPU and only convert transformer to Core ML")
raise
# Convert to Core ML
logger.info("Converting to Core ML...")
compute_unit_map = {
"ALL": ct.ComputeUnit.ALL,
"CPU_AND_GPU": ct.ComputeUnit.CPU_AND_GPU,
"CPU_AND_NE": ct.ComputeUnit.CPU_AND_NE,
}
compute_unit = compute_unit_map[args.compute_unit]
# Use float32 precision for better compatibility
coreml_vae_decoder = ct.convert(
traced_vae_decoder,
inputs=[ct.TensorType(name="z", shape=sample_z.shape, dtype=np.float32)],
outputs=[ct.TensorType(name="image")],
compute_units=compute_unit,
minimum_deployment_target=ct.target.macOS13,
compute_precision=ct.precision.FLOAT32,
)
logger.info("✓ Core ML conversion complete")
    # Quantize if requested
    if args.quantize_nbits:
        logger.info(f"Quantizing to {args.quantize_nbits} bits...")
        if args.quantize_nbits == 16:
            # quantization_utils.quantize_weights only supports the legacy
            # neuralnetwork format; ct.convert above produces an mlprogram,
            # so 16-bit weights must instead be requested at conversion time
            # via compute_precision=ct.precision.FLOAT16
            logger.warning(
                "16-bit quantization is not applied post hoc to mlprogram models; "
                "re-convert with compute_precision=ct.precision.FLOAT16 instead."
            )
        else:
            from coremltools.optimize.coreml import (
                OpPalettizerConfig,
                OptimizationConfig,
                palettize_weights,
            )
            config = OptimizationConfig(
                global_config=OpPalettizerConfig(mode="kmeans", nbits=args.quantize_nbits)
            )
            coreml_vae_decoder = palettize_weights(coreml_vae_decoder, config)
            logger.info("✓ Quantization complete")
# Set metadata
coreml_vae_decoder.author = f"Hugging Face - {args.model_version}"
coreml_vae_decoder.license = "NVIDIA License - See model card"
coreml_vae_decoder.version = args.model_version
coreml_vae_decoder.short_description = (
"Sana DC-AE VAE Decoder with 32x compression for efficient high-resolution image synthesis"
)
# Set input/output descriptions
coreml_vae_decoder.input_description["z"] = (
"The denoised latent embeddings from the transformer after the last diffusion step. "
"NOTE: Input should be unscaled (divide by scaling_factor before passing to decoder)"
)
coreml_vae_decoder.output_description["image"] = (
"Generated image normalized to range [-1, 1]"
)
# Save
logger.info(f"Saving to {output_path}...")
coreml_vae_decoder.save(output_path)
logger.info("✓ Saved successfully")
# Check correctness if requested
if args.check_output_correctness:
logger.info("Checking output correctness...")
baseline_out = baseline_decoder(sample_z).numpy()
coreml_out = coreml_vae_decoder.predict({"z": sample_z.numpy()})["image"]
# Compute metrics
max_diff = np.max(np.abs(baseline_out - coreml_out))
mean_diff = np.mean(np.abs(baseline_out - coreml_out))
# Compute PSNR
mse = np.mean((baseline_out - coreml_out) ** 2)
if mse > 0:
psnr = 20 * np.log10(np.max(np.abs(baseline_out)) / np.sqrt(mse))
else:
psnr = float('inf')
logger.info(f" Max difference: {max_diff:.6f}")
logger.info(f" Mean difference: {mean_diff:.6f}")
logger.info(f" PSNR: {psnr:.2f} dB")
if psnr > 35:
logger.info("✓ Output correctness check PASSED (PSNR > 35 dB)")
else:
logger.warning(f"⚠ Output correctness check: PSNR is {psnr:.2f} dB (target: >35 dB)")
    # Capture scaling_factor before releasing the model (vae is deleted below)
    scaling_factor = getattr(vae.config, "scaling_factor", None) if hasattr(vae, "config") else None
    # Cleanup
    del traced_vae_decoder, baseline_decoder, vae, coreml_vae_decoder
    gc.collect()
logger.info("")
logger.info("=" * 80)
logger.info("✓ VAE Decoder conversion complete!")
logger.info("=" * 80)
logger.info("")
logger.info(f"Output: {output_path}")
logger.info(f"Input: z [{1}, {latent_channels}, {args.latent_h}, {args.latent_w}]")
logger.info(f"Output: image [{1}, {3}, {args.latent_h * 32}, {args.latent_w * 32}]")
logger.info("")
logger.info("IMPORTANT: Remember to divide latents by scaling_factor before decoding!")
logger.info(f"scaling_factor = {vae.config.scaling_factor if hasattr(vae, 'config') else 'check model config'}")
def main():
args = get_arguments()
convert_vae_decoder(args)
if __name__ == "__main__":
main()