#!/usr/bin/env python3
"""
Export DACVAE (audio codec) to ONNX format.

This exports the encoder and decoder separately:
- Encoder: audio waveform -> latent features
- Decoder: latent features -> audio waveform

Usage:
    python -m onnx_export.export_dacvae --output-dir onnx_models --verify
"""

import argparse
import os
import sys

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

import dacvae

# Default DACVAE configuration (matches SAM Audio)
DEFAULT_CONFIG = {
    "encoder_dim": 64,
    "encoder_rates": [2, 8, 10, 12],
    "latent_dim": 1024,
    "decoder_dim": 1536,
    "decoder_rates": [12, 10, 8, 2],
    "n_codebooks": 16,
    "codebook_size": 1024,
    "codebook_dim": 128,
    "quantizer_dropout": False,
    "sample_rate": 48000,
}

# Prefix under which DACVAE weights are stored in the SAM Audio checkpoint.
_CHECKPOINT_PREFIX = "audio_codec."


def _hop_length() -> int:
    """Total encoder downsampling factor (product of the encoder strides)."""
    return int(np.prod(DEFAULT_CONFIG["encoder_rates"]))


def _validate_onnx(output_path: str) -> None:
    """Structurally validate an exported ONNX model.

    Loads the graph without external weight data to avoid OOM on large
    models — we only need to validate the structure, not the tensors.
    """
    import onnx

    model = onnx.load(output_path, load_external_data=False)
    onnx.checker.check_model(model, full_check=False)
    print("  ✓ ONNX model validation passed")


class DACVAEEncoderWrapper(nn.Module):
    """Wrapper for DACVAE encoder that outputs continuous latent features."""

    def __init__(self, encoder, quantizer):
        super().__init__()
        self.encoder = encoder
        self.in_proj = quantizer.in_proj

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Encode audio to latent features.

        Args:
            audio: Input waveform, shape (batch, 1, samples)

        Returns:
            latent_features: Continuous latent mean, shape (batch, 128, time_steps)
        """
        x = self.encoder(audio)
        # in_proj outputs 256 dim, chunk into mean and variance, use only mean
        mean, _ = self.in_proj(x).chunk(2, dim=1)
        return mean


class DACVAEDecoderWrapper(nn.Module):
    """Wrapper for DACVAE decoder that takes continuous latent features."""

    def __init__(self, decoder, quantizer):
        super().__init__()
        self.decoder = decoder
        self.out_proj = quantizer.out_proj

    def forward(self, latent_features: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features to audio.

        Args:
            latent_features: Continuous latent, shape (batch, 128, time_steps)

        Returns:
            audio: Output waveform, shape (batch, 1, samples)
        """
        x = self.out_proj(latent_features)
        return self.decoder(x)


def create_dacvae_model(model_id: str = "facebook/sam-audio-small") -> dacvae.DACVAE:
    """
    Create and load DACVAE model with weights from SAM Audio checkpoint.

    This uses the standalone dacvae library, avoiding loading the full
    SAM Audio model and its dependencies (vision encoder, imagebind, etc).

    Args:
        model_id: HuggingFace repo ID containing ``checkpoint.pt``.

    Returns:
        A DACVAE model in eval mode with ``hop_length`` and ``sample_rate``
        attributes attached for reference.
    """
    print("Creating DACVAE model...")
    model = dacvae.DACVAE(
        encoder_dim=DEFAULT_CONFIG["encoder_dim"],
        encoder_rates=DEFAULT_CONFIG["encoder_rates"],
        latent_dim=DEFAULT_CONFIG["latent_dim"],
        decoder_dim=DEFAULT_CONFIG["decoder_dim"],
        decoder_rates=DEFAULT_CONFIG["decoder_rates"],
        n_codebooks=DEFAULT_CONFIG["n_codebooks"],
        codebook_size=DEFAULT_CONFIG["codebook_size"],
        codebook_dim=DEFAULT_CONFIG["codebook_dim"],
        quantizer_dropout=DEFAULT_CONFIG["quantizer_dropout"],
        sample_rate=DEFAULT_CONFIG["sample_rate"],
    ).eval()

    # Load weights from SAM Audio checkpoint
    print(f"Downloading checkpoint from {model_id}...")
    checkpoint_path = hf_hub_download(
        repo_id=model_id,
        filename="checkpoint.pt",
    )

    print("Loading DACVAE weights from checkpoint...")
    state_dict = torch.load(
        checkpoint_path,
        map_location="cpu",
        weights_only=True,
        mmap=True,  # Memory-efficient loading
    )

    # Extract only DACVAE weights (prefixed with "audio_codec.").
    # Strip the prefix by slicing rather than str.replace, which would also
    # mangle any key that happened to contain the prefix elsewhere.
    dacvae_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith(_CHECKPOINT_PREFIX):
            new_key = k[len(_CHECKPOINT_PREFIX):]
            # clone() detaches the tensor from the mmap-backed storage
            dacvae_state_dict[new_key] = v.clone()

    # Load weights; report (rather than silently ignore) any mismatch.
    missing, unexpected = model.load_state_dict(dacvae_state_dict, strict=False)
    if missing:
        print(f"  ! {len(missing)} keys missing from checkpoint (left at init values)")
    if unexpected:
        print(f"  ! {len(unexpected)} unexpected keys ignored")

    # Clear large checkpoint from memory
    del state_dict

    print(f"  ✓ Loaded {len(dacvae_state_dict)} DACVAE weight tensors")

    # Attach derived quantities for reference
    model.hop_length = _hop_length()
    model.sample_rate = DEFAULT_CONFIG["sample_rate"]

    return model


def export_encoder(
    dacvae_model: dacvae.DACVAE,
    output_path: str,
    opset_version: int = 18,
    device: str = "cpu",
) -> None:
    """Export DACVAE encoder to ONNX.

    Args:
        dacvae_model: Loaded DACVAE model.
        output_path: Destination ``.onnx`` file path.
        opset_version: ONNX opset version (default 18, matching the CLI default).
        device: Device used for tracing.
    """
    print(f"Exporting DACVAE encoder to {output_path}...")

    wrapper = DACVAEEncoderWrapper(
        dacvae_model.encoder, dacvae_model.quantizer
    ).eval().to(device)

    # Sample input: 1 second of audio at 48kHz
    sample_rate = DEFAULT_CONFIG["sample_rate"]
    dummy_audio = torch.randn(1, 1, sample_rate, device=device)

    torch.onnx.export(
        wrapper,
        (dummy_audio,),
        output_path,
        input_names=["audio"],
        output_names=["latent_features"],
        dynamic_axes={
            "audio": {0: "batch", 2: "samples"},
            "latent_features": {0: "batch", 2: "time_steps"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )
    print("  ✓ Encoder exported successfully")

    _validate_onnx(output_path)


def export_decoder(
    dacvae_model: dacvae.DACVAE,
    output_path: str,
    opset_version: int = 18,
    device: str = "cpu",
) -> None:
    """Export DACVAE decoder to ONNX.

    Args:
        dacvae_model: Loaded DACVAE model.
        output_path: Destination ``.onnx`` file path.
        opset_version: ONNX opset version (default 18, matching the CLI default).
        device: Device used for tracing.
    """
    print(f"Exporting DACVAE decoder to {output_path}...")

    wrapper = DACVAEDecoderWrapper(
        dacvae_model.decoder, dacvae_model.quantizer
    ).eval().to(device)

    # Sample input: 25 time steps (1 second at 48kHz with hop_length=1920)
    time_steps = DEFAULT_CONFIG["sample_rate"] // _hop_length()
    dummy_latent = torch.randn(
        1, DEFAULT_CONFIG["codebook_dim"], time_steps, device=device
    )

    torch.onnx.export(
        wrapper,
        (dummy_latent,),
        output_path,
        input_names=["latent_features"],
        output_names=["waveform"],
        dynamic_axes={
            "latent_features": {0: "batch", 2: "time_steps"},
            "waveform": {0: "batch", 2: "samples"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )
    print("  ✓ Decoder exported successfully")

    _validate_onnx(output_path)


def verify_encoder(
    dacvae_model: dacvae.DACVAE,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-4,
) -> bool:
    """Verify ONNX encoder output matches PyTorch.

    Returns:
        True if the max absolute difference is within ``tolerance``.
    """
    import onnxruntime as ort

    print("Verifying encoder output...")

    wrapper = DACVAEEncoderWrapper(
        dacvae_model.encoder, dacvae_model.quantizer
    ).eval().to(device)

    # Test with random audio
    sample_rate = DEFAULT_CONFIG["sample_rate"]
    test_audio = torch.randn(1, 1, sample_rate * 2, device=device)  # 2 seconds

    # PyTorch output
    with torch.no_grad():
        pytorch_output = wrapper(test_audio).cpu().numpy()

    # ONNX Runtime output
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    onnx_output = sess.run(
        ["latent_features"], {"audio": test_audio.cpu().numpy()}
    )[0]

    # Compare
    max_diff = np.abs(pytorch_output - onnx_output).max()
    mean_diff = np.abs(pytorch_output - onnx_output).mean()
    print(f"  Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")

    if max_diff > tolerance:
        print(f"  ✗ Verification failed (tolerance: {tolerance})")
        return False

    print(f"  ✓ Verification passed (tolerance: {tolerance})")
    return True


def verify_decoder(
    dacvae_model: dacvae.DACVAE,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Verify ONNX decoder output matches PyTorch.

    Returns:
        True if the max absolute difference is within ``tolerance``.
    """
    import onnxruntime as ort

    print("Verifying decoder output...")

    wrapper = DACVAEDecoderWrapper(
        dacvae_model.decoder, dacvae_model.quantizer
    ).eval().to(device)

    # Test with random latent
    time_steps = DEFAULT_CONFIG["sample_rate"] // _hop_length()  # 25 steps = 1 second
    test_latent = torch.randn(
        1, DEFAULT_CONFIG["codebook_dim"], time_steps, device=device
    )

    # PyTorch output
    with torch.no_grad():
        pytorch_output = wrapper(test_latent).cpu().numpy()

    # ONNX Runtime output
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    onnx_output = sess.run(
        ["waveform"], {"latent_features": test_latent.cpu().numpy()}
    )[0]

    # Compare
    max_diff = np.abs(pytorch_output - onnx_output).max()
    mean_diff = np.abs(pytorch_output - onnx_output).mean()
    print(f"  Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")

    if max_diff > tolerance:
        print(f"  ✗ Verification failed (tolerance: {tolerance})")
        return False

    print(f"  ✓ Verification passed (tolerance: {tolerance})")
    return True


def main() -> None:
    """CLI entry point: export encoder and/or decoder, optionally verifying."""
    parser = argparse.ArgumentParser(description="Export DACVAE to ONNX")
    parser.add_argument(
        "--model-id",
        type=str,
        default="facebook/sam-audio-small",
        help="HuggingFace model ID (default: facebook/sam-audio-small)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--opset-version",
        type=int,
        default=18,
        help="ONNX opset version (default: 18)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use for export (default: cpu)",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output matches PyTorch",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-4,
        help="Tolerance for verification (default: 1e-4)",
    )
    parser.add_argument(
        "--encoder-only",
        action="store_true",
        help="Export only the encoder",
    )
    parser.add_argument(
        "--decoder-only",
        action="store_true",
        help="Export only the decoder",
    )
    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load model
    dacvae_model = create_dacvae_model(args.model_id)

    print("\nDACVAE Configuration:")
    print(f"  Model: {args.model_id}")
    print(f"  Sample rate: {DEFAULT_CONFIG['sample_rate']} Hz")
    print(f"  Hop length: {_hop_length()}")
    print(f"  Latent dim: {DEFAULT_CONFIG['codebook_dim']} (continuous)")

    all_ok = True

    # Export encoder
    if not args.decoder_only:
        encoder_path = os.path.join(args.output_dir, "dacvae_encoder.onnx")
        export_encoder(
            dacvae_model,
            encoder_path,
            opset_version=args.opset_version,
            device=args.device,
        )
        if args.verify:
            all_ok = verify_encoder(
                dacvae_model,
                encoder_path,
                device=args.device,
                tolerance=args.tolerance,
            ) and all_ok

    # Export decoder
    if not args.encoder_only:
        decoder_path = os.path.join(args.output_dir, "dacvae_decoder.onnx")
        export_decoder(
            dacvae_model,
            decoder_path,
            opset_version=args.opset_version,
            device=args.device,
        )
        if args.verify:
            all_ok = verify_decoder(
                dacvae_model,
                decoder_path,
                device=args.device,
                tolerance=args.tolerance * 10,  # Decoder has higher tolerance
            ) and all_ok

    # Propagate verification failure through the exit code instead of
    # silently exiting 0 (the original ignored the verify_* results).
    if not all_ok:
        print("\n✗ Export finished but verification FAILED")
        sys.exit(1)

    print(f"\n✓ Export complete! Models saved to {args.output_dir}/")


if __name__ == "__main__":
    main()