""" Unit tests for DiarizationService Tests speaker diarization functionality including: - Service initialization - Speaker merging logic - API endpoint integration """ import pytest from unittest.mock import Mock, patch, MagicMock import os class TestDiarizationService: """Tests for DiarizationService class""" @pytest.fixture def mock_env(self): """Mock environment with HF_TOKEN""" with patch.dict(os.environ, {"HF_TOKEN": "test_token"}): yield @pytest.fixture def service(self, mock_env): """Create DiarizationService instance""" with patch("app.services.diarization_service.torch") as mock_torch: mock_torch.cuda.is_available.return_value = False from app.services.diarization_service import DiarizationService return DiarizationService() def test_init_cpu_device(self, service): """Test service initializes with CPU when CUDA unavailable""" assert service.device == "cpu" assert service.compute_type == "int8" def test_check_requirements_missing_token(self): """Test check_requirements raises when HF_TOKEN missing""" with patch.dict(os.environ, {"HF_TOKEN": ""}): with patch("app.services.diarization_service.torch") as mock_torch: mock_torch.cuda.is_available.return_value = False from app.services.diarization_service import DiarizationService service = DiarizationService() with pytest.raises(ValueError) as exc: service.check_requirements() assert "HF_TOKEN" in str(exc.value) def test_check_requirements_with_token(self, service): """Test check_requirements passes with valid token""" # Should not raise service.check_requirements() class TestSpeakerMerging: """Tests for speaker-to-segment merging logic""" @pytest.fixture def mock_diarization(self): """Create mock pyannote diarization object""" mock = MagicMock() # Mock itertracks to return speaker segments mock.itertracks.return_value = [ (MockSegment(0.0, 2.0), None, "SPEAKER_00"), (MockSegment(2.0, 4.0), None, "SPEAKER_01"), (MockSegment(4.0, 6.0), None, "SPEAKER_00"), ] return mock def test_merge_speakers_midpoint_matching(self, mock_env, mock_diarization): """Test speaker merging uses midpoint matching""" with patch("app.services.diarization_service.torch") as mock_torch: mock_torch.cuda.is_available.return_value = False from app.services.diarization_service import DiarizationService service = DiarizationService() transcript = { "segments": [ {"start": 0.0, "end": 1.5, "text": "Hello"}, {"start": 2.5, "end": 3.5, "text": "World"}, {"start": 4.5, "end": 5.5, "text": "Goodbye"}, ], "language": "en" } result = service._merge_speakers(transcript, mock_diarization) assert len(result) == 3 assert result[0]["speaker"] == "SPEAKER_00" # 0.75 midpoint in 0-2 range assert result[1]["speaker"] == "SPEAKER_01" # 3.0 midpoint in 2-4 range assert result[2]["speaker"] == "SPEAKER_00" # 5.0 midpoint in 4-6 range def test_merge_speakers_preserves_text(self, mock_env, mock_diarization): """Test that original transcript text is preserved""" with patch("app.services.diarization_service.torch") as mock_torch: mock_torch.cuda.is_available.return_value = False from app.services.diarization_service import DiarizationService service = DiarizationService() transcript = { "segments": [ {"start": 0.0, "end": 1.0, "text": "Test text here"}, ], "language": "en" } result = service._merge_speakers(transcript, mock_diarization) assert result[0]["text"] == "Test text here" assert result[0]["start"] == 0.0 assert result[0]["end"] == 1.0 class MockSegment: """Mock pyannote Segment""" def __init__(self, start: float, end: float): self.start = start self.end = end class TestDiarizationAPI: """Integration tests for diarization API endpoint""" @pytest.fixture def client(self): """Create test client""" from fastapi.testclient import TestClient from app.main import app return TestClient(app) def test_diarize_endpoint_requires_file(self, client): """Test endpoint returns 422 when no file provided""" response = client.post("/api/v1/stt/upload/diarize") assert response.status_code == 422 def test_diarize_endpoint_accepts_parameters(self, client): """Test endpoint accepts speaker count parameters""" # Create a minimal audio file import io import wave # Create tiny WAV file buffer = io.BytesIO() with wave.open(buffer, 'wb') as wav: wav.setnchannels(1) wav.setsampwidth(2) wav.setframerate(16000) wav.writeframes(b'\x00' * 32000) # 1 second of silence buffer.seek(0) # This will fail without HF_TOKEN but should accept the request format response = client.post( "/api/v1/stt/upload/diarize", files={"file": ("test.wav", buffer, "audio/wav")}, data={ "num_speakers": 2, "min_speakers": 1, "max_speakers": 3, "language": "en" } ) # Should get 400 (missing token) or 500 (processing error), not 422 (validation) assert response.status_code in [400, 500] # Fixtures for mock environment @pytest.fixture def mock_env(): """Mock environment with HF_TOKEN""" with patch.dict(os.environ, {"HF_TOKEN": "test_token"}): yield