goabonga's picture
feat: add speaker recognition API with SpeechBrain ECAPA-TDNN
7323d5e unverified
import io
import logging
from contextlib import asynccontextmanager
from typing import cast
import torch
import torchaudio
from fastapi import FastAPI, File, HTTPException, UploadFile
from speechbrain.inference.speaker import SpeakerRecognition
from storage import StorageBackend, get_storage_backend
logger = logging.getLogger("uvicorn.error")
speaker_model: SpeakerRecognition | None = None
speaker_embeddings: dict[str, list[float]] = {}
storage: StorageBackend | None = None
def preprocess_audio(audio_bytes: bytes) -> torch.Tensor:
"""Load and preprocess audio to 16kHz mono."""
audio_buffer = io.BytesIO(audio_bytes)
waveform, sample_rate = torchaudio.load(audio_buffer)
# Resample to 16kHz if necessary
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
)
waveform = resampler(waveform)
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
return cast(torch.Tensor, waveform)
@asynccontextmanager
async def lifespan(app: FastAPI):
global speaker_model, speaker_embeddings, storage
logger.info("Loading SpeechBrain Speaker Recognition model...")
speaker_model = SpeakerRecognition.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
savedir="pretrained_models/spkrec-ecapa-voxceleb",
)
storage = get_storage_backend()
speaker_embeddings = storage.load()
logger.info(
f"Model loaded successfully! {len(speaker_embeddings)} speakers registered."
)
yield
logger.info("Shutting down...")
app = FastAPI(title="Reachy SpeechBrain API", lifespan=lifespan)
@app.get("/health", include_in_schema=False)
def health():
return {"status": "ok"}
@app.get("/speakers")
def list_speakers():
"""List all registered speakers."""
return {"speakers": list(speaker_embeddings.keys())}
@app.post("/speakers/{name}/enroll")
async def enroll_speaker(name: str, file: UploadFile = File(...)):
"""Enroll a new speaker by providing an audio sample."""
if speaker_model is None:
raise RuntimeError("Model not loaded")
audio_bytes = await file.read()
waveform = preprocess_audio(audio_bytes)
# Extract embedding
embedding = speaker_model.encode_batch(waveform)
embedding_list = embedding.squeeze().tolist()
# Store embedding
speaker_embeddings[name] = embedding_list
if storage:
storage.save(speaker_embeddings)
return {
"message": f"Speaker '{name}' enrolled successfully",
"embedding_size": len(embedding_list),
}
@app.delete("/speakers/{name}")
def delete_speaker(name: str):
"""Delete a registered speaker."""
if name not in speaker_embeddings:
raise HTTPException(status_code=404, detail=f"Speaker '{name}' not found")
del speaker_embeddings[name]
if storage:
storage.save(speaker_embeddings)
return {"message": f"Speaker '{name}' deleted successfully"}
@app.post("/identify")
async def identify_speaker(file: UploadFile = File(...)):
"""Identify the speaker from an audio sample."""
if speaker_model is None:
raise RuntimeError("Model not loaded")
if not speaker_embeddings:
raise HTTPException(
status_code=400,
detail="No speakers enrolled. Please enroll speakers first.",
)
audio_bytes = await file.read()
waveform = preprocess_audio(audio_bytes)
# Extract embedding for input audio
input_embedding = speaker_model.encode_batch(waveform)
# Compare with all registered speakers
best_match = None
best_score = -1.0
for name, stored_embedding in speaker_embeddings.items():
stored_tensor = torch.tensor(stored_embedding).unsqueeze(0)
score = speaker_model.similarity(input_embedding, stored_tensor)
score_value = float(score.squeeze())
if score_value > best_score:
best_score = score_value
best_match = name
# Threshold for identification (ECAPA-TDNN typically uses ~0.25)
threshold = 0.25
identified = best_score >= threshold
return {
"identified": identified,
"speaker": best_match if identified else None,
"confidence": best_score,
"threshold": threshold,
}
@app.post("/verify")
async def verify_speaker(name: str, file: UploadFile = File(...)):
"""Verify if the audio matches a specific speaker."""
if speaker_model is None:
raise RuntimeError("Model not loaded")
if name not in speaker_embeddings:
raise HTTPException(status_code=404, detail=f"Speaker '{name}' not found")
audio_bytes = await file.read()
waveform = preprocess_audio(audio_bytes)
# Extract embedding for input audio
input_embedding = speaker_model.encode_batch(waveform)
# Compare with stored embedding
stored_tensor = torch.tensor(speaker_embeddings[name]).unsqueeze(0)
score = speaker_model.similarity(input_embedding, stored_tensor)
score_value = float(score.squeeze())
threshold = 0.25
verified = score_value >= threshold
return {
"verified": verified,
"speaker": name,
"confidence": score_value,
"threshold": threshold,
}