"""Inference helpers: load a trained Vanta checkpoint OR a SepFormer-based backbone and extract a target speaker. Two backends, same interface: - VantaInference : our from-scratch trained checkpoint - VantaSepFormerInference : pretrained SepFormer + our ECAPA selector Pick at server startup via the VANTA_BACKEND env var. The trained-from-scratch model is the project's "training pedigree" piece; the SepFormer backbone delivers the audio quality we need for the live demo. """ from __future__ import annotations import io import subprocess from pathlib import Path import numpy as np import soundfile as sf import torch from vanta.config import SAMPLE_RATE from vanta.models.vanta import Vanta, VantaConfig from vanta.utils.audio import peak_normalize MAX_MIX_SECONDS = 30.0 ENROLL_SECONDS = 5.0 def _ffmpeg_decode(raw: bytes) -> np.ndarray: """Pipe arbitrary-container bytes through ffmpeg, get mono 16 kHz float32. libsndfile can't read MP4/M4A/WebM/MOV and so on. ffmpeg can. We spawn it on demand and stream in/out via pipes to avoid temp files. """ from imageio_ffmpeg import get_ffmpeg_exe cmd = [ get_ffmpeg_exe(), "-hide_banner", "-loglevel", "error", "-i", "pipe:0", "-vn", # ignore any video stream "-ac", "1", # mono "-ar", str(SAMPLE_RATE), "-f", "f32le", # raw float32 output — trivial to np.frombuffer "pipe:1", ] proc = subprocess.run(cmd, input=raw, capture_output=True) if proc.returncode != 0: err = proc.stderr.decode("utf-8", errors="replace").strip() raise ValueError(f"ffmpeg decode failed: {err or 'no stderr'}") return np.frombuffer(proc.stdout, dtype=np.float32).copy() def _to_mono_16k(raw: bytes) -> np.ndarray: # Fast path: libsndfile handles WAV/FLAC/OGG/MP3 without spawning a process. try: wav, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False) except Exception: # Anything libsndfile refuses — MP4, M4A, WebM, MOV, etc. — goes to ffmpeg. return _ffmpeg_decode(raw) if wav.ndim > 1: wav = wav.mean(axis=1) if sr != SAMPLE_RATE: import soxr wav = soxr.resample(wav, sr, SAMPLE_RATE, quality="HQ") return wav.astype(np.float32, copy=False) def _fit(wav: np.ndarray, target_samples: int) -> np.ndarray: if len(wav) >= target_samples: return wav[:target_samples] out = np.zeros(target_samples, dtype=wav.dtype) out[: len(wav)] = wav return out class VantaInference: """Wraps a trained Vanta model for single-file inference. Load once at startup, call `.extract(mixture_bytes, enrollment_bytes)` per request. Returns (extracted_wav_bytes, residue_wav_bytes). """ def __init__(self, checkpoint_path: Path, repeats: int = 2, device: str = "auto"): if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) self.model = Vanta(VantaConfig(repeats=repeats)) ck = torch.load(checkpoint_path, map_location=self.device, weights_only=False) self.model.load_state_dict(ck["model_state"]) self.model.to(self.device).eval() @torch.no_grad() def extract( self, mixture_bytes: bytes, enrollment_bytes: bytes ) -> tuple[bytes, bytes, dict]: mixture = _to_mono_16k(mixture_bytes) enrollment = _to_mono_16k(enrollment_bytes) # Guardrails on request size. orig_mix_samples = len(mixture) max_samples = int(MAX_MIX_SECONDS * SAMPLE_RATE) if len(mixture) > max_samples: mixture = mixture[:max_samples] # Enrollment has to be exactly ENROLL_SECONDS for our trained model. enrollment = _fit(enrollment, int(ENROLL_SECONDS * SAMPLE_RATE)) enrollment = peak_normalize(enrollment, peak=0.95) mix_t = torch.from_numpy(mixture).unsqueeze(0).to(self.device) enr_t = torch.from_numpy(enrollment).unsqueeze(0).to(self.device) # AMP matches how we trained, and it halves memory on long clips. use_amp = self.device.type == "cuda" if use_amp: with torch.autocast(device_type="cuda", dtype=torch.bfloat16): est = self.model(mix_t, enrollment=enr_t).float() else: est = self.model(mix_t, enrollment=enr_t) estimate = est.squeeze(0).cpu().numpy() # SI-SDR is scale-invariant, so nothing in training penalizes the decoder # for drifting to huge amplitudes. Model outputs routinely peak at # ±100+. Match the mixture's loudness so playback sounds natural and # PCM_16 encoding doesn't clip. mix_peak = float(np.max(np.abs(mixture[: len(estimate)]))) + 1e-8 est_peak = float(np.max(np.abs(estimate))) + 1e-8 estimate = estimate * (mix_peak * 0.95 / est_peak) # Residue = what Vanta removed. Handy for demos — users can play it and # hear "this is what the void consumed." residue = mixture[: len(estimate)] - estimate meta = { "sample_rate": SAMPLE_RATE, "input_seconds": orig_mix_samples / SAMPLE_RATE, "output_seconds": len(estimate) / SAMPLE_RATE, "truncated": orig_mix_samples > max_samples, } return _encode_wav(estimate), _encode_wav(residue), meta def _encode_wav(wav: np.ndarray) -> bytes: buf = io.BytesIO() sf.write(buf, wav, SAMPLE_RATE, subtype="PCM_16", format="WAV") return buf.getvalue() class VantaSepFormerInference: """SepFormer-backbone inference. Same public interface as VantaInference so server.py can swap between them without code changes.""" def __init__( self, sepformer_source: str = "speechbrain/sepformer-libri2mix", device: str = "auto", ): from vanta.models.sepformer_tse import SepFormerTSE if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) self.model = SepFormerTSE( sepformer_source=sepformer_source, device=self.device ) @torch.no_grad() def extract( self, mixture_bytes: bytes, enrollment_bytes: bytes ) -> tuple[bytes, bytes, dict]: mixture = _to_mono_16k(mixture_bytes) enrollment = _to_mono_16k(enrollment_bytes) orig_mix_samples = len(mixture) max_samples = int(MAX_MIX_SECONDS * SAMPLE_RATE) if len(mixture) > max_samples: mixture = mixture[:max_samples] enrollment = _fit(enrollment, int(ENROLL_SECONDS * SAMPLE_RATE)) enrollment = peak_normalize(enrollment, peak=0.95) mix_t = torch.from_numpy(mixture).unsqueeze(0).to(self.device) enr_t = torch.from_numpy(enrollment).unsqueeze(0).to(self.device) extracted_t, residue_t, model_meta = self.model(mix_t, enr_t) extracted = extracted_t.squeeze(0).cpu().numpy() residue = residue_t.squeeze(0).cpu().numpy() # Match the mixture's loudness for natural playback (SepFormer outputs # are typically lower-amplitude than the input). mix_peak = float(np.max(np.abs(mixture[: len(extracted)]))) + 1e-8 for arr in (extracted, residue): peak = float(np.max(np.abs(arr))) + 1e-8 if peak > 0: arr *= mix_peak * 0.95 / peak meta = { "sample_rate": SAMPLE_RATE, "input_seconds": orig_mix_samples / SAMPLE_RATE, "output_seconds": len(extracted) / SAMPLE_RATE, "truncated": orig_mix_samples > max_samples, "backend": "sepformer", **model_meta, } return _encode_wav(extracted), _encode_wav(residue), meta