# voice-tools/tests/unit/services/test_speaker_extraction.py
# jcudit's picture
# jcudit HF Staff
# feat: complete audio speaker separation feature with 3 workflows
# cb39c05
"""
Unit tests for SpeakerExtractionService
Tests speaker extraction functionality including:
- Embedding extraction from reference clips
- Cosine similarity comparison
- Segment matching and filtering
- Reference clip validation
"""
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
import numpy as np
import pytest
from src.services.speaker_extraction import SpeakerExtractionService
from src.models.audio_segment import AudioSegment, SegmentType
from src.models.speaker_profile import SpeakerProfile
@pytest.fixture
def speaker_extraction_service():
    """Provide a SpeakerExtractionService whose embedding model is a Mock.

    ``Pipeline`` in ``src.services.speaker_extraction`` is patched only for
    the duration of construction; the ``embedding_model`` attribute is then
    replaced with a plain ``Mock`` so tests can script its return value.
    """
    pipeline_patch = patch("src.services.speaker_extraction.Pipeline")
    with pipeline_patch:
        svc = SpeakerExtractionService()
        svc.embedding_model = Mock()
    return svc
@pytest.fixture
def mock_reference_audio(tmp_path):
    """Write a placeholder reference clip under ``tmp_path`` and return its path.

    The bytes are not real audio; tests that use this fixture patch the
    audio reader instead of decoding the file.
    """
    clip_path = tmp_path.joinpath("reference.m4a")
    clip_path.write_bytes(b"mock audio data")
    return clip_path
@pytest.fixture
def mock_target_audio(tmp_path):
    """Write a placeholder target recording under ``tmp_path`` and return its path.

    The bytes are not real audio; tests that use this fixture patch the
    audio reader instead of decoding the file.
    """
    recording_path = tmp_path.joinpath("target.m4a")
    recording_path.write_bytes(b"mock audio data")
    return recording_path
class TestSpeakerExtractionService:
    """Unit tests for the methods of SpeakerExtractionService.

    Audio I/O helpers in ``src.services.speaker_extraction`` are patched in
    every test, so no real audio is decoded or written.
    """

    def test_service_initialization(self):
        """Constructing the service calls ``Pipeline.from_pretrained`` once."""
        pipeline_patch = patch("src.services.speaker_extraction.Pipeline")
        with pipeline_patch as pipeline_cls:
            svc = SpeakerExtractionService()

        pipeline_cls.from_pretrained.assert_called_once()
        assert svc.embedding_model is not None

    def test_extract_reference_embedding_success(
        self, speaker_extraction_service, mock_reference_audio
    ):
        """A 5-second reference clip yields a 512-element embedding."""
        svc = speaker_extraction_service
        # The mocked model hands back this vector for whatever it is called with.
        svc.embedding_model.return_value = np.random.rand(512)

        read_patch = patch("src.services.speaker_extraction.read_audio")
        with read_patch as read_audio_mock:
            # 5 seconds of samples at 16 kHz
            read_audio_mock.return_value = (np.random.rand(16000 * 5), 16000)
            result = svc.extract_reference_embedding(str(mock_reference_audio))

        assert result is not None
        assert len(result) == 512
        read_audio_mock.assert_called_once()

    def test_extract_reference_embedding_too_short(
        self, speaker_extraction_service, mock_reference_audio
    ):
        """A 2-second reference clip is rejected with a "too short" ValueError."""
        clip = str(mock_reference_audio)
        read_patch = patch("src.services.speaker_extraction.read_audio")
        with read_patch as read_audio_mock:
            # 2 seconds of samples at 16 kHz
            read_audio_mock.return_value = (np.random.rand(16000 * 2), 16000)
            with pytest.raises(ValueError, match="too short"):
                speaker_extraction_service.extract_reference_embedding(clip)

    def test_extract_reference_embedding_low_quality(
        self, speaker_extraction_service, mock_reference_audio
    ):
        """A very quiet clip still produces an embedding but logs a quality warning."""
        svc = speaker_extraction_service
        read_patch = patch("src.services.speaker_extraction.read_audio")
        logger_patch = patch("src.services.speaker_extraction.logger")

        with read_patch as read_audio_mock, logger_patch as logger_mock:
            # Amplitude scaled down to at most 0.01 to simulate a poor recording.
            quiet_audio = np.random.rand(16000 * 5) * 0.01
            read_audio_mock.return_value = (quiet_audio, 16000)
            svc.embedding_model.return_value = np.random.rand(512)
            result = svc.extract_reference_embedding(str(mock_reference_audio))

        # The embedding is still returned; the problem is only reported via the logger.
        assert result is not None
        warning_texts = [
            str(entry).lower() for entry in logger_mock.warning.call_args_list
        ]
        assert any("quality" in text for text in warning_texts)

    def test_extract_target_embeddings(self, speaker_extraction_service, mock_target_audio):
        """Each detected voice segment is paired with a 512-element embedding."""
        svc = speaker_extraction_service
        # Stand-in for the diarization output: three speech segments.
        diarized = [
            AudioSegment(0.0, 3.0, "SPEAKER_00", 0.9, SegmentType.SPEECH),
            AudioSegment(4.0, 7.0, "SPEAKER_01", 0.85, SegmentType.SPEECH),
            AudioSegment(8.0, 11.0, "SPEAKER_00", 0.92, SegmentType.SPEECH),
        ]
        read_patch = patch("src.services.speaker_extraction.read_audio")
        detect_patch = patch.object(svc, "detect_voice_segments", return_value=diarized)

        with read_patch as read_audio_mock, detect_patch:
            read_audio_mock.return_value = (np.random.rand(16000 * 15), 16000)
            svc.embedding_model.return_value = np.random.rand(512)
            pairs = svc.extract_target_embeddings(str(mock_target_audio))

        assert len(pairs) == 3
        for segment, vector in pairs:
            assert isinstance(segment, AudioSegment)
            assert len(vector) == 512

    def test_compute_cosine_similarity(self, speaker_extraction_service):
        """Similarity is ~1 for equal, ~0 for orthogonal, negative for opposite vectors."""
        svc = speaker_extraction_service
        base = np.array([1.0, 0.0, 0.0])

        # Identical direction -> close to 1.0
        twin = np.array([1.0, 0.0, 0.0])
        assert svc.compute_similarity(base, twin) == pytest.approx(1.0, abs=0.01)

        # Orthogonal direction -> close to 0.0
        orthogonal = np.array([0.0, 1.0, 0.0])
        assert svc.compute_similarity(base, orthogonal) == pytest.approx(0.0, abs=0.01)

        # Opposite direction -> negative
        opposite = np.array([-1.0, 0.0, 0.0])
        assert svc.compute_similarity(base, opposite) < 0

    def test_match_segments_with_threshold(self, speaker_extraction_service):
        """Only segments whose similarity clears the threshold are returned."""
        svc = speaker_extraction_service
        reference = np.array([1.0, 0.0, 0.0])
        candidates = [
            # High match
            (
                AudioSegment(0.0, 3.0, "SPEAKER_00", 0.9, SegmentType.SPEECH),
                np.array([0.95, 0.1, 0.05]),
            ),
            # Low match
            (
                AudioSegment(4.0, 7.0, "SPEAKER_01", 0.85, SegmentType.SPEECH),
                np.array([0.2, 0.9, 0.1]),
            ),
            # High match
            (
                AudioSegment(8.0, 11.0, "SPEAKER_00", 0.92, SegmentType.SPEECH),
                np.array([0.98, 0.05, 0.02]),
            ),
        ]
        # Scripted similarity scores, consumed one per compute_similarity call.
        scripted_scores = [0.85, 0.25, 0.90]

        with patch.object(svc, "compute_similarity", side_effect=scripted_scores):
            matched = svc.match_segments(
                reference, candidates, threshold=0.40, min_confidence=0.30
            )

        # Only the two high-similarity segments survive, in their original order.
        assert len(matched) == 2
        assert [hit[0].start_time for hit in matched] == [0.0, 8.0]

    def test_match_segments_min_confidence_filter(self, speaker_extraction_service):
        """A segment with confidence below ``min_confidence`` is dropped."""
        svc = speaker_extraction_service
        reference = np.array([1.0, 0.0, 0.0])
        confident = (
            AudioSegment(0.0, 3.0, "SPEAKER_00", 0.95, SegmentType.SPEECH),
            np.array([1.0, 0.0, 0.0]),
        )
        # Segment confidence 0.25 is below the 0.30 floor used in the call.
        unsure = (
            AudioSegment(4.0, 7.0, "SPEAKER_01", 0.25, SegmentType.SPEECH),
            np.array([1.0, 0.0, 0.0]),
        )

        # Similarity is pinned high so only the confidence filter can reject.
        with patch.object(svc, "compute_similarity", return_value=0.95):
            matched = svc.match_segments(
                reference, [confident, unsure], threshold=0.40, min_confidence=0.30
            )

        assert len(matched) == 1
        assert matched[0][0].confidence == 0.95

    def test_validate_reference_clip_success(
        self, speaker_extraction_service, mock_reference_audio
    ):
        """A 5-second clip with moderate amplitude validates successfully."""
        read_patch = patch("src.services.speaker_extraction.read_audio")
        duration_patch = patch("src.services.speaker_extraction.get_audio_duration")

        with read_patch as read_audio_mock, duration_patch as duration_mock:
            # 5 seconds of samples with amplitude up to 0.5
            loud_enough = np.random.rand(16000 * 5) * 0.5
            read_audio_mock.return_value = (loud_enough, 16000)
            duration_mock.return_value = 5.0
            is_valid, message = speaker_extraction_service.validate_reference_clip(
                str(mock_reference_audio)
            )

        assert is_valid is True
        assert "valid" in message.lower() or message == ""

    def test_validate_reference_clip_too_short(
        self, speaker_extraction_service, mock_reference_audio
    ):
        """A clip reported as 2 seconds long fails validation as too short."""
        duration_patch = patch("src.services.speaker_extraction.get_audio_duration")
        with duration_patch as duration_mock:
            duration_mock.return_value = 2.0
            is_valid, message = speaker_extraction_service.validate_reference_clip(
                str(mock_reference_audio)
            )

        assert is_valid is False
        assert "short" in message.lower()

    def test_extract_and_export_concatenated(
        self, speaker_extraction_service, mock_reference_audio, mock_target_audio, tmp_path
    ):
        """Extraction with ``concatenate=True`` writes once and returns a report."""
        svc = speaker_extraction_service
        destination = tmp_path / "extracted.m4a"
        speaker_segments = [
            AudioSegment(0.0, 3.0, "SPEAKER_00", 0.9, SegmentType.SPEECH),
            AudioSegment(8.0, 11.0, "SPEAKER_00", 0.92, SegmentType.SPEECH),
        ]

        # Embedding extraction and matching are stubbed on the service itself.
        embedding_patch = patch.object(
            svc, "extract_reference_embedding", return_value=np.random.rand(512)
        )
        match_patch = patch.object(
            svc,
            "match_segments",
            return_value=[(segment, 0.85) for segment in speaker_segments],
        )
        read_patch = patch("src.services.speaker_extraction.read_audio")
        write_patch = patch("src.services.speaker_extraction.write_audio")

        with embedding_patch, match_patch:
            with read_patch as read_audio_mock, write_patch as write_audio_mock:
                read_audio_mock.return_value = (np.random.rand(16000 * 15), 16000)
                report = svc.extract_and_export(
                    reference_clip=str(mock_reference_audio),
                    target_file=str(mock_target_audio),
                    output_path=str(destination),
                    threshold=0.40,
                    concatenate=True,
                )

        assert report["segments_found"] >= 0
        assert report["segments_included"] >= 0
        assert "output_file" in report
        write_audio_mock.assert_called_once()

    def test_generate_extraction_report(self, speaker_extraction_service):
        """The report echoes its inputs and carries the confidence summary keys."""
        matches = [
            (AudioSegment(0.0, 3.0, "SPEAKER_00", 0.9, SegmentType.SPEECH), 0.85),
            (AudioSegment(8.0, 11.0, "SPEAKER_00", 0.92, SegmentType.SPEECH), 0.78),
            (AudioSegment(15.0, 18.0, "SPEAKER_00", 0.65, SegmentType.SPEECH), 0.42),
        ]

        report = speaker_extraction_service.generate_extraction_report(
            reference_clip="ref.m4a",
            target_file="target.m4a",
            threshold=0.40,
            matched_segments=matches,
            processing_time=45.2,
            output_file="output.m4a",
        )

        # Fields that must equal the values passed in (or derived counts).
        expected_fields = {
            "reference_clip": "ref.m4a",
            "target_file": "target.m4a",
            "threshold": 0.40,
            "segments_found": 3,
            "processing_time_seconds": 45.2,
        }
        for field, expected in expected_fields.items():
            assert report[field] == expected

        # Fields that only need to be present.
        assert "average_confidence" in report
        assert "low_confidence_segments" in report
class TestEmbeddingComparison:
    """Checks of ``compute_similarity`` run against a Mock instead of a real service.

    Binding the method to a ``Mock`` avoids constructing
    ``SpeakerExtractionService`` in these tests.
    """

    @staticmethod
    def _mock_bound_service():
        """Return a Mock carrying ``compute_similarity`` bound to that Mock."""
        stand_in = Mock()
        stand_in.compute_similarity = SpeakerExtractionService.compute_similarity.__get__(
            stand_in
        )
        return stand_in

    def test_cosine_similarity_normalized_vectors(self):
        """Orthogonal unit vectors have similarity close to 0."""
        service = self._mock_bound_service()

        x_axis = np.array([1.0, 0.0])
        y_axis = np.array([0.0, 1.0])

        assert service.compute_similarity(x_axis, y_axis) == pytest.approx(0.0, abs=0.01)

    def test_cosine_similarity_identical_vectors(self):
        """A vector compared with itself has similarity close to 1."""
        service = self._mock_bound_service()

        vector = np.random.rand(512)

        assert service.compute_similarity(vector, vector) == pytest.approx(1.0, abs=0.01)

    def test_cosine_similarity_opposite_vectors(self):
        """A vector and its negation have similarity close to -1."""
        service = self._mock_bound_service()

        forward = np.array([1.0, 2.0, 3.0])
        backward = np.array([-1.0, -2.0, -3.0])

        assert service.compute_similarity(forward, backward) == pytest.approx(-1.0, abs=0.01)

    def test_similarity_threshold_filtering(self):
        """A 0.7 cut-off separates a near-parallel vector from a divergent one."""
        service = self._mock_bound_service()
        reference = np.array([1.0, 0.0])

        near = np.array([0.95, 0.31])  # ~0.95 similarity
        far = np.array([0.5, 0.87])  # ~0.5 similarity

        assert service.compute_similarity(reference, near) > 0.7
        assert service.compute_similarity(reference, far) < 0.7