"""
End-to-end test comparing PyTorch SAM Audio with ONNX Runtime.

This script:
1. Loads a real audio sample from AudioCaps
2. Runs PyTorch inference using the original SAMAudio model
3. Runs ONNX inference using the exported models
4. Compares the output waveforms
"""

import os

import numpy as np
import torch
import torchaudio
from datasets import load_dataset


def load_audiocaps_sample():
    """Load a sample from the AudioCaps dataset."""
    print("Loading AudioCaps sample...")
    dset = load_dataset(
        "parquet",
        data_files="hf://datasets/OpenSound/AudioCaps/data/test-00000-of-00041.parquet",
    )
    sample = dset["train"][8]["audio"].get_all_samples()
    print(f"  Sample rate: {sample.sample_rate}")
    print(f"  Duration: {sample.data.shape[-1] / sample.sample_rate:.2f}s")
    return sample


def run_pytorch_inference(sample, device="cpu"):
    """Run inference using the PyTorch SAMAudio model."""
    print("\n=== PyTorch Inference ===")

    from sam_audio import SAMAudio, SAMAudioProcessor

    print("Loading SAMAudio model...")
    model = SAMAudio.from_pretrained("facebook/sam-audio-small").to(device).eval()
    processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-small")

    # Resample to the model's expected rate and downmix to mono.
    wav = torchaudio.functional.resample(
        sample.data, sample.sample_rate, processor.audio_sampling_rate
    )
    wav = wav.mean(0, keepdim=True)
    print(f"  Input audio shape: {wav.shape}")
    print(f"  Sample rate: {processor.audio_sampling_rate}")

    # A "+" anchor span at 6.3-7.0 s points the model at the target sound.
    inputs = processor(
        audios=[wav],
        descriptions=["A horn honking"],
        anchors=[[["+", 6.3, 7.0]]],
    ).to(device)

    print("Running separation...")
    with torch.inference_mode():
        result = model.separate(inputs)

    separated_audio = result.target[0].cpu().numpy()
    print(f"  Output shape: {separated_audio.shape}")

    return separated_audio, processor.audio_sampling_rate, wav.numpy()


def run_onnx_inference(sample, model_dir="."):
    """Run inference using the exported ONNX models."""
    print("\n=== ONNX Runtime Inference ===")

    import onnxruntime as ort
    from transformers import AutoTokenizer

    print("Loading ONNX models...")
    providers = ["CPUExecutionProvider"]
    dacvae_encoder = ort.InferenceSession(
        os.path.join(model_dir, "dacvae_encoder.onnx"), providers=providers
    )
    dacvae_decoder = ort.InferenceSession(
        os.path.join(model_dir, "dacvae_decoder.onnx"), providers=providers
    )
    t5_encoder = ort.InferenceSession(
        os.path.join(model_dir, "t5_encoder.onnx"), providers=providers
    )
    dit = ort.InferenceSession(
        os.path.join(model_dir, "dit_single_step.onnx"), providers=providers
    )
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_dir, "tokenizer"))
    print("  All models loaded")

    # Resample to 44.1 kHz (the rate used by the DAC-VAE here) and downmix to mono.
    wav = torchaudio.functional.resample(sample.data, sample.sample_rate, 44100)
    wav = wav.mean(0, keepdim=True)
    audio = wav.numpy().reshape(1, 1, -1).astype(np.float32)
    print(f"  Input audio shape: {audio.shape}")

    print("Encoding audio...")
    latent = dacvae_encoder.run(["latent_features"], {"audio": audio})[0]
    print(f"  Audio latent shape: {latent.shape}")

    print("Encoding text...")
    tokens = tokenizer(
        "A horn honking",
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=77,
    )
    text_features = t5_encoder.run(
        ["hidden_states"],
        {
            "input_ids": tokens["input_ids"].astype(np.int64),
            "attention_mask": tokens["attention_mask"].astype(np.int64),
        },
    )[0]
    print(f"  Text features shape: {text_features.shape}")

    batch_size = 1
    latent_dim = latent.shape[1]
    time_steps = latent.shape[2]

    # (batch, latent_dim, time) -> (batch, time, latent_dim) for the DiT.
    mixture_features = latent.transpose(0, 2, 1)

    audio_features = np.concatenate([mixture_features, mixture_features], axis=-1)
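    # The exported DiT takes an audio-feature tensor twice the latent width; this
    # harness fills both halves with the mixture latent (a simplification of this
    # test). Video features and anchor inputs are zeroed below, so generation is
    # conditioned on the text description and the mixture only.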

    def dit_feeds(noisy_audio, time):
        """Assemble the input dict for one DiT evaluation of the velocity field."""
        return {
            "noisy_audio": noisy_audio,
            "time": time,
            "audio_features": audio_features,
            "text_features": text_features,
            "text_mask": tokens["attention_mask"].astype(bool),
            "masked_video_features": np.zeros((batch_size, 1024, time_steps), dtype=np.float32),
            "anchor_ids": np.zeros((batch_size, time_steps), dtype=np.int64),
            "anchor_alignment": np.zeros((batch_size, time_steps), dtype=np.int64),
            "audio_pad_mask": np.ones((batch_size, time_steps), dtype=bool),
        }

    print("Running DiT (simplified test - 1 step)...")
    # Unseeded noise: the ONNX result will differ run-to-run and from PyTorch.
    initial = np.random.randn(batch_size, time_steps, 256).astype(np.float32)
    # Single evaluation first, as an I/O shape sanity check before the full solve.
    velocity = dit.run(["velocity"], dit_feeds(initial, np.array([0.0], dtype=np.float32)))[0]
    print(f"  DiT velocity shape: {velocity.shape}")

    print("Running full ODE solve (16 steps)...")
    num_steps = 16
    dt = 1.0 / num_steps
    x = initial.copy()
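    # Each iteration is one explicit midpoint (RK2) step of the flow-matching
    # ODE dx/dt = v(x, t), integrated from t = 0 (noise) to t = 1 (data):
    #   k1    = v(x_n, t_n)
    #   x_mid = x_n + (dt / 2) * k1
    #   k2    = v(x_mid, t_n + dt / 2)
    #   x_n+1 = x_n + dt * k2
    # The 16-step midpoint solver is this harness's choice and may not match the
    # sampler used inside SAMAudio.separate().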
    for i in range(num_steps):
        t = np.array([i * dt], dtype=np.float32)
        t_mid = np.array([t[0] + dt / 2], dtype=np.float32)

        # Velocity at the current point, then at the midpoint predictor.
        k1 = dit.run(["velocity"], dit_feeds(x, t))[0]
        x_mid = x + (dt / 2) * k1
        k2 = dit.run(["velocity"], dit_feeds(x_mid, t_mid))[0]
        x = x + dt * k2
        print(f"  Step {i + 1}/{num_steps}")

    print("Decoding audio...")
    # The 256-dim DiT state is assumed to carry the target latent in its first
    # latent_dim channels; transpose back to (batch, latent_dim, time) for the decoder.
    target_latent = x[:, :, :latent_dim].transpose(0, 2, 1)
    separated_latent = target_latent

    # The decoder appears to be exported with a fixed input length, so decode in
    # 25-step chunks and zero-pad the final one.
    chunk_size = 25
    T = separated_latent.shape[2]

    audio_chunks = []
    for start_idx in range(0, T, chunk_size):
        end_idx = min(start_idx + chunk_size, T)
        chunk = separated_latent[:, :, start_idx:end_idx]

        # Zero-pad the final chunk so the decoder always sees chunk_size steps.
        actual_size = chunk.shape[2]
        if actual_size < chunk_size:
            pad_size = chunk_size - actual_size
            chunk = np.pad(chunk, ((0, 0), (0, 0), (0, pad_size)), mode="constant")

        chunk_audio = dacvae_decoder.run(
            ["waveform"], {"latent_features": chunk.astype(np.float32)}
        )[0]

        # Trim off the audio produced by the zero-padding (the decoder emits
        # 1920 waveform samples per latent step).
        if actual_size < chunk_size:
            samples_per_step = 1920
            trim_samples = actual_size * samples_per_step
            chunk_audio = chunk_audio[:, :, :trim_samples]

        audio_chunks.append(chunk_audio)
        print(f"  Decoded chunk {start_idx // chunk_size + 1}/{(T + chunk_size - 1) // chunk_size}")

    separated_audio = np.concatenate(audio_chunks, axis=2)
    print(f"  Output audio shape: {separated_audio.shape}")

    return separated_audio.squeeze(), 44100


def compare_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr):
    """Compare PyTorch and ONNX outputs."""
    print("\n=== Comparison ===")

    import scipy.signal

    if pytorch_sr != onnx_sr:
        print(f"Resampling PyTorch output from {pytorch_sr} to {onnx_sr}...")
        num_samples = int(len(pytorch_audio) * onnx_sr / pytorch_sr)
        pytorch_audio_resampled = scipy.signal.resample(pytorch_audio, num_samples)
    else:
        pytorch_audio_resampled = pytorch_audio

    # Trim both signals to a common length before comparing.
    min_len = min(len(pytorch_audio_resampled), len(onnx_audio))
    pytorch_trimmed = pytorch_audio_resampled[:min_len]
    onnx_trimmed = onnx_audio[:min_len]

    # The two runs draw independent starting noise, so close sample-wise
    # agreement is not expected; correlation is the more informative metric.
    diff = np.abs(pytorch_trimmed - onnx_trimmed)
    max_diff = diff.max()
    mean_diff = diff.mean()
    correlation = np.corrcoef(pytorch_trimmed, onnx_trimmed)[0, 1]

    print(f"  PyTorch audio length: {len(pytorch_audio)} samples")
    print(f"  ONNX audio length: {len(onnx_audio)} samples")
    print(f"  Max difference: {max_diff:.6f}")
    print(f"  Mean difference: {mean_diff:.6f}")
    print(f"  Correlation: {correlation:.6f}")

    return max_diff, mean_diff, correlation


def save_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr, input_audio, input_sr):
    """Save audio outputs for listening comparison."""
    import soundfile as sf

    output_dir = "test_outputs"
    os.makedirs(output_dir, exist_ok=True)

    sf.write(os.path.join(output_dir, "input.wav"), input_audio.squeeze(), input_sr)
    print(f"Saved input to {output_dir}/input.wav")

    sf.write(os.path.join(output_dir, "pytorch_output.wav"), pytorch_audio, pytorch_sr)
    print(f"Saved PyTorch output to {output_dir}/pytorch_output.wav")

    sf.write(os.path.join(output_dir, "onnx_output.wav"), onnx_audio, onnx_sr)
    print(f"Saved ONNX output to {output_dir}/onnx_output.wav")


def main():
    import argparse

    parser = argparse.ArgumentParser(description="End-to-end SAM Audio test")
    parser.add_argument("--model-dir", default=".", help="ONNX model directory")
    parser.add_argument("--device", default="cpu", choices=["cpu", "cuda"])
    parser.add_argument("--save-outputs", action="store_true", help="Save audio files")
    parser.add_argument("--skip-pytorch", action="store_true", help="Skip PyTorch inference")
    args = parser.parse_args()

    sample = load_audiocaps_sample()

    if not args.skip_pytorch:
        pytorch_audio, pytorch_sr, input_audio = run_pytorch_inference(sample, args.device)
    else:
        print("\nSkipping PyTorch inference")
        pytorch_audio, pytorch_sr = None, None
        input_audio = sample.data.mean(0).numpy()

    onnx_audio, onnx_sr = run_onnx_inference(sample, args.model_dir)

    if pytorch_audio is not None:
        compare_outputs(pytorch_audio, onnx_audio, pytorch_sr, onnx_sr)

    if args.save_outputs:
        print("\n=== Saving Outputs ===")
        if pytorch_audio is not None:
            # input_audio was resampled to the model rate inside
            # run_pytorch_inference, so save it at pytorch_sr, not the
            # original sample rate.
            save_outputs(
                pytorch_audio, onnx_audio, pytorch_sr, onnx_sr,
                input_audio, pytorch_sr,
            )
        else:
            # Without a PyTorch pass there is only the ONNX waveform to save.
            import soundfile as sf
            os.makedirs("test_outputs", exist_ok=True)
            sf.write("test_outputs/onnx_output.wav", onnx_audio, onnx_sr)
            print("Saved ONNX output to test_outputs/onnx_output.wav")

    print("\n✓ End-to-end test complete!")


if __name__ == "__main__":
    main()