File size: 5,320 Bytes
7323d5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import io
import logging
from contextlib import asynccontextmanager
from typing import cast

import torch
import torchaudio
from fastapi import FastAPI, File, HTTPException, UploadFile
from speechbrain.inference.speaker import SpeakerRecognition

from storage import StorageBackend, get_storage_backend

# Log through uvicorn's error logger so messages share the server's handlers/format.
logger = logging.getLogger("uvicorn.error")

# Mutable module-level state, populated during the FastAPI lifespan startup
# (see `lifespan` below); None/empty until initialization completes.
speaker_model: SpeakerRecognition | None = None
# Maps speaker name -> embedding vector; mirrored to `storage` on every change.
speaker_embeddings: dict[str, list[float]] = {}
storage: StorageBackend | None = None


def preprocess_audio(audio_bytes: bytes) -> torch.Tensor:
    """Decode raw audio bytes and normalize to 16 kHz mono.

    Args:
        audio_bytes: Encoded audio in any container/codec torchaudio can
            decode (e.g. WAV, FLAC).

    Returns:
        A float waveform tensor of shape (1, num_samples) at 16 kHz.

    Raises:
        Whatever ``torchaudio.load`` raises for undecodable input
        (currently surfaces as a 500 to API callers).
    """
    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))

    # Downmix to mono *before* resampling: channel averaging commutes with
    # linear resampling, and resampling one channel costs half as much as two.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # The ECAPA speaker model loaded in `lifespan` is fed this waveform
    # directly, so normalize everything to a single 16 kHz rate.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )
        waveform = resampler(waveform)

    return cast(torch.Tensor, waveform)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the speaker model and persisted embeddings.

    Runs before the server accepts requests; everything after ``yield``
    executes on shutdown.
    """
    global speaker_model, speaker_embeddings, storage

    logger.info("Loading SpeechBrain Speaker Recognition model...")
    speaker_model = SpeakerRecognition.from_hparams(
        source="speechbrain/spkrec-ecapa-voxceleb",
        savedir="pretrained_models/spkrec-ecapa-voxceleb",
    )

    # Restore previously enrolled speakers from the configured backend.
    storage = get_storage_backend()
    speaker_embeddings = storage.load()

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info(
        "Model loaded successfully! %d speakers registered.",
        len(speaker_embeddings),
    )
    yield
    logger.info("Shutting down...")


app = FastAPI(title="Reachy SpeechBrain API", lifespan=lifespan)


@app.get("/health", include_in_schema=False)
def health():
    """Liveness probe; excluded from the OpenAPI schema."""
    return dict(status="ok")


@app.get("/speakers")
def list_speakers():
    """Return the names of every enrolled speaker."""
    # Iterating the dict yields its keys; unpack into a fresh list.
    return {"speakers": [*speaker_embeddings]}


@app.post("/speakers/{name}/enroll")
async def enroll_speaker(name: str, file: UploadFile = File(...)):
    """Enroll (or re-enroll) a speaker from an uploaded audio sample.

    Extracts an embedding from the audio and stores it under ``name``,
    overwriting any previous enrollment for that name, then persists the
    full embedding map to the storage backend.
    """
    if speaker_model is None:
        # 503 tells clients the service is still starting up, instead of the
        # opaque 500 an unhandled RuntimeError would produce.
        raise HTTPException(status_code=503, detail="Model not loaded")

    audio_bytes = await file.read()
    waveform = preprocess_audio(audio_bytes)

    # squeeze() drops the batch dimensions so tolist() yields a flat,
    # JSON-serializable vector for the storage backend.
    embedding = speaker_model.encode_batch(waveform)
    embedding_list = embedding.squeeze().tolist()

    speaker_embeddings[name] = embedding_list
    if storage:
        storage.save(speaker_embeddings)

    return {
        "message": f"Speaker '{name}' enrolled successfully",
        "embedding_size": len(embedding_list),
    }


@app.delete("/speakers/{name}")
def delete_speaker(name: str):
    """Remove a previously enrolled speaker and persist the change."""
    # EAFP: attempt the removal and translate a miss into a 404.
    try:
        del speaker_embeddings[name]
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Speaker '{name}' not found"
        ) from None

    if storage:
        storage.save(speaker_embeddings)

    return {"message": f"Speaker '{name}' deleted successfully"}


@app.post("/identify")
async def identify_speaker(file: UploadFile = File(...)):
    """Identify the speaker from an audio sample.

    Scores the input's embedding against every enrolled speaker and reports
    the best match, gated by a similarity threshold.
    """
    if speaker_model is None:
        # 503 tells clients the service is still starting up, instead of the
        # opaque 500 an unhandled RuntimeError would produce.
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not speaker_embeddings:
        raise HTTPException(
            status_code=400,
            detail="No speakers enrolled. Please enroll speakers first.",
        )

    audio_bytes = await file.read()
    waveform = preprocess_audio(audio_bytes)

    input_embedding = speaker_model.encode_batch(waveform)

    # Linear scan keeping the highest similarity score. The -1.0 sentinel
    # assumes scores are bounded below by -1 (presumably cosine similarity
    # — NOTE(review): confirm against SpeechBrain's `similarity` contract).
    best_match: str | None = None
    best_score = -1.0

    for name, stored_embedding in speaker_embeddings.items():
        stored_tensor = torch.tensor(stored_embedding).unsqueeze(0)
        score = speaker_model.similarity(input_embedding, stored_tensor)
        score_value = float(score.squeeze())

        if score_value > best_score:
            best_score = score_value
            best_match = name

    # Threshold for identification (ECAPA-TDNN typically uses ~0.25).
    threshold = 0.25
    identified = best_score >= threshold

    return {
        "identified": identified,
        "speaker": best_match if identified else None,
        "confidence": best_score,
        "threshold": threshold,
    }


@app.post("/verify")
async def verify_speaker(name: str, file: UploadFile = File(...)):
    """Verify whether an audio sample matches a specific enrolled speaker.

    NOTE: the route has no ``{name}`` placeholder, so FastAPI binds ``name``
    as a query parameter, e.g. ``POST /verify?name=alice``.
    """
    if speaker_model is None:
        # 503 tells clients the service is still starting up, instead of the
        # opaque 500 an unhandled RuntimeError would produce.
        raise HTTPException(status_code=503, detail="Model not loaded")

    if name not in speaker_embeddings:
        raise HTTPException(status_code=404, detail=f"Speaker '{name}' not found")

    audio_bytes = await file.read()
    waveform = preprocess_audio(audio_bytes)

    input_embedding = speaker_model.encode_batch(waveform)

    # Score the input against the single stored embedding for `name`.
    stored_tensor = torch.tensor(speaker_embeddings[name]).unsqueeze(0)
    score = speaker_model.similarity(input_embedding, stored_tensor)
    score_value = float(score.squeeze())

    # Same operating point as /identify (ECAPA-TDNN convention, ~0.25).
    threshold = 0.25
    verified = score_value >= threshold

    return {
        "verified": verified,
        "speaker": name,
        "confidence": score_value,
        "threshold": threshold,
    }