Spaces:
Running
Running
File size: 6,129 Bytes
"""
Speaker Embedding Extraction using ECAPA-TDNN architecture via SpeechBrain.
Handles audio preprocessing, feature extraction, and L2-normalized embeddings.
"""
import inspect
from pathlib import Path
from typing import Union, List, Tuple
import numpy as np
import torch
import torchaudio
from loguru import logger
class EcapaTDNNEmbedder:
    """
    Speaker embedding extractor using ECAPA-TDNN architecture.

    Produces 192-dim L2-normalized speaker embeddings per audio segment.
    The SpeechBrain model is loaded lazily on the first call to
    ``extract_embedding``; constructing the embedder only resolves the
    device and creates the cache directory.
    """

    MODEL_SOURCE = "speechbrain/spkrec-ecapa-voxceleb"
    SAMPLE_RATE = 16000
    EMBEDDING_DIM = 192

    def __init__(self, device: str = "auto", cache_dir: str = "./model_cache"):
        """
        Args:
            device: Torch device string, or "auto" to pick "cuda" when
                available and "cpu" otherwise.
            cache_dir: Directory under which model files are stored; it is
                created if it does not exist.
        """
        self.device = self._resolve_device(device)
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Populated by _load_model() on first use.
        self._model = None
        logger.info(f"EcapaTDNNEmbedder initialized on device: {self.device}")

    def _resolve_device(self, device: str) -> str:
        """Return *device* unchanged, mapping "auto" to "cuda" or "cpu"."""
        if device == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    def _build_hparams_kwargs(self, encoder_cls, savedir: Path, hf_cache: Path) -> dict:
        """
        Build keyword arguments for ``encoder_cls.from_hparams``.

        Optional arguments are only included when the installed
        ``from_hparams`` signature declares them, so the same call works
        across versions that differ in which parameters they accept.

        Args:
            encoder_cls: Class exposing a ``from_hparams`` callable.
            savedir: Directory passed as ``savedir``.
            hf_cache: Directory passed as ``huggingface_cache_dir`` when
                that parameter is supported.

        Returns:
            Dict of keyword arguments ready to be splatted into
            ``from_hparams``.
        """
        kwargs = {
            "source": self.MODEL_SOURCE,
            "savedir": str(savedir),
            "run_opts": {"device": self.device},
        }
        sig = inspect.signature(encoder_cls.from_hparams)
        if "huggingface_cache_dir" in sig.parameters:
            kwargs["huggingface_cache_dir"] = str(hf_cache)
        if "local_strategy" in sig.parameters:
            try:
                from speechbrain.utils.fetching import LocalStrategy

                kwargs["local_strategy"] = LocalStrategy.COPY
            except Exception as exc:
                # Best effort: fall back to the library's default strategy,
                # but leave a trace instead of failing silently.
                logger.debug(
                    f"LocalStrategy unavailable; using default fetch strategy: {exc}"
                )
        return kwargs

    def _load_model(self) -> None:
        """
        Load the ECAPA-TDNN encoder into ``self._model`` if not yet loaded.

        Raises:
            ImportError: If an import fails while loading (reported as
                SpeechBrain not being installed).
            RuntimeError: If loading fails for any other reason, including
                ``from_hparams`` returning ``None`` twice.
        """
        if self._model is not None:
            return
        try:
            try:
                from speechbrain.inference.classifiers import EncoderClassifier
            except ImportError:
                # Backward compatibility with older SpeechBrain versions.
                from speechbrain.pretrained import EncoderClassifier

            savedir = self.cache_dir / "ecapa_tdnn"
            hf_cache = self.cache_dir / "hf_cache"
            savedir.mkdir(parents=True, exist_ok=True)
            hf_cache.mkdir(parents=True, exist_ok=True)

            logger.info(f"Loading ECAPA-TDNN from {self.MODEL_SOURCE}...")
            logger.info(f"Savedir: {savedir}, exists: {savedir.exists()}")

            kwargs = self._build_hparams_kwargs(EncoderClassifier, savedir, hf_cache)
            model = EncoderClassifier.from_hparams(**kwargs)
            if model is None:
                # Some SpeechBrain/HF hub combinations ignore optional kwargs.
                logger.warning("ECAPA load returned None; retrying with minimal from_hparams kwargs.")
                model = EncoderClassifier.from_hparams(
                    source=self.MODEL_SOURCE,
                    savedir=str(savedir),
                    run_opts={"device": self.device},
                )
            if model is None:
                raise RuntimeError("EncoderClassifier.from_hparams returned None")

            # Switch to eval mode before publishing the model on self, so a
            # failure here does not leave a half-initialized model cached.
            model.eval()
            self._model = model
            logger.success("ECAPA-TDNN model loaded successfully.")
        except ImportError as exc:
            raise ImportError("SpeechBrain not installed.") from exc
        except Exception as exc:
            raise RuntimeError(f"Failed to load ECAPA-TDNN model: {exc}") from exc

    def preprocess_audio(
        self, audio: Union[np.ndarray, torch.Tensor], sample_rate: int
    ) -> torch.Tensor:
        """
        Resample and normalize audio to 16kHz mono float32 tensor.

        Args:
            audio: 1-D waveform, or a 2-D array whose first axis is averaged
                away when it has more than one entry (assumed to be the
                channel axis, i.e. (channels, samples) — confirm with
                callers).
            sample_rate: Sample rate of *audio* in Hz.

        Returns:
            Float32 tensor, peak-normalized so its maximum absolute value is
            1 (left untouched when the signal is empty or all zeros), with
            the leading axis squeezed out.
        """
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        # Cast every input, not just NumPy arrays: integer tensors cannot be
        # averaged with mean(), and the documented output dtype is float32.
        audio = audio.float()

        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        if audio.shape[0] > 1:
            # Down-mix by averaging over the first axis.
            audio = audio.mean(dim=0, keepdim=True)

        if sample_rate != self.SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.SAMPLE_RATE
            ).to(audio.device)  # keep the resampling kernel on the audio's device
            audio = resampler(audio)

        # max() of an empty tensor raises, so skip normalization for empty audio.
        if audio.numel() > 0:
            max_val = audio.abs().max()
            if max_val > 0:
                audio = audio / max_val
        return audio.squeeze(0)

    def extract_embedding(self, audio: torch.Tensor) -> np.ndarray:
        """
        Extract L2-normalized ECAPA-TDNN embedding from a preprocessed audio tensor.

        Loads the model on first use.

        Args:
            audio: Preprocessed waveform; a batch axis is prepended before it
                is passed to the model.

        Returns:
            L2-normalized embedding of shape (192,). A zero-norm embedding is
            returned as-is rather than divided by zero.
        """
        self._load_model()
        with torch.no_grad():
            audio_batch = audio.unsqueeze(0).to(self.device)
            # One entry per batch item; 1.0 presumably marks the whole
            # waveform as valid (relative length) — confirm against SpeechBrain.
            lengths = torch.tensor([1.0]).to(self.device)
            embedding = self._model.encode_batch(audio_batch, lengths)
        embedding = embedding.squeeze().cpu().numpy()
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        return embedding

    def _segment_sample_bounds(
        self, start: float, end: float, total_samples: int
    ) -> Tuple[int, int]:
        """Convert (start, end) seconds to sample indices clamped to [0, total_samples]."""
        start_sample = min(max(int(start * self.SAMPLE_RATE), 0), total_samples)
        end_sample = min(max(int(end * self.SAMPLE_RATE), 0), total_samples)
        return start_sample, end_sample

    def extract_embeddings_from_segments(
        self,
        audio: torch.Tensor,
        sample_rate: int,
        segments: List[Tuple[float, float]],
        min_duration: float = 0.5,
    ) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
        """
        Extract embeddings for a list of (start, end) time segments.

        Args:
            audio: Full recording, in any form accepted by ``preprocess_audio``.
            sample_rate: Sample rate of *audio* in Hz.
            segments: (start, end) pairs in seconds.
            min_duration: Segments whose nominal duration (end - start) is
                below this many seconds are skipped.

        Returns:
            A tuple ``(embeddings, valid_segments)``: the stacked embeddings,
            one row per kept segment, and the matching (start, end) pairs as
            given by the caller. When nothing is kept, an empty
            ``(0, EMBEDDING_DIM)`` array and an empty list.
        """
        processed = self.preprocess_audio(audio, sample_rate)
        total_samples = processed.shape[0]
        embeddings = []
        valid_segments = []
        for start, end in segments:
            if end - start < min_duration:
                continue
            # Clamp before slicing: a negative start would otherwise be read
            # as an offset from the end of the recording.
            start_sample, end_sample = self._segment_sample_bounds(
                start, end, total_samples
            )
            segment_audio = processed[start_sample:end_sample]
            if segment_audio.shape[0] == 0:
                continue
            embeddings.append(self.extract_embedding(segment_audio))
            valid_segments.append((start, end))
        if not embeddings:
            return np.empty((0, self.EMBEDDING_DIM)), []
        return np.stack(embeddings), valid_segments
|