# Reachy SpeechBrain speaker-recognition API (Hugging Face Space).
| import io | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from typing import cast | |
| import torch | |
| import torchaudio | |
| from fastapi import FastAPI, File, HTTPException, UploadFile | |
| from speechbrain.inference.speaker import SpeakerRecognition | |
| from storage import StorageBackend, get_storage_backend | |
# Reuse uvicorn's error logger so messages show up in the server output.
logger = logging.getLogger("uvicorn.error")

# Module-level state, populated in lifespan() at startup:
speaker_model: SpeakerRecognition | None = None  # ECAPA speaker model; None until loaded
speaker_embeddings: dict[str, list[float]] = {}  # speaker name -> stored embedding vector
storage: StorageBackend | None = None  # persistence backend for the embeddings dict
def preprocess_audio(audio_bytes: bytes) -> torch.Tensor:
    """Decode raw audio bytes into a 16 kHz mono waveform tensor.

    Args:
        audio_bytes: Encoded audio in any format torchaudio can decode.

    Returns:
        A ``(1, num_samples)`` float tensor at 16 kHz.
    """
    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))

    target_rate = 16000
    if sample_rate != target_rate:
        # Resample so the input matches the model's expected rate.
        resample = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate
        )
        waveform = resample(waveform)

    if waveform.shape[0] > 1:
        # Average channels to collapse multi-channel audio down to mono.
        waveform = waveform.mean(dim=0, keepdim=True)

    return cast(torch.Tensor, waveform)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model and stored embeddings on startup.

    NOTE(review): ``asynccontextmanager`` is imported at the top of the file
    but was not applied here; FastAPI requires the lifespan callable to be an
    async context manager, so the decorator is restored.

    Yields control to the running application, then logs on shutdown.
    """
    global speaker_model, speaker_embeddings, storage

    logger.info("Loading SpeechBrain Speaker Recognition model...")
    speaker_model = SpeakerRecognition.from_hparams(
        source="speechbrain/spkrec-ecapa-voxceleb",
        savedir="pretrained_models/spkrec-ecapa-voxceleb",
    )

    # Restore previously enrolled speakers from the configured backend.
    storage = get_storage_backend()
    speaker_embeddings = storage.load()
    logger.info(
        f"Model loaded successfully! {len(speaker_embeddings)} speakers registered."
    )

    yield
    logger.info("Shutting down...")
| app = FastAPI(title="Reachy SpeechBrain API", lifespan=lifespan) | |
def health():
    """Liveness probe: always reports an OK status.

    NOTE(review): no ``@app.get`` decorator is visible on any handler in this
    copy of the file — confirm route registration wasn't lost upstream.
    """
    return {"status": "ok"}
def list_speakers():
    """Return the names of every enrolled speaker."""
    # Iterating a dict yields its keys; unpack them into a fresh list.
    return {"speakers": [*speaker_embeddings]}
async def enroll_speaker(name: str, file: UploadFile = File(...)):
    """Enroll a new speaker by providing an audio sample.

    Args:
        name: Identifier to register the speaker under (overwrites existing).
        file: Uploaded audio sample.

    Returns:
        A confirmation message plus the embedding's dimensionality.
    """
    if speaker_model is None:
        raise RuntimeError("Model not loaded")

    raw_audio = await file.read()
    waveform = preprocess_audio(raw_audio)

    # Compute the speaker embedding and flatten it to a plain list so it
    # serializes cleanly to the storage backend.
    vector = speaker_model.encode_batch(waveform).squeeze().tolist()

    speaker_embeddings[name] = vector
    if storage:
        storage.save(speaker_embeddings)

    return {
        "message": f"Speaker '{name}' enrolled successfully",
        "embedding_size": len(vector),
    }
def delete_speaker(name: str):
    """Remove a registered speaker and persist the change.

    Raises:
        HTTPException: 404 when *name* has not been enrolled.
    """
    # EAFP: attempt the removal and map a missing key to a 404 response.
    try:
        del speaker_embeddings[name]
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Speaker '{name}' not found"
        ) from None

    if storage:
        storage.save(speaker_embeddings)

    return {"message": f"Speaker '{name}' deleted successfully"}
async def identify_speaker(file: UploadFile = File(...)):
    """Identify which enrolled speaker an audio sample belongs to.

    Compares the sample's embedding against every stored speaker and
    accepts the best match only if it clears the similarity threshold.

    Raises:
        HTTPException: 400 when no speakers have been enrolled yet.
    """
    if speaker_model is None:
        raise RuntimeError("Model not loaded")
    if not speaker_embeddings:
        raise HTTPException(
            status_code=400,
            detail="No speakers enrolled. Please enroll speakers first.",
        )

    waveform = preprocess_audio(await file.read())
    probe = speaker_model.encode_batch(waveform)

    # Score the probe against every enrolled embedding, keeping the best.
    best_match = None
    best_score = -1.0
    for candidate, vector in speaker_embeddings.items():
        reference = torch.tensor(vector).unsqueeze(0)
        similarity = float(speaker_model.similarity(probe, reference).squeeze())
        if similarity > best_score:
            best_score, best_match = similarity, candidate

    # Accept only matches above the threshold (ECAPA-TDNN typically ~0.25).
    threshold = 0.25
    identified = best_score >= threshold
    return {
        "identified": identified,
        "speaker": best_match if identified else None,
        "confidence": best_score,
        "threshold": threshold,
    }
async def verify_speaker(name: str, file: UploadFile = File(...)):
    """Verify whether an audio sample matches a specific enrolled speaker.

    Args:
        name: The speaker to verify against.
        file: Uploaded audio sample.

    Raises:
        HTTPException: 404 when *name* has not been enrolled.
    """
    if speaker_model is None:
        raise RuntimeError("Model not loaded")
    if name not in speaker_embeddings:
        raise HTTPException(status_code=404, detail=f"Speaker '{name}' not found")

    waveform = preprocess_audio(await file.read())
    probe = speaker_model.encode_batch(waveform)

    # Cosine similarity between the probe and the stored reference embedding.
    reference = torch.tensor(speaker_embeddings[name]).unsqueeze(0)
    confidence = float(speaker_model.similarity(probe, reference).squeeze())

    # Same acceptance threshold as identification.
    threshold = 0.25
    return {
        "verified": confidence >= threshold,
        "speaker": name,
        "confidence": confidence,
        "threshold": threshold,
    }