matbee committed · verified
Commit 23278d3 · 1 Parent(s): 7f4b648

Add CLAP reranking support (audio + text encoders)

.gitattributes CHANGED
@@ -34,4 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 *.data filter=lfs diff=lfs merge=lfs -text
+residual.wav filter=lfs diff=lfs merge=lfs -text
 test_audio.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -26,6 +26,10 @@ ONNX-converted models for [SAM-Audio](https://github.com/facebookresearch/sam-audio)
 | `tokenizer/` | SentencePiece tokenizer files (T5) | - |
 | `peaframe_tokenizer/` | ModernBERT tokenizer files (PEAFrame) | - |
 | `peaframe_config.json` | PEAFrame scaling parameters | - |
+| `clap_audio_encoder.onnx` | CLAP audio encoder (HTSAT-tiny) | ~118 MB |
+| `clap_text_encoder.onnx` | CLAP text encoder (RoBERTa-base) | ~481 MB |
+| `clap_tokenizer/` | RoBERTa tokenizer files (CLAP) | - |
+| `clap_config.json` | CLAP audio preprocessing parameters | - |
 
 ## Installation
 
@@ -84,6 +88,24 @@ python onnx_inference.py \
     --output separated.wav
 ```
 
+### CLAP Reranking
+Generate multiple candidates and select the best using CLAP audio-text similarity:
+```bash
+python onnx_inference.py \
+    --audio input.wav \
+    --text "person speaking" \
+    --rerank \
+    --num-candidates 4 \
+    --output separated.wav
+```
+
+Reranking generates multiple separation candidates with different random seeds and uses CLAP to score audio-text similarity, selecting the candidate that best matches the text description. This can improve quality at the cost of ~4x inference time.
+
+Options:
+- `--rerank` - Enable reranking mode
+- `--num-candidates N` - Number of candidates (default: 4)
+- `--rerank-seed SEED` - Random seed for reproducibility
+
 ### Visual Prompting with SAM3 Mask
 ```bash
 # First generate a mask with SAM3 (see generate_sam3_mask.py)
@@ -116,6 +138,11 @@ python onnx_inference.py \
   - Uses ModernBERT tokenizer
   - Processes audio in ~3.3s chunks with 50% overlap
   - Default threshold: 0.3
+- **CLAP**: Audio-text similarity model for candidate reranking
+  - Audio encoder: HTSAT-tiny
+  - Text encoder: RoBERTa-base
+  - Embedding dimension: 512
+  - Default candidates: 4
 
 ## Exporting Models
 
@@ -143,6 +170,9 @@ python -m onnx_export.export_vision --model facebook/sam-audio-small --output ./
 
 # PEAFrame Span Predictor
 python -m onnx_export.export_peaframe --output-dir ./onnx_models --verify
+
+# CLAP Reranking (audio + text encoders)
+python -m onnx_export.export_clap --output-dir ./onnx_models --verify
 ```
 
 ### FP16 Quantization (for large models)
@@ -170,6 +200,7 @@ The inference script automatically detects FP16 models and handles input conversion
 | `export_t5.py` | T5 text encoder |
 | `export_vision.py` | Vision encoder (CLIP-based) |
 | `export_peaframe.py` | PEAFrame span predictor + tokenizer |
+| `export_clap.py` | CLAP audio + text encoders for reranking |
 | `standalone_config.py` | Config classes for standalone export |
 
 ## License
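Both CLAP encoders emit L2-normalized 512-dim embeddings, so the similarity used to pick the best candidate reduces to a dot product (cosine similarity). A minimal sketch of that scoring step against the exported models (a sketch, assuming `onnxruntime` and `transformers` are installed and the files from this commit sit in the working directory):

```python
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# Load the exported CLAP encoders and the RoBERTa tokenizer shipped alongside them.
tokenizer = AutoTokenizer.from_pretrained("clap_tokenizer")
text_sess = ort.InferenceSession("clap_text_encoder.onnx", providers=["CPUExecutionProvider"])
audio_sess = ort.InferenceSession("clap_audio_encoder.onnx", providers=["CPUExecutionProvider"])

# Embed the text prompt once; it is shared across all candidates.
tokens = tokenizer("person speaking", return_tensors="np", padding=True,
                   truncation=True, max_length=77)
text_embed = text_sess.run(["text_embed"], {
    "input_ids": tokens["input_ids"].astype(np.int64),
    "attention_mask": tokens["attention_mask"].astype(np.int64),
})[0]  # [1, 512], already L2-normalized by the exported graph

# Placeholder candidate: 10 s at 48 kHz (real candidates come from the separator).
waveform = np.random.randn(1, 480000).astype(np.float32)
audio_embed = audio_sess.run(["audio_embed"], {"waveform": waveform})[0]  # [1, 512]

score = float(audio_embed @ text_embed.T)  # cosine similarity; argmax over candidates wins
```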
clap_audio_encoder.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46fb0e4d80e2e6403361e1245fa298da9f1530365743082217a4e69d4bb127c6
+size 1176682
clap_audio_encoder.onnx.data ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:49456668f90249bd4429441b8a65440750a17965d28448f8c72de69849a61f0f
+size 123731968
clap_config.json ADDED
@@ -0,0 +1,10 @@
+{
+  "sample_rate": 48000,
+  "window_size": 1024,
+  "hop_size": 480,
+  "mel_bins": 64,
+  "fmin": 50,
+  "fmax": 14000,
+  "max_audio_len": 480000,
+  "embed_dim": 512
+}
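These parameters drive the audio-side preprocessing before the CLAP encoder. A sketch of that preprocessing, mirroring the `score_with_clap` logic added to `onnx_inference.py` below (the helper name is illustrative; the input is assumed to be mono audio already resampled to 48 kHz):

```python
import json
import numpy as np

with open("clap_config.json") as f:
    cfg = json.load(f)

def preprocess_for_clap(audio: np.ndarray) -> np.ndarray:
    """Quantize and pad/truncate a mono 48 kHz waveform for the CLAP audio encoder."""
    # Round-trip through int16 to match the PyTorch pipeline's quantization.
    audio = (audio * 32768.0).astype(np.int16).astype(np.float32) / 32768.0
    max_len = cfg.get("max_audio_len", 480000)  # 10 s * 48000 Hz
    if len(audio) > max_len:
        audio = audio[:max_len]                      # truncate long clips
    elif len(audio) < max_len:
        n_repeat = int(np.ceil(max_len / len(audio)))
        audio = np.tile(audio, n_repeat)[:max_len]   # repeat-pad short clips
    return audio.reshape(1, -1).astype(np.float32)   # [batch, samples]
```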
clap_text_encoder.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c700f9351d2a32cea5ebd0df0d8ce856f6436b9a54d70caf2d693ec79bb33373
+size 1600036
clap_text_encoder.onnx.data ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:542b5813e0fbfcb341d6db39c2b38118178cd4b8c5397fb80906bee14b1fe579
+size 503393280
clap_tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
clap_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
+{
+  "bos_token": {
+    "content": "<s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "cls_token": {
+    "content": "<s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "</s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "mask_token": {
+    "content": "<mask>",
+    "lstrip": true,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "<pad>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "sep_token": {
+    "content": "</s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "<unk>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  }
+}
clap_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,57 @@
+{
+  "add_prefix_space": false,
+  "added_tokens_decoder": {
+    "0": {
+      "content": "<s>",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "1": {
+      "content": "<pad>",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "2": {
+      "content": "</s>",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "3": {
+      "content": "<unk>",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "50264": {
+      "content": "<mask>",
+      "lstrip": true,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "bos_token": "<s>",
+  "clean_up_tokenization_spaces": false,
+  "cls_token": "<s>",
+  "eos_token": "</s>",
+  "errors": "replace",
+  "extra_special_tokens": {},
+  "mask_token": "<mask>",
+  "model_max_length": 512,
+  "pad_token": "<pad>",
+  "sep_token": "</s>",
+  "tokenizer_class": "RobertaTokenizer",
+  "unk_token": "<unk>"
+}
clap_tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
onnx_export/export_clap.py ADDED
@@ -0,0 +1,331 @@
+#!/usr/bin/env python3
+"""
+Export CLAP (Contrastive Language-Audio Pretraining) model to ONNX.
+
+The CLAP model is used for reranking separation candidates by scoring
+audio-text similarity.
+
+Usage:
+    python -m onnx_export.export_clap --output-dir onnx_models --verify
+"""
+
+import os
+import argparse
+import json
+import torch
+import torch.nn as nn
+from huggingface_hub import hf_hub_download
+
+
+def get_clap_model(checkpoint_file=None, device="cpu"):
+    """Load the CLAP model from laion_clap."""
+    import laion_clap
+
+    model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-tiny").to(device)
+
+    if checkpoint_file is None:
+        checkpoint_file = hf_hub_download(
+            repo_id="lukewys/laion_clap", filename="630k-best.pt"
+        )
+
+    state_dict = torch.load(checkpoint_file, map_location=device, weights_only=False)["state_dict"]
+
+    # Handle module prefix from DataParallel
+    if next(iter(state_dict.items()))[0].startswith("module"):
+        state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+    # Remove position_ids if present (not needed)
+    if "text_branch.embeddings.position_ids" in state_dict:
+        del state_dict["text_branch.embeddings.position_ids"]
+
+    model.model.load_state_dict(state_dict)
+    return model.eval()
+
+
+class CLAPAudioEncoderWrapper(nn.Module):
+    """
+    Wrapper for CLAP audio encoder for ONNX export.
+
+    Takes waveform input directly and processes through the HTSAT audio branch.
+    """
+
+    def __init__(self, model):
+        super().__init__()
+        self.audio_branch = model.model.audio_branch
+        self.audio_transform = model.model.audio_transform
+        self.audio_projection = model.model.audio_projection
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            waveform: [batch, samples] audio waveform at 48kHz, 10 seconds (480000 samples)
+
+        Returns:
+            audio_embed: [batch, 512] normalized audio embedding
+        """
+        # Compute spectrogram from waveform
+        x = self.audio_branch.spectrogram_extractor(waveform)  # [B, 1, T, F]
+        x = self.audio_branch.logmel_extractor(x)  # [B, 1, T, mel_bins]
+
+        # Batch normalization
+        x = x.transpose(1, 3)  # [B, mel_bins, T, 1]
+        x = self.audio_branch.bn0(x)
+        x = x.transpose(1, 3)  # [B, 1, T, mel_bins]
+
+        # Reshape for Swin Transformer using the original method
+        x = self.audio_branch.reshape_wav2img(x)
+
+        # Forward through transformer features
+        output_dict = self.audio_branch.forward_features(x)
+        embedding = output_dict["embedding"]  # [B, 768]
+
+        # Project to 512-dim: projection first, then transform
+        x = self.audio_projection(embedding)  # 768 -> 512
+        x = self.audio_transform(x)  # 512 -> 512
+
+        # L2 normalize
+        x = x / x.norm(dim=-1, keepdim=True)
+        return x
+
+
+class CLAPTextEncoderWrapper(nn.Module):
+    """Wrapper for CLAP text encoder for ONNX export."""
+
+    def __init__(self, model):
+        super().__init__()
+        self.text_branch = model.model.text_branch
+        self.text_transform = model.model.text_transform
+        self.text_projection = model.model.text_projection
+
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            input_ids: [batch, seq_len] token IDs
+            attention_mask: [batch, seq_len] attention mask
+
+        Returns:
+            text_embed: [batch, 512] normalized text embedding
+        """
+        x = self.text_branch(input_ids=input_ids, attention_mask=attention_mask)
+        x = x.pooler_output  # [B, 768]
+        x = self.text_projection(x)  # 768 -> 512
+        x = self.text_transform(x)  # 512 -> 512
+        # L2 normalize
+        x = x / x.norm(dim=-1, keepdim=True)
+        return x
+
+
+def export_clap_audio_encoder(model, output_path, opset_version=17, device="cpu"):
+    """Export CLAP audio encoder to ONNX."""
+    import onnx
+
+    print(f"Exporting CLAP audio encoder to {output_path}...")
+
+    wrapper = CLAPAudioEncoderWrapper(model).eval().to(device)
+
+    # Sample input: 10 seconds of audio at 48kHz (480000 samples)
+    batch_size = 1
+    num_samples = 480000  # 10 seconds at 48kHz
+
+    dummy_waveform = torch.randn(batch_size, num_samples, device=device)
+
+    # Test forward pass
+    with torch.no_grad():
+        output = wrapper(dummy_waveform)
+        print(f" Audio encoder output shape: {output.shape}")
+
+    torch.onnx.export(
+        wrapper,
+        (dummy_waveform,),
+        output_path,
+        input_names=["waveform"],
+        output_names=["audio_embed"],
+        dynamic_axes={
+            "waveform": {0: "batch_size"},
+            "audio_embed": {0: "batch_size"},
+        },
+        opset_version=opset_version,
+        do_constant_folding=True,
+    )
+
+    # Validate
+    onnx_model = onnx.load(output_path)
+    onnx.checker.check_model(onnx_model)
+    print(" ✓ CLAP audio encoder exported successfully")
+
+    return True
+
+
+def export_clap_text_encoder(model, output_path, opset_version=17, device="cpu"):
+    """Export CLAP text encoder to ONNX."""
+    import onnx
+
+    print(f"Exporting CLAP text encoder to {output_path}...")
+
+    wrapper = CLAPTextEncoderWrapper(model).eval().to(device)
+
+    # Sample input
+    batch_size = 1
+    seq_len = 77
+
+    dummy_input_ids = torch.randint(0, 50265, (batch_size, seq_len), device=device)
+    dummy_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long, device=device)
+
+    # Test forward pass
+    with torch.no_grad():
+        output = wrapper(dummy_input_ids, dummy_attention_mask)
+        print(f" Text encoder output shape: {output.shape}")
+
+    torch.onnx.export(
+        wrapper,
+        (dummy_input_ids, dummy_attention_mask),
+        output_path,
+        input_names=["input_ids", "attention_mask"],
+        output_names=["text_embed"],
+        dynamic_axes={
+            "input_ids": {0: "batch_size", 1: "seq_len"},
+            "attention_mask": {0: "batch_size", 1: "seq_len"},
+            "text_embed": {0: "batch_size"},
+        },
+        opset_version=opset_version,
+        do_constant_folding=True,
+    )
+
+    # Validate
+    onnx_model = onnx.load(output_path)
+    onnx.checker.check_model(onnx_model)
+    print(" ✓ CLAP text encoder exported successfully")
+
+    return True
+
+
+def save_clap_config(model, output_path):
+    """Save CLAP audio preprocessing config."""
+    audio_cfg = model.model_cfg["audio_cfg"]
+
+    config = {
+        "sample_rate": audio_cfg["sample_rate"],
+        "window_size": audio_cfg["window_size"],
+        "hop_size": audio_cfg["hop_size"],
+        "mel_bins": audio_cfg["mel_bins"],
+        "fmin": audio_cfg["fmin"],
+        "fmax": audio_cfg["fmax"],
+        "max_audio_len": 480000,  # 10 seconds at 48kHz
+        "embed_dim": 512,
+    }
+
+    with open(output_path, "w") as f:
+        json.dump(config, f, indent=2)
+
+    print(f" ✓ Config saved to {output_path}")
+    return config
+
+
+def save_clap_tokenizer(output_dir):
+    """Save RoBERTa tokenizer for CLAP text encoding."""
+    from transformers import RobertaTokenizer
+
+    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+    tokenizer.save_pretrained(output_dir)
+    print(f" ✓ Tokenizer saved to {output_dir}")
+
+
+def verify_clap(model, audio_onnx_path, text_onnx_path, config, device="cpu"):
+    """Verify ONNX outputs match PyTorch."""
+    import onnxruntime as ort
+    import numpy as np
+
+    print("Verifying CLAP ONNX outputs...")
+
+    # Create sample audio (10 seconds at 48kHz)
+    sample_waveform = torch.randn(1, 480000)  # [batch, samples]
+
+    # PyTorch audio embedding
+    wrapper = CLAPAudioEncoderWrapper(model).eval()
+    with torch.no_grad():
+        pytorch_audio_embed = wrapper(sample_waveform).numpy()
+
+    # ONNX audio embedding
+    audio_sess = ort.InferenceSession(audio_onnx_path, providers=["CPUExecutionProvider"])
+    onnx_audio_embed = audio_sess.run(
+        ["audio_embed"],
+        {"waveform": sample_waveform.numpy().astype(np.float32)},
+    )[0]
+
+    audio_diff = np.abs(pytorch_audio_embed - onnx_audio_embed).max()
+    print(f" Audio encoder max diff: {audio_diff:.2e}")
+
+    # Text embedding verification
+    from transformers import RobertaTokenizer
+    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+    tokens = tokenizer(["a person speaking"], return_tensors="pt", padding=True, truncation=True)
+
+    text_wrapper = CLAPTextEncoderWrapper(model).eval()
+    with torch.no_grad():
+        pytorch_text_embed = text_wrapper(tokens["input_ids"], tokens["attention_mask"]).numpy()
+
+    text_sess = ort.InferenceSession(text_onnx_path, providers=["CPUExecutionProvider"])
+    onnx_text_embed = text_sess.run(
+        ["text_embed"],
+        {
+            "input_ids": tokens["input_ids"].numpy().astype(np.int64),
+            "attention_mask": tokens["attention_mask"].numpy().astype(np.int64),
+        },
+    )[0]
+
+    text_diff = np.abs(pytorch_text_embed - onnx_text_embed).max()
+    print(f" Text encoder max diff: {text_diff:.2e}")
+
+    max_diff = max(audio_diff, text_diff)
+    if max_diff < 1e-4:
+        print(" ✓ Verification passed")
+        return True
+    else:
+        print(f" ✗ Verification failed (max diff: {max_diff:.2e})")
+        return False
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Export CLAP to ONNX")
+    parser.add_argument("--output-dir", type=str, default="onnx_models")
+    parser.add_argument("--checkpoint", type=str, default=None, help="CLAP checkpoint path")
+    parser.add_argument("--opset", type=int, default=18)
+    parser.add_argument("--device", type=str, default="cpu")
+    parser.add_argument("--verify", action="store_true")
+
+    args = parser.parse_args()
+
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Load model
+    print("Loading CLAP model...")
+    model = get_clap_model(args.checkpoint, args.device)
+
+    # Export audio encoder
+    audio_path = os.path.join(args.output_dir, "clap_audio_encoder.onnx")
+    export_clap_audio_encoder(model, audio_path, args.opset, args.device)
+
+    # Export text encoder
+    text_path = os.path.join(args.output_dir, "clap_text_encoder.onnx")
+    export_clap_text_encoder(model, text_path, args.opset, args.device)
+
+    # Save config
+    config_path = os.path.join(args.output_dir, "clap_config.json")
+    config = save_clap_config(model, config_path)
+
+    # Save tokenizer
+    tokenizer_dir = os.path.join(args.output_dir, "clap_tokenizer")
+    os.makedirs(tokenizer_dir, exist_ok=True)
+    save_clap_tokenizer(tokenizer_dir)
+
+    # Verify
+    if args.verify:
+        verify_clap(model, audio_path, text_path, config, args.device)
+
+    print(f"\n✓ Export complete!")
+    print(f" Audio encoder: {audio_path}")
+    print(f" Text encoder: {text_path}")
+
+
+if __name__ == "__main__":
+    main()
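Both wrappers bake L2 normalization into the exported graphs, which is what lets the inference side score candidates with a plain matrix product. A quick post-export sanity check of that invariant (a sketch; the path assumes the exporter's default `onnx_models` output directory):

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx_models/clap_audio_encoder.onnx",
                            providers=["CPUExecutionProvider"])
waveform = np.random.randn(1, 480000).astype(np.float32)  # 10 s at 48 kHz
embed = sess.run(["audio_embed"], {"waveform": waveform})[0]

assert embed.shape == (1, 512)
# Unit-length embeddings mean a dot product equals cosine similarity downstream.
assert np.abs(np.linalg.norm(embed, axis=-1) - 1.0).max() < 1e-5
```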
onnx_inference.py CHANGED
@@ -177,6 +177,42 @@ class SAMAudioONNXPipeline:
             self.peaframe_config = json.load(f)
         print(" ✓ PEAFrame config loaded")
 
+        # Load CLAP for reranking if available
+        self.clap_audio_encoder = None
+        self.clap_text_encoder = None
+        self.clap_tokenizer = None
+        self.clap_config = None
+
+        clap_audio_path = os.path.join(model_dir, "clap_audio_encoder.onnx")
+        clap_text_path = os.path.join(model_dir, "clap_text_encoder.onnx")
+
+        if os.path.exists(clap_audio_path) and os.path.exists(clap_text_path):
+            self.clap_audio_encoder = ort.InferenceSession(
+                clap_audio_path,
+                providers=providers,
+            )
+            print(" ✓ CLAP audio encoder loaded")
+
+            self.clap_text_encoder = ort.InferenceSession(
+                clap_text_path,
+                providers=providers,
+            )
+            print(" ✓ CLAP text encoder loaded")
+
+            # Load CLAP tokenizer
+            tokenizer_path = os.path.join(model_dir, "clap_tokenizer")
+            if os.path.exists(tokenizer_path):
+                from transformers import AutoTokenizer
+                self.clap_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+                print(" ✓ CLAP tokenizer loaded")
+
+            # Load CLAP config
+            config_path = os.path.join(model_dir, "clap_config.json")
+            if os.path.exists(config_path):
+                with open(config_path) as f:
+                    self.clap_config = json.load(f)
+                print(" ✓ CLAP config loaded")
+
         # Load tokenizer
         self._load_tokenizer()
         print(" ✓ Tokenizer loaded")
@@ -615,6 +651,154 @@
 
         return np.array([anchor_ids], dtype=np.int64), anchor_alignment
 
+    def score_with_clap(
+        self,
+        audio_candidates: list[np.ndarray],
+        text: str,
+    ) -> np.ndarray:
+        """
+        Score audio candidates against text using CLAP.
+
+        The CLAP audio encoder expects waveforms at 48kHz, padded/truncated to
+        10 seconds (480000 samples).
+
+        Args:
+            audio_candidates: List of audio waveforms, each shape (samples,)
+            text: Text description to match against
+
+        Returns:
+            scores: Array of similarity scores, shape (num_candidates,)
+        """
+        if self.clap_audio_encoder is None:
+            raise RuntimeError("CLAP audio encoder not loaded")
+        if self.clap_text_encoder is None:
+            raise RuntimeError("CLAP text encoder not loaded")
+        if self.clap_tokenizer is None:
+            raise RuntimeError("CLAP tokenizer not loaded")
+        if self.clap_config is None:
+            raise RuntimeError("CLAP config not loaded")
+
+        config = self.clap_config
+        max_audio_len = config.get("max_audio_len", 480000)
+
+        # Encode text (only once, same for all candidates)
+        tokens = self.clap_tokenizer(
+            text,
+            return_tensors="np",
+            padding=True,
+            truncation=True,
+            max_length=77,
+        )
+
+        text_embed = self.clap_text_encoder.run(
+            ["text_embed"],
+            {
+                "input_ids": tokens["input_ids"].astype(np.int64),
+                "attention_mask": tokens["attention_mask"].astype(np.int64),
+            },
+        )[0]  # [1, 512]
+
+        # Encode each audio candidate
+        audio_embeds = []
+        for audio in audio_candidates:
+            # Preprocess: quantize, pad/truncate
+            # Match PyTorch: int16_to_float32(float32_to_int16(audio))
+            audio = (audio * 32768.0).astype(np.int16).astype(np.float32) / 32768.0
+
+            # Pad or truncate to max_audio_len
+            if len(audio) > max_audio_len:
+                audio = audio[:max_audio_len]
+            elif len(audio) < max_audio_len:
+                # Repeat-pad
+                n_repeat = int(np.ceil(max_audio_len / len(audio)))
+                audio = np.tile(audio, n_repeat)[:max_audio_len]
+
+            # Reshape for CLAP: [batch, samples]
+            audio_input = audio.reshape(1, -1).astype(np.float32)
+
+            # Encode audio
+            audio_embed = self.clap_audio_encoder.run(
+                ["audio_embed"],
+                {"waveform": audio_input},
+            )[0]  # [1, 512]
+
+            audio_embeds.append(audio_embed)
+
+        # Stack audio embeddings: [num_candidates, 512]
+        audio_embeds = np.concatenate(audio_embeds, axis=0)
+
+        # Compute similarity scores: audio @ text.T
+        # audio_embeds: [num_candidates, 512]
+        # text_embed: [1, 512]
+        scores = np.matmul(audio_embeds, text_embed.T).squeeze(-1)  # [num_candidates]
+
+        return scores
+
+    def generate_candidates(
+        self,
+        audio_features: np.ndarray,
+        text_features: np.ndarray,
+        text_mask: np.ndarray,
+        num_candidates: int = 4,
+        masked_video_features: Optional[np.ndarray] = None,
+        anchor_ids: Optional[np.ndarray] = None,
+        anchor_alignment: Optional[np.ndarray] = None,
+        seed: Optional[int] = None,
+    ) -> list[tuple[np.ndarray, np.ndarray]]:
+        """
+        Generate multiple separation candidates with different random seeds.
+
+        Args:
+            audio_features: Encoded audio features [B, T, C]
+            text_features: Encoded text features
+            text_mask: Text attention mask
+            num_candidates: Number of candidates to generate
+            masked_video_features: Optional video features
+            anchor_ids: Optional anchor IDs
+            anchor_alignment: Optional anchor alignment
+            seed: Base random seed (candidates use seed, seed+1, seed+2, ...)
+
+        Returns:
+            List of (target_latent, residual_latent) tuples
+        """
+        B, T, C = audio_features.shape
+
+        candidates = []
+
+        for i in range(num_candidates):
+            # Set seed for reproducibility
+            if seed is not None:
+                np.random.seed(seed + i)
+
+            # Initialize with different random noise
+            x = np.random.randn(B, T, C).astype(np.float32)
+
+            # Run ODE solver
+            steps = self.num_ode_steps
+            dt = 1.0 / steps
+
+            for step_idx in range(steps):
+                t = step_idx * dt
+
+                k1 = self.dit_step(
+                    x, t, audio_features, text_features, text_mask,
+                    masked_video_features, anchor_ids, anchor_alignment
+                )
+                x_mid = x + k1 * (dt / 2.0)
+                k2 = self.dit_step(
+                    x_mid, t + dt/2.0, audio_features, text_features, text_mask,
+                    masked_video_features, anchor_ids, anchor_alignment
+                )
+                x = x + k2 * dt
+
+            # Extract target and residual latents
+            target_latent = x[:, :, :128].transpose(0, 2, 1)  # [B, 128, T]
+            residual_latent = x[:, :, 128:].transpose(0, 2, 1)  # [B, 128, T]
+
+            candidates.append((target_latent, residual_latent))
+
+        return candidates
+
     def dit_step(
         self,
         noisy_audio: np.ndarray,
@@ -678,16 +862,25 @@
         predict_spans: bool = False,
         manual_anchors: Optional[list[tuple[str, float, float]]] = None,
         span_threshold: float = 0.3,
+        rerank: bool = False,
+        num_candidates: int = 4,
+        rerank_seed: Optional[int] = None,
     ) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], float]:
         """
         Perform the full separation pipeline.
-
+
         Args:
             audio: Input mixture waveform
             text: Text description of the target source
            video_path: Optional path to a video for visual conditioning
            mask_path: Optional path to a video/image mask for visual prompting
-
+            predict_spans: Whether to use PEAFrame for span prediction
+            manual_anchors: Optional list of manual anchor spans
+            span_threshold: Threshold for span prediction
+            rerank: Whether to generate multiple candidates and rerank with CLAP
+            num_candidates: Number of candidates for reranking
+            rerank_seed: Random seed for reproducible candidate generation
+
         Returns:
             Tuple of (target audio, residual audio, masked video frames if any, fps)
             - target: The separated sound matching the text/visual prompt
@@ -740,41 +933,77 @@
         masked_video_features = self.encode_video(norm_frames)  # This returns [B, 1024, T] (BCT)
         print(f" Video features shape: {masked_video_features.shape}")
 
-        # 4. Run ODE solver (midpoint method)
-        print("3. Running ODE solver...")
-        # Start from random noise
-        # Note: audio_features is [B, T, 256], DiT output is [B, T, 256]
-        B, T, C = audio_features.shape
-        x = np.random.randn(B, T, C).astype(np.float32)
-
-        steps = self.num_ode_steps
-        dt = 1.0 / steps
-
-        for i in range(steps):
-            t = i * dt
-            print(f" ODE step {i+1}/{steps}", end="\r")
-
-            k1 = self.dit_step(
-                x, t, audio_features, text_features, text_mask,
-                masked_video_features, anchor_ids, anchor_alignment
-            )
-            x_mid = x + k1 * (dt / 2.0)
-            k2 = self.dit_step(
-                x_mid, t + dt/2.0, audio_features, text_features, text_mask,
-                masked_video_features, anchor_ids, anchor_alignment
+        # 4. Run ODE solver (with optional reranking)
+        if rerank and self.clap_audio_encoder is not None:
+            print(f"3. Generating {num_candidates} candidates for reranking...")
+
+            # Generate multiple candidates
+            candidates = self.generate_candidates(
+                audio_features, text_features, text_mask,
+                num_candidates=num_candidates,
+                masked_video_features=masked_video_features,
+                anchor_ids=anchor_ids,
+                anchor_alignment=anchor_alignment,
+                seed=rerank_seed,
             )
 
-            x = x + k2 * dt
-
-        # Extract target and residual latents
-        # The DiT model produces [B, T, 256] where:
-        # - First 128 channels = target (the separated sound)
-        # - Last 128 channels = residual (everything else)
-        # This matches the PyTorch implementation in sam_audio/model/model.py
-        target_latent = x[:, :, :128].transpose(0, 2, 1)  # [B, 128, T] for decoder
-        residual_latent = x[:, :, 128:].transpose(0, 2, 1)  # [B, 128, T] for decoder
-        print(f"\n Target latent shape: {target_latent.shape}")
-        print(f" Residual latent shape: {residual_latent.shape}")
+            # Decode all candidate audios
+            print("3b. Decoding candidate audios...")
+            candidate_audios = []
+            for i, (target_latent, _) in enumerate(candidates):
+                decoded = self.decode_audio(target_latent)
+                candidate_audios.append(decoded)
+                print(f" Candidate {i+1}/{num_candidates} decoded", end="\r")
+            print()
+
+            # Score with CLAP
+            print("3c. Scoring candidates with CLAP...")
+            scores = self.score_with_clap(candidate_audios, text)
+            best_idx = int(np.argmax(scores))
+            print(f" Scores: {scores}")
+            print(f" Selected candidate {best_idx + 1}/{num_candidates} (score: {scores[best_idx]:.4f})")
+
+            # Use best candidate
+            target_latent, residual_latent = candidates[best_idx]
+            print(f" Target latent shape: {target_latent.shape}")
+            print(f" Residual latent shape: {residual_latent.shape}")
+
+        else:
+            # Single candidate path (original behavior)
+            print("3. Running ODE solver...")
+            # Start from random noise
+            # Note: audio_features is [B, T, 256], DiT output is [B, T, 256]
+            B, T, C = audio_features.shape
+            x = np.random.randn(B, T, C).astype(np.float32)
+
+            steps = self.num_ode_steps
+            dt = 1.0 / steps
+
+            for i in range(steps):
+                t = i * dt
+                print(f" ODE step {i+1}/{steps}", end="\r")
+
+                k1 = self.dit_step(
+                    x, t, audio_features, text_features, text_mask,
+                    masked_video_features, anchor_ids, anchor_alignment
+                )
+                x_mid = x + k1 * (dt / 2.0)
+                k2 = self.dit_step(
+                    x_mid, t + dt/2.0, audio_features, text_features, text_mask,
+                    masked_video_features, anchor_ids, anchor_alignment
+                )
+
+                x = x + k2 * dt
+
+            # Extract target and residual latents
+            # The DiT model produces [B, T, 256] where:
+            # - First 128 channels = target (the separated sound)
+            # - Last 128 channels = residual (everything else)
+            # This matches the PyTorch implementation in sam_audio/model/model.py
+            target_latent = x[:, :, :128].transpose(0, 2, 1)  # [B, 128, T] for decoder
+            residual_latent = x[:, :, 128:].transpose(0, 2, 1)  # [B, 128, T] for decoder
+            print(f"\n Target latent shape: {target_latent.shape}")
+            print(f" Residual latent shape: {residual_latent.shape}")
 
         # 5. Decode both to waveforms
         print("4. Decoding target audio...")
@@ -818,6 +1047,23 @@
         default=0.3,
         help="Threshold for span prediction (default: 0.3)",
     )
+    parser.add_argument(
+        "--rerank",
+        action="store_true",
+        help="Generate multiple candidates and rerank with CLAP",
+    )
+    parser.add_argument(
+        "--num-candidates",
+        type=int,
+        default=4,
+        help="Number of candidates for reranking (default: 4)",
+    )
+    parser.add_argument(
+        "--rerank-seed",
+        type=int,
+        default=None,
+        help="Random seed for reproducible candidate generation",
+    )
     parser.add_argument("--output", type=str, default="target.wav", help="Output WAV file path for target (separated) audio")
     parser.add_argument("--output-residual", type=str, default="residual.wav", help="Output WAV file path for residual audio")
     parser.add_argument("--output-video", type=str, help="Optional path to save masked video with separated audio")
@@ -870,6 +1116,9 @@
         predict_spans=args.predict_spans,
         manual_anchors=manual_anchors,
         span_threshold=args.span_threshold,
+        rerank=args.rerank,
+        num_candidates=args.num_candidates,
+        rerank_seed=args.rerank_seed,
     )
 
     # Save output audio files
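Both the reranking branch and the single-candidate branch above integrate the same flow-matching ODE with the explicit midpoint (RK2) rule. Distilled into a standalone sketch (the vector field here is a toy stand-in for `dit_step`):

```python
import numpy as np

def midpoint_integrate(f, x: np.ndarray, steps: int) -> np.ndarray:
    """Integrate dx/dt = f(x, t) from t=0 to t=1 with the explicit midpoint rule."""
    dt = 1.0 / steps
    for i in range(steps):
        t = i * dt
        k1 = f(x, t)                               # slope at the start of the step
        k2 = f(x + k1 * (dt / 2.0), t + dt / 2.0)  # slope at the midpoint
        x = x + k2 * dt                            # second-order accurate update
    return x

# Usage sketch: start from noise shaped like the DiT latents [B, T, 256].
x0 = np.random.randn(1, 100, 256).astype(np.float32)
x1 = midpoint_integrate(lambda x, t: -x, x0, steps=32)  # toy vector field
```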
residual.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4dfbb54fecf275f6cb4c13e934ccd2971ed17e454c7e52152dc8ae69fedf808
+size 960044