# Author: goabonga
# Commit: feat: add speaker recognition API with SpeechBrain ECAPA-TDNN
# Revision: 7323d5e (unverified)
import io
import wave
from unittest.mock import MagicMock, patch
import pytest
import torch
import app as app_module
from app import preprocess_audio
def create_wav_buffer(
    duration_seconds: float = 1.0, sample_rate: int = 16000
) -> io.BytesIO:
    """Return an in-memory mono, 16-bit PCM WAV file containing only silence.

    The returned buffer is rewound to offset 0, so it can be passed directly
    as an upload payload without an extra ``seek``.

    Args:
        duration_seconds: Clip length; ``duration_seconds * sample_rate`` is
            truncated with ``int`` to get the frame count.
        sample_rate: Frame rate written into the WAV header.
    """
    bytes_per_sample = 2  # 16-bit PCM, one channel -> 2 bytes per frame
    frame_count = int(duration_seconds * sample_rate)
    silence = b"\x00" * (bytes_per_sample * frame_count)

    out = io.BytesIO()
    with wave.open(out, "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(bytes_per_sample)
        writer.setframerate(sample_rate)
        writer.writeframes(silence)
    out.seek(0)
    return out
def test_health(client):
    """The health probe answers 200 with a static ``{"status": "ok"}`` body."""
    expected_body = {"status": "ok"}
    resp = client.get("/health")
    assert resp.status_code == 200
    assert resp.json() == expected_body
def test_list_speakers_empty(client):
    """Test listing speakers when none are enrolled.

    ``patch.dict(..., clear=True)`` empties the shared registry for the
    duration of the test. On exit it clears the registry again before
    restoring the snapshot, so the original contents come back exactly, even
    if something was added while the test ran. The previous version restored
    with ``update`` alone, which would have left such additions in place.
    """
    with patch.dict(app_module.speaker_embeddings, clear=True):
        response = client.get("/speakers")
        assert response.status_code == 200
        assert response.json() == {"speakers": []}
def test_list_speakers_with_data(client):
    """Test listing speakers when exactly ``alice`` and ``bob`` are enrolled.

    The registry is cleared before seeding (``clear=True``). Otherwise the
    exact set comparison below fails whenever other speakers are already
    enrolled, for example loaded at startup or left behind by another test.
    The original registry contents are restored when the ``with`` block exits.
    """
    seeded = {"alice": [0.1] * 192, "bob": [0.2] * 192}
    with patch.dict(app_module.speaker_embeddings, seeded, clear=True):
        response = client.get("/speakers")
        assert response.status_code == 200
        data = response.json()
        assert set(data["speakers"]) == {"alice", "bob"}
def test_enroll_speaker(client):
    """Enrolling "alice" registers her and reports an embedding size of 192.

    The speaker model, ``torchaudio.load`` and the storage backend are all
    replaced with mocks. The embedding registry starts empty, and every
    patched global is restored when the ``with`` block exits.
    """
    encoder = MagicMock()
    encoder.encode_batch.return_value = torch.randn(1, 192)

    model_patch = patch.object(app_module, "speaker_model", encoder)
    registry_patch = patch.dict(app_module.speaker_embeddings, clear=True)
    load_patch = patch("app.torchaudio.load")
    storage_patch = patch.object(app_module, "storage", MagicMock())

    with model_patch, registry_patch, load_patch as fake_load, storage_patch:
        fake_load.return_value = (torch.zeros(1, 16000), 16000)
        response = client.post(
            "/speakers/alice/enroll",
            files={"file": ("test.wav", create_wav_buffer(), "audio/wav")},
        )
        assert response.status_code == 200
        data = response.json()
        assert "alice" in data["message"]
        assert data["embedding_size"] == 192
        assert "alice" in app_module.speaker_embeddings
def test_enroll_speaker_model_not_loaded(client):
    """Enrollment raises ``RuntimeError("Model not loaded...")`` without a model.

    The test expects the error to propagate out of ``client.post``, which
    relies on how the ``client`` fixture's test client handles server errors.
    """
    upload = create_wav_buffer()
    with patch.object(app_module, "speaker_model", None):
        with pytest.raises(RuntimeError, match="Model not loaded"):
            client.post(
                "/speakers/alice/enroll",
                files={"file": ("test.wav", upload, "audio/wav")},
            )
def test_delete_speaker(client):
    """``DELETE /speakers/alice`` returns 200 and drops alice from the registry.

    Storage is swapped for a ``MagicMock`` so the test never touches real
    persisted data. Both the registry and ``storage`` are restored afterwards.
    """
    registry_patch = patch.dict(
        app_module.speaker_embeddings, {"alice": [0.1] * 192}
    )
    storage_patch = patch.object(app_module, "storage", MagicMock())
    with registry_patch, storage_patch:
        resp = client.delete("/speakers/alice")
        assert resp.status_code == 200
        assert "alice" not in app_module.speaker_embeddings
def test_delete_speaker_not_found(client):
    """Test that deleting a speaker who is not enrolled returns 404.

    The registry is module-level state shared by every test. The precondition
    "unknown is not enrolled" is therefore enforced explicitly rather than
    assumed, and the registry is restored afterwards. This keeps the test
    independent of startup or leftover state, and ensures it can never delete
    a genuine "unknown" enrollment.
    """
    with patch.dict(app_module.speaker_embeddings):
        app_module.speaker_embeddings.pop("unknown", None)
        response = client.delete("/speakers/unknown")
        assert response.status_code == 404
def test_identify_speaker(client):
    """A mocked similarity of 0.85 identifies alice, the only enrolled speaker.

    The response is also expected to expose ``confidence`` and ``threshold``.
    Model, registry and audio decoding are patched only for this test.
    """
    fake_model = MagicMock()
    fake_model.encode_batch.return_value = torch.randn(1, 192)
    fake_model.similarity.return_value = torch.tensor([[0.85]])
    enrolled = {"alice": [0.1] * 192}

    with patch.object(app_module, "speaker_model", fake_model):
        with patch.dict(app_module.speaker_embeddings, enrolled, clear=True):
            with patch("app.torchaudio.load") as fake_load:
                fake_load.return_value = (torch.zeros(1, 16000), 16000)
                upload = ("test.wav", create_wav_buffer(), "audio/wav")
                response = client.post("/identify", files={"file": upload})

                assert response.status_code == 200
                result = response.json()
                assert result["identified"] is True
                assert result["speaker"] == "alice"
                assert "confidence" in result
                assert "threshold" in result
def test_identify_speaker_no_match(client):
    """A mocked similarity of 0.1 is too low, so nobody is identified."""
    low_score_model = MagicMock()
    low_score_model.encode_batch.return_value = torch.randn(1, 192)
    low_score_model.similarity.return_value = torch.tensor([[0.1]])  # Below threshold

    model_patch = patch.object(app_module, "speaker_model", low_score_model)
    registry_patch = patch.dict(
        app_module.speaker_embeddings, {"alice": [0.1] * 192}, clear=True
    )
    load_patch = patch("app.torchaudio.load")

    with model_patch, registry_patch, load_patch as fake_load:
        fake_load.return_value = (torch.zeros(1, 16000), 16000)
        response = client.post(
            "/identify",
            files={"file": ("test.wav", create_wav_buffer(), "audio/wav")},
        )
        assert response.status_code == 200
        body = response.json()
        assert body["identified"] is False
        assert body["speaker"] is None
def test_identify_no_speakers_enrolled(client):
    """Identification is rejected with 400 while the registry is empty."""
    with patch.object(app_module, "speaker_model", MagicMock()):
        with patch.dict(app_module.speaker_embeddings, clear=True):
            resp = client.post(
                "/identify",
                files={"file": ("test.wav", create_wav_buffer(), "audio/wav")},
            )
            assert resp.status_code == 400
            assert "No speakers enrolled" in resp.json()["detail"]
def test_verify_speaker(client):
    """Verifying alice against her enrollment succeeds at a similarity of 0.85."""
    matching_model = MagicMock()
    matching_model.encode_batch.return_value = torch.randn(1, 192)
    matching_model.similarity.return_value = torch.tensor([[0.85]])

    model_patch = patch.object(app_module, "speaker_model", matching_model)
    registry_patch = patch.dict(
        app_module.speaker_embeddings, {"alice": [0.1] * 192}, clear=True
    )
    load_patch = patch("app.torchaudio.load")

    with model_patch, registry_patch, load_patch as fake_load:
        fake_load.return_value = (torch.zeros(1, 16000), 16000)
        response = client.post(
            "/verify",
            params={"name": "alice"},
            files={"file": ("test.wav", create_wav_buffer(), "audio/wav")},
        )
        assert response.status_code == 200
        verdict = response.json()
        assert verdict["verified"] is True
        assert verdict["speaker"] == "alice"
        assert "confidence" in verdict
def test_verify_speaker_not_found(client):
    """Verifying a name that is not in the (emptied) registry returns 404."""
    upload = create_wav_buffer()
    with patch.object(app_module, "speaker_model", MagicMock()):
        with patch.dict(app_module.speaker_embeddings, clear=True):
            resp = client.post(
                "/verify",
                params={"name": "unknown"},
                files={"file": ("test.wav", upload, "audio/wav")},
            )
            assert resp.status_code == 404
def test_verify_speaker_failed(client):
    """Verification of alice fails when the mocked similarity is only 0.1."""
    mismatching_model = MagicMock()
    mismatching_model.encode_batch.return_value = torch.randn(1, 192)
    mismatching_model.similarity.return_value = torch.tensor([[0.1]])  # Below threshold
    enrolled = {"alice": [0.1] * 192}

    with patch.object(app_module, "speaker_model", mismatching_model):
        with patch.dict(app_module.speaker_embeddings, enrolled, clear=True):
            with patch("app.torchaudio.load") as fake_load:
                fake_load.return_value = (torch.zeros(1, 16000), 16000)
                response = client.post(
                    "/verify",
                    params={"name": "alice"},
                    files={"file": ("test.wav", create_wav_buffer(), "audio/wav")},
                )
                assert response.status_code == 200
                assert response.json()["verified"] is False
def test_identify_model_not_loaded(client):
    """``/identify`` raises ``RuntimeError`` when no model is loaded.

    Alice is added on top of the existing registry (not cleared), so the
    request gets past any "no speakers" check. The registry and model are
    restored afterwards.
    """
    model_patch = patch.object(app_module, "speaker_model", None)
    registry_patch = patch.dict(
        app_module.speaker_embeddings, {"alice": [0.1] * 192}
    )
    with model_patch, registry_patch:
        upload = create_wav_buffer()
        with pytest.raises(RuntimeError, match="Model not loaded"):
            client.post("/identify", files={"file": ("test.wav", upload, "audio/wav")})
def test_verify_model_not_loaded(client):
    """``/verify`` for an enrolled speaker raises ``RuntimeError`` with no model.

    Alice is added on top of the existing registry so that the name lookup
    succeeds. The registry and model are restored afterwards.
    """
    with patch.object(app_module, "speaker_model", None):
        with patch.dict(app_module.speaker_embeddings, {"alice": [0.1] * 192}):
            upload = create_wav_buffer()
            with pytest.raises(RuntimeError, match="Model not loaded"):
                client.post(
                    "/verify",
                    params={"name": "alice"},
                    files={"file": ("test.wav", upload, "audio/wav")},
                )
def test_preprocess_audio_with_resampling():
    """44.1 kHz input goes through a ``Resample(orig_freq=44100, new_freq=16000)``.

    The resampler is mocked to return a (1, 16000) tensor. The final shape
    check therefore also confirms that the resampler's output is what
    ``preprocess_audio`` returns.
    """
    fake_resampler = MagicMock(return_value=torch.zeros(1, 16000))
    with patch("app.torchaudio.load") as fake_load, patch(
        "app.torchaudio.transforms.Resample", return_value=fake_resampler
    ) as fake_resample_cls:
        fake_load.return_value = (torch.zeros(1, 44100), 44100)
        result = preprocess_audio(b"fake_audio_data")

    fake_resample_cls.assert_called_once_with(orig_freq=44100, new_freq=16000)
    assert result.shape == (1, 16000)
def test_preprocess_audio_stereo_to_mono():
    """Two-channel input at 16 kHz comes back with a single channel."""
    stereo_waveform = torch.zeros(2, 16000)  # 2 channels x 16000 samples
    with patch("app.torchaudio.load", return_value=(stereo_waveform, 16000)):
        mono = preprocess_audio(b"fake_audio_data")
    # Only the channel axis is checked: it must collapse to one.
    assert mono.shape[0] == 1
def test_preprocess_audio_no_changes_needed():
    """Mono 16 kHz input passes through with its (1, 16000) shape intact."""
    mono_16k = (torch.zeros(1, 16000), 16000)
    with patch("app.torchaudio.load", return_value=mono_16k):
        output = preprocess_audio(b"fake_audio_data")
    assert output.shape == (1, 16000)