"""
End-to-end test comparing PyTorch SAM Audio with ONNX Runtime.

This script:
1. Loads a real audio sample from AudioCaps
2. Runs PyTorch inference using the original SAMAudio model
3. Runs ONNX inference using the exported models
4. Compares the output waveforms
"""
|
| | import torch |
| | import torchaudio |
| | import numpy as np |
| | import os |
| | from datasets import load_dataset |
| |
|
| |
|
def load_audiocaps_sample():
    """Fetch one test-split clip from the OpenSound/AudioCaps dataset.

    Returns the decoded sample object exposing ``.data`` (waveform tensor)
    and ``.sample_rate``.
    """
    print("Loading AudioCaps sample...")
    dataset = load_dataset(
        "parquet",
        data_files="hf://datasets/OpenSound/AudioCaps/data/test-00000-of-00041.parquet",
    )
    # Index 8 is an arbitrary but fixed clip so runs are comparable.
    clip = dataset["train"][8]["audio"].get_all_samples()
    rate = clip.sample_rate
    print(f" Sample rate: {rate}")
    print(f" Duration: {clip.data.shape[-1] / rate:.2f}s")
    return clip
| |
|
| |
|
def run_pytorch_inference(sample, device="cpu"):
    """Separate the target sound with the reference PyTorch SAMAudio model.

    Returns a tuple ``(separated_audio, sample_rate, input_waveform)`` where
    ``input_waveform`` is the resampled mono mix fed to the processor.
    """
    print("\n=== PyTorch Inference ===")

    from sam_audio import SAMAudio, SAMAudioProcessor

    print("Loading SAMAudio model...")
    model = SAMAudio.from_pretrained("facebook/sam-audio-small").to(device).eval()
    processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-small")

    # Resample to the processor's native rate and mix down to mono.
    target_sr = processor.audio_sampling_rate
    mono = torchaudio.functional.resample(sample.data, sample.sample_rate, target_sr)
    mono = mono.mean(0, keepdim=True)

    print(f" Input audio shape: {mono.shape}")
    print(f" Sample rate: {target_sr}")

    # Text prompt plus a positive time anchor at 6.3–7.0 s.
    batch = processor(
        audios=[mono],
        descriptions=["A horn honking"],
        anchors=[[["+", 6.3, 7.0]]]
    ).to(device)

    print("Running separation...")
    with torch.inference_mode():
        result = model.separate(batch)

    separated = result.target[0].cpu().numpy()
    print(f" Output shape: {separated.shape}")

    return separated, target_sr, mono.numpy()
| |
|
| |
|
def run_onnx_inference(sample, model_dir="."):
    """Run the full separation pipeline using the exported ONNX models.

    Pipeline: DAC-VAE audio encode -> T5 text encode -> midpoint (RK2) ODE
    solve with the DiT velocity model -> chunked DAC-VAE decode.

    Args:
        sample: decoded audio sample exposing ``.data`` and ``.sample_rate``.
        model_dir: directory containing the ``*.onnx`` files and ``tokenizer/``.

    Returns:
        Tuple ``(separated_waveform, 44100)``.
    """
    print("\n=== ONNX Runtime Inference ===")

    import onnxruntime as ort
    from transformers import AutoTokenizer

    print("Loading ONNX models...")
    providers = ["CPUExecutionProvider"]

    def _session(filename):
        # All sessions run on CPU so results are machine-independent.
        return ort.InferenceSession(os.path.join(model_dir, filename), providers=providers)

    dacvae_encoder = _session("dacvae_encoder.onnx")
    dacvae_decoder = _session("dacvae_decoder.onnx")
    t5_encoder = _session("t5_encoder.onnx")
    dit = _session("dit_single_step.onnx")

    tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_dir, "tokenizer"))
    print(" All models loaded")

    # Resample to the exporter's fixed 44.1 kHz rate and mix down to mono.
    wav = torchaudio.functional.resample(
        sample.data, sample.sample_rate, 44100
    )
    wav = wav.mean(0, keepdim=True)
    audio = wav.numpy().reshape(1, 1, -1).astype(np.float32)

    print(f" Input audio shape: {audio.shape}")

    print("Encoding audio...")
    latent = dacvae_encoder.run(
        ["latent_features"],
        {"audio": audio}
    )[0]
    print(f" Audio latent shape: {latent.shape}")

    print("Encoding text...")
    tokens = tokenizer(
        "A horn honking",
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=77,
    )
    text_features = t5_encoder.run(
        ["hidden_states"],
        {
            "input_ids": tokens["input_ids"].astype(np.int64),
            "attention_mask": tokens["attention_mask"].astype(np.int64),
        }
    )[0]
    print(f" Text features shape: {text_features.shape}")

    print("Running DiT (simplified test - 1 step)...")
    batch_size = 1
    latent_dim = latent.shape[1]
    time_steps = latent.shape[2]

    # DiT consumes time-major features: (batch, time, channels).
    mixture_features = latent.transpose(0, 2, 1)

    # Conditioning is [mixture, masked-mixture]; with no mask applied the
    # two halves are identical.
    audio_features = np.concatenate([
        mixture_features,
        mixture_features
    ], axis=-1)

    # Inputs that never change between solver steps — built once instead of
    # being rebuilt (previously copy-pasted) for every DiT invocation.
    static_feeds = {
        "audio_features": audio_features,
        "text_features": text_features,
        "text_mask": tokens["attention_mask"].astype(bool),
        "masked_video_features": np.zeros((batch_size, 1024, time_steps), dtype=np.float32),
        "anchor_ids": np.zeros((batch_size, time_steps), dtype=np.int64),
        "anchor_alignment": np.zeros((batch_size, time_steps), dtype=np.int64),
        "audio_pad_mask": np.ones((batch_size, time_steps), dtype=bool),
    }

    def _velocity(noisy_audio, t):
        # One DiT forward pass: flow velocity for the given state and time.
        return dit.run(
            ["velocity"],
            {"noisy_audio": noisy_audio, "time": t, **static_feeds},
        )[0]

    # NOTE(review): unseeded noise makes runs nondeterministic, which also
    # makes the PyTorch/ONNX comparison noisy — consider np.random.seed(...)
    # if exact reproducibility is wanted.
    initial = np.random.randn(batch_size, time_steps, 256).astype(np.float32)

    # Single-step smoke test to validate I/O shapes before the full solve.
    velocity = _velocity(initial, np.array([0.0], dtype=np.float32))
    print(f" DiT velocity shape: {velocity.shape}")

    print("Running full ODE solve (16 steps)...")
    num_steps = 16
    dt = 1.0 / num_steps
    x = initial.copy()

    # Midpoint (RK2) integration of dx/dt = v(x, t) from t=0 to t=1.
    for i in range(num_steps):
        t = np.array([i * dt], dtype=np.float32)
        t_mid = np.array([t[0] + dt / 2], dtype=np.float32)

        k1 = _velocity(x, t)
        x_mid = x + (dt / 2) * k1
        k2 = _velocity(x_mid, t_mid)

        x = x + dt * k2
        print(f" Step {i+1}/{num_steps}")

    print("Decoding audio...")

    # The solver state concatenates [target, extra] channels; keep only the
    # first latent_dim channels and restore (batch, channels, time) layout.
    target_latent = x[:, :, :latent_dim].transpose(0, 2, 1)
    separated_latent = target_latent

    # The decoder was exported with a fixed window, so decode in chunks.
    chunk_size = 25
    T = separated_latent.shape[2]

    audio_chunks = []
    for start_idx in range(0, T, chunk_size):
        end_idx = min(start_idx + chunk_size, T)
        chunk = separated_latent[:, :, start_idx:end_idx]

        # Zero-pad the final partial chunk up to the fixed window size.
        actual_size = chunk.shape[2]
        if actual_size < chunk_size:
            pad_size = chunk_size - actual_size
            chunk = np.pad(chunk, ((0, 0), (0, 0), (0, pad_size)), mode='constant')

        chunk_audio = dacvae_decoder.run(
            ["waveform"],
            {"latent_features": chunk.astype(np.float32)}
        )[0]

        if actual_size < chunk_size:
            # Trim the samples produced by the zero padding.
            # 1920 is assumed to be the decoder hop (output samples per
            # latent frame at 44.1 kHz) — TODO confirm against the exporter.
            samples_per_step = 1920
            trim_samples = actual_size * samples_per_step
            chunk_audio = chunk_audio[:, :, :trim_samples]

        audio_chunks.append(chunk_audio)
        print(f" Decoded chunk {start_idx//chunk_size + 1}/{(T + chunk_size - 1)//chunk_size}")

    separated_audio = np.concatenate(audio_chunks, axis=2)

    print(f" Output audio shape: {separated_audio.shape}")

    return separated_audio.squeeze(), 44100
| |
|
| |
|
| |
|
def compare_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr):
    """Quantify agreement between the PyTorch and ONNX output waveforms.

    The PyTorch waveform is resampled to the ONNX rate when the rates differ,
    both signals are trimmed to the shorter length, and sample-wise statistics
    are reported.

    Returns:
        Tuple ``(max_diff, mean_diff, correlation)``.
    """
    print("\n=== Comparison ===")

    import scipy.signal

    if pytorch_sr == onnx_sr:
        aligned = pytorch_audio
    else:
        print(f"Resampling PyTorch output from {pytorch_sr} to {onnx_sr}...")
        target_len = int(len(pytorch_audio) * onnx_sr / pytorch_sr)
        aligned = scipy.signal.resample(pytorch_audio, target_len)

    # Trim to the common length so the arrays compare element-wise.
    n = min(len(aligned), len(onnx_audio))
    reference = aligned[:n]
    candidate = onnx_audio[:n]

    abs_err = np.abs(reference - candidate)
    max_diff = abs_err.max()
    mean_diff = abs_err.mean()
    correlation = np.corrcoef(reference, candidate)[0, 1]

    print(f" PyTorch audio length: {len(pytorch_audio)} samples")
    print(f" ONNX audio length: {len(onnx_audio)} samples")
    print(f" Max difference: {max_diff:.6f}")
    print(f" Mean difference: {mean_diff:.6f}")
    print(f" Correlation: {correlation:.6f}")

    return max_diff, mean_diff, correlation
| |
|
| |
|
def save_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr, input_audio, input_sr):
    """Write the input mix and both separated outputs to test_outputs/ as WAV."""
    import soundfile as sf

    output_dir = "test_outputs"
    os.makedirs(output_dir, exist_ok=True)

    # (filename, waveform, sample rate, label used in the log line)
    tracks = [
        ("input.wav", input_audio.squeeze(), input_sr, "input"),
        ("pytorch_output.wav", pytorch_audio, pytorch_sr, "PyTorch output"),
        ("onnx_output.wav", onnx_audio, onnx_sr, "ONNX output"),
    ]
    for filename, waveform, rate, label in tracks:
        sf.write(os.path.join(output_dir, filename), waveform, rate)
        print(f"Saved {label} to {output_dir}/{filename}")
| |
|
| |
|
def main():
    """CLI entry point: run both backends on one AudioCaps clip and compare."""
    import argparse

    parser = argparse.ArgumentParser(description="End-to-end SAM Audio test")
    parser.add_argument("--model-dir", default=".", help="ONNX model directory")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"])
    parser.add_argument("--save-outputs", action="store_true", help="Save audio files")
    parser.add_argument("--skip-pytorch", action="store_true", help="Skip PyTorch inference")
    args = parser.parse_args()

    sample = load_audiocaps_sample()

    if args.skip_pytorch:
        print("\nSkipping PyTorch inference")
        pytorch_audio = None
        pytorch_sr = None
        # Fall back to the raw mono mix so it can still be saved if requested.
        input_audio = sample.data.mean(0).numpy()
    else:
        pytorch_audio, pytorch_sr, input_audio = run_pytorch_inference(sample, args.device)

    onnx_audio, onnx_sr = run_onnx_inference(sample, args.model_dir)

    # Comparison only makes sense when both backends produced output.
    if pytorch_audio is not None:
        compare_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr)

    if args.save_outputs:
        print("\n=== Saving Outputs ===")
        if pytorch_audio is None:
            # PyTorch was skipped: only the ONNX result exists.
            import soundfile as sf
            os.makedirs("test_outputs", exist_ok=True)
            sf.write("test_outputs/onnx_output.wav", onnx_audio, onnx_sr)
            print("Saved ONNX output to test_outputs/onnx_output.wav")
        else:
            save_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr,
                         input_audio, sample.sample_rate)

    print("\n✓ End-to-end test complete!")
| |
|
| |
|
# Script entry point.
if __name__ == "__main__":
    main()
| |
|