File size: 6,333 Bytes
673435a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""
Unit tests for DiarizationService

Tests speaker diarization functionality including:
- Service initialization
- Speaker merging logic
- API endpoint integration
"""

import pytest
from unittest.mock import Mock, patch, MagicMock
import os


class TestDiarizationService:
    """Tests for DiarizationService initialization and requirement checks.

    The ``service`` fixture depends on the module-level ``mock_env`` fixture,
    which sets HF_TOKEN to a dummy value for the duration of each test.
    """

    @pytest.fixture
    def service(self, mock_env):
        """Yield a DiarizationService constructed while CUDA reports unavailable.

        ``yield`` (rather than ``return``) keeps the ``torch`` patch active for
        the whole test, so methods called by the test body see the same mocked
        ``torch`` that ``__init__`` saw. This matches how
        ``test_check_requirements_missing_token`` runs ``check_requirements``
        inside its own patch.
        """
        with patch("app.services.diarization_service.torch") as mock_torch:
            mock_torch.cuda.is_available.return_value = False
            from app.services.diarization_service import DiarizationService
            yield DiarizationService()

    def test_init_cpu_device(self, service):
        """Test service initializes with CPU when CUDA unavailable"""
        assert service.device == "cpu"
        assert service.compute_type == "int8"

    def test_check_requirements_missing_token(self):
        """Test check_requirements raises when HF_TOKEN missing"""
        # An empty HF_TOKEN is treated as missing.
        with patch.dict(os.environ, {"HF_TOKEN": ""}):
            with patch("app.services.diarization_service.torch") as mock_torch:
                mock_torch.cuda.is_available.return_value = False
                from app.services.diarization_service import DiarizationService
                service = DiarizationService()

                # `match` is applied with re.search to str(exception).
                with pytest.raises(ValueError, match="HF_TOKEN"):
                    service.check_requirements()

    def test_check_requirements_with_token(self, service):
        """Test check_requirements passes with valid token"""
        # Should not raise
        service.check_requirements()


class TestSpeakerMerging:
    """Tests for DiarizationService._merge_speakers (speaker-to-segment merging)."""

    @pytest.fixture
    def service(self, mock_env):
        """Yield a DiarizationService built while CUDA reports unavailable.

        Replaces the setup previously duplicated in every test. Yielding from
        inside the ``with`` keeps ``torch`` patched while the test body calls
        ``_merge_speakers``, exactly as the former inline setup did.
        """
        with patch("app.services.diarization_service.torch") as mock_torch:
            mock_torch.cuda.is_available.return_value = False
            from app.services.diarization_service import DiarizationService
            yield DiarizationService()

    @pytest.fixture
    def mock_diarization(self):
        """Create mock pyannote diarization object.

        ``itertracks`` returns (segment, track, label) triples covering three
        consecutive turns: 0-2 SPEAKER_00, 2-4 SPEAKER_01, 4-6 SPEAKER_00.
        The arguments it is called with are not checked (presumably
        ``yield_label=True`` — confirm against the service).
        """
        mock = MagicMock()

        # A list (not a generator) so repeated itertracks() calls still see all turns.
        mock.itertracks.return_value = [
            (MockSegment(0.0, 2.0), None, "SPEAKER_00"),
            (MockSegment(2.0, 4.0), None, "SPEAKER_01"),
            (MockSegment(4.0, 6.0), None, "SPEAKER_00"),
        ]
        return mock

    def test_merge_speakers_midpoint_matching(self, service, mock_diarization):
        """Test speaker merging uses midpoint matching"""
        transcript = {
            "segments": [
                {"start": 0.0, "end": 1.5, "text": "Hello"},
                {"start": 2.5, "end": 3.5, "text": "World"},
                {"start": 4.5, "end": 5.5, "text": "Goodbye"},
            ],
            "language": "en",
        }

        result = service._merge_speakers(transcript, mock_diarization)

        assert len(result) == 3
        assert result[0]["speaker"] == "SPEAKER_00"  # 0.75 midpoint in 0-2 range
        assert result[1]["speaker"] == "SPEAKER_01"  # 3.0 midpoint in 2-4 range
        assert result[2]["speaker"] == "SPEAKER_00"  # 5.0 midpoint in 4-6 range

    def test_merge_speakers_preserves_text(self, service, mock_diarization):
        """Test that original transcript text is preserved"""
        transcript = {
            "segments": [
                {"start": 0.0, "end": 1.0, "text": "Test text here"},
            ],
            "language": "en",
        }

        result = service._merge_speakers(transcript, mock_diarization)

        assert result[0]["text"] == "Test text here"
        assert result[0]["start"] == 0.0
        assert result[0]["end"] == 1.0


class MockSegment:
    """Lightweight stand-in for a pyannote ``Segment``.

    Exposes only the ``start`` and ``end`` attributes that the speaker-merging
    tests read from the turns yielded by ``itertracks``.
    """

    def __init__(self, start: float, end: float):
        # Store the turn boundaries exactly as given; no validation is performed.
        self.start, self.end = start, end


class TestDiarizationAPI:
    """Integration tests for the /api/v1/stt/upload/diarize endpoint."""

    @pytest.fixture
    def client(self):
        """Create a TestClient bound to the FastAPI app."""
        from fastapi.testclient import TestClient
        from app.main import app
        return TestClient(app)

    @staticmethod
    def _silent_wav(seconds=1, sample_rate=16000):
        """Return an in-memory mono, 16-bit PCM WAV of silence, rewound to offset 0.

        Defaults reproduce the original fixture: 16 kHz, 1 second, 32000 data bytes.
        """
        import io
        import wave

        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)  # 16-bit samples
            wav.setframerate(sample_rate)
            # Mono 16-bit -> 2 bytes per frame -> 2 * sample_rate bytes per second.
            wav.writeframes(b"\x00" * (2 * sample_rate * seconds))
        # wave does not close a caller-supplied file object, so it is still usable.
        buffer.seek(0)
        return buffer

    def test_diarize_endpoint_requires_file(self, client):
        """Test endpoint returns 422 when no file provided"""
        response = client.post("/api/v1/stt/upload/diarize")
        assert response.status_code == 422

    def test_diarize_endpoint_accepts_parameters(self, client):
        """Test endpoint accepts speaker count parameters (request is not rejected with 422).

        HF_TOKEN is forced empty so the outcome does not depend on a token that
        may be set in the developer/CI environment. Otherwise the real pipeline
        could run and succeed. This assumes the service reads HF_TOKEN at request
        time — TODO confirm.
        """
        buffer = self._silent_wav()

        with patch.dict(os.environ, {"HF_TOKEN": ""}):
            response = client.post(
                "/api/v1/stt/upload/diarize",
                files={"file": ("test.wav", buffer, "audio/wav")},
                data={
                    "num_speakers": 2,
                    "min_speakers": 1,
                    "max_speakers": 3,
                    "language": "en",
                },
            )

        # Should get 400 (missing token) or 500 (processing error), not 422 (validation)
        assert response.status_code in (400, 500)


# Fixtures for mock environment
@pytest.fixture
def mock_env():
    """Set HF_TOKEN to "test_token" in os.environ while a test runs.

    ``patch.dict`` restores the previous environment when the ``with`` block
    exits, i.e. after the requesting test (and its dependent fixtures) finish.
    """
    with patch.dict(os.environ, HF_TOKEN="test_token"):
        yield