Spaces:
Sleeping
Sleeping
File size: 6,333 Bytes
673435a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 | """
Unit tests for DiarizationService
Tests speaker diarization functionality including:
- Service initialization
- Speaker merging logic
- API endpoint integration
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
import os
class TestDiarizationService:
    """Unit tests for DiarizationService construction and requirement checks."""

    @pytest.fixture
    def mock_env(self):
        """Expose a fake HF_TOKEN in os.environ while the requesting test runs."""
        with patch.dict(os.environ, HF_TOKEN="test_token"):
            yield

    @pytest.fixture
    def service(self, mock_env):
        """Build a DiarizationService while the patched torch reports no CUDA."""
        with patch("app.services.diarization_service.torch") as fake_torch:
            fake_torch.cuda.is_available.return_value = False
            from app.services.diarization_service import DiarizationService
            instance = DiarizationService()
        # The torch patch is already undone here; only construction saw it.
        return instance

    def test_init_cpu_device(self, service):
        """With CUDA unavailable the service selects CPU and int8 compute."""
        assert (service.device, service.compute_type) == ("cpu", "int8")

    def test_check_requirements_missing_token(self):
        """An empty HF_TOKEN makes check_requirements raise a ValueError naming it."""
        with patch.dict(os.environ, HF_TOKEN=""):
            with patch("app.services.diarization_service.torch") as fake_torch:
                fake_torch.cuda.is_available.return_value = False
                from app.services.diarization_service import DiarizationService
                tokenless = DiarizationService()
                # `match` does a regex search on str(exception); the pattern is literal here.
                with pytest.raises(ValueError, match="HF_TOKEN"):
                    tokenless.check_requirements()

    def test_check_requirements_with_token(self, service):
        """check_requirements completes without raising when HF_TOKEN is set."""
        service.check_requirements()  # any exception fails the test
class TestSpeakerMerging:
    """Tests for speaker-to-segment merging logic.

    ``_merge_speakers`` receives a hand-built transcript plus a mocked pyannote
    diarization whose ``itertracks`` returns three fixed speaker turns.
    """

    @pytest.fixture
    def service(self, mock_env):
        """Yield a DiarizationService constructed with CUDA reported unavailable.

        The fixture yields from inside the torch patch, so the test body runs
        against the same mocked ``torch`` the constructor saw. This replaces the
        patch/import/construct boilerplate each test previously repeated inline.
        """
        with patch("app.services.diarization_service.torch") as mock_torch:
            mock_torch.cuda.is_available.return_value = False
            from app.services.diarization_service import DiarizationService
            yield DiarizationService()

    @pytest.fixture
    def mock_diarization(self):
        """Create a mock pyannote diarization object.

        ``itertracks`` returns ``(segment, track, label)`` tuples:
        SPEAKER_00 on 0.0-2.0, SPEAKER_01 on 2.0-4.0, SPEAKER_00 on 4.0-6.0.
        The same list object is returned on every call, so it can be iterated
        more than once.
        """
        mock = MagicMock()
        mock.itertracks.return_value = [
            (MockSegment(0.0, 2.0), None, "SPEAKER_00"),
            (MockSegment(2.0, 4.0), None, "SPEAKER_01"),
            (MockSegment(4.0, 6.0), None, "SPEAKER_00"),
        ]
        return mock

    def test_merge_speakers_midpoint_matching(self, service, mock_diarization):
        """Test speaker merging assigns each segment by its midpoint."""
        transcript = {
            "segments": [
                {"start": 0.0, "end": 1.5, "text": "Hello"},
                {"start": 2.5, "end": 3.5, "text": "World"},
                {"start": 4.5, "end": 5.5, "text": "Goodbye"},
            ],
            "language": "en",
        }

        result = service._merge_speakers(transcript, mock_diarization)

        assert len(result) == 3
        assert result[0]["speaker"] == "SPEAKER_00"  # 0.75 midpoint in 0-2 range
        assert result[1]["speaker"] == "SPEAKER_01"  # 3.0 midpoint in 2-4 range
        assert result[2]["speaker"] == "SPEAKER_00"  # 5.0 midpoint in 4-6 range

    def test_merge_speakers_preserves_text(self, service, mock_diarization):
        """Test that original transcript text and timings are preserved."""
        transcript = {
            "segments": [
                {"start": 0.0, "end": 1.0, "text": "Test text here"},
            ],
            "language": "en",
        }

        result = service._merge_speakers(transcript, mock_diarization)

        assert result[0]["text"] == "Test text here"
        assert result[0]["start"] == 0.0
        assert result[0]["end"] == 1.0
class MockSegment:
    """Minimal stand-in for a pyannote ``Segment``.

    Besides the ``start``/``end`` attributes, it exposes the derived
    ``duration`` and ``middle`` properties of pyannote's Segment API, so code
    under test may use either form. A readable ``__repr__`` keeps assertion
    failures informative.
    """

    def __init__(self, start: float, end: float):
        self.start = start
        self.end = end

    @property
    def duration(self) -> float:
        """Length of the segment, ``end - start``."""
        return self.end - self.start

    @property
    def middle(self) -> float:
        """Midpoint of the segment, ``(start + end) / 2``."""
        return 0.5 * (self.start + self.end)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(start={self.start!r}, end={self.end!r})"
class TestDiarizationAPI:
    """Integration tests for the diarization upload API endpoint."""

    @pytest.fixture
    def client(self):
        """Create a FastAPI test client bound to the application."""
        from fastapi.testclient import TestClient
        from app.main import app
        return TestClient(app)

    def test_diarize_endpoint_requires_file(self, client):
        """Test endpoint returns 422 when no file is provided."""
        response = client.post("/api/v1/stt/upload/diarize")
        assert response.status_code == 422

    def test_diarize_endpoint_accepts_parameters(self, client):
        """Test endpoint accepts speaker-count parameters (no 422 validation error).

        HF_TOKEN is cleared for the request so the outcome does not depend on
        the developer/CI environment. Previously a real token in the environment
        could trigger a genuine diarization run, which is slow and can return 200.
        NOTE(review): this isolation only helps if the token is read per request
        rather than cached when ``app.main`` is imported; confirm.
        """
        import io
        import wave

        # 1 s of 16 kHz mono 16-bit silence: 16000 frames * 2 bytes = 32000 bytes.
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)
            wav.setframerate(16000)
            wav.writeframes(b"\x00" * 32000)
        buffer.seek(0)

        with patch.dict(os.environ, {"HF_TOKEN": ""}):
            response = client.post(
                "/api/v1/stt/upload/diarize",
                files={"file": ("test.wav", buffer, "audio/wav")},
                data={
                    "num_speakers": 2,
                    "min_speakers": 1,
                    "max_speakers": 3,
                    "language": "en",
                },
            )

        # Expect 400 (missing token) or 500 (processing error), never 422 (validation).
        assert response.status_code in (400, 500), (
            f"expected 400/500, got {response.status_code}: {response.text}"
        )
# Module-level fixtures, used by test classes that do not define their own.
@pytest.fixture
def mock_env():
    """Set HF_TOKEN=test_token in os.environ for one test; undone on teardown."""
    with patch.dict(os.environ, HF_TOKEN="test_token"):
        yield
|