Spaces:
Sleeping
Sleeping
File size: 5,320 Bytes
7323d5e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 | import io
import logging
from contextlib import asynccontextmanager
from typing import cast
import torch
import torchaudio
from fastapi import FastAPI, File, HTTPException, UploadFile
from speechbrain.inference.speaker import SpeakerRecognition
from storage import StorageBackend, get_storage_backend
logger = logging.getLogger("uvicorn.error")
speaker_model: SpeakerRecognition | None = None
speaker_embeddings: dict[str, list[float]] = {}
storage: StorageBackend | None = None
def preprocess_audio(audio_bytes: bytes) -> torch.Tensor:
"""Load and preprocess audio to 16kHz mono."""
audio_buffer = io.BytesIO(audio_bytes)
waveform, sample_rate = torchaudio.load(audio_buffer)
# Resample to 16kHz if necessary
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
)
waveform = resampler(waveform)
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
return cast(torch.Tensor, waveform)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the model and persisted embeddings at
    startup, yield for the app's lifetime, then log on shutdown.

    Populates the module-level globals that the request handlers read.
    """
    global speaker_model, speaker_embeddings, storage
    logger.info("Loading SpeechBrain Speaker Recognition model...")
    # Downloads (or reuses) the pretrained ECAPA-TDNN checkpoint into a
    # local cache directory.
    speaker_model = SpeakerRecognition.from_hparams(
        source="speechbrain/spkrec-ecapa-voxceleb",
        savedir="pretrained_models/spkrec-ecapa-voxceleb",
    )
    # Restore previously enrolled speakers from the configured backend.
    storage = get_storage_backend()
    speaker_embeddings = storage.load()
    logger.info(
        f"Model loaded successfully! {len(speaker_embeddings)} speakers registered."
    )
    yield
    # No explicit teardown needed: embeddings are persisted on every mutation.
    logger.info("Shutting down...")
# Application instance; the lifespan hook above loads the model and storage
# backend before any request is served.
app = FastAPI(title="Reachy SpeechBrain API", lifespan=lifespan)
@app.get("/health", include_in_schema=False)
def health():
    """Liveness probe; hidden from the OpenAPI schema."""
    return dict(status="ok")
@app.get("/speakers")
def list_speakers():
    """Return the names of every enrolled speaker."""
    names = [*speaker_embeddings]
    return {"speakers": names}
@app.post("/speakers/{name}/enroll")
async def enroll_speaker(name: str, file: UploadFile = File(...)):
    """Enroll a new speaker by providing an audio sample.

    Extracts an ECAPA-TDNN embedding from the uploaded audio and stores it
    under ``name``, overwriting any previous enrollment for that name.
    """
    if speaker_model is None:
        # 503 tells HTTP clients the service is up but not ready, instead
        # of the opaque 500 an uncaught RuntimeError would produce.
        raise HTTPException(status_code=503, detail="Model not loaded")
    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty audio file")
    waveform = preprocess_audio(audio_bytes)
    # Extract the speaker embedding and flatten it to a plain list so it
    # can be JSON-serialized by the storage backend.
    embedding = speaker_model.encode_batch(waveform)
    embedding_list = embedding.squeeze().tolist()
    speaker_embeddings[name] = embedding_list
    if storage is not None:
        storage.save(speaker_embeddings)
    return {
        "message": f"Speaker '{name}' enrolled successfully",
        "embedding_size": len(embedding_list),
    }
@app.delete("/speakers/{name}")
def delete_speaker(name: str):
    """Remove a registered speaker and persist the updated registry."""
    try:
        del speaker_embeddings[name]
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Speaker '{name}' not found"
        ) from None
    if storage:
        storage.save(speaker_embeddings)
    return {"message": f"Speaker '{name}' deleted successfully"}
@app.post("/identify")
async def identify_speaker(file: UploadFile = File(...)):
    """Identify the speaker from an audio sample.

    Compares the input embedding against every enrolled speaker and
    reports the highest-similarity match if it clears the threshold.
    """
    if speaker_model is None:
        # 503 (service not ready) instead of an opaque 500 from RuntimeError.
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not speaker_embeddings:
        raise HTTPException(
            status_code=400,
            detail="No speakers enrolled. Please enroll speakers first.",
        )
    audio_bytes = await file.read()
    waveform = preprocess_audio(audio_bytes)
    input_embedding = speaker_model.encode_batch(waveform)
    best_match = None
    # Cosine similarity ranges over [-1, 1]; start strictly below that
    # range. The previous sentinel of -1.0 could never be matched if a
    # score came out at exactly -1.0.
    best_score = float("-inf")
    for name, stored_embedding in speaker_embeddings.items():
        stored_tensor = torch.tensor(stored_embedding).unsqueeze(0)
        score = speaker_model.similarity(input_embedding, stored_tensor)
        score_value = float(score.squeeze())
        if score_value > best_score:
            best_score = score_value
            best_match = name
    # Empirical operating point for ECAPA-TDNN verification (~0.25).
    threshold = 0.25
    identified = best_score >= threshold
    return {
        "identified": identified,
        "speaker": best_match if identified else None,
        "confidence": best_score,
        "threshold": threshold,
    }
@app.post("/verify")
async def verify_speaker(name: str, file: UploadFile = File(...)):
    """Verify whether the audio matches a specific enrolled speaker.

    ``name`` arrives as a query parameter (the route path has no
    placeholder). Returns the similarity score against the stored
    embedding and whether it clears the decision threshold.
    """
    if speaker_model is None:
        # 503 (service not ready) instead of an opaque 500 from RuntimeError.
        raise HTTPException(status_code=503, detail="Model not loaded")
    if name not in speaker_embeddings:
        raise HTTPException(status_code=404, detail=f"Speaker '{name}' not found")
    audio_bytes = await file.read()
    waveform = preprocess_audio(audio_bytes)
    input_embedding = speaker_model.encode_batch(waveform)
    # Compare against the single stored embedding for this speaker.
    stored_tensor = torch.tensor(speaker_embeddings[name]).unsqueeze(0)
    score = speaker_model.similarity(input_embedding, stored_tensor)
    score_value = float(score.squeeze())
    # Same operating point used by /identify.
    threshold = 0.25
    verified = score_value >= threshold
    return {
        "verified": verified,
        "speaker": name,
        "confidence": score_value,
        "threshold": threshold,
    }
|