#!/usr/bin/env python3
"""
Export DACVAE (audio codec) to ONNX format.
This exports the encoder and decoder separately:
- Encoder: audio waveform → latent features
- Decoder: latent features → audio waveform
Usage:
python -m onnx_export.export_dacvae --output-dir onnx_models --verify
"""
import argparse
import math
import os
import sys

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

import dacvae
# Default DACVAE configuration (matches SAM Audio)
DEFAULT_CONFIG = {
    "encoder_dim": 64,            # base channel width of the encoder
    "encoder_rates": [2, 8, 10, 12],   # per-stage downsampling; product (1920) is the hop length
    "latent_dim": 1024,
    "decoder_dim": 1536,          # base channel width of the decoder
    "decoder_rates": [12, 10, 8, 2],   # upsampling factors, mirror of encoder_rates
    "n_codebooks": 16,
    "codebook_size": 1024,
    "codebook_dim": 128,          # matches the 128 latent channels used by the ONNX wrappers
    "quantizer_dropout": False,
    "sample_rate": 48000,         # Hz
}
class DACVAEEncoderWrapper(nn.Module):
    """Wrapper for DACVAE encoder that outputs continuous latent features."""

    def __init__(self, encoder, quantizer):
        super().__init__()
        self.encoder = encoder
        # Only the quantizer's input projection is needed for continuous latents.
        self.in_proj = quantizer.in_proj

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Encode audio to latent features.

        Args:
            audio: Input waveform, shape (batch, 1, samples)

        Returns:
            latent_features: Continuous latent mean, shape (batch, 128, time_steps)
        """
        encoded = self.encoder(audio)
        # The projection produces mean and variance stacked along the channel
        # axis; keep only the mean half and drop the variance.
        projected = self.in_proj(encoded)
        mean, _variance = projected.chunk(2, dim=1)
        return mean
class DACVAEDecoderWrapper(nn.Module):
    """Wrapper for DACVAE decoder that takes continuous latent features."""

    def __init__(self, decoder, quantizer):
        super().__init__()
        self.decoder = decoder
        # Output projection maps the 128-dim latent back to the decoder's input space.
        self.out_proj = quantizer.out_proj

    def forward(self, latent_features: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features to audio.

        Args:
            latent_features: Continuous latent, shape (batch, 128, time_steps)

        Returns:
            audio: Output waveform, shape (batch, 1, samples)
        """
        projected = self.out_proj(latent_features)
        waveform = self.decoder(projected)
        return waveform
def create_dacvae_model(model_id: str = "facebook/sam-audio-small") -> dacvae.DACVAE:
    """
    Create and load DACVAE model with weights from SAM Audio checkpoint.

    This uses the standalone dacvae library, avoiding loading the full SAM Audio
    model and its dependencies (vision encoder, imagebind, etc).

    Args:
        model_id: HuggingFace repo ID containing a ``checkpoint.pt`` whose
            DACVAE weights are prefixed with ``audio_codec.``.

    Returns:
        A ``dacvae.DACVAE`` in eval mode, with ``hop_length`` and
        ``sample_rate`` attributes attached for reference.
    """
    print("Creating DACVAE model...")
    model = dacvae.DACVAE(
        encoder_dim=DEFAULT_CONFIG["encoder_dim"],
        encoder_rates=DEFAULT_CONFIG["encoder_rates"],
        latent_dim=DEFAULT_CONFIG["latent_dim"],
        decoder_dim=DEFAULT_CONFIG["decoder_dim"],
        decoder_rates=DEFAULT_CONFIG["decoder_rates"],
        n_codebooks=DEFAULT_CONFIG["n_codebooks"],
        codebook_size=DEFAULT_CONFIG["codebook_size"],
        codebook_dim=DEFAULT_CONFIG["codebook_dim"],
        quantizer_dropout=DEFAULT_CONFIG["quantizer_dropout"],
        sample_rate=DEFAULT_CONFIG["sample_rate"],
    ).eval()
    # Load weights from SAM Audio checkpoint
    print(f"Downloading checkpoint from {model_id}...")
    checkpoint_path = hf_hub_download(
        repo_id=model_id,
        filename="checkpoint.pt",
    )
    print("Loading DACVAE weights from checkpoint...")
    state_dict = torch.load(
        checkpoint_path,
        map_location="cpu",
        weights_only=True,
        mmap=True,  # Memory-efficient loading
    )
    # Extract only DACVAE weights (prefixed with "audio_codec."); clone() so the
    # tensors are detached from the mmap'd checkpoint and it can be freed.
    dacvae_state_dict = {
        k.replace("audio_codec.", ""): v.clone()
        for k, v in state_dict.items()
        if k.startswith("audio_codec.")
    }
    # strict=False tolerates buffers the standalone library doesn't define,
    # but surface any mismatch instead of discarding it silently.
    missing, unexpected = model.load_state_dict(dacvae_state_dict, strict=False)
    if missing:
        print(f" ! {len(missing)} model keys not found in checkpoint (kept at init values)")
    if unexpected:
        print(f" ! {len(unexpected)} checkpoint keys not used by the model")
    # Clear large checkpoint from memory
    del state_dict
    print(f" ✓ Loaded {len(dacvae_state_dict)} DACVAE weight tensors")
    # Attach derived constants for callers' reference.
    model.hop_length = math.prod(DEFAULT_CONFIG["encoder_rates"])
    model.sample_rate = DEFAULT_CONFIG["sample_rate"]
    return model
def export_encoder(
    dacvae_model: dacvae.DACVAE,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
) -> None:
    """Export DACVAE encoder to ONNX."""
    print(f"Exporting DACVAE encoder to {output_path}...")
    wrapper = (
        DACVAEEncoderWrapper(dacvae_model.encoder, dacvae_model.quantizer)
        .eval()
        .to(device)
    )
    # Trace with one second of random audio at the model sample rate.
    example_audio = torch.randn(1, 1, DEFAULT_CONFIG["sample_rate"], device=device)
    torch.onnx.export(
        wrapper,
        (example_audio,),
        output_path,
        input_names=["audio"],
        output_names=["latent_features"],
        dynamic_axes={
            "audio": {0: "batch", 2: "samples"},
            "latent_features": {0: "batch", 2: "time_steps"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )
    print(f" ✓ Encoder exported successfully")
    # Structural validation only: skip external weight data to avoid OOM.
    import onnx

    exported = onnx.load(output_path, load_external_data=False)
    onnx.checker.check_model(exported, full_check=False)
    print(f" ✓ ONNX model validation passed")
def export_decoder(
    dacvae_model: dacvae.DACVAE,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
) -> None:
    """Export DACVAE decoder to ONNX.

    Args:
        dacvae_model: Loaded DACVAE whose decoder/quantizer are wrapped.
        output_path: Destination ``.onnx`` file (weights stored externally).
        opset_version: ONNX opset to target.
        device: Device used for the dummy tracing input.
    """
    print(f"Exporting DACVAE decoder to {output_path}...")
    wrapper = DACVAEDecoderWrapper(
        dacvae_model.decoder,
        dacvae_model.quantizer,
    ).eval().to(device)
    # Sample input: one second of latents. hop_length is the product of the
    # encoder downsampling rates (was an awkward __import__("numpy") call).
    hop_length = math.prod(DEFAULT_CONFIG["encoder_rates"])
    time_steps = DEFAULT_CONFIG["sample_rate"] // hop_length
    dummy_latent = torch.randn(1, 128, time_steps, device=device)
    torch.onnx.export(
        wrapper,
        (dummy_latent,),
        output_path,
        input_names=["latent_features"],
        output_names=["waveform"],
        dynamic_axes={
            "latent_features": {0: "batch", 2: "time_steps"},
            "waveform": {0: "batch", 2: "samples"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )
    print(f" ✓ Decoder exported successfully")
    # Validate structure only; loading external data here could exhaust memory.
    import onnx

    model = onnx.load(output_path, load_external_data=False)
    onnx.checker.check_model(model, full_check=False)
    print(f" ✓ ONNX model validation passed")
def verify_encoder(
    dacvae_model: dacvae.DACVAE,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-4,
) -> bool:
    """Verify ONNX encoder output matches PyTorch."""
    import onnxruntime as ort
    import numpy as np

    print("Verifying encoder output...")
    reference = (
        DACVAEEncoderWrapper(dacvae_model.encoder, dacvae_model.quantizer)
        .eval()
        .to(device)
    )
    # Probe with two seconds of random audio.
    probe = torch.randn(1, 1, 2 * DEFAULT_CONFIG["sample_rate"], device=device)
    with torch.no_grad():
        expected = reference(probe).cpu().numpy()
    # Run the same input through ONNX Runtime on CPU.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (actual,) = session.run(["latent_features"], {"audio": probe.cpu().numpy()})
    delta = np.abs(expected - actual)
    max_diff = delta.max()
    mean_diff = delta.mean()
    print(f" Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")
    if max_diff > tolerance:
        print(f" ✗ Verification failed (tolerance: {tolerance})")
        return False
    print(f" ✓ Verification passed (tolerance: {tolerance})")
    return True
def verify_decoder(
    dacvae_model: dacvae.DACVAE,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Verify ONNX decoder output matches PyTorch."""
    import onnxruntime as ort
    import numpy as np

    print("Verifying decoder output...")
    reference = (
        DACVAEDecoderWrapper(dacvae_model.decoder, dacvae_model.quantizer)
        .eval()
        .to(device)
    )
    # One second's worth of random latent frames as the probe.
    frames = DEFAULT_CONFIG["sample_rate"] // int(np.prod(DEFAULT_CONFIG["encoder_rates"]))
    probe = torch.randn(1, 128, frames, device=device)
    with torch.no_grad():
        expected = reference(probe).cpu().numpy()
    # Run the same latent through ONNX Runtime on CPU.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (actual,) = session.run(["waveform"], {"latent_features": probe.cpu().numpy()})
    delta = np.abs(expected - actual)
    max_diff = delta.max()
    mean_diff = delta.mean()
    print(f" Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")
    if max_diff > tolerance:
        print(f" ✗ Verification failed (tolerance: {tolerance})")
        return False
    print(f" ✓ Verification passed (tolerance: {tolerance})")
    return True
def main():
    """CLI entry point: export DACVAE encoder/decoder to ONNX, optionally verify.

    Exits with status 1 if any requested verification fails (previously the
    script always exited 0, hiding failures from CI pipelines).
    """
    parser = argparse.ArgumentParser(description="Export DACVAE to ONNX")
    parser.add_argument(
        "--model-id",
        type=str,
        default="facebook/sam-audio-small",
        help="HuggingFace model ID (default: facebook/sam-audio-small)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--opset-version",
        type=int,
        default=18,
        help="ONNX opset version (default: 18)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use for export (default: cpu)",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output matches PyTorch",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-4,
        help="Tolerance for verification (default: 1e-4)",
    )
    parser.add_argument(
        "--encoder-only",
        action="store_true",
        help="Export only the encoder",
    )
    parser.add_argument(
        "--decoder-only",
        action="store_true",
        help="Export only the decoder",
    )
    args = parser.parse_args()
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    # Load model
    dacvae_model = create_dacvae_model(args.model_id)
    print("\nDACVAE Configuration:")
    print(f" Model: {args.model_id}")
    print(f" Sample rate: {DEFAULT_CONFIG['sample_rate']} Hz")
    # math.prod replaces the previous __import__('numpy') one-liner.
    print(f" Hop length: {math.prod(DEFAULT_CONFIG['encoder_rates'])}")
    print(f" Latent dim: 128 (continuous)")
    all_ok = True
    # Export encoder
    if not args.decoder_only:
        encoder_path = os.path.join(args.output_dir, "dacvae_encoder.onnx")
        export_encoder(
            dacvae_model,
            encoder_path,
            opset_version=args.opset_version,
            device=args.device,
        )
        if args.verify:
            all_ok = verify_encoder(
                dacvae_model,
                encoder_path,
                device=args.device,
                tolerance=args.tolerance,
            ) and all_ok
    # Export decoder
    if not args.encoder_only:
        decoder_path = os.path.join(args.output_dir, "dacvae_decoder.onnx")
        export_decoder(
            dacvae_model,
            decoder_path,
            opset_version=args.opset_version,
            device=args.device,
        )
        if args.verify:
            all_ok = verify_decoder(
                dacvae_model,
                decoder_path,
                device=args.device,
                tolerance=args.tolerance * 10,  # Decoder has higher tolerance
            ) and all_ok
    if not all_ok:
        print("\n✗ Verification failed")
        sys.exit(1)
    print(f"\n✓ Export complete! Models saved to {args.output_dir}/")
# Run the exporter only when executed as a script, not on import.
if __name__ == "__main__":
    main()