|
|
|
|
|
""" |
|
|
SAM Audio ONNX Runtime Inference Example |
|
|
|
|
|
This script demonstrates how to use the exported ONNX models for audio source |
|
|
separation inference. It shows the complete pipeline from text input to |
|
|
separated audio output. |
|
|
|
|
|
Usage: |
|
|
python onnx_inference.py --audio input.wav --text "a person speaking" |
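    python onnx_inference.py --video input.mp4 --text "a person speaking"
    python onnx_inference.py --audio input.wav --text "a dog barking" --predict-spans
    python onnx_inference.py --audio input.wav --text "a dog barking" --rerank --num-candidates 4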
|
|
""" |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
import numpy as np |
|
|
import json |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
def load_audio(path: str, target_sr: int = 48000) -> np.ndarray: |
|
|
"""Load audio file and resample to target sample rate. Supports video files via torchaudio/librosa.""" |
|
|
|
|
|
try: |
|
|
import torchaudio |
|
|
import torch |
|
|
wav, sr = torchaudio.load(path) |
|
|
if wav.shape[0] > 1: |
|
|
wav = wav.mean(0, keepdim=True) |
|
|
if sr != target_sr: |
|
|
resampler = torchaudio.transforms.Resample(sr, target_sr) |
|
|
wav = resampler(wav) |
|
|
return wav.squeeze().numpy().astype(np.float32) |
|
|
    except Exception:
|
|
|
|
|
try: |
|
|
import librosa |
|
|
audio, sr = librosa.load(path, sr=target_sr, mono=True) |
|
|
return audio.astype(np.float32) |
|
|
except ImportError: |
|
|
raise ImportError("Please install torchaudio or librosa: pip install torchaudio librosa") |
|
|
except Exception as e2: |
|
|
raise RuntimeError(f"Failed to load audio from {path}: {e2}") |
|
|
|
|
|
|
|
|
def save_audio(audio: np.ndarray, path: str, sample_rate: int = 48000): |
|
|
"""Save audio to WAV file.""" |
|
|
try: |
|
|
import soundfile as sf |
|
|
|
|
|
if audio.ndim > 1: |
|
|
audio = audio.flatten() |
|
|
sf.write(path, audio, sample_rate) |
|
|
print(f"Saved audio to {path}") |
|
|
except ImportError: |
|
|
raise ImportError("Please install soundfile: pip install soundfile") |
|
|
|
|
|
|
|
|
def save_video_with_audio(frames: np.ndarray, audio: np.ndarray, path: str, sample_rate: int = 48000, fps: float = 24.0): |
|
|
"""Save masked video frames and separated audio to a movie file.""" |
|
|
try: |
|
|
import torch |
|
|
import torchvision |
|
|
import torchaudio |
|
|
|
|
|
|
|
|
|
|
|
        # Denormalize from [-1, 1] to [0, 255]; clip before the uint8 cast so
        # interpolation overshoot cannot wrap around.
        frames_uint8 = (np.clip(frames * 0.5 + 0.5, 0.0, 1.0) * 255).astype(np.uint8)
|
|
|
|
|
|
|
|
video_tensor = torch.from_numpy(frames_uint8).permute(0, 2, 3, 1) |
|
|
|
|
|
|
|
|
if audio.ndim == 1: |
|
|
audio = audio[None, :] |
|
|
audio_tensor = torch.from_numpy(audio) |
|
|
|
|
|
print(f"Saving merged video to {path}...") |
|
|
torchvision.io.write_video( |
|
|
path, |
|
|
video_tensor, |
|
|
fps=fps, |
|
|
video_codec="libx264", |
|
|
audio_array=audio_tensor, |
|
|
audio_fps=sample_rate, |
|
|
audio_codec="aac" |
|
|
) |
|
|
print(f" ✓ Video saved to {path}") |
|
|
except Exception as e: |
|
|
print(f"Warning: Failed to save video: {e}") |
|
|
|
|
|
|
|
|
class SAMAudioONNXPipeline: |
|
|
""" |
|
|
ONNX-based SAM Audio inference pipeline. |
|
|
|
|
|
This class orchestrates all the ONNX models to perform audio source separation. |
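
    Example (a minimal sketch, assuming models exported to ./onnx_models):
        pipeline = SAMAudioONNXPipeline(model_dir="onnx_models", device="cpu")
        mixture = load_audio("input.wav")
        target, residual, _, _ = pipeline.separate(mixture, "a person speaking")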
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model_dir: str = "onnx_models", |
|
|
device: str = "cpu", |
|
|
num_ode_steps: int = 16, |
|
|
): |
|
|
import onnxruntime as ort |
|
|
|
|
|
self.model_dir = model_dir |
|
|
self.num_ode_steps = num_ode_steps |
|
|
self.step_size = 1.0 / num_ode_steps |
|
|
|
|
|
|
|
|
if device == "cuda": |
|
|
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] |
|
|
else: |
|
|
providers = ["CPUExecutionProvider"] |
|
|
|
|
|
|
|
|
print("Loading ONNX models...") |
|
|
|
|
|
self.dacvae_encoder = ort.InferenceSession( |
|
|
os.path.join(model_dir, "dacvae_encoder.onnx"), |
|
|
providers=providers, |
|
|
) |
|
|
print(" ✓ DACVAE encoder loaded") |
|
|
|
|
|
self.dacvae_decoder = ort.InferenceSession( |
|
|
os.path.join(model_dir, "dacvae_decoder.onnx"), |
|
|
providers=providers, |
|
|
) |
|
|
print(" ✓ DACVAE decoder loaded") |
|
|
|
|
|
self.t5_encoder = ort.InferenceSession( |
|
|
os.path.join(model_dir, "t5_encoder.onnx"), |
|
|
providers=providers, |
|
|
) |
|
|
print(" ✓ T5 encoder loaded") |
|
|
|
|
|
self.dit = ort.InferenceSession( |
|
|
os.path.join(model_dir, "dit_single_step.onnx"), |
|
|
providers=providers, |
|
|
) |
|
|
print(" ✓ DiT denoiser loaded") |
|
|
|
|
|
|
|
|
self.vision_encoder = None |
|
|
vision_path = os.path.join(model_dir, "vision_encoder.onnx") |
|
|
if os.path.exists(vision_path): |
|
|
self.vision_encoder = ort.InferenceSession( |
|
|
vision_path, |
|
|
providers=providers, |
|
|
) |
|
|
print(" ✓ Vision encoder loaded") |
|
|
|
|
|
|
|
|
self.peaframe = None |
|
|
self.peaframe_tokenizer = None |
|
|
self.peaframe_config = None |
|
|
peaframe_path = os.path.join(model_dir, "peaframe.onnx") |
|
|
if os.path.exists(peaframe_path): |
|
|
self.peaframe = ort.InferenceSession( |
|
|
peaframe_path, |
|
|
providers=providers, |
|
|
) |
|
|
print(" ✓ PEAFrame loaded") |
|
|
|
|
|
|
|
|
tokenizer_path = os.path.join(model_dir, "peaframe_tokenizer") |
|
|
if os.path.exists(tokenizer_path): |
|
|
from transformers import AutoTokenizer |
|
|
self.peaframe_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) |
|
|
print(" ✓ PEAFrame tokenizer loaded") |
|
|
|
|
|
|
|
|
config_path = os.path.join(model_dir, "peaframe_config.json") |
|
|
if os.path.exists(config_path): |
|
|
with open(config_path) as f: |
|
|
self.peaframe_config = json.load(f) |
|
|
print(" ✓ PEAFrame config loaded") |
|
|
|
|
|
|
|
|
self.clap_audio_encoder = None |
|
|
self.clap_text_encoder = None |
|
|
self.clap_tokenizer = None |
|
|
self.clap_config = None |
|
|
|
|
|
clap_audio_path = os.path.join(model_dir, "clap_audio_encoder.onnx") |
|
|
clap_text_path = os.path.join(model_dir, "clap_text_encoder.onnx") |
|
|
|
|
|
if os.path.exists(clap_audio_path) and os.path.exists(clap_text_path): |
|
|
self.clap_audio_encoder = ort.InferenceSession( |
|
|
clap_audio_path, |
|
|
providers=providers, |
|
|
) |
|
|
print(" ✓ CLAP audio encoder loaded") |
|
|
|
|
|
self.clap_text_encoder = ort.InferenceSession( |
|
|
clap_text_path, |
|
|
providers=providers, |
|
|
) |
|
|
print(" ✓ CLAP text encoder loaded") |
|
|
|
|
|
|
|
|
tokenizer_path = os.path.join(model_dir, "clap_tokenizer") |
|
|
if os.path.exists(tokenizer_path): |
|
|
from transformers import AutoTokenizer |
|
|
self.clap_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) |
|
|
print(" ✓ CLAP tokenizer loaded") |
|
|
|
|
|
|
|
|
config_path = os.path.join(model_dir, "clap_config.json") |
|
|
if os.path.exists(config_path): |
|
|
with open(config_path) as f: |
|
|
self.clap_config = json.load(f) |
|
|
print(" ✓ CLAP config loaded") |
|
|
|
|
|
|
|
|
self._load_tokenizer() |
|
|
print(" ✓ Tokenizer loaded") |
|
|
|
|
|
print("All models loaded!") |
|
|
|
|
|
def _load_tokenizer(self): |
|
|
""" |
|
|
Load the T5 tokenizer using SentencePiece. |
|
|
This avoids the dependency on the 'transformers' library. |
|
|
""" |
|
|
try: |
|
|
import sentencepiece as spm |
|
|
except ImportError: |
|
|
raise ImportError("Please install sentencepiece: pip install sentencepiece") |
|
|
|
|
|
|
|
|
sp_path = os.path.join(self.model_dir, "tokenizer", "spiece.model") |
|
|
if not os.path.exists(sp_path): |
|
|
sp_path = os.path.join(self.model_dir, "spiece.model") |
|
|
|
|
|
if not os.path.exists(sp_path): |
|
|
raise FileNotFoundError(f"SentencePiece model not found at {sp_path}") |
|
|
|
|
|
|
|
|
class T5ONNXTokenizer: |
|
|
def __init__(self, sp_path): |
|
|
self.sp = spm.SentencePieceProcessor() |
|
|
self.sp.load(sp_path) |
|
|
|
|
|
def encode(self, text: str) -> np.ndarray: |
|
|
ids = self.sp.encode(text) |
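
                # T5 uses token id 1 as EOS (id 0 is <pad>); append EOS if missing.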
|
|
if len(ids) > 0 and ids[-1] != 1: |
|
|
ids.append(1) |
|
|
elif len(ids) == 0: |
|
|
ids = [1] |
|
|
return np.array(ids, dtype=np.int64).reshape(1, -1) |
|
|
|
|
|
def decode(self, tokens: np.ndarray) -> str: |
|
|
if tokens.ndim > 1: |
|
|
tokens = tokens.flatten() |
|
|
return self.sp.decode(tokens.tolist()) |
|
|
|
|
|
self.tokenizer = T5ONNXTokenizer(sp_path) |
|
|
|
|
|
def load_video_frames(self, path: str, num_steps: int, mask_path: Optional[str] = None) -> tuple[np.ndarray, np.ndarray, float]: |
|
|
""" |
|
|
Load video frames and align them to audio latent steps. |
|
|
Optionally applies a binary mask for visual prompting. |
|
|
        Returns (normalized_frames, visual_frames, fps).
|
|
""" |
|
|
try: |
|
|
from torchcodec.decoders import VideoDecoder |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
except ImportError: |
|
|
raise ImportError("Please install torchcodec and torch: pip install torchcodec torch") |
|
|
|
|
|
decoder = VideoDecoder(path, dimension_order="NCHW") |
|
|
all_data = decoder.get_frames_in_range(0, len(decoder)) |
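
        # Each conditioning step covers hop_length audio samples; compute one
        # timestamp per step so the temporally nearest frame can be selected.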
|
|
|
|
|
|
|
|
|
|
|
hop_length = 1536 |
|
|
sample_rate = 48000 |
|
|
step_timestamps = np.arange(num_steps) * hop_length / sample_rate |
|
|
|
|
|
|
|
|
metadata = decoder.metadata |
|
|
fps = metadata.average_fps if metadata.average_fps is not None else 24.0 |
|
|
|
|
|
|
|
|
diffs = np.abs(all_data.pts_seconds.numpy()[:, None] - step_timestamps[None, :]) |
|
|
frame_idxs = np.argmin(diffs, axis=0) |
|
|
|
|
|
frames = all_data.data[frame_idxs] |
|
|
|
|
|
|
|
|
if mask_path: |
|
|
print(f" Applying mask from {mask_path}...") |
|
|
mask_decoder = VideoDecoder(mask_path, dimension_order="NCHW") |
|
|
mask_data = mask_decoder.get_frames_in_range(0, len(mask_decoder)) |
|
|
|
|
|
|
|
|
m_diffs = np.abs(mask_data.pts_seconds.numpy()[:, None] - step_timestamps[None, :]) |
|
|
m_frame_idxs = np.argmin(m_diffs, axis=0) |
|
|
masks = mask_data.data[m_frame_idxs] |
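
            # Binarize the (possibly RGB) mask: pixels brighter than 128 count
            # as masked and are zeroed out in the matching frames.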
|
|
|
|
|
|
|
|
|
|
|
binary_mask = (masks.float().mean(dim=1, keepdim=True) > 128).float() |
|
|
frames = frames.float() * (1.0 - binary_mask) |
|
|
|
|
|
|
|
|
image_size = 336 |
|
|
frames_resized = F.interpolate(frames.float(), size=(image_size, image_size), mode="bicubic") |
|
|
frames_norm = (frames_resized / 255.0 - 0.5) / 0.5 |
|
|
|
|
|
        frames_np = frames_norm.numpy()
        return frames_np, frames_np, fps
|
|
|
|
|
def encode_video(self, frames: np.ndarray) -> np.ndarray: |
|
|
"""Run vision encoder on framed images.""" |
|
|
if self.vision_encoder is None: |
|
|
raise RuntimeError("Vision encoder model not loaded") |
|
|
|
|
|
|
|
|
|
|
|
all_features = [] |
|
|
for i in range(len(frames)): |
|
|
frame = frames[i:i+1] |
|
|
outputs = self.vision_encoder.run( |
|
|
["vision_features"], |
|
|
{"video_frames": frame} |
|
|
) |
|
|
all_features.append(outputs[0]) |
|
|
|
|
|
features = np.concatenate(all_features, axis=0) |
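
        # Rearrange (num_frames, feat_dim) -> (1, feat_dim, num_frames), the
        # layout dit_step expects for masked_video_features.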
|
|
|
|
|
|
|
|
return features.transpose(1, 0)[None, :, :] |
|
|
|
|
|
|
|
|
def encode_audio(self, audio: np.ndarray) -> np.ndarray: |
|
|
""" |
|
|
Encode audio waveform to latent features. |
|
|
|
|
|
Args: |
|
|
audio: Audio waveform, shape (samples,) or (1, 1, samples) |
|
|
|
|
|
Returns: |
|
|
Latent features, shape (1, latent_dim, time_steps) |
|
|
""" |
|
|
|
|
|
if audio.ndim == 1: |
|
|
audio = audio.reshape(1, 1, -1) |
|
|
elif audio.ndim == 2: |
|
|
audio = audio.reshape(1, *audio.shape) |
|
|
|
|
|
outputs = self.dacvae_encoder.run( |
|
|
["latent_features"], |
|
|
{"audio": audio.astype(np.float32)}, |
|
|
) |
|
|
return outputs[0] |
|
|
|
|
|
def decode_audio(self, latent: np.ndarray) -> np.ndarray: |
|
|
""" |
|
|
Decode latent features to audio waveform. |
|
|
|
|
|
Uses chunked decoding since the DACVAE decoder was exported with |
|
|
fixed 25 time steps. Processes in chunks and concatenates. |
|
|
|
|
|
Args: |
|
|
latent: Latent features, shape (1, latent_dim, time_steps) |
|
|
|
|
|
Returns: |
|
|
Audio waveform, shape (samples,) |
|
|
""" |
|
|
chunk_size = 25 |
|
|
hop_length = 1920 |
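
        # Each latent step decodes to hop_length waveform samples; a padded
        # final chunk is trimmed back to actual_size * hop_length below.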
|
|
|
|
|
_, _, time_steps = latent.shape |
|
|
|
|
|
audio_chunks = [] |
|
|
for start_idx in range(0, time_steps, chunk_size): |
|
|
end_idx = min(start_idx + chunk_size, time_steps) |
|
|
chunk = latent[:, :, start_idx:end_idx] |
|
|
|
|
|
|
|
|
actual_size = chunk.shape[2] |
|
|
if actual_size < chunk_size: |
|
|
pad_size = chunk_size - actual_size |
|
|
chunk = np.pad(chunk, ((0, 0), (0, 0), (0, pad_size)), mode='constant') |
|
|
|
|
|
|
|
|
chunk_audio = self.dacvae_decoder.run( |
|
|
["waveform"], |
|
|
{"latent_features": chunk.astype(np.float32)}, |
|
|
)[0] |
|
|
|
|
|
|
|
|
if actual_size < chunk_size: |
|
|
trim_samples = actual_size * hop_length |
|
|
chunk_audio = chunk_audio[:, :, :trim_samples] |
|
|
|
|
|
audio_chunks.append(chunk_audio) |
|
|
|
|
|
|
|
|
full_audio = np.concatenate(audio_chunks, axis=2) |
|
|
return full_audio.squeeze() |
|
|
|
|
|
def encode_text(self, text: str) -> tuple[np.ndarray, np.ndarray]: |
|
|
""" |
|
|
Encode text prompt to features. |
|
|
|
|
|
Args: |
|
|
text: Text description of the audio to separate |
|
|
|
|
|
Returns: |
|
|
Tuple of (hidden_states, attention_mask) |
|
|
""" |
|
|
input_ids = self.tokenizer.encode(text) |
|
|
attention_mask = np.ones_like(input_ids) |
|
|
|
|
|
outputs = self.t5_encoder.run( |
|
|
["hidden_states"], |
|
|
{ |
|
|
"input_ids": input_ids.astype(np.int64), |
|
|
"attention_mask": attention_mask.astype(np.int64), |
|
|
}, |
|
|
) |
|
|
|
|
|
return outputs[0], attention_mask |
|
|
|
|
|
def predict_spans( |
|
|
self, |
|
|
audio: np.ndarray, |
|
|
text: str, |
|
|
threshold: Optional[float] = None, |
|
|
) -> list[tuple[float, float]]: |
|
|
""" |
|
|
Predict time spans in audio that match the text description. |
|
|
|
|
|
Args: |
|
|
audio: Audio waveform, shape (samples,) |
|
|
text: Text description of target sound |
|
|
threshold: Detection threshold (default from config) |
|
|
|
|
|
Returns: |
|
|
List of (start_seconds, end_seconds) tuples |
|
|
""" |
|
|
if self.peaframe is None: |
|
|
raise RuntimeError("PEAFrame model not loaded") |
|
|
if self.peaframe_tokenizer is None: |
|
|
raise RuntimeError("PEAFrame tokenizer not loaded") |
|
|
if self.peaframe_config is None: |
|
|
raise RuntimeError("PEAFrame config not loaded") |
|
|
|
|
|
config = self.peaframe_config |
|
|
if threshold is None: |
|
|
threshold = config.get("threshold", 0.3) |
|
|
|
|
|
|
|
|
tokens = self.peaframe_tokenizer( |
|
|
text, |
|
|
return_tensors="np", |
|
|
padding=True, |
|
|
truncation=True, |
|
|
max_length=512, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
sample_rate = config.get("sampling_rate", 48000) |
|
|
hop_length = config.get("hop_length", 1920) |
|
|
expected_samples = 160000 |
|
|
|
|
|
|
|
|
audio_len = len(audio) |
|
|
all_probs = [] |
|
|
|
|
|
if audio_len <= expected_samples: |
|
|
|
|
|
            # Flatten, zero-pad to the fixed input length the ONNX graph expects,
            # then reshape to (1, 1, samples).
            audio_flat = audio.reshape(-1)
            audio_input = np.pad(audio_flat, (0, expected_samples - len(audio_flat)))
            audio_input = audio_input.reshape(1, 1, -1)
|
|
|
|
|
|
|
|
outputs = self.peaframe.run( |
|
|
["audio_embeds", "text_embeds"], |
|
|
{ |
|
|
"input_ids": tokens["input_ids"].astype(np.int64), |
|
|
"input_values": audio_input.astype(np.float32), |
|
|
"attention_mask": tokens["attention_mask"].astype(np.int64), |
|
|
}, |
|
|
) |
|
|
audio_embeds = outputs[0] |
|
|
text_embeds = outputs[1] |
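
            # Per-frame similarity: dot each audio frame embedding with the text
            # embedding, apply the exported scale/bias, then a sigmoid.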
|
|
|
|
|
|
|
|
logits = np.matmul(audio_embeds, text_embeds[:, :, None]) |
|
|
logits = logits.squeeze(-1) |
|
|
|
|
|
|
|
|
logit_scale = config.get("logit_scale", 0.0) |
|
|
logit_bias = config.get("logit_bias", 0.0) |
|
|
logits = logits * logit_scale + logit_bias |
|
|
|
|
|
|
|
|
probs = 1.0 / (1.0 + np.exp(-logits)) |
|
|
|
|
|
|
|
|
num_frames = (audio_len + hop_length - 1) // hop_length |
|
|
all_probs = probs[0, :num_frames] |
|
|
else: |
|
|
|
|
|
chunk_size = expected_samples |
|
|
stride = chunk_size // 2 |
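
            # Slide a half-overlapping window; overlapping frame probabilities
            # are averaged when the chunks are merged below.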
|
|
|
|
|
for start in range(0, audio_len, stride): |
|
|
end = min(start + chunk_size, audio_len) |
|
|
chunk = audio[start:end] |
|
|
|
|
|
|
|
|
if len(chunk) < chunk_size: |
|
|
chunk = np.pad(chunk, (0, chunk_size - len(chunk))) |
|
|
|
|
|
chunk_input = chunk.reshape(1, 1, -1) |
|
|
|
|
|
|
|
|
outputs = self.peaframe.run( |
|
|
["audio_embeds", "text_embeds"], |
|
|
{ |
|
|
"input_ids": tokens["input_ids"].astype(np.int64), |
|
|
"input_values": chunk_input.astype(np.float32), |
|
|
"attention_mask": tokens["attention_mask"].astype(np.int64), |
|
|
}, |
|
|
) |
|
|
audio_embeds = outputs[0] |
|
|
text_embeds = outputs[1] |
|
|
|
|
|
|
|
|
logits = np.matmul(audio_embeds, text_embeds[:, :, None]) |
|
|
logits = logits.squeeze(-1) |
|
|
|
|
|
|
|
|
logit_scale = config.get("logit_scale", 0.0) |
|
|
logit_bias = config.get("logit_bias", 0.0) |
|
|
logits = logits * logit_scale + logit_bias |
|
|
|
|
|
|
|
|
chunk_probs = 1.0 / (1.0 + np.exp(-logits)) |
|
|
all_probs.append(chunk_probs[0]) |
|
|
|
|
|
|
|
|
if end >= audio_len: |
|
|
break |
|
|
|
|
|
|
|
|
if len(all_probs) == 1: |
|
|
all_probs = all_probs[0] |
|
|
else: |
|
|
|
|
|
total_frames = (audio_len + hop_length - 1) // hop_length |
|
|
merged_probs = np.zeros(total_frames) |
|
|
counts = np.zeros(total_frames) |
|
|
|
|
|
for i, chunk_probs in enumerate(all_probs): |
|
|
chunk_start = (i * stride) // hop_length |
|
|
chunk_frames = len(chunk_probs) |
|
|
chunk_end = min(chunk_start + chunk_frames, total_frames) |
|
|
actual_frames = chunk_end - chunk_start |
|
|
|
|
|
merged_probs[chunk_start:chunk_end] += chunk_probs[:actual_frames] |
|
|
counts[chunk_start:chunk_end] += 1 |
|
|
|
|
|
|
|
|
all_probs = merged_probs / np.maximum(counts, 1) |
|
|
|
|
|
|
|
|
preds = all_probs > threshold |
|
|
|
|
|
|
|
|
spans = [] |
|
|
|
|
|
|
|
in_span = False |
|
|
start_idx = 0 |
|
|
for i, pred in enumerate(preds): |
|
|
if pred and not in_span: |
|
|
start_idx = i |
|
|
in_span = True |
|
|
elif not pred and in_span: |
|
|
end_idx = i |
|
|
start_sec = start_idx * hop_length / sample_rate |
|
|
end_sec = end_idx * hop_length / sample_rate |
|
|
spans.append((start_sec, end_sec)) |
|
|
in_span = False |
|
|
|
|
|
|
|
|
if in_span: |
|
|
end_sec = len(preds) * hop_length / sample_rate |
|
|
start_sec = start_idx * hop_length / sample_rate |
|
|
spans.append((start_sec, end_sec)) |
|
|
|
|
|
return spans |
|
|
|
|
|
def process_anchors( |
|
|
self, |
|
|
spans: list[tuple[str, float, float]], |
|
|
seq_len: int, |
|
|
sample_rate: int = 48000, |
|
|
hop_length: int = 1920, |
|
|
) -> tuple[np.ndarray, np.ndarray]: |
|
|
""" |
|
|
Convert span predictions to anchor tensors for DiT. |
|
|
|
|
|
Args: |
|
|
spans: List of (sign, start_sec, end_sec) tuples |
|
|
sign is "+", "-", or "null" |
|
|
seq_len: Number of audio feature frames |
|
|
sample_rate: Audio sample rate |
|
|
hop_length: Samples per feature frame |
|
|
|
|
|
Returns: |
|
|
Tuple of (anchor_ids, anchor_alignment) |
|
|
- anchor_ids: [1, num_anchors] - anchor type indices |
|
|
- anchor_alignment: [1, seq_len] - maps each frame to anchor index |
|
|
""" |
|
|
|
|
|
anchor_dict = {"<null>": 0, "+": 1, "-": 2, "<pad>": 3, "null": 0} |
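
        # Slot 0 is the global <null> anchor and slot 1 is <pad>; span anchors
        # are appended from index 2, and frames default to the <pad> slot.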
|
|
|
|
|
|
|
|
anchor_ids = [anchor_dict["<null>"], anchor_dict["<pad>"]] |
|
|
anchor_alignment = np.zeros((1, seq_len), dtype=np.int64) |
|
|
|
|
|
|
|
|
anchor_alignment[0, :] = 1 |
|
|
|
|
|
for sign, start_sec, end_sec in spans: |
|
|
|
|
|
start_idx = int(start_sec * sample_rate / hop_length) |
|
|
end_idx = int(end_sec * sample_rate / hop_length) |
|
|
|
|
|
|
|
|
start_idx = max(0, min(start_idx, seq_len)) |
|
|
end_idx = max(0, min(end_idx, seq_len)) |
|
|
|
|
|
if start_idx < end_idx: |
|
|
|
|
|
anchor_idx = len(anchor_ids) |
|
|
anchor_alignment[0, start_idx:end_idx] = anchor_idx |
|
|
anchor_ids.append(anchor_dict.get(sign, anchor_dict["+"])) |
|
|
|
|
|
return np.array([anchor_ids], dtype=np.int64), anchor_alignment |
|
|
|
|
|
def score_with_clap( |
|
|
self, |
|
|
audio_candidates: list[np.ndarray], |
|
|
text: str, |
|
|
) -> np.ndarray: |
|
|
""" |
|
|
Score audio candidates against text using CLAP. |
|
|
|
|
|
The CLAP audio encoder expects waveforms at 48kHz, padded/truncated to |
|
|
10 seconds (480000 samples). |
|
|
|
|
|
Args: |
|
|
audio_candidates: List of audio waveforms, each shape (samples,) |
|
|
text: Text description to match against |
|
|
|
|
|
Returns: |
|
|
scores: Array of similarity scores, shape (num_candidates,) |
|
|
""" |
|
|
if self.clap_audio_encoder is None: |
|
|
raise RuntimeError("CLAP audio encoder not loaded") |
|
|
if self.clap_text_encoder is None: |
|
|
raise RuntimeError("CLAP text encoder not loaded") |
|
|
if self.clap_tokenizer is None: |
|
|
raise RuntimeError("CLAP tokenizer not loaded") |
|
|
if self.clap_config is None: |
|
|
raise RuntimeError("CLAP config not loaded") |
|
|
|
|
|
config = self.clap_config |
|
|
max_audio_len = config.get("max_audio_len", 480000) |
|
|
|
|
|
|
|
|
tokens = self.clap_tokenizer( |
|
|
text, |
|
|
return_tensors="np", |
|
|
padding=True, |
|
|
truncation=True, |
|
|
max_length=77, |
|
|
) |
|
|
|
|
|
text_embed = self.clap_text_encoder.run( |
|
|
["text_embed"], |
|
|
{ |
|
|
"input_ids": tokens["input_ids"].astype(np.int64), |
|
|
"attention_mask": tokens["attention_mask"].astype(np.int64), |
|
|
}, |
|
|
)[0] |
|
|
|
|
|
|
|
|
audio_embeds = [] |
|
|
for audio in audio_candidates: |
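
            # Round-trip through int16 to mimic 16-bit WAV quantization.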
|
|
|
|
|
|
|
|
audio = (audio * 32768.0).astype(np.int16).astype(np.float32) / 32768.0 |
|
|
|
|
|
|
|
|
if len(audio) > max_audio_len: |
|
|
audio = audio[:max_audio_len] |
|
|
elif len(audio) < max_audio_len: |
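
                # Loop (tile) short clips to max_audio_len instead of zero-padding.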
|
|
|
|
|
n_repeat = int(np.ceil(max_audio_len / len(audio))) |
|
|
audio = np.tile(audio, n_repeat)[:max_audio_len] |
|
|
|
|
|
|
|
|
audio_input = audio.reshape(1, -1).astype(np.float32) |
|
|
|
|
|
|
|
|
audio_embed = self.clap_audio_encoder.run( |
|
|
["audio_embed"], |
|
|
{"waveform": audio_input}, |
|
|
)[0] |
|
|
|
|
|
audio_embeds.append(audio_embed) |
|
|
|
|
|
|
|
|
audio_embeds = np.concatenate(audio_embeds, axis=0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scores = np.matmul(audio_embeds, text_embed.T).squeeze(-1) |
|
|
|
|
|
return scores |
|
|
|
|
|
def generate_candidates( |
|
|
self, |
|
|
audio_features: np.ndarray, |
|
|
text_features: np.ndarray, |
|
|
text_mask: np.ndarray, |
|
|
num_candidates: int = 4, |
|
|
masked_video_features: Optional[np.ndarray] = None, |
|
|
anchor_ids: Optional[np.ndarray] = None, |
|
|
anchor_alignment: Optional[np.ndarray] = None, |
|
|
seed: Optional[int] = None, |
|
|
) -> list[tuple[np.ndarray, np.ndarray]]: |
|
|
""" |
|
|
Generate multiple separation candidates with different random seeds. |
|
|
|
|
|
Args: |
|
|
audio_features: Encoded audio features [B, T, C] |
|
|
text_features: Encoded text features |
|
|
text_mask: Text attention mask |
|
|
num_candidates: Number of candidates to generate |
|
|
masked_video_features: Optional video features |
|
|
anchor_ids: Optional anchor IDs |
|
|
anchor_alignment: Optional anchor alignment |
|
|
seed: Base random seed (candidates use seed, seed+1, seed+2, ...) |
|
|
|
|
|
Returns: |
|
|
List of (target_latent, residual_latent) tuples |
|
|
""" |
|
|
B, T, C = audio_features.shape |
|
|
|
|
|
candidates = [] |
|
|
|
|
|
for i in range(num_candidates): |
|
|
|
|
|
if seed is not None: |
|
|
np.random.seed(seed + i) |
|
|
|
|
|
|
|
|
x = np.random.randn(B, T, C).astype(np.float32) |
|
|
|
|
|
|
|
|
steps = self.num_ode_steps |
|
|
dt = 1.0 / steps |
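
            # Integrate the flow ODE from t=0 to t=1 with the explicit midpoint
            # method (RK2): evaluate the velocity at t, take a half step, then
            # step the full dt using the midpoint velocity.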
|
|
|
|
|
for step_idx in range(steps): |
|
|
t = step_idx * dt |
|
|
|
|
|
k1 = self.dit_step( |
|
|
x, t, audio_features, text_features, text_mask, |
|
|
masked_video_features, anchor_ids, anchor_alignment |
|
|
) |
|
|
x_mid = x + k1 * (dt / 2.0) |
|
|
k2 = self.dit_step( |
|
|
x_mid, t + dt/2.0, audio_features, text_features, text_mask, |
|
|
masked_video_features, anchor_ids, anchor_alignment |
|
|
) |
|
|
x = x + k2 * dt |
|
|
|
|
|
|
|
|
target_latent = x[:, :, :128].transpose(0, 2, 1) |
|
|
residual_latent = x[:, :, 128:].transpose(0, 2, 1) |
|
|
|
|
|
candidates.append((target_latent, residual_latent)) |
|
|
|
|
|
return candidates |
|
|
|
|
|
def dit_step( |
|
|
self, |
|
|
noisy_audio: np.ndarray, |
|
|
time: float, |
|
|
audio_features: np.ndarray, |
|
|
text_features: np.ndarray, |
|
|
text_mask: np.ndarray, |
|
|
masked_video_features: Optional[np.ndarray] = None, |
|
|
anchor_ids: Optional[np.ndarray] = None, |
|
|
anchor_alignment: Optional[np.ndarray] = None, |
|
|
) -> np.ndarray: |
|
|
"""Run a single DiT denoiser step.""" |
|
|
batch_size = noisy_audio.shape[0] |
|
|
seq_len = noisy_audio.shape[1] |
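
        # Detect the export precision from the graph's first input so both
        # fp32 and fp16 exports work with the same code path.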
|
|
|
|
|
|
|
|
first_input = self.dit.get_inputs()[0] |
|
|
use_fp16 = first_input.type == 'tensor(float16)' |
|
|
float_dtype = np.float16 if use_fp16 else np.float32 |
|
|
|
|
|
|
|
|
if anchor_ids is None: |
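
            # No spans given: anchor 0 = <null>, anchor 1 = <pad>, and the
            # alignment below maps every frame to the <null> anchor (index 0).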
|
|
|
|
|
anchor_ids = np.zeros((batch_size, 2), dtype=np.int64) |
|
|
anchor_ids[:, 1] = 3 |
|
|
|
|
|
if anchor_alignment is None: |
|
|
|
|
|
anchor_alignment = np.zeros((batch_size, seq_len), dtype=np.int64) |
|
|
|
|
|
|
|
|
audio_pad_mask = np.ones((batch_size, seq_len), dtype=np.bool_) |
|
|
|
|
|
|
|
|
if masked_video_features is None: |
|
|
vision_dim = 1024 |
|
|
masked_video_features = np.zeros((batch_size, vision_dim, seq_len), dtype=float_dtype) |
|
|
|
|
|
inputs = { |
|
|
"noisy_audio": noisy_audio.astype(float_dtype), |
|
|
"time": np.array([time], dtype=float_dtype), |
|
|
"audio_features": audio_features.astype(float_dtype), |
|
|
"text_features": text_features.astype(float_dtype), |
|
|
"text_mask": text_mask.astype(np.bool_), |
|
|
"masked_video_features": masked_video_features.astype(float_dtype), |
|
|
"anchor_ids": anchor_ids.astype(np.int64), |
|
|
"anchor_alignment": anchor_alignment.astype(np.int64), |
|
|
"audio_pad_mask": audio_pad_mask.astype(np.bool_), |
|
|
} |
|
|
|
|
|
outputs = self.dit.run(None, inputs) |
|
|
return outputs[0] |
|
|
|
|
|
|
|
|
def separate( |
|
|
self, |
|
|
audio: np.ndarray, |
|
|
text: str, |
|
|
video_path: Optional[str] = None, |
|
|
mask_path: Optional[str] = None, |
|
|
predict_spans: bool = False, |
|
|
manual_anchors: Optional[list[tuple[str, float, float]]] = None, |
|
|
span_threshold: float = 0.3, |
|
|
rerank: bool = False, |
|
|
num_candidates: int = 4, |
|
|
rerank_seed: Optional[int] = None, |
|
|
) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], float]: |
|
|
""" |
|
|
Perform the full separation pipeline. |
|
|
|
|
|
Args: |
|
|
audio: Input mixture waveform |
|
|
text: Text description of the target source |
|
|
video_path: Optional path to a video for visual conditioning |
|
|
mask_path: Optional path to a video/image mask for visual prompting |
|
|
predict_spans: Whether to use PEAFrame for span prediction |
|
|
manual_anchors: Optional list of manual anchor spans |
|
|
span_threshold: Threshold for span prediction |
|
|
rerank: Whether to generate multiple candidates and rerank with CLAP |
|
|
num_candidates: Number of candidates for reranking |
|
|
rerank_seed: Random seed for reproducible candidate generation |
|
|
|
|
|
Returns: |
|
|
Tuple of (target audio, residual audio, masked video frames if any, fps) |
|
|
- target: The separated sound matching the text/visual prompt |
|
|
- residual: Everything else in the audio (the remainder) |
|
|
""" |
|
|
|
|
|
print("1. Encoding audio...") |
|
|
latent_features = self.encode_audio(audio) |
|
|
|
|
|
latent_features = latent_features.transpose(0, 2, 1) |
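
        # The DiT denoises target and residual jointly, so the mixture latent
        # is duplicated along the channel axis to condition both streams.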
|
|
|
|
|
|
|
|
audio_features = np.concatenate([latent_features, latent_features], axis=2) |
|
|
print(f" Audio latent shape: {latent_features.shape}") |
|
|
|
|
|
|
|
|
print("2. Encoding text...") |
|
|
text_features, text_mask = self.encode_text(text) |
|
|
print(f" Text features shape: {text_features.shape}") |
|
|
|
|
|
|
|
|
anchor_ids = None |
|
|
anchor_alignment = None |
|
|
seq_len = latent_features.shape[1] |
|
|
|
|
|
if manual_anchors: |
|
|
print("2.5. Processing manual anchors...") |
|
|
anchor_ids, anchor_alignment = self.process_anchors( |
|
|
manual_anchors, seq_len |
|
|
) |
|
|
print(f" Anchors: {len(manual_anchors)} spans specified") |
|
|
elif predict_spans and self.peaframe is not None: |
|
|
print("2.5. Predicting spans with PEAFrame...") |
|
|
detected_spans = self.predict_spans(audio, text, threshold=span_threshold) |
|
|
if detected_spans: |
|
|
|
|
|
anchors = [("+", s, e) for s, e in detected_spans] |
|
|
anchor_ids, anchor_alignment = self.process_anchors(anchors, seq_len) |
|
|
print(f" Detected {len(detected_spans)} spans: {detected_spans}") |
|
|
else: |
|
|
print(" No spans detected, using null anchors") |
|
|
|
|
|
|
|
|
masked_video_features = None |
|
|
visual_frames = None |
|
|
fps = 24.0 |
|
|
if video_path and self.vision_encoder: |
|
|
print("3a. Loading and encoding video...") |
|
|
norm_frames, visual_frames, fps = self.load_video_frames(video_path, latent_features.shape[1], mask_path) |
|
|
masked_video_features = self.encode_video(norm_frames) |
|
|
print(f" Video features shape: {masked_video_features.shape}") |
|
|
|
|
|
|
|
|
if rerank and self.clap_audio_encoder is not None: |
|
|
print(f"3. Generating {num_candidates} candidates for reranking...") |
|
|
|
|
|
|
|
|
candidates = self.generate_candidates( |
|
|
audio_features, text_features, text_mask, |
|
|
num_candidates=num_candidates, |
|
|
masked_video_features=masked_video_features, |
|
|
anchor_ids=anchor_ids, |
|
|
anchor_alignment=anchor_alignment, |
|
|
seed=rerank_seed, |
|
|
) |
|
|
|
|
|
|
|
|
print("3b. Decoding candidate audios...") |
|
|
candidate_audios = [] |
|
|
for i, (target_latent, _) in enumerate(candidates): |
|
|
decoded = self.decode_audio(target_latent) |
|
|
candidate_audios.append(decoded) |
|
|
print(f" Candidate {i+1}/{num_candidates} decoded", end="\r") |
|
|
print() |
|
|
|
|
|
|
|
|
print("3c. Scoring candidates with CLAP...") |
|
|
scores = self.score_with_clap(candidate_audios, text) |
|
|
best_idx = int(np.argmax(scores)) |
|
|
print(f" Scores: {scores}") |
|
|
print(f" Selected candidate {best_idx + 1}/{num_candidates} (score: {scores[best_idx]:.4f})") |
|
|
|
|
|
|
|
|
target_latent, residual_latent = candidates[best_idx] |
|
|
print(f" Target latent shape: {target_latent.shape}") |
|
|
print(f" Residual latent shape: {residual_latent.shape}") |
|
|
|
|
|
else: |
|
|
|
|
|
print("3. Running ODE solver...") |
|
|
|
|
|
|
|
|
B, T, C = audio_features.shape |
|
|
x = np.random.randn(B, T, C).astype(np.float32) |
|
|
|
|
|
steps = self.num_ode_steps |
|
|
dt = 1.0 / steps |
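
            # Explicit midpoint (RK2) integration, as in generate_candidates.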
|
|
|
|
|
for i in range(steps): |
|
|
t = i * dt |
|
|
print(f" ODE step {i+1}/{steps}", end="\r") |
|
|
|
|
|
k1 = self.dit_step( |
|
|
x, t, audio_features, text_features, text_mask, |
|
|
masked_video_features, anchor_ids, anchor_alignment |
|
|
) |
|
|
x_mid = x + k1 * (dt / 2.0) |
|
|
k2 = self.dit_step( |
|
|
x_mid, t + dt/2.0, audio_features, text_features, text_mask, |
|
|
masked_video_features, anchor_ids, anchor_alignment |
|
|
) |
|
|
|
|
|
x = x + k2 * dt |
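
            # The state carries target and residual stacked along the channel
            # axis; split at 128 channels and restore (B, C, T) for the decoder.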
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
target_latent = x[:, :, :128].transpose(0, 2, 1) |
|
|
residual_latent = x[:, :, 128:].transpose(0, 2, 1) |
|
|
print(f"\n Target latent shape: {target_latent.shape}") |
|
|
print(f" Residual latent shape: {residual_latent.shape}") |
|
|
|
|
|
|
|
|
print("4. Decoding target audio...") |
|
|
target_audio = self.decode_audio(target_latent) |
|
|
print(f" Target audio shape: {target_audio.shape}") |
|
|
|
|
|
print("5. Decoding residual audio...") |
|
|
residual_audio = self.decode_audio(residual_latent) |
|
|
print(f" Residual audio shape: {residual_audio.shape}") |
|
|
|
|
|
return target_audio, residual_audio, visual_frames, fps |
|
|
|
|
|
|
|
|
def main(): |
|
|
parser = argparse.ArgumentParser( |
|
|
description="SAM Audio ONNX Runtime Inference" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--audio", |
|
|
type=str, |
|
|
help="Path to input audio file (optional if --video is provided)", |
|
|
) |
|
|
parser.add_argument("--text", type=str, default="", help="Text description of the target source (optional if --video is provided)") |
|
|
parser.add_argument("--video", type=str, help="Optional path to video file for conditional separation") |
|
|
parser.add_argument("--mask", type=str, help="Optional path to mask file (visual prompting)") |
|
|
parser.add_argument( |
|
|
"--predict-spans", |
|
|
action="store_true", |
|
|
help="Use PEAFrame to automatically detect time spans matching the text", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--anchor", |
|
|
nargs=3, |
|
|
action="append", |
|
|
metavar=("SIGN", "START", "END"), |
|
|
help="Manual anchor: --anchor + 6.3 7.0 (sign is +, -, or null)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--span-threshold", |
|
|
type=float, |
|
|
default=0.3, |
|
|
help="Threshold for span prediction (default: 0.3)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--rerank", |
|
|
action="store_true", |
|
|
help="Generate multiple candidates and rerank with CLAP", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--num-candidates", |
|
|
type=int, |
|
|
default=4, |
|
|
help="Number of candidates for reranking (default: 4)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--rerank-seed", |
|
|
type=int, |
|
|
default=None, |
|
|
help="Random seed for reproducible candidate generation", |
|
|
) |
|
|
parser.add_argument("--output", type=str, default="target.wav", help="Output WAV file path for target (separated) audio") |
|
|
parser.add_argument("--output-residual", type=str, default="residual.wav", help="Output WAV file path for residual audio") |
|
|
parser.add_argument("--output-video", type=str, help="Optional path to save masked video with separated audio") |
|
|
parser.add_argument("--model-dir", type=str, default="onnx_models", help="Directory containing ONNX models") |
|
|
parser.add_argument("--steps", type=int, default=16, help="Number of ODE solver steps") |
|
|
parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"], help="Inference device") |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
manual_anchors = None |
|
|
if args.anchor: |
|
|
manual_anchors = [] |
|
|
for sign, start, end in args.anchor: |
|
|
if sign not in ("+", "-", "null"): |
|
|
parser.error(f"Invalid anchor sign: {sign}. Use +, -, or null") |
|
|
manual_anchors.append((sign, float(start), float(end))) |
|
|
print(f"Manual anchors: {manual_anchors}") |
|
|
|
|
|
|
|
|
    # Validate arguments before loading the models so bad invocations fail fast.
    if not args.audio and not args.video:
        parser.error("At least one of --audio or --video must be provided.")

    if not args.text and not args.video:
        parser.error("--text is required for audio-only separation.")

    pipeline = SAMAudioONNXPipeline(
        model_dir=args.model_dir,
        device=args.device,
        num_ode_steps=args.steps,
    )
|
|
|
|
|
audio_path = args.audio if args.audio else args.video |
|
|
|
|
|
|
|
|
print(f"\nLoading audio from: {audio_path}") |
|
|
audio = load_audio(audio_path, target_sr=48000) |
|
|
print(f"Audio duration: {len(audio)/48000:.2f} seconds") |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
target_audio, residual_audio, masked_frames, fps = pipeline.separate( |
|
|
audio, |
|
|
args.text, |
|
|
video_path=args.video if args.video else None, |
|
|
mask_path=args.mask, |
|
|
predict_spans=args.predict_spans, |
|
|
manual_anchors=manual_anchors, |
|
|
span_threshold=args.span_threshold, |
|
|
rerank=args.rerank, |
|
|
num_candidates=args.num_candidates, |
|
|
rerank_seed=args.rerank_seed, |
|
|
) |
|
|
|
|
|
|
|
|
save_audio(target_audio, args.output, sample_rate=48000) |
|
|
save_audio(residual_audio, args.output_residual, sample_rate=48000) |
|
|
|
|
|
|
|
|
if args.output_video and masked_frames is not None: |
|
|
save_video_with_audio(masked_frames, target_audio, args.output_video, sample_rate=48000, fps=fps) |
|
|
|
|
|
print(f"\n✓ Done!") |
|
|
print(f" Target audio saved to: {args.output}") |
|
|
print(f" Residual audio saved to: {args.output_residual}") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"\nError during separation: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|