Spaces:
Running
Running
File size: 2,449 Bytes
32de4f6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | """Pretrained ECAPA-TDNN speaker encoder.
Wraps speechbrain's VoxCeleb-trained ECAPA-TDNN. Given an enrollment clip,
returns a 192-d speaker embedding (the "voice fingerprint").
Frozen by default: fine-tuning speaker encoders during TSE training tends to
destabilize the identity space. We want the fingerprint to stay recognizable.
The checkpoint (~25 MB) downloads on first use to `data/_models/ecapa_voxceleb/`.
"""
from __future__ import annotations
from pathlib import Path
import torch
import torch.nn as nn
from vanta.config import DATA_DIR
ECAPA_EMBED_DIM = 192
class SpeakerEncoder(nn.Module):
def __init__(
self,
savedir: Path | None = None,
freeze: bool = True,
run_opts: dict | None = None,
):
super().__init__()
# Lazy import: loading speechbrain pulls in a lot; we only want it when
# this class is actually instantiated.
from speechbrain.inference.speaker import EncoderClassifier
from speechbrain.utils.fetching import LocalStrategy
savedir = savedir or (DATA_DIR / "_models" / "ecapa_voxceleb")
savedir.mkdir(parents=True, exist_ok=True)
# COPY instead of SYMLINK — Windows refuses symlinks without admin or
# Developer Mode, so defaulting to COPY is portable.
self.encoder = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
savedir=str(savedir),
run_opts=run_opts or {},
local_strategy=LocalStrategy.COPY,
)
self.embed_dim = ECAPA_EMBED_DIM
self.freeze = freeze
if freeze:
for p in self.encoder.parameters():
p.requires_grad_(False)
self.encoder.eval()
def forward(self, wav: torch.Tensor) -> torch.Tensor:
"""wav: (B, T) at 16 kHz. Returns (B, 192) speaker embeddings."""
if wav.dim() == 3:
wav = wav.squeeze(1)
# ECAPA expects (B, T). speechbrain returns (B, 1, 192) -> squeeze.
if self.freeze:
with torch.no_grad():
emb = self.encoder.encode_batch(wav)
else:
emb = self.encoder.encode_batch(wav)
return emb.squeeze(1)
def train(self, mode: bool = True):
# If frozen, keep batchnorm/running stats in eval regardless of parent.
super().train(mode)
if self.freeze:
self.encoder.eval()
return self
|