# Vanta / vanta/inference.py
# Komalpreet Kaur
# Switch to SepFormer backend for +9.7 dB SI-SDR
# 828f7dd unverified
"""Inference helpers: load a trained Vanta checkpoint OR a SepFormer-based
backbone and extract a target speaker.
Two backends, same interface:
- VantaInference : our from-scratch trained checkpoint
- VantaSepFormerInference : pretrained SepFormer + our ECAPA selector
Pick at server startup via the VANTA_BACKEND env var. The trained-from-scratch
model is the project's "training pedigree" piece; the SepFormer backbone
delivers the audio quality we need for the live demo.
"""
from __future__ import annotations
import io
import subprocess
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
from vanta.config import SAMPLE_RATE
from vanta.models.vanta import Vanta, VantaConfig
from vanta.utils.audio import peak_normalize
# Longest stretch of a mixture processed per request; the remainder is cut off
# and the response metadata reports it as truncated.
MAX_MIX_SECONDS: float = 30.0
# Enrollment clips are trimmed or zero-padded to exactly this duration.
ENROLL_SECONDS: float = 5.0
def _ffmpeg_decode(raw: bytes, timeout: float | None = 120.0) -> np.ndarray:
    """Pipe arbitrary-container bytes through ffmpeg, get mono float32 at SAMPLE_RATE.

    libsndfile can't read MP4/M4A/WebM/MOV and so on. ffmpeg can. We first
    stream in/out via pipes to avoid temp files. A pipe is not seekable,
    though. MP4/M4A/MOV files whose index (the ``moov`` atom) is written after
    the media data (no "faststart") cannot be demuxed from it. So if the piped
    decode fails, we retry once from a temporary file before giving up.

    Args:
        raw: Encoded media bytes, typically straight from a client upload.
        timeout: Wall-clock limit in seconds for each ffmpeg invocation, so a
            pathological input cannot block the worker forever. ``None``
            disables the limit.

    Returns:
        1-D float32 array of mono samples at ``SAMPLE_RATE``.

    Raises:
        ValueError: If ffmpeg times out, or exits non-zero on both attempts.
    """
    import contextlib
    import os
    import tempfile

    from imageio_ffmpeg import get_ffmpeg_exe

    exe = get_ffmpeg_exe()

    def _run(source: str, stdin_bytes: bytes) -> subprocess.CompletedProcess:
        """Run one ffmpeg decode of ``source`` and return the finished process."""
        cmd = [
            exe,
            "-hide_banner",
            "-loglevel", "error",
            "-i", source,
            "-vn",  # ignore any video stream
            "-ac", "1",  # mono
            "-ar", str(SAMPLE_RATE),
            "-f", "f32le",  # raw float32 output — trivial to np.frombuffer
            "pipe:1",
        ]
        try:
            return subprocess.run(
                cmd, input=stdin_bytes, capture_output=True, timeout=timeout
            )
        except subprocess.TimeoutExpired as exc:
            # subprocess.run kills the child on timeout; surface it as the
            # same ValueError callers already handle for bad input.
            raise ValueError(f"ffmpeg decode timed out after {timeout} s") from exc

    proc = _run("pipe:0", raw)
    if proc.returncode != 0:
        # Retry from a seekable file (see docstring). An empty stdin payload
        # keeps ffmpeg from ever reading the server process's own stdin.
        fd, tmp_path = tempfile.mkstemp(prefix="vanta_decode_")
        try:
            with os.fdopen(fd, "wb") as fh:
                fh.write(raw)
            # "file:" forces ffmpeg's file protocol so the path is never
            # mistaken for another protocol prefix.
            proc = _run(f"file:{tmp_path}", b"")
        finally:
            with contextlib.suppress(OSError):
                os.unlink(tmp_path)
    if proc.returncode != 0:
        err = proc.stderr.decode("utf-8", errors="replace").strip()
        raise ValueError(f"ffmpeg decode failed: {err or 'no stderr'}")
    # frombuffer views the immutable bytes; copy to get a writable array.
    return np.frombuffer(proc.stdout, dtype=np.float32).copy()
def _to_mono_16k(raw: bytes) -> np.ndarray:
    """Decode uploaded audio bytes into a mono float32 array at ``SAMPLE_RATE``.

    Tries libsndfile first (in-process, fast). Anything it refuses is handed
    to ffmpeg. Multi-channel audio is averaged to mono, and other sample
    rates are resampled with soxr.

    Raises:
        ValueError: If ffmpeg cannot decode the input, or the input decodes
            to zero samples. Empty audio would otherwise surface later as an
            opaque numpy reduction or model error.
    """
    try:
        # Fast path: libsndfile handles WAV/FLAC/OGG/MP3 without spawning a process.
        wav, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)
    except Exception:
        # Deliberately broad: anything libsndfile refuses — MP4, M4A, WebM,
        # MOV, etc. — goes to ffmpeg, which already emits mono at SAMPLE_RATE.
        wav, sr = _ffmpeg_decode(raw), SAMPLE_RATE
    if wav.size == 0:
        raise ValueError("decoded audio is empty")
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    if sr != SAMPLE_RATE:
        import soxr

        wav = soxr.resample(wav, sr, SAMPLE_RATE, quality="HQ")
    return wav.astype(np.float32, copy=False)
def _fit(wav: np.ndarray, target_samples: int) -> np.ndarray:
    """Trim ``wav`` or right-pad it with zeros to ``target_samples`` samples.

    Trimming returns a view of the input. Padding allocates a new array with
    the input's dtype.
    """
    missing = target_samples - len(wav)
    if missing > 0:
        return np.concatenate((wav, np.zeros(missing, dtype=wav.dtype)))
    return wav[:target_samples]
class VantaInference:
    """Wraps a trained Vanta model for single-file inference.

    Load once at startup, call `.extract(mixture_bytes, enrollment_bytes)` per
    request. Returns (extracted_wav_bytes, residue_wav_bytes, meta).
    """

    def __init__(self, checkpoint_path: Path, repeats: int = 2, device: str = "auto"):
        """Build a Vanta model and load its weights from ``checkpoint_path``.

        Args:
            checkpoint_path: File loaded with ``torch.load``. It must hold a
                dict whose ``"model_state"`` entry is the model's state dict.
            repeats: Forwarded as ``VantaConfig(repeats=...)``. Presumably it
                must match the value used in training, because
                ``load_state_dict`` is strict by default.
            device: ``"auto"`` selects CUDA when available, else CPU. Any
                other value is passed to ``torch.device``.
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model = Vanta(VantaConfig(repeats=repeats))
        # SECURITY: weights_only=False unpickles arbitrary objects, and
        # unpickling can execute code. Only load checkpoints you produced.
        ck = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ck["model_state"])
        self.model.to(self.device).eval()

    @torch.no_grad()
    def extract(
        self, mixture_bytes: bytes, enrollment_bytes: bytes
    ) -> tuple[bytes, bytes, dict]:
        """Extract the enrolled speaker from ``mixture_bytes``.

        Args:
            mixture_bytes: Encoded mixture audio. Only the first
                ``MAX_MIX_SECONDS`` are processed.
            enrollment_bytes: Encoded reference clip of the target speaker.
                It is trimmed/zero-padded to ``ENROLL_SECONDS`` and
                peak-normalized.

        Returns:
            ``(extracted_wav, residue_wav, meta)``. The first two are 16-bit
            PCM WAV bytes. ``meta`` has ``sample_rate``, ``input_seconds``,
            ``output_seconds`` and ``truncated``.

        Raises:
            ValueError: From decoding (e.g. undecodable input), or if the
                model returns no samples.
        """
        mixture = _to_mono_16k(mixture_bytes)
        enrollment = _to_mono_16k(enrollment_bytes)

        # Guardrails on request size.
        orig_mix_samples = len(mixture)
        max_samples = int(MAX_MIX_SECONDS * SAMPLE_RATE)
        if len(mixture) > max_samples:
            mixture = mixture[:max_samples]

        # Enrollment has to be exactly ENROLL_SECONDS for our trained model.
        enrollment = _fit(enrollment, int(ENROLL_SECONDS * SAMPLE_RATE))
        enrollment = peak_normalize(enrollment, peak=0.95)

        mix_t = torch.from_numpy(mixture).unsqueeze(0).to(self.device)
        enr_t = torch.from_numpy(enrollment).unsqueeze(0).to(self.device)

        # AMP matches how we trained, and it halves memory on long clips.
        if self.device.type == "cuda":
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                est = self.model(mix_t, enrollment=enr_t).float()
        else:
            est = self.model(mix_t, enrollment=enr_t)
        estimate = est.squeeze(0).cpu().numpy()

        # The decoder's output length is not guaranteed to equal the input's
        # (e.g. padding to a stride multiple). Align both to the common
        # length so the residue subtraction below cannot hit a shape
        # mismatch.
        n = min(len(mixture), len(estimate))
        if n == 0:
            raise ValueError("model produced no output samples")
        estimate = estimate[:n]
        mix_ref = mixture[:n]

        # SI-SDR is scale-invariant, so nothing in training penalizes the decoder
        # for drifting to huge amplitudes. Model outputs routinely peak at
        # ±100+. Match the mixture's loudness so playback sounds natural and
        # PCM_16 encoding doesn't clip.
        mix_peak = float(np.max(np.abs(mix_ref))) + 1e-8
        est_peak = float(np.max(np.abs(estimate))) + 1e-8
        estimate = estimate * (mix_peak * 0.95 / est_peak)

        # Residue = what Vanta removed. Handy for demos — users can play it and
        # hear "this is what the void consumed."
        residue = mix_ref - estimate

        meta = {
            "sample_rate": SAMPLE_RATE,
            "input_seconds": orig_mix_samples / SAMPLE_RATE,
            "output_seconds": n / SAMPLE_RATE,
            "truncated": orig_mix_samples > max_samples,
        }
        return _encode_wav(estimate), _encode_wav(residue), meta
def _encode_wav(wav: np.ndarray) -> bytes:
    """Encode a mono float waveform as 16-bit PCM WAV bytes at ``SAMPLE_RATE``.

    libsndfile only clips float→integer conversions when clipping is enabled
    (``SFC_SET_CLIPPING``). Otherwise samples outside [-1, 1] can wrap
    around into loud full-scale glitches. Inputs here can exceed that range:
    a residue computed as mixture minus estimate is not renormalized, and
    float sources may be louder than full scale. So clamp before writing.
    NaN becomes silence and ±inf become full scale, rather than producing
    undefined PCM values.
    """
    safe = np.nan_to_num(wav, nan=0.0, posinf=1.0, neginf=-1.0)
    safe = np.clip(safe, -1.0, 1.0)
    buf = io.BytesIO()
    sf.write(buf, safe, SAMPLE_RATE, subtype="PCM_16", format="WAV")
    return buf.getvalue()
class VantaSepFormerInference:
    """Inference backend built on a pretrained SepFormer plus our selector.

    Exposes the same public ``extract`` interface as ``VantaInference`` so
    server.py can swap between them without code changes.
    """

    def __init__(
        self,
        sepformer_source: str = "speechbrain/sepformer-libri2mix",
        device: str = "auto",
    ):
        """Construct the ``SepFormerTSE`` wrapper on the chosen device.

        Args:
            sepformer_source: Forwarded to ``SepFormerTSE(sepformer_source=...)``.
            device: ``"auto"`` selects CUDA when available, else CPU. Any
                other value is passed to ``torch.device``.
        """
        # Imported lazily, presumably so deployments that only use the
        # from-scratch backend never load the SepFormer stack — confirm.
        from vanta.models.sepformer_tse import SepFormerTSE

        resolved = device
        if resolved == "auto":
            resolved = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(resolved)
        self.model = SepFormerTSE(sepformer_source=sepformer_source, device=self.device)

    @torch.no_grad()
    def extract(
        self, mixture_bytes: bytes, enrollment_bytes: bytes
    ) -> tuple[bytes, bytes, dict]:
        """Separate the enrolled speaker out of ``mixture_bytes``.

        Returns ``(extracted_wav, residue_wav, meta)``. Both WAVs are 16-bit
        PCM rescaled to 0.95× the mixture's peak. ``meta`` carries request
        bookkeeping plus ``backend`` and whatever the model itself reports.
        Keys reported by the model override same-named base keys.
        """
        mixture = _to_mono_16k(mixture_bytes)
        enrollment = _to_mono_16k(enrollment_bytes)

        input_len = len(mixture)
        limit = int(MAX_MIX_SECONDS * SAMPLE_RATE)
        if input_len > limit:
            mixture = mixture[:limit]

        enroll_len = int(ENROLL_SECONDS * SAMPLE_RATE)
        enrollment = peak_normalize(_fit(enrollment, enroll_len), peak=0.95)

        def as_batch(signal: np.ndarray) -> torch.Tensor:
            # (T,) numpy -> (1, T) tensor on the model's device.
            return torch.from_numpy(signal).unsqueeze(0).to(self.device)

        extracted_t, residue_t, model_meta = self.model(
            as_batch(mixture), as_batch(enrollment)
        )
        extracted, residue = (
            t.squeeze(0).cpu().numpy() for t in (extracted_t, residue_t)
        )

        # SepFormer outputs are typically quieter than the input, so bring each
        # output up to 0.95x the mixture's peak for natural playback.
        target_peak = (float(np.max(np.abs(mixture[: len(extracted)]))) + 1e-8) * 0.95
        for signal in (extracted, residue):
            peak = float(np.max(np.abs(signal))) + 1e-8
            # With the epsilon this only fails when the peak is NaN, in which
            # case the signal is left untouched.
            if peak > 0:
                np.multiply(signal, target_peak / peak, out=signal)

        meta = dict(
            sample_rate=SAMPLE_RATE,
            input_seconds=input_len / SAMPLE_RATE,
            output_seconds=len(extracted) / SAMPLE_RATE,
            truncated=input_len > limit,
            backend="sepformer",
        )
        meta.update(model_meta)
        return _encode_wav(extracted), _encode_wav(residue), meta