"""
Export CLAP (Contrastive Language-Audio Pretraining) model to ONNX.

The CLAP model is used for reranking separation candidates by scoring
audio-text similarity.

Usage:
    python -m onnx_export.export_clap --output-dir onnx_models --verify
"""
|
|
import argparse
import json
import os

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
|
|
|
|
def get_clap_model(checkpoint_file=None, device="cpu"):
    """Load the CLAP model from laion_clap."""
    import laion_clap

    model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-tiny").to(device)

    if checkpoint_file is None:
        checkpoint_file = hf_hub_download(
            repo_id="lukewys/laion_clap", filename="630k-best.pt"
        )

    state_dict = torch.load(checkpoint_file, map_location=device, weights_only=False)["state_dict"]

    # Checkpoints trained with DataParallel prefix every key with "module.";
    # strip it so the keys match the bare model.
    if next(iter(state_dict.items()))[0].startswith("module"):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # Newer transformers versions no longer register position_ids as a buffer
    # in RobertaEmbeddings, so drop it from older checkpoints.
    if "text_branch.embeddings.position_ids" in state_dict:
        del state_dict["text_branch.embeddings.position_ids"]

    model.model.load_state_dict(state_dict)
    return model.eval()
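# Quick smoke test for the loader (a sketch, not part of the export pipeline;
# downloads the 630k-best.pt checkpoint on first use):
#
#     model = get_clap_model()
#     with torch.no_grad():
#         emb = model.get_text_embedding(["a dog barking"], use_tensor=True)
#     print(emb.shape)  # torch.Size([1, 512])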
|
|
|
|
class CLAPAudioEncoderWrapper(nn.Module):
    """
    Wrapper for CLAP audio encoder for ONNX export.

    Takes waveform input directly and processes through the HTSAT audio branch.
    """

    def __init__(self, model):
        super().__init__()
        self.audio_branch = model.model.audio_branch
        self.audio_transform = model.model.audio_transform
        self.audio_projection = model.model.audio_projection

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: [batch, samples] audio waveform at 48kHz, 10 seconds (480000 samples)

        Returns:
            audio_embed: [batch, 512] normalized audio embedding
        """
        # STFT followed by log-mel extraction: [batch, 1, time, mel_bins]
        x = self.audio_branch.spectrogram_extractor(waveform)
        x = self.audio_branch.logmel_extractor(x)

        # Batch-normalize over the mel dimension
        x = x.transpose(1, 3)
        x = self.audio_branch.bn0(x)
        x = x.transpose(1, 3)

        # Reshape the spectrogram into the image-like layout HTSAT expects
        x = self.audio_branch.reshape_wav2img(x)

        # Run the HTSAT (Swin Transformer) backbone
        output_dict = self.audio_branch.forward_features(x)
        embedding = output_dict["embedding"]

        # Project into the joint audio-text embedding space
        x = self.audio_projection(embedding)
        x = self.audio_transform(x)

        # L2-normalize so dot products are cosine similarities
        x = x / x.norm(dim=-1, keepdim=True)
        return x
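# Sanity-check sketch for the wrapper (shapes only; names refer to the helpers
# defined in this file):
#
#     model = get_clap_model()
#     audio_enc = CLAPAudioEncoderWrapper(model).eval()
#     with torch.no_grad():
#         emb = audio_enc(torch.randn(2, 480000))
#     print(emb.shape)         # torch.Size([2, 512])
#     print(emb.norm(dim=-1))  # ~1.0 per row, since outputs are L2-normalized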
|
|
|
|
class CLAPTextEncoderWrapper(nn.Module):
    """Wrapper for CLAP text encoder for ONNX export."""

    def __init__(self, model):
        super().__init__()
        self.text_branch = model.model.text_branch
        self.text_transform = model.model.text_transform
        self.text_projection = model.model.text_projection

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: [batch, seq_len] token IDs
            attention_mask: [batch, seq_len] attention mask

        Returns:
            text_embed: [batch, 512] normalized text embedding
        """
        # RoBERTa text branch; the pooled output feeds the projection
        x = self.text_branch(input_ids=input_ids, attention_mask=attention_mask)
        x = x.pooler_output
        x = self.text_projection(x)
        x = self.text_transform(x)

        # L2-normalize so dot products are cosine similarities
        x = x / x.norm(dim=-1, keepdim=True)
        return x
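# Because both wrappers emit unit-norm embeddings, reranking reduces to a
# matrix product. A sketch (uses the helpers above; `candidates` is a
# hypothetical [num_candidates, 480000] waveform batch):
#
#     model = get_clap_model()
#     audio_enc = CLAPAudioEncoderWrapper(model).eval()
#     text_enc = CLAPTextEncoderWrapper(model).eval()
#     from transformers import RobertaTokenizer
#     tok = RobertaTokenizer.from_pretrained("roberta-base")
#     t = tok(["a person speaking"], return_tensors="pt", padding=True)
#     with torch.no_grad():
#         scores = audio_enc(candidates) @ text_enc(t["input_ids"], t["attention_mask"]).T
#     best = scores.squeeze(-1).argmax()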
|
|
|
|
def export_clap_audio_encoder(model, output_path, opset_version=18, device="cpu"):
    """Export CLAP audio encoder to ONNX."""
    import onnx

    print(f"Exporting CLAP audio encoder to {output_path}...")

    wrapper = CLAPAudioEncoderWrapper(model).eval().to(device)

    # CLAP operates on 10 s of 48 kHz audio: 480000 samples
    batch_size = 1
    num_samples = 480000

    dummy_waveform = torch.randn(batch_size, num_samples, device=device)

    # Run once in PyTorch to confirm the wrapper works end to end
    with torch.no_grad():
        output = wrapper(dummy_waveform)
        print(f"  Audio encoder output shape: {output.shape}")

    torch.onnx.export(
        wrapper,
        (dummy_waveform,),
        output_path,
        input_names=["waveform"],
        output_names=["audio_embed"],
        dynamic_axes={
            "waveform": {0: "batch_size"},
            "audio_embed": {0: "batch_size"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
    )

    # Structural validation of the exported graph
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    print("  ✓ CLAP audio encoder exported successfully")

    return True
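# Running the exported graph (a sketch; the path assumes the default layout
# produced by main()):
#
#     import numpy as np
#     import onnxruntime as ort
#     sess = ort.InferenceSession("onnx_models/clap_audio_encoder.onnx",
#                                 providers=["CPUExecutionProvider"])
#     wav = np.random.randn(1, 480000).astype(np.float32)
#     emb = sess.run(["audio_embed"], {"waveform": wav})[0]  # (1, 512)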
|
|
|
|
def export_clap_text_encoder(model, output_path, opset_version=18, device="cpu"):
    """Export CLAP text encoder to ONNX."""
    import onnx

    print(f"Exporting CLAP text encoder to {output_path}...")

    wrapper = CLAPTextEncoderWrapper(model).eval().to(device)

    # Dummy token IDs drawn from the RoBERTa vocabulary (50265 entries)
    batch_size = 1
    seq_len = 77

    dummy_input_ids = torch.randint(0, 50265, (batch_size, seq_len), device=device)
    dummy_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long, device=device)

    # Run once in PyTorch to confirm the wrapper works end to end
    with torch.no_grad():
        output = wrapper(dummy_input_ids, dummy_attention_mask)
        print(f"  Text encoder output shape: {output.shape}")

    torch.onnx.export(
        wrapper,
        (dummy_input_ids, dummy_attention_mask),
        output_path,
        input_names=["input_ids", "attention_mask"],
        output_names=["text_embed"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "attention_mask": {0: "batch_size", 1: "seq_len"},
            "text_embed": {0: "batch_size"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
    )

    # Structural validation of the exported graph
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    print("  ✓ CLAP text encoder exported successfully")

    return True
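# Running the exported text graph (a sketch; paths assume the default layout
# produced by main()):
#
#     import onnxruntime as ort
#     from transformers import RobertaTokenizer
#     tok = RobertaTokenizer.from_pretrained("onnx_models/clap_tokenizer")
#     t = tok(["a person speaking"], return_tensors="np", padding=True)
#     sess = ort.InferenceSession("onnx_models/clap_text_encoder.onnx",
#                                 providers=["CPUExecutionProvider"])
#     emb = sess.run(["text_embed"],
#                    {"input_ids": t["input_ids"].astype("int64"),
#                     "attention_mask": t["attention_mask"].astype("int64")})[0]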
|
|
|
|
def save_clap_config(model, output_path):
    """Save CLAP audio preprocessing config."""
    audio_cfg = model.model_cfg["audio_cfg"]

    config = {
        "sample_rate": audio_cfg["sample_rate"],
        "window_size": audio_cfg["window_size"],
        "hop_size": audio_cfg["hop_size"],
        "mel_bins": audio_cfg["mel_bins"],
        "fmin": audio_cfg["fmin"],
        "fmax": audio_cfg["fmax"],
        "max_audio_len": 480000,
        "embed_dim": 512,
    }

    with open(output_path, "w") as f:
        json.dump(config, f, indent=2)

    print(f"  ✓ Config saved to {output_path}")
    return config
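# For the 630k HTSAT-tiny checkpoint, the saved JSON typically looks like the
# following (values come from the checkpoint's audio_cfg, so treat this as
# illustrative rather than guaranteed):
#
#     {
#       "sample_rate": 48000,
#       "window_size": 1024,
#       "hop_size": 480,
#       "mel_bins": 64,
#       "fmin": 50,
#       "fmax": 14000,
#       "max_audio_len": 480000,
#       "embed_dim": 512
#     }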
|
|
|
|
def save_clap_tokenizer(output_dir):
    """Save RoBERTa tokenizer for CLAP text encoding."""
    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    tokenizer.save_pretrained(output_dir)
    print(f"  ✓ Tokenizer saved to {output_dir}")
|
|
|
|
def verify_clap(model, audio_onnx_path, text_onnx_path, config, device="cpu"):
    """Verify ONNX outputs match PyTorch."""
    import numpy as np
    import onnxruntime as ort

    print("Verifying CLAP ONNX outputs...")

    # Verification runs on CPU so the PyTorch reference and the
    # CPUExecutionProvider session see identical inputs.
    sample_waveform = torch.randn(1, config["max_audio_len"])

    # PyTorch reference for the audio branch
    wrapper = CLAPAudioEncoderWrapper(model).cpu().eval()
    with torch.no_grad():
        pytorch_audio_embed = wrapper(sample_waveform).numpy()

    # ONNX audio branch
    audio_sess = ort.InferenceSession(audio_onnx_path, providers=["CPUExecutionProvider"])
    onnx_audio_embed = audio_sess.run(
        ["audio_embed"],
        {"waveform": sample_waveform.numpy().astype(np.float32)},
    )[0]

    audio_diff = np.abs(pytorch_audio_embed - onnx_audio_embed).max()
    print(f"  Audio encoder max diff: {audio_diff:.2e}")

    # PyTorch reference for the text branch
    from transformers import RobertaTokenizer
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    tokens = tokenizer(["a person speaking"], return_tensors="pt", padding=True, truncation=True)

    text_wrapper = CLAPTextEncoderWrapper(model).cpu().eval()
    with torch.no_grad():
        pytorch_text_embed = text_wrapper(tokens["input_ids"], tokens["attention_mask"]).numpy()

    # ONNX text branch
    text_sess = ort.InferenceSession(text_onnx_path, providers=["CPUExecutionProvider"])
    onnx_text_embed = text_sess.run(
        ["text_embed"],
        {
            "input_ids": tokens["input_ids"].numpy().astype(np.int64),
            "attention_mask": tokens["attention_mask"].numpy().astype(np.int64),
        },
    )[0]

    text_diff = np.abs(pytorch_text_embed - onnx_text_embed).max()
    print(f"  Text encoder max diff: {text_diff:.2e}")

    max_diff = max(audio_diff, text_diff)
    if max_diff < 1e-4:
        print("  ✓ Verification passed")
        return True
    else:
        print(f"  ✗ Verification failed (max diff: {max_diff:.2e})")
        return False
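# End-to-end reranking sketch with the exported graphs, reusing the session
# objects from verify_clap (paths assume main()'s default layout; `candidates`
# is a hypothetical [N, 480000] float32 array of separation candidates):
#
#     text_emb = text_sess.run(["text_embed"],
#                              {"input_ids": ids, "attention_mask": mask})[0]  # (1, 512)
#     audio_emb = audio_sess.run(["audio_embed"],
#                                {"waveform": candidates})[0]                  # (N, 512)
#     scores = audio_emb @ text_emb.T  # cosine similarities, embeddings are unit-norm
#     best = scores.argmax()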
|
|
|
|
def main():
    parser = argparse.ArgumentParser(description="Export CLAP to ONNX")
    parser.add_argument("--output-dir", type=str, default="onnx_models")
    parser.add_argument("--checkpoint", type=str, default=None, help="CLAP checkpoint path")
    parser.add_argument("--opset", type=int, default=18)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--verify", action="store_true")

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Load the PyTorch model once and reuse it for both exports
    print("Loading CLAP model...")
    model = get_clap_model(args.checkpoint, args.device)

    # Audio encoder
    audio_path = os.path.join(args.output_dir, "clap_audio_encoder.onnx")
    export_clap_audio_encoder(model, audio_path, args.opset, args.device)

    # Text encoder
    text_path = os.path.join(args.output_dir, "clap_text_encoder.onnx")
    export_clap_text_encoder(model, text_path, args.opset, args.device)

    # Preprocessing config
    config_path = os.path.join(args.output_dir, "clap_config.json")
    config = save_clap_config(model, config_path)

    # Tokenizer for downstream text encoding
    tokenizer_dir = os.path.join(args.output_dir, "clap_tokenizer")
    os.makedirs(tokenizer_dir, exist_ok=True)
    save_clap_tokenizer(tokenizer_dir)

    # Optional parity check between PyTorch and ONNX outputs
    if args.verify:
        verify_clap(model, audio_path, text_path, config, args.device)

    print("\n✓ Export complete!")
    print(f"  Audio encoder: {audio_path}")
    print(f"  Text encoder: {text_path}")
|
|
|
if __name__ == "__main__":
    main()
|
|