Spaces:
Running
Running
File size: 7,878 Bytes
828f7dd 32de4f6 828f7dd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 | """Inference helpers: load a trained Vanta checkpoint OR a SepFormer-based
backbone and extract a target speaker.
Two backends, same interface:
- VantaInference : our from-scratch trained checkpoint
- VantaSepFormerInference : pretrained SepFormer + our ECAPA selector
Pick at server startup via the VANTA_BACKEND env var. The trained-from-scratch
model is the project's "training pedigree" piece; the SepFormer backbone
delivers the audio quality we need for the live demo.
"""
from __future__ import annotations
import io
import subprocess
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
from vanta.config import SAMPLE_RATE
from vanta.models.vanta import Vanta, VantaConfig
from vanta.utils.audio import peak_normalize
# Upper bound on mixture length: the extract() methods below keep only the
# first MAX_MIX_SECONDS of audio and report `truncated` in their metadata.
MAX_MIX_SECONDS = 30.0
# Enrollment clips are cropped or zero-padded (via _fit) to exactly this many
# seconds before being handed to the model.
ENROLL_SECONDS = 5.0
def _ffmpeg_decode(raw: bytes) -> np.ndarray:
    """Decode audio bytes in any ffmpeg-readable container to mono float32.

    Used for containers libsndfile can't read (MP4/M4A/WebM/MOV and so on).
    The encoded bytes are written to an ffmpeg child process on stdin and the
    decoded samples are read back from stdout, so no temp files are involved.
    ffmpeg is asked for a single channel at ``SAMPLE_RATE`` in raw ``f32le``.

    Args:
        raw: Encoded audio (or audio+video) file contents.

    Returns:
        A writable 1-D float32 array of decoded samples.

    Raises:
        ValueError: If ffmpeg exits with a non-zero status; the message
            carries ffmpeg's stderr when there is any.
    """
    from imageio_ffmpeg import get_ffmpeg_exe

    ffmpeg_exe = get_ffmpeg_exe()
    # Read from stdin, keep ffmpeg quiet except for real errors.
    input_args = ["-hide_banner", "-loglevel", "error", "-i", "pipe:0"]
    # -vn drops any video stream, -ac 1 downmixes to mono, and f32le is raw
    # float32 that np.frombuffer can consume directly.
    output_args = [
        "-vn",
        "-ac", "1",
        "-ar", str(SAMPLE_RATE),
        "-f", "f32le",
        "pipe:1",
    ]
    result = subprocess.run(
        [ffmpeg_exe, *input_args, *output_args],
        input=raw,
        capture_output=True,
    )
    if result.returncode == 0:
        # frombuffer yields a read-only view over the bytes; copy to own it.
        return np.frombuffer(result.stdout, dtype=np.float32).copy()

    stderr_text = result.stderr.decode("utf-8", errors="replace").strip()
    raise ValueError(f"ffmpeg decode failed: {stderr_text or 'no stderr'}")
def _to_mono_16k(raw: bytes) -> np.ndarray:
    """Decode encoded audio bytes into a mono float32 array at SAMPLE_RATE.

    libsndfile (via soundfile) is tried first because it needs no child
    process. If it raises for any reason, the bytes are handed to
    ``_ffmpeg_decode`` instead, which already returns mono audio at the
    target rate.

    Args:
        raw: Encoded audio file contents.

    Returns:
        A float32 sample array: channel-averaged and, when the file's rate
        differs from SAMPLE_RATE, resampled with soxr.
    """
    try:
        samples, rate = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)
    except Exception:
        # Whatever libsndfile refuses (MP4, M4A, WebM, MOV, ...) goes to ffmpeg.
        return _ffmpeg_decode(raw)

    # Multi-channel reads come back as (frames, channels); average the channels.
    mono = samples.mean(axis=1) if samples.ndim > 1 else samples

    if sr_matches := (rate == SAMPLE_RATE):
        return mono.astype(np.float32, copy=False)

    import soxr

    resampled = soxr.resample(mono, rate, SAMPLE_RATE, quality="HQ")
    return resampled.astype(np.float32, copy=False)
def _fit(wav: np.ndarray, target_samples: int) -> np.ndarray:
    """Crop or right-pad a 1-D signal so it is exactly `target_samples` long.

    Signals that are already long enough are returned as a leading slice of
    the input; shorter ones come back as a new array with trailing zeros in
    the input's dtype.
    """
    shortfall = target_samples - len(wav)
    if shortfall <= 0:
        return wav[:target_samples]
    tail = np.zeros(shortfall, dtype=wav.dtype)
    return np.concatenate((wav, tail))
class VantaInference:
    """Wraps a trained Vanta model for single-file inference.

    Load once at startup, call `.extract(mixture_bytes, enrollment_bytes)` per
    request. Returns (extracted_wav_bytes, residue_wav_bytes, meta).
    """

    def __init__(self, checkpoint_path: Path, repeats: int = 2, device: str = "auto"):
        """Build the model and load its weights from `checkpoint_path`.

        Args:
            checkpoint_path: Checkpoint file containing a "model_state" entry.
            repeats: Forwarded to ``VantaConfig(repeats=...)``.
            device: Torch device string; "auto" picks "cuda" when available,
                otherwise "cpu".
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model = Vanta(VantaConfig(repeats=repeats))
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only point this at checkpoints
        # produced by a trusted training run.
        ck = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ck["model_state"])
        self.model.to(self.device).eval()

    @torch.no_grad()
    def extract(
        self, mixture_bytes: bytes, enrollment_bytes: bytes
    ) -> tuple[bytes, bytes, dict]:
        """Extract the enrolled speaker from a mixture.

        Args:
            mixture_bytes: Encoded audio file containing the mixture.
            enrollment_bytes: Encoded audio file of the target speaker.

        Returns:
            ``(extracted_wav, residue_wav, meta)``. Both audio items are
            PCM_16 WAV bytes. ``meta`` has the keys "sample_rate",
            "input_seconds", "output_seconds" and "truncated".

        Raises:
            ValueError: If either input cannot be decoded or decodes to zero
                samples.
        """
        mixture = _to_mono_16k(mixture_bytes)
        enrollment = _to_mono_16k(enrollment_bytes)

        # Reject empty audio up front: an empty mixture would otherwise fail
        # later with an opaque error (np.max on a zero-size array), and an
        # empty enrollment would be padded to pure silence.
        if mixture.size == 0:
            raise ValueError("mixture audio is empty")
        if enrollment.size == 0:
            raise ValueError("enrollment audio is empty")

        # Guardrails on request size.
        orig_mix_samples = len(mixture)
        max_samples = int(MAX_MIX_SECONDS * SAMPLE_RATE)
        if len(mixture) > max_samples:
            mixture = mixture[:max_samples]

        # Enrollment has to be exactly ENROLL_SECONDS for our trained model.
        enrollment = _fit(enrollment, int(ENROLL_SECONDS * SAMPLE_RATE))
        enrollment = peak_normalize(enrollment, peak=0.95)

        mix_t = torch.from_numpy(mixture).unsqueeze(0).to(self.device)
        enr_t = torch.from_numpy(enrollment).unsqueeze(0).to(self.device)

        # AMP matches how we trained, and it halves memory on long clips.
        use_amp = self.device.type == "cuda"
        if use_amp:
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                est = self.model(mix_t, enrollment=enr_t).float()
        else:
            est = self.model(mix_t, enrollment=enr_t)
        estimate = est.squeeze(0).cpu().numpy()

        # SI-SDR is scale-invariant, so nothing in training penalizes the decoder
        # for drifting to huge amplitudes. Model outputs routinely peak at
        # ±100+. Match the mixture's loudness so playback sounds natural and
        # PCM_16 encoding doesn't clip.
        mix_ref = mixture[: len(estimate)]
        mix_peak = float(np.max(np.abs(mix_ref))) + 1e-8
        est_peak = float(np.max(np.abs(estimate))) + 1e-8
        estimate = estimate * (mix_peak * 0.95 / est_peak)

        # Residue = what Vanta removed. Handy for demos — users can play it and
        # hear "this is what the void consumed."
        # NOTE(review): the residue is not re-normalized, so its peak can reach
        # roughly mix_peak * 1.95 and may exceed full scale when encoded as
        # PCM_16 — confirm whether that is acceptable.
        residue = mix_ref - estimate

        meta = {
            "sample_rate": SAMPLE_RATE,
            "input_seconds": orig_mix_samples / SAMPLE_RATE,
            "output_seconds": len(estimate) / SAMPLE_RATE,
            "truncated": orig_mix_samples > max_samples,
        }
        return _encode_wav(estimate), _encode_wav(residue), meta
def _encode_wav(wav: np.ndarray) -> bytes:
    """Serialize `wav` to an in-memory PCM_16 WAV file at SAMPLE_RATE.

    Returns the complete WAV file contents as bytes.
    """
    with io.BytesIO() as sink:
        sf.write(sink, wav, SAMPLE_RATE, subtype="PCM_16", format="WAV")
        return sink.getvalue()
class VantaSepFormerInference:
    """SepFormer-backbone inference. Same public interface as VantaInference
    so server.py can swap between them without code changes.

    `.extract(mixture_bytes, enrollment_bytes)` returns
    (extracted_wav_bytes, residue_wav_bytes, meta).
    """

    def __init__(
        self,
        sepformer_source: str = "speechbrain/sepformer-libri2mix",
        device: str = "auto",
    ):
        """Build the SepFormer-based extractor.

        Args:
            sepformer_source: Forwarded to ``SepFormerTSE(sepformer_source=...)``.
            device: Torch device string; "auto" picks "cuda" when available,
                otherwise "cpu".
        """
        # Imported lazily so this backend's dependencies are only loaded when
        # the backend is actually constructed.
        from vanta.models.sepformer_tse import SepFormerTSE

        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model = SepFormerTSE(
            sepformer_source=sepformer_source, device=self.device
        )

    @staticmethod
    def _scale_to_peak(wav: np.ndarray, target_peak: float) -> np.ndarray:
        """Return a rescaled copy of `wav` whose absolute peak is ~`target_peak`."""
        # The epsilon keeps the division finite for an all-zero signal.
        peak = float(np.max(np.abs(wav))) + 1e-8
        return wav * (target_peak / peak)

    @torch.no_grad()
    def extract(
        self, mixture_bytes: bytes, enrollment_bytes: bytes
    ) -> tuple[bytes, bytes, dict]:
        """Extract the enrolled speaker from a mixture.

        Args:
            mixture_bytes: Encoded audio file containing the mixture.
            enrollment_bytes: Encoded audio file of the target speaker.

        Returns:
            ``(extracted_wav, residue_wav, meta)``. Both audio items are
            PCM_16 WAV bytes. ``meta`` has "sample_rate", "input_seconds",
            "output_seconds", "truncated" and "backend", followed by whatever
            keys the model returns in its own metadata dict (which take
            precedence on a name clash).

        Raises:
            ValueError: If either input cannot be decoded or decodes to zero
                samples.
        """
        mixture = _to_mono_16k(mixture_bytes)
        enrollment = _to_mono_16k(enrollment_bytes)

        # Reject empty audio up front: an empty mixture would otherwise fail
        # later with an opaque error (np.max on a zero-size array), and an
        # empty enrollment would be padded to pure silence.
        if mixture.size == 0:
            raise ValueError("mixture audio is empty")
        if enrollment.size == 0:
            raise ValueError("enrollment audio is empty")

        orig_mix_samples = len(mixture)
        max_samples = int(MAX_MIX_SECONDS * SAMPLE_RATE)
        if len(mixture) > max_samples:
            mixture = mixture[:max_samples]

        enrollment = _fit(enrollment, int(ENROLL_SECONDS * SAMPLE_RATE))
        enrollment = peak_normalize(enrollment, peak=0.95)

        mix_t = torch.from_numpy(mixture).unsqueeze(0).to(self.device)
        enr_t = torch.from_numpy(enrollment).unsqueeze(0).to(self.device)
        extracted_t, residue_t, model_meta = self.model(mix_t, enr_t)
        extracted = extracted_t.squeeze(0).cpu().numpy()
        residue = residue_t.squeeze(0).cpu().numpy()

        # Match the mixture's loudness for natural playback (SepFormer outputs
        # are typically lower-amplitude than the input).
        # Scaling is done out-of-place: on CPU these arrays share storage with
        # the tensors the model returned, so an in-place multiply would also
        # rewrite the model's own outputs.
        mix_peak = float(np.max(np.abs(mixture[: len(extracted)]))) + 1e-8
        target_peak = mix_peak * 0.95
        extracted = self._scale_to_peak(extracted, target_peak)
        residue = self._scale_to_peak(residue, target_peak)

        meta = {
            "sample_rate": SAMPLE_RATE,
            "input_seconds": orig_mix_samples / SAMPLE_RATE,
            "output_seconds": len(extracted) / SAMPLE_RATE,
            "truncated": orig_mix_samples > max_samples,
            "backend": "sepformer",
            **model_meta,
        }
        return _encode_wav(extracted), _encode_wav(residue), meta
|