#!/usr/bin/env python3
"""
Export CLAP (Contrastive Language-Audio Pretraining) model to ONNX.
The CLAP model is used for reranking separation candidates by scoring
audio-text similarity.
Usage:
python -m onnx_export.export_clap --output-dir onnx_models --verify
"""
import os
import argparse
import json
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
def get_clap_model(checkpoint_file=None, device="cpu"):
"""Load the CLAP model from laion_clap."""
import laion_clap
model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-tiny").to(device)
if checkpoint_file is None:
checkpoint_file = hf_hub_download(
repo_id="lukewys/laion_clap", filename="630k-best.pt"
)
state_dict = torch.load(checkpoint_file, map_location=device, weights_only=False)["state_dict"]
# Handle module prefix from DataParallel
if next(iter(state_dict.items()))[0].startswith("module"):
state_dict = {k[7:]: v for k, v in state_dict.items()}
# Remove position_ids if present (not needed)
if "text_branch.embeddings.position_ids" in state_dict:
del state_dict["text_branch.embeddings.position_ids"]
model.model.load_state_dict(state_dict)
return model.eval()
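# The default checkpoint (630k-best.pt from lukewys/laion_clap) corresponds to the
# enable_fusion=False, HTSAT-tiny configuration constructed above; pass --checkpoint
# to load a locally downloaded file instead.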
class CLAPAudioEncoderWrapper(nn.Module):
"""
Wrapper for CLAP audio encoder for ONNX export.
Takes waveform input directly and processes through the HTSAT audio branch.
"""
def __init__(self, model):
super().__init__()
self.audio_branch = model.model.audio_branch
self.audio_transform = model.model.audio_transform
self.audio_projection = model.model.audio_projection
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
"""
Args:
waveform: [batch, samples] audio waveform at 48kHz, 10 seconds (480000 samples)
Returns:
audio_embed: [batch, 512] normalized audio embedding
"""
# Compute spectrogram from waveform
x = self.audio_branch.spectrogram_extractor(waveform) # [B, 1, T, F]
x = self.audio_branch.logmel_extractor(x) # [B, 1, T, mel_bins]
# Batch normalization
x = x.transpose(1, 3) # [B, mel_bins, T, 1]
x = self.audio_branch.bn0(x)
x = x.transpose(1, 3) # [B, 1, T, mel_bins]
# Reshape for Swin Transformer using the original method
x = self.audio_branch.reshape_wav2img(x)
# Forward through transformer features
output_dict = self.audio_branch.forward_features(x)
embedding = output_dict["embedding"] # [B, 768]
# Project to 512-dim: projection first, then transform
x = self.audio_projection(embedding) # 768 -> 512
x = self.audio_transform(x) # 512 -> 512
# L2 normalize
x = x / x.norm(dim=-1, keepdim=True)
return x
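# Note: the exported audio graph expects a fixed-length input of 480000 samples
# (10 s at 48 kHz); only the batch dimension is dynamic, so callers should resample
# and pad/trim candidate clips to that length beforehand.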
class CLAPTextEncoderWrapper(nn.Module):
"""Wrapper for CLAP text encoder for ONNX export."""
def __init__(self, model):
super().__init__()
self.text_branch = model.model.text_branch
self.text_transform = model.model.text_transform
self.text_projection = model.model.text_projection
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Args:
input_ids: [batch, seq_len] token IDs
attention_mask: [batch, seq_len] attention mask
Returns:
text_embed: [batch, 512] normalized text embedding
"""
x = self.text_branch(input_ids=input_ids, attention_mask=attention_mask)
x = x.pooler_output # [B, 768]
x = self.text_projection(x) # 768 -> 512
x = self.text_transform(x) # 512 -> 512
# L2 normalize
x = x / x.norm(dim=-1, keepdim=True)
return x
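# Both wrappers L2-normalize their outputs, so scoring N candidate audio embeddings
# A [N, 512] against a prompt embedding t [1, 512] reduces to a matrix product,
# e.g. scores = A @ t.T (cosine similarities in [-1, 1]).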
def export_clap_audio_encoder(model, output_path, opset_version=17, device="cpu"):
"""Export CLAP audio encoder to ONNX."""
import onnx
print(f"Exporting CLAP audio encoder to {output_path}...")
wrapper = CLAPAudioEncoderWrapper(model).eval().to(device)
# Sample input: 10 seconds of audio at 48kHz (480000 samples)
batch_size = 1
num_samples = 480000 # 10 seconds at 48kHz
dummy_waveform = torch.randn(batch_size, num_samples, device=device)
# Test forward pass
with torch.no_grad():
output = wrapper(dummy_waveform)
print(f" Audio encoder output shape: {output.shape}")
torch.onnx.export(
wrapper,
(dummy_waveform,),
output_path,
input_names=["waveform"],
output_names=["audio_embed"],
dynamic_axes={
"waveform": {0: "batch_size"},
"audio_embed": {0: "batch_size"},
},
opset_version=opset_version,
do_constant_folding=True,
)
# Validate
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
print(" ✓ CLAP audio encoder exported successfully")
return True
def export_clap_text_encoder(model, output_path, opset_version=17, device="cpu"):
"""Export CLAP text encoder to ONNX."""
import onnx
print(f"Exporting CLAP text encoder to {output_path}...")
wrapper = CLAPTextEncoderWrapper(model).eval().to(device)
# Sample input
batch_size = 1
seq_len = 77
dummy_input_ids = torch.randint(0, 50265, (batch_size, seq_len), device=device)
dummy_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long, device=device)
# Test forward pass
with torch.no_grad():
output = wrapper(dummy_input_ids, dummy_attention_mask)
print(f" Text encoder output shape: {output.shape}")
torch.onnx.export(
wrapper,
(dummy_input_ids, dummy_attention_mask),
output_path,
input_names=["input_ids", "attention_mask"],
output_names=["text_embed"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "seq_len"},
"attention_mask": {0: "batch_size", 1: "seq_len"},
"text_embed": {0: "batch_size"},
},
opset_version=opset_version,
do_constant_folding=True,
)
# Validate
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
print(" ✓ CLAP text encoder exported successfully")
return True
def save_clap_config(model, output_path):
"""Save CLAP audio preprocessing config."""
audio_cfg = model.model_cfg["audio_cfg"]
config = {
"sample_rate": audio_cfg["sample_rate"],
"window_size": audio_cfg["window_size"],
"hop_size": audio_cfg["hop_size"],
"mel_bins": audio_cfg["mel_bins"],
"fmin": audio_cfg["fmin"],
"fmax": audio_cfg["fmax"],
"max_audio_len": 480000, # 10 seconds at 48kHz
"embed_dim": 512,
}
with open(output_path, "w") as f:
json.dump(config, f, indent=2)
print(f" ✓ Config saved to {output_path}")
return config
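# For reference, with the default HTSAT-tiny 630k checkpoint the saved config
# typically contains values along the lines of (illustrative; the real values are
# read from the loaded model's audio_cfg):
#   sample_rate=48000, window_size=1024, hop_size=480, mel_bins=64,
#   fmin=50, fmax=14000, max_audio_len=480000, embed_dim=512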
def save_clap_tokenizer(output_dir):
"""Save RoBERTa tokenizer for CLAP text encoding."""
from transformers import RobertaTokenizer
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
tokenizer.save_pretrained(output_dir)
print(f" ✓ Tokenizer saved to {output_dir}")
def verify_clap(model, audio_onnx_path, text_onnx_path, config, device="cpu"):
"""Verify ONNX outputs match PyTorch."""
import onnxruntime as ort
import numpy as np
print("Verifying CLAP ONNX outputs...")
# Create sample audio (10 seconds at 48kHz)
    sample_waveform = torch.randn(1, 480000, device=device)  # [batch, samples]
    # PyTorch audio embedding
    wrapper = CLAPAudioEncoderWrapper(model).eval().to(device)
    with torch.no_grad():
        pytorch_audio_embed = wrapper(sample_waveform).cpu().numpy()
# ONNX audio embedding
audio_sess = ort.InferenceSession(audio_onnx_path, providers=["CPUExecutionProvider"])
onnx_audio_embed = audio_sess.run(
["audio_embed"],
{"waveform": sample_waveform.numpy().astype(np.float32)},
)[0]
audio_diff = np.abs(pytorch_audio_embed - onnx_audio_embed).max()
print(f" Audio encoder max diff: {audio_diff:.2e}")
# Text embedding verification
from transformers import RobertaTokenizer
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
tokens = tokenizer(["a person speaking"], return_tensors="pt", padding=True, truncation=True)
    text_wrapper = CLAPTextEncoderWrapper(model).eval().to(device)
    with torch.no_grad():
        pytorch_text_embed = text_wrapper(
            tokens["input_ids"].to(device), tokens["attention_mask"].to(device)
        ).cpu().numpy()
text_sess = ort.InferenceSession(text_onnx_path, providers=["CPUExecutionProvider"])
onnx_text_embed = text_sess.run(
["text_embed"],
{
"input_ids": tokens["input_ids"].numpy().astype(np.int64),
"attention_mask": tokens["attention_mask"].numpy().astype(np.int64),
},
)[0]
text_diff = np.abs(pytorch_text_embed - onnx_text_embed).max()
print(f" Text encoder max diff: {text_diff:.2e}")
max_diff = max(audio_diff, text_diff)
if max_diff < 1e-4:
print(" ✓ Verification passed")
return True
else:
print(f" ✗ Verification failed (max diff: {max_diff:.2e})")
return False
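# Illustrative sketch (not called by the export flow): how a downstream consumer
# could score separation candidates against a text prompt using the exported ONNX
# models. The helper name and default paths below are hypothetical examples, not
# part of this module's CLI; candidates must already be 48 kHz, 480000-sample clips.
def rerank_with_onnx(candidate_waveforms, prompt,
                     audio_onnx_path="onnx_models/clap_audio_encoder.onnx",
                     text_onnx_path="onnx_models/clap_text_encoder.onnx",
                     tokenizer_dir="onnx_models/clap_tokenizer"):
    import numpy as np
    import onnxruntime as ort
    from transformers import RobertaTokenizer
    audio_sess = ort.InferenceSession(audio_onnx_path, providers=["CPUExecutionProvider"])
    text_sess = ort.InferenceSession(text_onnx_path, providers=["CPUExecutionProvider"])
    tokenizer = RobertaTokenizer.from_pretrained(tokenizer_dir)
    tokens = tokenizer([prompt], return_tensors="np", padding=True, truncation=True)
    # [N, 512] candidate embeddings and [1, 512] prompt embedding, both L2-normalized
    audio_embed = audio_sess.run(
        ["audio_embed"],
        {"waveform": np.asarray(candidate_waveforms, dtype=np.float32)},
    )[0]
    text_embed = text_sess.run(
        ["text_embed"],
        {
            "input_ids": tokens["input_ids"].astype(np.int64),
            "attention_mask": tokens["attention_mask"].astype(np.int64),
        },
    )[0]
    # Cosine similarities [N]; higher means the candidate better matches the prompt
    return (audio_embed @ text_embed.T).squeeze(-1)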
def main():
parser = argparse.ArgumentParser(description="Export CLAP to ONNX")
parser.add_argument("--output-dir", type=str, default="onnx_models")
parser.add_argument("--checkpoint", type=str, default=None, help="CLAP checkpoint path")
parser.add_argument("--opset", type=int, default=18)
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--verify", action="store_true")
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# Load model
print("Loading CLAP model...")
model = get_clap_model(args.checkpoint, args.device)
# Export audio encoder
audio_path = os.path.join(args.output_dir, "clap_audio_encoder.onnx")
export_clap_audio_encoder(model, audio_path, args.opset, args.device)
# Export text encoder
text_path = os.path.join(args.output_dir, "clap_text_encoder.onnx")
export_clap_text_encoder(model, text_path, args.opset, args.device)
# Save config
config_path = os.path.join(args.output_dir, "clap_config.json")
config = save_clap_config(model, config_path)
# Save tokenizer
tokenizer_dir = os.path.join(args.output_dir, "clap_tokenizer")
os.makedirs(tokenizer_dir, exist_ok=True)
save_clap_tokenizer(tokenizer_dir)
# Verify
if args.verify:
verify_clap(model, audio_path, text_path, config, args.device)
print(f"\n✓ Export complete!")
print(f" Audio encoder: {audio_path}")
print(f" Text encoder: {text_path}")
if __name__ == "__main__":
main()