Spaces:
Paused
Paused
| """ | |
| Unit tests for SpeakerExtractionService | |
| Tests speaker extraction functionality including: | |
| - Embedding extraction from reference clips | |
| - Cosine similarity comparison | |
| - Segment matching and filtering | |
| - Reference clip validation | |
| """ | |
| from pathlib import Path | |
| from unittest.mock import MagicMock, Mock, patch | |
| import numpy as np | |
| import pytest | |
| from src.services.speaker_extraction import SpeakerExtractionService | |
| from src.models.audio_segment import AudioSegment, SegmentType | |
| from src.models.speaker_profile import SpeakerProfile | |
@pytest.fixture
def speaker_extraction_service():
    """Provide a SpeakerExtractionService whose models are test doubles.

    ``Pipeline`` in ``src.services.speaker_extraction`` is patched only while
    the service is constructed, so the constructor's call to
    ``Pipeline.from_pretrained`` hits a mock instead of the real class.
    ``embedding_model`` is then replaced with a plain ``Mock`` that individual
    tests configure (e.g. by setting ``return_value``).

    Returns:
        SpeakerExtractionService: an instance safe to use without real models.
    """
    # The patch is only needed during __init__; the mock itself is not used.
    with patch("src.services.speaker_extraction.Pipeline"):
        service = SpeakerExtractionService()
    service.embedding_model = Mock()
    return service
@pytest.fixture
def mock_reference_audio(tmp_path):
    """Create a placeholder reference-clip file and return its path.

    The file is named ``reference.m4a`` but contains only the bytes
    ``b"mock audio data"``, not decodable audio. Tests that need samples must
    patch the audio readers.

    Args:
        tmp_path: pytest's per-test temporary directory (``pathlib.Path``).

    Returns:
        pathlib.Path: path to the written placeholder file.
    """
    ref_file = tmp_path / "reference.m4a"
    ref_file.write_bytes(b"mock audio data")
    return ref_file
@pytest.fixture
def mock_target_audio(tmp_path):
    """Create a placeholder target-audio file and return its path.

    The file is named ``target.m4a`` but contains only the bytes
    ``b"mock audio data"``, not decodable audio. Tests that need samples must
    patch the audio readers.

    Args:
        tmp_path: pytest's per-test temporary directory (``pathlib.Path``).

    Returns:
        pathlib.Path: path to the written placeholder file.
    """
    target_file = tmp_path / "target.m4a"
    target_file.write_bytes(b"mock audio data")
    return target_file
class TestSpeakerExtractionService:
    """Test suite for SpeakerExtractionService.

    Tests use the ``speaker_extraction_service`` fixture, whose
    ``embedding_model`` is a ``Mock``. Most tests also patch module-level
    helpers in ``src.services.speaker_extraction`` (``read_audio``,
    ``write_audio``, ``get_audio_duration``, ``logger``) with synthetic data.
    Audio buffers are ``np.random.rand(16000 * s)`` paired with a sample rate
    of 16000, i.e. ``s`` seconds of audio.
    """

    def test_service_initialization(self):
        """Test that service initializes embedding model correctly.

        Constructing the service must call ``Pipeline.from_pretrained`` exactly
        once; ``Pipeline`` is patched so the real class is never touched.
        """
        with patch("src.services.speaker_extraction.Pipeline") as mock_pipeline:
            service = SpeakerExtractionService()
            mock_pipeline.from_pretrained.assert_called_once()
            # NOTE(review): only checks that the attribute is set to something;
            # consider tying it to mock_pipeline.from_pretrained's return value
            # once the constructor's contract is confirmed.
            assert service.embedding_model is not None

    def test_extract_reference_embedding_success(
        self, speaker_extraction_service, mock_reference_audio
    ):
        """Test extracting embedding from valid reference clip.

        A 5-second synthetic clip should yield the 512-dim vector returned by
        the mocked embedding model, reading the audio exactly once.
        """
        # Mock embedding extraction
        mock_embedding = np.random.rand(512)
        speaker_extraction_service.embedding_model.return_value = mock_embedding

        with patch("src.services.speaker_extraction.read_audio") as mock_read:
            mock_read.return_value = (np.random.rand(16000 * 5), 16000)  # 5 seconds
            embedding = speaker_extraction_service.extract_reference_embedding(
                str(mock_reference_audio)
            )

        assert embedding is not None
        assert len(embedding) == 512
        mock_read.assert_called_once()

    def test_extract_reference_embedding_too_short(
        self, speaker_extraction_service, mock_reference_audio
    ):
        """Test that reference clip shorter than 3 seconds raises error.

        A 2-second clip must raise ``ValueError`` whose message contains
        "too short".
        """
        with patch("src.services.speaker_extraction.read_audio") as mock_read:
            mock_read.return_value = (np.random.rand(16000 * 2), 16000)  # 2 seconds
            with pytest.raises(ValueError, match="too short"):
                speaker_extraction_service.extract_reference_embedding(str(mock_reference_audio))

    def test_extract_reference_embedding_low_quality(
        self, speaker_extraction_service, mock_reference_audio
    ):
        """Test warning for low quality reference clip.

        Very quiet audio (amplitude scaled by 0.01) must still produce an
        embedding, but some ``logger.warning`` call must mention "quality".
        """
        with patch("src.services.speaker_extraction.read_audio") as mock_read:
            # Very low amplitude audio (poor quality)
            audio_data = np.random.rand(16000 * 5) * 0.01
            mock_read.return_value = (audio_data, 16000)
            mock_embedding = np.random.rand(512)
            speaker_extraction_service.embedding_model.return_value = mock_embedding
            # Patch the module logger so emitted warnings can be inspected.
            with patch("src.services.speaker_extraction.logger") as mock_logger:
                embedding = speaker_extraction_service.extract_reference_embedding(
                    str(mock_reference_audio)
                )
                # Should still return embedding but log warning
                assert embedding is not None
                assert any(
                    "quality" in str(call).lower() for call in mock_logger.warning.call_args_list
                )

    def test_extract_target_embeddings(self, speaker_extraction_service, mock_target_audio):
        """Test extracting embeddings from all segments in target audio.

        With ``detect_voice_segments`` patched to return three segments, the
        result must be three ``(AudioSegment, embedding)`` pairs with 512-dim
        embeddings.
        """
        # Mock diarization result with multiple segments
        mock_segments = [
            AudioSegment(0.0, 3.0, "SPEAKER_00", 0.9, SegmentType.SPEECH),
            AudioSegment(4.0, 7.0, "SPEAKER_01", 0.85, SegmentType.SPEECH),
            AudioSegment(8.0, 11.0, "SPEAKER_00", 0.92, SegmentType.SPEECH),
        ]

        with patch("src.services.speaker_extraction.read_audio") as mock_read:
            mock_read.return_value = (np.random.rand(16000 * 15), 16000)
            with patch.object(
                speaker_extraction_service, "detect_voice_segments", return_value=mock_segments
            ):
                mock_embedding = np.random.rand(512)
                speaker_extraction_service.embedding_model.return_value = mock_embedding
                segments_with_embeddings = speaker_extraction_service.extract_target_embeddings(
                    str(mock_target_audio)
                )

        assert len(segments_with_embeddings) == 3
        for seg, emb in segments_with_embeddings:
            assert isinstance(seg, AudioSegment)
            assert len(emb) == 512

    def test_compute_cosine_similarity(self, speaker_extraction_service):
        """Test cosine similarity calculation between embeddings.

        Covers identical (~1.0), orthogonal (~0.0) and opposite (< 0) vectors.
        """
        # Identical embeddings should have similarity close to 1.0
        embedding1 = np.array([1.0, 0.0, 0.0])
        embedding2 = np.array([1.0, 0.0, 0.0])
        similarity = speaker_extraction_service.compute_similarity(embedding1, embedding2)
        assert similarity == pytest.approx(1.0, abs=0.01)

        # Orthogonal embeddings should have similarity close to 0.0
        embedding3 = np.array([0.0, 1.0, 0.0])
        similarity2 = speaker_extraction_service.compute_similarity(embedding1, embedding3)
        assert similarity2 == pytest.approx(0.0, abs=0.01)

        # Opposite embeddings should have negative similarity
        embedding4 = np.array([-1.0, 0.0, 0.0])
        similarity3 = speaker_extraction_service.compute_similarity(embedding1, embedding4)
        assert similarity3 < 0

    def test_match_segments_with_threshold(self, speaker_extraction_service):
        """Test segment matching with confidence threshold.

        Segments whose similarity falls below ``threshold`` are dropped. The
        result is a list of tuples whose first element is the AudioSegment,
        kept in input order.
        """
        reference_embedding = np.array([1.0, 0.0, 0.0])
        segments_with_embeddings = [
            (
                AudioSegment(0.0, 3.0, "SPEAKER_00", 0.9, SegmentType.SPEECH),
                np.array([0.95, 0.1, 0.05]),
            ),  # High match
            (
                AudioSegment(4.0, 7.0, "SPEAKER_01", 0.85, SegmentType.SPEECH),
                np.array([0.2, 0.9, 0.1]),
            ),  # Low match
            (
                AudioSegment(8.0, 11.0, "SPEAKER_00", 0.92, SegmentType.SPEECH),
                np.array([0.98, 0.05, 0.02]),
            ),  # High match
        ]

        # side_effect yields these scores in call order, so this assumes
        # match_segments calls compute_similarity once per segment, in list order.
        with patch.object(
            speaker_extraction_service,
            "compute_similarity",
            side_effect=[0.85, 0.25, 0.90],  # Similarity scores
        ):
            matched = speaker_extraction_service.match_segments(
                reference_embedding, segments_with_embeddings, threshold=0.40, min_confidence=0.30
            )

        assert len(matched) == 2  # Only high match segments
        assert matched[0][0].start_time == 0.0
        assert matched[1][0].start_time == 8.0

    def test_match_segments_min_confidence_filter(self, speaker_extraction_service):
        """Test that segments below min_confidence are filtered out.

        Both segments score 0.95 similarity, but the one whose segment
        confidence (0.25) is below ``min_confidence`` (0.30) must be dropped.
        """
        reference_embedding = np.array([1.0, 0.0, 0.0])
        segments_with_embeddings = [
            (
                AudioSegment(0.0, 3.0, "SPEAKER_00", 0.95, SegmentType.SPEECH),
                np.array([1.0, 0.0, 0.0]),
            ),
            (
                AudioSegment(4.0, 7.0, "SPEAKER_01", 0.25, SegmentType.SPEECH),
                np.array([1.0, 0.0, 0.0]),
            ),  # Low confidence
        ]

        with patch.object(
            speaker_extraction_service,
            "compute_similarity",
            return_value=0.95,  # High similarity
        ):
            matched = speaker_extraction_service.match_segments(
                reference_embedding, segments_with_embeddings, threshold=0.40, min_confidence=0.30
            )

        assert len(matched) == 1  # Low confidence segment filtered
        assert matched[0][0].confidence == 0.95

    def test_validate_reference_clip_success(
        self, speaker_extraction_service, mock_reference_audio
    ):
        """Test reference clip validation with good quality audio.

        A 5-second, moderate-amplitude clip must validate, returning
        ``(True, message)`` where the message mentions "valid" or is empty.
        """
        with patch("src.services.speaker_extraction.read_audio") as mock_read:
            # Good quality: 5 seconds, decent amplitude
            audio_data = np.random.rand(16000 * 5) * 0.5
            mock_read.return_value = (audio_data, 16000)
            with patch("src.services.speaker_extraction.get_audio_duration") as mock_duration:
                mock_duration.return_value = 5.0
                is_valid, message = speaker_extraction_service.validate_reference_clip(
                    str(mock_reference_audio)
                )

        assert is_valid is True
        assert "valid" in message.lower() or message == ""

    def test_validate_reference_clip_too_short(
        self, speaker_extraction_service, mock_reference_audio
    ):
        """Test reference clip validation fails for short clips.

        A reported duration of 2.0 seconds must yield ``(False, message)``
        with "short" in the message.
        """
        # NOTE(review): read_audio is NOT patched here, and the fixture file
        # holds placeholder bytes; this assumes validate_reference_clip rejects
        # on get_audio_duration before decoding the file — confirm.
        with patch("src.services.speaker_extraction.get_audio_duration") as mock_duration:
            mock_duration.return_value = 2.0  # Too short
            is_valid, message = speaker_extraction_service.validate_reference_clip(
                str(mock_reference_audio)
            )

        assert is_valid is False
        assert "short" in message.lower()

    def test_extract_and_export_concatenated(
        self, speaker_extraction_service, mock_reference_audio, mock_target_audio, tmp_path
    ):
        """Test end-to-end extraction with concatenation.

        Reference embedding and segment matching are mocked (two matches at
        similarity 0.85); the run must return a report containing
        ``output_file`` and write the output exactly once.
        """
        output_file = tmp_path / "extracted.m4a"

        mock_segments = [
            AudioSegment(0.0, 3.0, "SPEAKER_00", 0.9, SegmentType.SPEECH),
            AudioSegment(8.0, 11.0, "SPEAKER_00", 0.92, SegmentType.SPEECH),
        ]

        # NOTE(review): extract_target_embeddings / detect_voice_segments are not
        # patched here; confirm extract_and_export either skips them or tolerates
        # the synthetic audio returned by the patched read_audio.
        with patch.object(
            speaker_extraction_service,
            "extract_reference_embedding",
            return_value=np.random.rand(512),
        ):
            with patch.object(
                speaker_extraction_service,
                "match_segments",
                return_value=[(seg, 0.85) for seg in mock_segments],
            ):
                with patch("src.services.speaker_extraction.read_audio") as mock_read:
                    mock_read.return_value = (np.random.rand(16000 * 15), 16000)
                    with patch("src.services.speaker_extraction.write_audio") as mock_write:
                        report = speaker_extraction_service.extract_and_export(
                            reference_clip=str(mock_reference_audio),
                            target_file=str(mock_target_audio),
                            output_path=str(output_file),
                            threshold=0.40,
                            concatenate=True,
                        )

        # NOTE(review): the ">= 0" checks only confirm the keys exist with
        # non-negative values; consider pinning exact counts (two segments are
        # matched above) once the report contract is confirmed.
        assert report["segments_found"] >= 0
        assert report["segments_included"] >= 0
        assert "output_file" in report
        mock_write.assert_called_once()

    def test_generate_extraction_report(self, speaker_extraction_service):
        """Test extraction report generation.

        Input fields must be echoed back under the expected keys, the matched
        segments counted, and summary keys ``average_confidence`` and
        ``low_confidence_segments`` present.
        """
        # Each entry is (AudioSegment, similarity score).
        matched_segments = [
            (AudioSegment(0.0, 3.0, "SPEAKER_00", 0.9, SegmentType.SPEECH), 0.85),
            (AudioSegment(8.0, 11.0, "SPEAKER_00", 0.92, SegmentType.SPEECH), 0.78),
            (AudioSegment(15.0, 18.0, "SPEAKER_00", 0.65, SegmentType.SPEECH), 0.42),
        ]

        report = speaker_extraction_service.generate_extraction_report(
            reference_clip="ref.m4a",
            target_file="target.m4a",
            threshold=0.40,
            matched_segments=matched_segments,
            processing_time=45.2,
            output_file="output.m4a",
        )

        assert report["reference_clip"] == "ref.m4a"
        assert report["target_file"] == "target.m4a"
        assert report["threshold"] == 0.40
        assert report["segments_found"] == 3
        assert report["processing_time_seconds"] == 45.2
        assert "average_confidence" in report
        assert "low_confidence_segments" in report
class TestEmbeddingComparison:
    """Test suite for embedding comparison utilities.

    ``SpeakerExtractionService.compute_similarity`` is exercised without
    running the service constructor, which calls ``Pipeline.from_pretrained``.
    Instead, the real method is bound to a ``Mock`` standing in for ``self``.
    """

    @staticmethod
    def _make_service():
        """Return a Mock whose ``compute_similarity`` is the real method.

        Binding the function through the descriptor protocol (``__get__``)
        runs the real implementation with the Mock as ``self``. Any attribute
        the method reads from ``self`` therefore resolves on the Mock rather
        than on model state.
        """
        service = Mock()
        service.compute_similarity = SpeakerExtractionService.compute_similarity.__get__(service)
        return service

    def test_cosine_similarity_normalized_vectors(self):
        """Orthogonal unit vectors have cosine similarity ~0.0."""
        service = self._make_service()

        v1 = np.array([1.0, 0.0])
        v2 = np.array([0.0, 1.0])
        similarity = service.compute_similarity(v1, v2)
        assert similarity == pytest.approx(0.0, abs=0.01)

    def test_cosine_similarity_identical_vectors(self):
        """A vector compared with itself has cosine similarity ~1.0."""
        service = self._make_service()

        v1 = np.random.rand(512)
        similarity = service.compute_similarity(v1, v1)
        assert similarity == pytest.approx(1.0, abs=0.01)

    def test_cosine_similarity_opposite_vectors(self):
        """A vector and its negation have cosine similarity ~-1.0."""
        service = self._make_service()

        v1 = np.array([1.0, 2.0, 3.0])
        v2 = np.array([-1.0, -2.0, -3.0])
        similarity = service.compute_similarity(v1, v2)
        assert similarity == pytest.approx(-1.0, abs=0.01)

    def test_similarity_threshold_filtering(self):
        """Scores land at their expected values and on either side of 0.7."""
        service = self._make_service()

        reference = np.array([1.0, 0.0])
        high_match = np.array([0.95, 0.31])  # cos ~= 0.951
        low_match = np.array([0.5, 0.87])  # cos ~= 0.498

        high = service.compute_similarity(reference, high_match)
        low = service.compute_similarity(reference, low_match)

        # Pin the advertised values, not just their side of the threshold;
        # the sibling tests already require raw (un-rescaled) cosine output.
        assert high == pytest.approx(0.95, abs=0.01)
        assert low == pytest.approx(0.5, abs=0.01)
        assert high > 0.7
        assert low < 0.7