# Who-Spoke-When/models/embedder.py
# Last change: "Fix cache dir for HuggingFace Spaces" (commit c16fd7c, by ConvxO2)
"""
Speaker Embedding Extraction using ECAPA-TDNN architecture via SpeechBrain.
Handles audio preprocessing, feature extraction, and L2-normalized embeddings.
"""
import os
import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Union, List, Tuple
from loguru import logger
class EcapaTDNNEmbedder:
    """
    Speaker embedding extractor using ECAPA-TDNN architecture.
    Produces 192-dim L2-normalized speaker embeddings per audio segment.

    The SpeechBrain model is loaded lazily on first use, so constructing an
    instance is cheap and does not require SpeechBrain to be importable.
    """

    MODEL_SOURCE = "speechbrain/spkrec-ecapa-voxceleb"
    SAMPLE_RATE = 16000  # the pretrained model expects 16 kHz mono input
    EMBEDDING_DIM = 192  # output dimensionality of the ECAPA-TDNN encoder

    def __init__(self, device: str = "auto", cache_dir: str = "/tmp/model_cache"):
        """
        Args:
            device: "auto" (pick CUDA when available), "cuda", or "cpu".
            cache_dir: writable directory where downloaded model files are cached.
        """
        self.device = self._resolve_device(device)
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._model = None  # populated lazily by _load_model()
        logger.info(f"EcapaTDNNEmbedder initialized on device: {self.device}")

    def _resolve_device(self, device: str) -> str:
        """Map "auto" to the best available device; pass explicit choices through."""
        if device == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    def _load_model(self):
        """Load the SpeechBrain ECAPA-TDNN encoder once (idempotent).

        Raises:
            ImportError: if the `speechbrain` package is not installed.
        """
        if self._model is not None:
            return
        # Narrow try: only a failure of these imports means "not installed".
        # Errors raised later (download, hparams parsing) propagate unchanged
        # instead of being masked by a misleading install hint.
        try:
            import speechbrain.utils.fetching as _fetching
            from speechbrain.inference.classifiers import EncoderClassifier
        except ImportError as exc:
            raise ImportError(
                "SpeechBrain not installed. Run: pip install speechbrain"
            ) from exc

        import shutil as _shutil
        from pathlib import Path as _Path

        def _patched_link(src, dst, local_strategy):
            # HuggingFace Spaces storage rejects sym/hardlinks, so replace
            # SpeechBrain's link strategy with a plain file copy.
            dst = _Path(dst)
            src = _Path(src)
            dst.parent.mkdir(parents=True, exist_ok=True)
            if dst.exists() or dst.is_symlink():
                dst.unlink()
            _shutil.copy2(str(src), str(dst))

        _fetching.link_with_strategy = _patched_link

        logger.info(f"Loading ECAPA-TDNN from {self.MODEL_SOURCE}...")
        savedir = str(self.cache_dir / "ecapa_tdnn")
        os.makedirs(savedir, exist_ok=True)  # uses module-level `os` import
        self._model = EncoderClassifier.from_hparams(
            source=self.MODEL_SOURCE,
            savedir=savedir,
            run_opts={"device": self.device},
        )
        self._model.eval()
        logger.success("ECAPA-TDNN model loaded successfully.")

    def preprocess_audio(
        self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int
    ) -> torch.Tensor:
        """Resample and normalize audio to 16kHz mono float32 tensor.

        Args:
            audio: 1-D (samples,) or 2-D (channels, samples) waveform.
            sample_rate: sampling rate of `audio` in Hz.

        Returns:
            1-D float32 tensor at 16 kHz, peak-normalized to [-1, 1]
            (left untouched when the input is all zeros).
        """
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        if audio.shape[0] > 1:
            # Downmix multi-channel input to mono by averaging channels.
            audio = audio.mean(dim=0, keepdim=True)
        if sample_rate != self.SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.SAMPLE_RATE
            )
            audio = resampler(audio)
        # Peak-normalize; skip silent input to avoid division by zero.
        max_val = audio.abs().max()
        if max_val > 0:
            audio = audio / max_val
        return audio.squeeze(0)

    def extract_embedding(self, audio: torch.Tensor) -> np.ndarray:
        """
        Extract L2-normalized ECAPA-TDNN embedding from a preprocessed audio tensor.

        Args:
            audio: 1-D 16 kHz waveform (output of preprocess_audio).

        Returns:
            L2-normalized embedding of shape (192,)
        """
        self._load_model()
        with torch.no_grad():
            audio_batch = audio.unsqueeze(0).to(self.device)
            # Relative length 1.0 means "use the full waveform".
            lengths = torch.tensor([1.0]).to(self.device)
            embedding = self._model.encode_batch(audio_batch, lengths)
        embedding = embedding.squeeze().cpu().numpy()
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        return embedding

    def extract_embeddings_from_segments(
        self,
        audio: torch.Tensor,
        sample_rate: int,
        segments: List[Tuple[float, float]],
        min_duration: float = 0.5,
    ) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
        """Extract embeddings for a list of (start, end) time segments.

        Segments shorter than `min_duration` seconds, or lying entirely past
        the end of the audio, are skipped; the returned segment list mirrors
        the embedding rows actually produced.

        Returns:
            (embeddings, valid_segments) where embeddings has shape
            (len(valid_segments), 192); both are empty when nothing qualifies.
        """
        processed = self.preprocess_audio(audio, sample_rate)
        embeddings = []
        valid_segments = []
        for start, end in segments:
            duration = end - start
            if duration < min_duration:
                continue
            start_sample = int(start * self.SAMPLE_RATE)
            end_sample = int(end * self.SAMPLE_RATE)
            segment_audio = processed[start_sample:end_sample]
            if segment_audio.shape[0] == 0:
                continue  # segment falls outside the available audio
            embeddings.append(self.extract_embedding(segment_audio))
            valid_segments.append((start, end))
        if not embeddings:
            return np.empty((0, self.EMBEDDING_DIM)), []
        return np.stack(embeddings), valid_segments