matbee committed · verified
Commit 0b98973 · Parent(s): 820e270

Upload test_e2e.py with huggingface_hub

Files changed (1): test_e2e.py (+375, -0)
test_e2e.py ADDED
@@ -0,0 +1,375 @@
#!/usr/bin/env python3
"""
End-to-end test comparing PyTorch SAM Audio with ONNX Runtime.

This script:
1. Loads a real audio sample from AudioCaps
2. Runs PyTorch inference using the original SAMAudio model
3. Runs ONNX inference using the exported models
4. Compares the output waveforms
"""

import torch
import torchaudio
import numpy as np
import os
from datasets import load_dataset


def load_audiocaps_sample():
    """Load a sample from AudioCaps dataset."""
    print("Loading AudioCaps sample...")
    dset = load_dataset(
        "parquet",
        data_files="hf://datasets/OpenSound/AudioCaps/data/test-00000-of-00041.parquet",
    )
    sample = dset["train"][8]["audio"].get_all_samples()
    print(f" Sample rate: {sample.sample_rate}")
    print(f" Duration: {sample.data.shape[-1] / sample.sample_rate:.2f}s")
    return sample


def run_pytorch_inference(sample, device="cpu"):
    """Run inference using PyTorch SAMAudio model."""
    print("\n=== PyTorch Inference ===")

    from sam_audio import SAMAudio, SAMAudioProcessor

    # Load model and processor
    print("Loading SAMAudio model...")
    model = SAMAudio.from_pretrained("facebook/sam-audio-small").to(device).eval()
    processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-small")

    # Resample and prepare input
    wav = torchaudio.functional.resample(
        sample.data, sample.sample_rate, processor.audio_sampling_rate
    )
    wav = wav.mean(0, keepdim=True)  # Convert to mono

    print(f" Input audio shape: {wav.shape}")
    print(f" Sample rate: {processor.audio_sampling_rate}")

    # Prepare inputs with explicit anchor
    inputs = processor(
        audios=[wav],
        descriptions=["A horn honking"],
        anchors=[[["+", 6.3, 7.0]]]
    ).to(device)
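    # The anchor triple above appears to be [label, start_s, end_s]; "+" presumably marks
    # a positive region around 6.3-7.0 s. This reading of the processor API is an
    # assumption, not verified against the SAMAudio documentation.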

    # Run separation
    print("Running separation...")
    with torch.inference_mode():
        result = model.separate(inputs)

    separated_audio = result.target[0].cpu().numpy()
    print(f" Output shape: {separated_audio.shape}")

    return separated_audio, processor.audio_sampling_rate, wav.numpy()


def run_onnx_inference(sample, model_dir="."):
    """Run inference using ONNX models."""
    print("\n=== ONNX Runtime Inference ===")

    import onnxruntime as ort
    from transformers import AutoTokenizer

    # Load models
    print("Loading ONNX models...")
    providers = ["CPUExecutionProvider"]
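    # CPU-only here; with onnxruntime-gpu installed, "CUDAExecutionProvider" could be
    # prepended to this list (not exercised in this test).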

    dacvae_encoder = ort.InferenceSession(
        os.path.join(model_dir, "dacvae_encoder.onnx"),
        providers=providers,
    )
    dacvae_decoder = ort.InferenceSession(
        os.path.join(model_dir, "dacvae_decoder.onnx"),
        providers=providers,
    )
    t5_encoder = ort.InferenceSession(
        os.path.join(model_dir, "t5_encoder.onnx"),
        providers=providers,
    )
    dit = ort.InferenceSession(
        os.path.join(model_dir, "dit_single_step.onnx"),
        providers=providers,
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_dir, "tokenizer"))
    print(" All models loaded")

    # Prepare audio (resample to 44.1kHz for DACVAE)
    wav = torchaudio.functional.resample(
        sample.data, sample.sample_rate, 44100
    )
    wav = wav.mean(0, keepdim=True)  # Convert to mono
    audio = wav.numpy().reshape(1, 1, -1).astype(np.float32)

    print(f" Input audio shape: {audio.shape}")

    # 1. Encode audio
    print("Encoding audio...")
    latent = dacvae_encoder.run(
        ["latent_features"],
        {"audio": audio}
    )[0]
    print(f" Audio latent shape: {latent.shape}")

    # 2. Encode text
    print("Encoding text...")
    tokens = tokenizer(
        "A horn honking",
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=77,
    )
    text_features = t5_encoder.run(
        ["hidden_states"],
        {
            "input_ids": tokens["input_ids"].astype(np.int64),
            "attention_mask": tokens["attention_mask"].astype(np.int64),
        }
    )[0]
    print(f" Text features shape: {text_features.shape}")

    # 3. Run ODE solving (simplified - just one step for testing)
    print("Running DiT (simplified test - 1 step)...")
    batch_size = 1
    latent_dim = latent.shape[1]  # 128
    time_steps = latent.shape[2]

    # Prepare inputs
    # SAMAudio._get_audio_features: returns torch.cat([audio_features, audio_features], dim=2)
    # So audio_features is the mixture DUPLICATED, not mixture + zeros!
    mixture_features = latent.transpose(0, 2, 1)  # (B, T, 128) - from DACVAE

    # Duplicate mixture features (this is what SAMAudio actually does)
    audio_features = np.concatenate([
        mixture_features,  # Mixture latent
        mixture_features   # Mixture latent (DUPLICATE - not zeros!)
    ], axis=-1)  # -> (B, T, 256)

    # noisy_audio starts from random noise for ODE solving from t=0 to t=1
    # SAMAudio uses: noise = torch.randn_like(audio_features)
    initial = np.random.randn(batch_size, time_steps, 256).astype(np.float32)
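    # Note: this noise is drawn independently of the PyTorch run (no shared seed), so the
    # two pipelines can only be compared for similarity (e.g. the correlation computed in
    # compare_outputs), not for sample-exact equality.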

    # Just run one step to verify the model works
    velocity = dit.run(
        ["velocity"],
        {
            "noisy_audio": initial,
            "time": np.array([0.0], dtype=np.float32),
            "audio_features": audio_features,
            "text_features": text_features,
            "text_mask": tokens["attention_mask"].astype(bool),
            "masked_video_features": np.zeros((batch_size, 1024, time_steps), dtype=np.float32),
            "anchor_ids": np.zeros((batch_size, time_steps), dtype=np.int64),
            "anchor_alignment": np.zeros((batch_size, time_steps), dtype=np.int64),
            "audio_pad_mask": np.ones((batch_size, time_steps), dtype=bool),
        }
    )[0]
    print(f" DiT velocity shape: {velocity.shape}")

    # 4. Run full ODE solve (16 steps, midpoint method)
    print("Running full ODE solve (16 steps)...")
    num_steps = 16
    dt = 1.0 / num_steps
    x = initial.copy()
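    # Explicit midpoint scheme, as implemented in the loop below:
    #   k1 = v(t_n, x_n)
    #   k2 = v(t_n + dt/2, x_n + (dt/2) * k1)
    #   x_{n+1} = x_n + dt * k2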

    for i in range(num_steps):
        t = np.array([i * dt], dtype=np.float32)
        t_mid = np.array([t[0] + dt / 2], dtype=np.float32)

        # k1 = f(t, x)
        k1 = dit.run(
            ["velocity"],
            {
                "noisy_audio": x,
                "time": t,
                "audio_features": audio_features,
                "text_features": text_features,
                "text_mask": tokens["attention_mask"].astype(bool),
                "masked_video_features": np.zeros((batch_size, 1024, time_steps), dtype=np.float32),
                "anchor_ids": np.zeros((batch_size, time_steps), dtype=np.int64),
                "anchor_alignment": np.zeros((batch_size, time_steps), dtype=np.int64),
                "audio_pad_mask": np.ones((batch_size, time_steps), dtype=bool),
            }
        )[0]

        # Midpoint
        x_mid = x + (dt / 2) * k1

        # k2 = f(t_mid, x_mid)
        k2 = dit.run(
            ["velocity"],
            {
                "noisy_audio": x_mid,
                "time": t_mid,
                "audio_features": audio_features,
                "text_features": text_features,
                "text_mask": tokens["attention_mask"].astype(bool),
                "masked_video_features": np.zeros((batch_size, 1024, time_steps), dtype=np.float32),
                "anchor_ids": np.zeros((batch_size, time_steps), dtype=np.int64),
                "anchor_alignment": np.zeros((batch_size, time_steps), dtype=np.int64),
                "audio_pad_mask": np.ones((batch_size, time_steps), dtype=bool),
            }
        )[0]

        # Update
        x = x + dt * k2
        print(f" Step {i+1}/{num_steps}")

    # 5. Extract separated latent and decode in chunks
    # (The DACVAE decoder was exported with fixed time=25, so we decode in chunks)
    print("Decoding audio...")
    # SAMAudio: target is first 128 dims, residual is second 128 dims
    # generated_features.reshape(2*B, C, T) -> first B = channels 0:128 (target)
    target_latent = x[:, :, :latent_dim].transpose(0, 2, 1)  # (B, 128, T) - TARGET
    separated_latent = target_latent

    # The decoder expects chunks of 25 time steps
    chunk_size = 25
    T = separated_latent.shape[2]

    # Process in chunks and concatenate
    audio_chunks = []
    for start_idx in range(0, T, chunk_size):
        end_idx = min(start_idx + chunk_size, T)
        chunk = separated_latent[:, :, start_idx:end_idx]

        # Pad last chunk if needed
        actual_size = chunk.shape[2]
        if actual_size < chunk_size:
            pad_size = chunk_size - actual_size
            chunk = np.pad(chunk, ((0, 0), (0, 0), (0, pad_size)), mode='constant')

        chunk_audio = dacvae_decoder.run(
            ["waveform"],
            {"latent_features": chunk.astype(np.float32)}
        )[0]

        # For padded chunks, trim the output
        if actual_size < chunk_size:
            # Each latent time step corresponds to hop_length (1920) output samples
            samples_per_step = 1920
            trim_samples = actual_size * samples_per_step
            chunk_audio = chunk_audio[:, :, :trim_samples]

        audio_chunks.append(chunk_audio)
        print(f" Decoded chunk {start_idx//chunk_size + 1}/{(T + chunk_size - 1)//chunk_size}")

    # Concatenate all chunks
    separated_audio = np.concatenate(audio_chunks, axis=2)
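
    # Sanity check (hedged): if each latent step really maps to 1920 output samples, as
    # assumed in the trimming logic above, the decoded length should equal T * 1920.
    if separated_audio.shape[2] != T * 1920:
        print(f" Warning: decoded length {separated_audio.shape[2]} != T * 1920 = {T * 1920}")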

    print(f" Output audio shape: {separated_audio.shape}")

    return separated_audio.squeeze(), 44100


def compare_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr):
    """Compare PyTorch and ONNX outputs."""
    print("\n=== Comparison ===")

    import scipy.signal

    # Resample to same rate if needed
    if pytorch_sr != onnx_sr:
        print(f"Resampling PyTorch output from {pytorch_sr} to {onnx_sr}...")
        # Use scipy for resampling
        num_samples = int(len(pytorch_audio) * onnx_sr / pytorch_sr)
        pytorch_audio_resampled = scipy.signal.resample(pytorch_audio, num_samples)
    else:
        pytorch_audio_resampled = pytorch_audio

    # Trim to same length
    min_len = min(len(pytorch_audio_resampled), len(onnx_audio))
    pytorch_trimmed = pytorch_audio_resampled[:min_len]
    onnx_trimmed = onnx_audio[:min_len]

    # Compute differences
    diff = np.abs(pytorch_trimmed - onnx_trimmed)
    max_diff = diff.max()
    mean_diff = diff.mean()

    # Compute correlation
    correlation = np.corrcoef(pytorch_trimmed, onnx_trimmed)[0, 1]

    print(f" PyTorch audio length: {len(pytorch_audio)} samples")
    print(f" ONNX audio length: {len(onnx_audio)} samples")
    print(f" Max difference: {max_diff:.6f}")
    print(f" Mean difference: {mean_diff:.6f}")
    print(f" Correlation: {correlation:.6f}")

    return max_diff, mean_diff, correlation


def save_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr, input_audio, input_sr):
    """Save audio outputs for listening comparison."""
    import soundfile as sf

    output_dir = "test_outputs"
    os.makedirs(output_dir, exist_ok=True)

    # Save input
    sf.write(os.path.join(output_dir, "input.wav"), input_audio.squeeze(), input_sr)
    print(f"Saved input to {output_dir}/input.wav")

    # Save PyTorch output
    sf.write(os.path.join(output_dir, "pytorch_output.wav"), pytorch_audio, pytorch_sr)
    print(f"Saved PyTorch output to {output_dir}/pytorch_output.wav")

    # Save ONNX output
    sf.write(os.path.join(output_dir, "onnx_output.wav"), onnx_audio, onnx_sr)
    print(f"Saved ONNX output to {output_dir}/onnx_output.wav")


def main():
    import argparse

    parser = argparse.ArgumentParser(description="End-to-end SAM Audio test")
    parser.add_argument("--model-dir", default=".", help="ONNX model directory")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"])
    parser.add_argument("--save-outputs", action="store_true", help="Save audio files")
    parser.add_argument("--skip-pytorch", action="store_true", help="Skip PyTorch inference")
    args = parser.parse_args()

    # Load sample
    sample = load_audiocaps_sample()

    # Run PyTorch inference
    if not args.skip_pytorch:
        pytorch_audio, pytorch_sr, input_audio = run_pytorch_inference(sample, args.device)
    else:
        print("\nSkipping PyTorch inference")
        pytorch_audio, pytorch_sr = None, None
        input_audio = sample.data.mean(0).numpy()

    # Run ONNX inference
    onnx_audio, onnx_sr = run_onnx_inference(sample, args.model_dir)

    # Compare outputs
    if pytorch_audio is not None:
        compare_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr)

    # Save outputs
    if args.save_outputs:
        print("\n=== Saving Outputs ===")
        if pytorch_audio is not None:
            # input_audio was resampled to the processor rate in run_pytorch_inference,
            # so save it at pytorch_sr rather than the original sample rate.
            save_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr,
                         input_audio, pytorch_sr)
        else:
            import soundfile as sf
            os.makedirs("test_outputs", exist_ok=True)
            sf.write("test_outputs/onnx_output.wav", onnx_audio, onnx_sr)
            print("Saved ONNX output to test_outputs/onnx_output.wav")

    print("\n✓ End-to-end test complete!")


if __name__ == "__main__":
    main()