voiceforge / backend /tests /integration /test_diarization.py
lordofgaming
Initial VoiceForge deployment (clean)
673435a
"""
Unit and integration tests for DiarizationService.

Covers speaker diarization functionality, including:
- Service initialization
- Speaker merging logic
- API endpoint integration
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
import os
class TestDiarizationService:
    """Exercise DiarizationService construction and its requirement check."""

    @pytest.fixture
    def mock_env(self):
        """Expose a placeholder HF_TOKEN in ``os.environ`` for the test's duration."""
        fake_env = {"HF_TOKEN": "test_token"}
        with patch.dict(os.environ, fake_env):
            yield

    @pytest.fixture
    def service(self, mock_env):
        """Build a DiarizationService while torch reports CUDA as unavailable.

        The torch patch is released once the instance has been constructed;
        only the environment patch from ``mock_env`` stays active for the test.
        """
        with patch("app.services.diarization_service.torch") as torch_stub:
            torch_stub.cuda.is_available.return_value = False
            from app.services.diarization_service import DiarizationService

            instance = DiarizationService()
        return instance

    def test_init_cpu_device(self, service):
        """Without CUDA the service selects the CPU device and int8 compute."""
        assert service.device == "cpu"
        assert service.compute_type == "int8"

    def test_check_requirements_missing_token(self):
        """check_requirements raises ValueError naming HF_TOKEN when it is empty."""
        blank_token_env = patch.dict(os.environ, {"HF_TOKEN": ""})
        cpu_only_torch = patch("app.services.diarization_service.torch")
        with blank_token_env, cpu_only_torch as torch_stub:
            torch_stub.cuda.is_available.return_value = False
            from app.services.diarization_service import DiarizationService

            service = DiarizationService()
            with pytest.raises(ValueError) as excinfo:
                service.check_requirements()

        # Checked after the raises-block so the assertion actually executes.
        assert "HF_TOKEN" in str(excinfo.value)

    def test_check_requirements_with_token(self, service):
        """check_requirements completes without raising when HF_TOKEN is set."""
        # No explicit assertion: any exception raised here fails the test.
        service.check_requirements()
class TestSpeakerMerging:
    """Tests for speaker-to-segment merging logic."""

    @pytest.fixture
    def mock_diarization(self):
        """Return a stand-in diarization result with three speaker turns.

        ``itertracks`` returns ``(segment, track, label)`` triples covering
        0-2 s (SPEAKER_00), 2-4 s (SPEAKER_01) and 4-6 s (SPEAKER_00).
        """
        mock = MagicMock()
        mock.itertracks.return_value = [
            (MockSegment(0.0, 2.0), None, "SPEAKER_00"),
            (MockSegment(2.0, 4.0), None, "SPEAKER_01"),
            (MockSegment(4.0, 6.0), None, "SPEAKER_00"),
        ]
        return mock

    @pytest.fixture
    def service(self, mock_env):
        """Yield a DiarizationService built with torch patched to report no CUDA.

        Yielding (rather than returning) keeps the torch patch active for the
        whole test, so ``_merge_speakers`` runs under the same patch as the
        constructor. This replaces the setup previously repeated in each test.
        """
        with patch("app.services.diarization_service.torch") as mock_torch:
            mock_torch.cuda.is_available.return_value = False
            from app.services.diarization_service import DiarizationService

            yield DiarizationService()

    def test_merge_speakers_midpoint_matching(self, service, mock_diarization):
        """Each transcript segment gets the speaker whose turn contains its midpoint."""
        transcript = {
            "segments": [
                {"start": 0.0, "end": 1.5, "text": "Hello"},
                {"start": 2.5, "end": 3.5, "text": "World"},
                {"start": 4.5, "end": 5.5, "text": "Goodbye"},
            ],
            "language": "en",
        }

        result = service._merge_speakers(transcript, mock_diarization)

        assert len(result) == 3
        assert result[0]["speaker"] == "SPEAKER_00"  # 0.75 midpoint in 0-2 range
        assert result[1]["speaker"] == "SPEAKER_01"  # 3.0 midpoint in 2-4 range
        assert result[2]["speaker"] == "SPEAKER_00"  # 5.0 midpoint in 4-6 range

    def test_merge_speakers_preserves_text(self, service, mock_diarization):
        """Merging keeps the original text and timestamps of a segment."""
        transcript = {
            "segments": [
                {"start": 0.0, "end": 1.0, "text": "Test text here"},
            ],
            "language": "en",
        }

        result = service._merge_speakers(transcript, mock_diarization)

        assert result[0]["text"] == "Test text here"
        assert result[0]["start"] == 0.0
        assert result[0]["end"] == 1.0
class MockSegment:
    """Minimal stand-in for a pyannote Segment: just a start and an end time."""

    def __init__(self, start: float, end: float):
        # Only these two attributes are stored; nothing else is emulated.
        self.start, self.end = start, end
class TestDiarizationAPI:
    """Integration tests for the diarization upload endpoint."""

    @pytest.fixture
    def client(self):
        """Return a FastAPI TestClient wrapping the application object."""
        from fastapi.testclient import TestClient
        from app.main import app

        http_client = TestClient(app)
        return http_client

    def test_diarize_endpoint_requires_file(self, client):
        """Posting without a file yields 422."""
        no_file_response = client.post("/api/v1/stt/upload/diarize")
        assert no_file_response.status_code == 422

    def test_diarize_endpoint_accepts_parameters(self, client):
        """A request carrying speaker-count form fields is not rejected with 422."""
        import io
        import wave

        # Build a tiny in-memory WAV: mono, 16-bit, 16 kHz.
        sample_rate = 16000
        sample_width = 2  # bytes per sample
        wav_bytes = io.BytesIO()
        with wave.open(wav_bytes, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(sample_width)
            writer.setframerate(sample_rate)
            # 16000 frames * 2 bytes = 32000 zero bytes, i.e. 1 second of silence.
            writer.writeframes(bytes(sample_rate * sample_width))
        wav_bytes.seek(0)

        upload = {"file": ("test.wav", wav_bytes, "audio/wav")}
        form_fields = {
            "num_speakers": 2,
            "min_speakers": 1,
            "max_speakers": 3,
            "language": "en",
        }

        # This will fail without HF_TOKEN but should accept the request format.
        response = client.post(
            "/api/v1/stt/upload/diarize",
            files=upload,
            data=form_fields,
        )

        # Expect 400 (missing token) or 500 (processing error), not 422 (validation).
        assert response.status_code in (400, 500)
# Module-level fixture providing a mocked environment
@pytest.fixture
def mock_env():
    """Expose a placeholder HF_TOKEN in ``os.environ`` for the test's duration."""
    fake_env = {"HF_TOKEN": "test_token"}
    with patch.dict(os.environ, fake_env):
        yield