|
|
|
|
|
""" |
|
|
Export DACVAE (audio codec) to ONNX format. |
|
|
|
|
|
This exports the encoder and decoder separately: |
|
|
- Encoder: audio waveform → latent features |
|
|
- Decoder: latent features → audio waveform |
|
|
|
|
|
Usage: |
|
|
python -m onnx_export.export_dacvae --output-dir onnx_models --verify |
|
|
""" |
|
|
|
|
|
import argparse
import math
import os

import torch
import torch.nn as nn

import dacvae
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
|
|
|
|
# Architecture hyper-parameters for the DACVAE model; keys are passed verbatim
# as keyword arguments to dacvae.DACVAE in create_dacvae_model().
DEFAULT_CONFIG = {
    "encoder_dim": 64,  # base channel width of the encoder
    "encoder_rates": [2, 8, 10, 12],  # downsampling rates; product is the hop length (1920)
    "latent_dim": 1024,
    "decoder_dim": 1536,  # base channel width of the decoder
    "decoder_rates": [12, 10, 8, 2],  # upsampling rates (mirror of encoder_rates)
    "n_codebooks": 16,  # quantizer configuration (see dacvae.DACVAE)
    "codebook_size": 1024,
    "codebook_dim": 128,  # channel count of the continuous latent used by the export wrappers
    "quantizer_dropout": False,
    "sample_rate": 48000,  # Hz
}
|
|
|
|
|
|
|
|
class DACVAEEncoderWrapper(nn.Module):
    """Wrapper for DACVAE encoder that outputs continuous latent features."""

    def __init__(self, encoder, quantizer):
        super().__init__()
        self.encoder = encoder
        self.in_proj = quantizer.in_proj

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Encode audio to latent features.

        Args:
            audio: Input waveform, shape (batch, 1, samples)

        Returns:
            latent_features: Continuous latent mean, shape (batch, 128, time_steps)
        """
        encoded = self.encoder(audio)
        projected = self.in_proj(encoded)
        # The projection packs two halves along the channel axis; only the
        # first half (the latent mean) is exported — the rest is discarded.
        mean, _ = projected.chunk(2, dim=1)
        return mean
|
|
|
|
|
|
|
|
class DACVAEDecoderWrapper(nn.Module):
    """Wrapper for DACVAE decoder that takes continuous latent features."""

    def __init__(self, decoder, quantizer):
        super().__init__()
        self.decoder = decoder
        self.out_proj = quantizer.out_proj

    def forward(self, latent_features: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features to audio.

        Args:
            latent_features: Continuous latent, shape (batch, 128, time_steps)

        Returns:
            audio: Output waveform, shape (batch, 1, samples)
        """
        # Project the compact latent back up to the decoder's input width,
        # then run the decoder to synthesize the waveform.
        expanded = self.out_proj(latent_features)
        return self.decoder(expanded)
|
|
|
|
|
|
|
|
def create_dacvae_model(model_id: str = "facebook/sam-audio-small") -> dacvae.DACVAE:
    """
    Create and load DACVAE model with weights from SAM Audio checkpoint.

    This uses the standalone dacvae library, avoiding loading the full SAM Audio
    model and its dependencies (vision encoder, imagebind, etc).

    Args:
        model_id: HuggingFace repo ID holding ``checkpoint.pt``.

    Returns:
        A DACVAE model in eval mode with ``hop_length`` and ``sample_rate``
        attributes set from DEFAULT_CONFIG.
    """
    print("Creating DACVAE model...")

    # DEFAULT_CONFIG keys match the DACVAE constructor's parameter names.
    model = dacvae.DACVAE(**DEFAULT_CONFIG).eval()

    print(f"Downloading checkpoint from {model_id}...")
    checkpoint_path = hf_hub_download(
        repo_id=model_id,
        filename="checkpoint.pt",
    )

    print("Loading DACVAE weights from checkpoint...")
    state_dict = torch.load(
        checkpoint_path,
        map_location="cpu",
        weights_only=True,
        mmap=True,
    )

    # Keep only the audio-codec subtree, stripping its prefix.  .clone()
    # detaches each tensor from the mmap-backed storage so the checkpoint
    # can be released below.
    prefix = "audio_codec."
    dacvae_state_dict = {
        k.removeprefix(prefix): v.clone()
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }

    # strict=False tolerates key mismatches, but report them instead of
    # silently discarding the result as before.
    load_result = model.load_state_dict(dacvae_state_dict, strict=False)
    if load_result.missing_keys:
        print(f"  ! {len(load_result.missing_keys)} model keys missing from checkpoint")
    if load_result.unexpected_keys:
        print(f"  ! {len(load_result.unexpected_keys)} unexpected checkpoint keys ignored")

    del state_dict

    print(f" ✓ Loaded {len(dacvae_state_dict)} DACVAE weight tensors")

    # Samples per latent frame = product of the encoder downsampling rates.
    model.hop_length = math.prod(DEFAULT_CONFIG["encoder_rates"])
    model.sample_rate = DEFAULT_CONFIG["sample_rate"]

    return model
|
|
|
|
|
|
|
|
def export_encoder(
    dacvae_model: dacvae.DACVAE,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
) -> None:
    """Export DACVAE encoder to ONNX.

    Args:
        dacvae_model: Loaded DACVAE model whose encoder and quantizer
            input projection are wrapped for export.
        output_path: Destination ``.onnx`` file (weights are written as
            external data alongside it).
        opset_version: ONNX opset to target.
        device: Device used for tracing.
    """
    import onnx  # local import: keep ONNX optional for non-export users of this module

    print(f"Exporting DACVAE encoder to {output_path}...")

    wrapper = DACVAEEncoderWrapper(
        dacvae_model.encoder,
        dacvae_model.quantizer,
    ).eval().to(device)

    # One second of dummy audio is enough to trace; batch/time axes are
    # declared dynamic below.
    sample_rate = DEFAULT_CONFIG["sample_rate"]
    dummy_audio = torch.randn(1, 1, sample_rate, device=device)

    torch.onnx.export(
        wrapper,
        (dummy_audio,),
        output_path,
        input_names=["audio"],
        output_names=["latent_features"],
        dynamic_axes={
            "audio": {0: "batch", 2: "samples"},
            "latent_features": {0: "batch", 2: "time_steps"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )

    print(" ✓ Encoder exported successfully")

    # Structural validation only; external weight data is not loaded.
    model = onnx.load(output_path, load_external_data=False)
    onnx.checker.check_model(model, full_check=False)
    print(" ✓ ONNX model validation passed")
|
|
|
|
|
|
|
|
def export_decoder(
    dacvae_model: dacvae.DACVAE,
    output_path: str,
    opset_version: int = 21,
    device: str = "cpu",
) -> None:
    """Export DACVAE decoder to ONNX.

    Args:
        dacvae_model: Loaded DACVAE model whose decoder and quantizer
            output projection are wrapped for export.
        output_path: Destination ``.onnx`` file (weights are written as
            external data alongside it).
        opset_version: ONNX opset to target.
        device: Device used for tracing.
    """
    import onnx  # local import: keep ONNX optional for non-export users of this module

    print(f"Exporting DACVAE decoder to {output_path}...")

    wrapper = DACVAEDecoderWrapper(
        dacvae_model.decoder,
        dacvae_model.quantizer,
    ).eval().to(device)

    # One second's worth of latent frames; hop length is the product of the
    # encoder downsampling rates (math.prod replaces the __import__ hack).
    hop_length = math.prod(DEFAULT_CONFIG["encoder_rates"])
    time_steps = DEFAULT_CONFIG["sample_rate"] // hop_length
    dummy_latent = torch.randn(
        1, DEFAULT_CONFIG["codebook_dim"], time_steps, device=device
    )

    torch.onnx.export(
        wrapper,
        (dummy_latent,),
        output_path,
        input_names=["latent_features"],
        output_names=["waveform"],
        dynamic_axes={
            "latent_features": {0: "batch", 2: "time_steps"},
            "waveform": {0: "batch", 2: "samples"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=True,
        external_data=True,
    )

    print(" ✓ Decoder exported successfully")

    # Structural validation only; external weight data is not loaded.
    model = onnx.load(output_path, load_external_data=False)
    onnx.checker.check_model(model, full_check=False)
    print(" ✓ ONNX model validation passed")
|
|
|
|
|
|
|
|
def verify_encoder(
    dacvae_model: dacvae.DACVAE,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-4,
) -> bool:
    """Verify ONNX encoder output matches PyTorch.

    Runs the same random audio through both the PyTorch wrapper and the
    exported ONNX model and compares outputs element-wise.
    """
    import onnxruntime as ort
    import numpy as np

    print("Verifying encoder output...")

    wrapper = DACVAEEncoderWrapper(
        dacvae_model.encoder,
        dacvae_model.quantizer,
    ).eval().to(device)

    # Two seconds of random audio as the shared test input.
    num_samples = DEFAULT_CONFIG["sample_rate"] * 2
    test_audio = torch.randn(1, 1, num_samples, device=device)

    # Reference output from PyTorch.
    with torch.no_grad():
        reference = wrapper(test_audio).cpu().numpy()

    # Candidate output from ONNX Runtime (CPU only, for determinism).
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (candidate,) = session.run(
        ["latent_features"],
        {"audio": test_audio.cpu().numpy()},
    )

    abs_diff = np.abs(reference - candidate)
    max_diff = abs_diff.max()
    mean_diff = abs_diff.mean()

    print(f" Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")

    if max_diff > tolerance:
        print(f" ✗ Verification failed (tolerance: {tolerance})")
        return False

    print(f" ✓ Verification passed (tolerance: {tolerance})")
    return True
|
|
|
|
|
|
|
|
def verify_decoder(
    dacvae_model: dacvae.DACVAE,
    onnx_path: str,
    device: str = "cpu",
    tolerance: float = 1e-3,
) -> bool:
    """Verify ONNX decoder output matches PyTorch.

    Runs the same random latent through both the PyTorch wrapper and the
    exported ONNX model and compares outputs element-wise.
    """
    import onnxruntime as ort
    import numpy as np

    print("Verifying decoder output...")

    wrapper = DACVAEDecoderWrapper(
        dacvae_model.decoder,
        dacvae_model.quantizer,
    ).eval().to(device)

    # One second's worth of random latent frames as the shared test input.
    hop_length = int(np.prod(DEFAULT_CONFIG["encoder_rates"]))
    frames = DEFAULT_CONFIG["sample_rate"] // hop_length
    test_latent = torch.randn(1, 128, frames, device=device)

    # Reference output from PyTorch.
    with torch.no_grad():
        reference = wrapper(test_latent).cpu().numpy()

    # Candidate output from ONNX Runtime (CPU only, for determinism).
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (candidate,) = session.run(
        ["waveform"],
        {"latent_features": test_latent.cpu().numpy()},
    )

    abs_diff = np.abs(reference - candidate)
    max_diff = abs_diff.max()
    mean_diff = abs_diff.mean()

    print(f" Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")

    if max_diff > tolerance:
        print(f" ✗ Verification failed (tolerance: {tolerance})")
        return False

    print(f" ✓ Verification passed (tolerance: {tolerance})")
    return True
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: export (and optionally verify) DACVAE to ONNX.

    Exits with a nonzero status if any requested verification fails, so CI
    can detect export regressions.
    """
    parser = argparse.ArgumentParser(description="Export DACVAE to ONNX")
    parser.add_argument(
        "--model-id",
        type=str,
        default="facebook/sam-audio-small",
        help="HuggingFace model ID (default: facebook/sam-audio-small)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="onnx_models",
        help="Output directory for ONNX models",
    )
    parser.add_argument(
        "--opset-version",
        type=int,
        default=18,
        help="ONNX opset version (default: 18)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to use for export (default: cpu)",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify ONNX output matches PyTorch",
    )
    parser.add_argument(
        "--tolerance",
        type=float,
        default=1e-4,
        help="Tolerance for verification (default: 1e-4)",
    )
    parser.add_argument(
        "--encoder-only",
        action="store_true",
        help="Export only the encoder",
    )
    parser.add_argument(
        "--decoder-only",
        action="store_true",
        help="Export only the decoder",
    )

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    dacvae_model = create_dacvae_model(args.model_id)

    print("\nDACVAE Configuration:")
    print(f" Model: {args.model_id}")
    print(f" Sample rate: {DEFAULT_CONFIG['sample_rate']} Hz")
    # Hop length = product of the encoder downsampling rates.
    print(f" Hop length: {math.prod(DEFAULT_CONFIG['encoder_rates'])}")
    print(f" Latent dim: {DEFAULT_CONFIG['codebook_dim']} (continuous)")

    all_passed = True

    if not args.decoder_only:
        encoder_path = os.path.join(args.output_dir, "dacvae_encoder.onnx")
        export_encoder(
            dacvae_model,
            encoder_path,
            opset_version=args.opset_version,
            device=args.device,
        )

        if args.verify:
            all_passed = verify_encoder(
                dacvae_model,
                encoder_path,
                device=args.device,
                tolerance=args.tolerance,
            ) and all_passed

    if not args.encoder_only:
        decoder_path = os.path.join(args.output_dir, "dacvae_decoder.onnx")
        export_decoder(
            dacvae_model,
            decoder_path,
            opset_version=args.opset_version,
            device=args.device,
        )

        if args.verify:
            # verify_decoder's own default tolerance is looser (1e-3 vs 1e-4);
            # keep the original x10 scaling of the CLI tolerance here.
            all_passed = verify_decoder(
                dacvae_model,
                decoder_path,
                device=args.device,
                tolerance=args.tolerance * 10,
            ) and all_passed

    if not all_passed:
        # Previously failures were only printed and the script still exited 0.
        raise SystemExit(1)

    print(f"\n✓ Export complete! Models saved to {args.output_dir}/")


if __name__ == "__main__":
    main()
|
|
|