# Page-extraction residue (appears to be a hosting status banner), kept verbatim:
# "Spaces: Sleeping Sleeping"
"""
Unit tests for DiarizationService.

Tests speaker diarization functionality including:
- Service initialization
- Speaker merging logic
- API endpoint integration
"""
import os
from unittest.mock import MagicMock, Mock, patch

import pytest
class TestDiarizationService:
    """Tests for DiarizationService initialization and requirement checks."""

    @pytest.fixture
    def mock_env(self):
        """Set HF_TOKEN to a dummy value for the duration of each test."""
        with patch.dict(os.environ, {"HF_TOKEN": "test_token"}):
            yield

    @pytest.fixture
    def service(self, mock_env):
        """Create a DiarizationService while CUDA is reported as unavailable.

        ``torch`` is patched only while the instance is constructed; the
        tests below therefore expect device selection to happen in ``__init__``.
        Depending on ``mock_env`` keeps HF_TOKEN set for the whole test.
        """
        with patch("app.services.diarization_service.torch") as mock_torch:
            mock_torch.cuda.is_available.return_value = False
            # Deferred import: the app package is only needed when this fixture runs.
            from app.services.diarization_service import DiarizationService

            return DiarizationService()

    def test_init_cpu_device(self, service):
        """Service initializes with the CPU device and int8 compute type without CUDA."""
        assert service.device == "cpu"
        assert service.compute_type == "int8"

    def test_check_requirements_missing_token(self):
        """check_requirements raises a ValueError mentioning HF_TOKEN when it is empty."""
        with patch.dict(os.environ, {"HF_TOKEN": ""}), patch(
            "app.services.diarization_service.torch"
        ) as mock_torch:
            mock_torch.cuda.is_available.return_value = False
            from app.services.diarization_service import DiarizationService

            service = DiarizationService()
            # `match` is applied with re.search against str(exception).
            with pytest.raises(ValueError, match="HF_TOKEN"):
                service.check_requirements()

    def test_check_requirements_with_token(self, service):
        """check_requirements does not raise when HF_TOKEN is set."""
        service.check_requirements()
class TestSpeakerMerging:
    """Tests for the speaker-to-segment merging logic (``_merge_speakers``)."""

    @pytest.fixture
    def mock_env(self):
        """Set HF_TOKEN to a dummy value for the duration of each test."""
        with patch.dict(os.environ, {"HF_TOKEN": "test_token"}):
            yield

    @pytest.fixture
    def mock_diarization(self):
        """Create a mock pyannote diarization object.

        ``itertracks`` returns three consecutive turns:
        SPEAKER_00 over 0.0-2.0, SPEAKER_01 over 2.0-4.0, SPEAKER_00 over 4.0-6.0.
        """
        mock = MagicMock()
        # Tuples are (segment, track, label); presumably mirrors what the service
        # reads from itertracks(yield_label=True) - confirm against the service.
        mock.itertracks.return_value = [
            (MockSegment(0.0, 2.0), None, "SPEAKER_00"),
            (MockSegment(2.0, 4.0), None, "SPEAKER_01"),
            (MockSegment(4.0, 6.0), None, "SPEAKER_00"),
        ]
        return mock

    @pytest.fixture
    def service(self, mock_env):
        """Yield a DiarizationService built with CUDA reported as unavailable.

        The torch patch stays active until the test finishes, so anything
        the service does with torch during the test also sees the mock.
        """
        with patch("app.services.diarization_service.torch") as mock_torch:
            mock_torch.cuda.is_available.return_value = False
            from app.services.diarization_service import DiarizationService

            yield DiarizationService()

    def test_merge_speakers_midpoint_matching(self, service, mock_diarization):
        """Each transcript segment gets the speaker whose turn contains its midpoint."""
        transcript = {
            "segments": [
                {"start": 0.0, "end": 1.5, "text": "Hello"},
                {"start": 2.5, "end": 3.5, "text": "World"},
                {"start": 4.5, "end": 5.5, "text": "Goodbye"},
            ],
            "language": "en",
        }

        result = service._merge_speakers(transcript, mock_diarization)

        assert len(result) == 3
        assert result[0]["speaker"] == "SPEAKER_00"  # 0.75 midpoint in 0-2 range
        assert result[1]["speaker"] == "SPEAKER_01"  # 3.0 midpoint in 2-4 range
        assert result[2]["speaker"] == "SPEAKER_00"  # 5.0 midpoint in 4-6 range

    def test_merge_speakers_preserves_text(self, service, mock_diarization):
        """The original segment text and timestamps are carried through unchanged."""
        transcript = {
            "segments": [
                {"start": 0.0, "end": 1.0, "text": "Test text here"},
            ],
            "language": "en",
        }

        result = service._merge_speakers(transcript, mock_diarization)

        assert result[0]["text"] == "Test text here"
        assert result[0]["start"] == 0.0
        assert result[0]["end"] == 1.0
class MockSegment:
    """Lightweight stand-in for a pyannote ``Segment``.

    Holds ``start`` and ``end`` as plain instance attributes; no other
    Segment behaviour is emulated.
    """

    def __init__(self, start: float, end: float):
        self.start, self.end = start, end
| class TestDiarizationAPI: | |
| """Integration tests for diarization API endpoint""" | |
| def client(self): | |
| """Create test client""" | |
| from fastapi.testclient import TestClient | |
| from app.main import app | |
| return TestClient(app) | |
| def test_diarize_endpoint_requires_file(self, client): | |
| """Test endpoint returns 422 when no file provided""" | |
| response = client.post("/api/v1/stt/upload/diarize") | |
| assert response.status_code == 422 | |
| def test_diarize_endpoint_accepts_parameters(self, client): | |
| """Test endpoint accepts speaker count parameters""" | |
| # Create a minimal audio file | |
| import io | |
| import wave | |
| # Create tiny WAV file | |
| buffer = io.BytesIO() | |
| with wave.open(buffer, 'wb') as wav: | |
| wav.setnchannels(1) | |
| wav.setsampwidth(2) | |
| wav.setframerate(16000) | |
| wav.writeframes(b'\x00' * 32000) # 1 second of silence | |
| buffer.seek(0) | |
| # This will fail without HF_TOKEN but should accept the request format | |
| response = client.post( | |
| "/api/v1/stt/upload/diarize", | |
| files={"file": ("test.wav", buffer, "audio/wav")}, | |
| data={ | |
| "num_speakers": 2, | |
| "min_speakers": 1, | |
| "max_speakers": 3, | |
| "language": "en" | |
| } | |
| ) | |
| # Should get 400 (missing token) or 500 (processing error), not 422 (validation) | |
| assert response.status_code in [400, 500] | |
# Module-level fixtures: available to every test in this module by name
# (a class-level fixture with the same name takes precedence inside that class).
@pytest.fixture
def mock_env():
    """Set HF_TOKEN to a dummy value for the duration of the requesting test.

    Yields nothing; ``patch.dict`` restores the original environment on teardown.
    """
    with patch.dict(os.environ, {"HF_TOKEN": "test_token"}):
        yield