| """Tests audio STT/TTS.""" |
|
|
| import base64 |
| import sys |
| import types |
| from unittest.mock import patch |
|
|
| import pytest |
| from fastapi import HTTPException |
|
|
| from maris_core.audio.stt import SttRequest, transcribe |
| from maris_core.audio.tts import TtsRequest, synthesize |
| from maris_core.memory_context import ConversationMemoryStore |
|
|
|
|
@pytest.mark.asyncio
async def test_stt_requires_explicit_maris_model() -> None:
    """Verify that STT without a configured model does not return fallback text.

    With ``STT_MODEL`` overridden to an empty string, ``transcribe`` must
    raise an ``HTTPException`` whose status code is 503.
    """
    encoded_audio = base64.b64encode(b"dummy_audio_data").decode()
    empty_model_env = patch.dict("os.environ", {"STT_MODEL": ""}, clear=False)

    with empty_model_env:
        request = SttRequest(audio_base64=encoded_audio)
        with pytest.raises(HTTPException) as raised:
            await transcribe(request)

    # Checked after the ``raises`` block has exited so the assertion executes.
    assert raised.value.status_code == 503
|
|
|
|
@pytest.mark.asyncio
async def test_stt_returns_emotional_metadata() -> None:
    """Verify that the STT response carries emotion metadata next to the transcript.

    The ASR pipeline builder is patched to return a stub that yields a fixed
    transcript; the test then checks the transcript, the detected emotion,
    the response style and the lower bound on the emotion confidence.
    """
    encoded_audio = base64.b64encode(b"dummy_audio_data").decode()

    def fake_asr(path):
        # Stand-in for the real ASR callable: ignores its input.
        return {"text": "Palīdzi, man ir panika"}

    model_env = patch.dict("os.environ", {"STT_MODEL": "MarisUK/test-stt"}, clear=False)
    pipeline_stub = patch(
        "maris_core.audio.stt._build_asr_pipeline",
        return_value=fake_asr,
    )

    with model_env, pipeline_stub:
        request = SttRequest(audio_base64=encoded_audio)
        result = await transcribe(request)

    assert result.transcript == "Palīdzi, man ir panika"
    assert result.detected_emotion == "distressed"
    assert result.response_style == "empathetic_supportive"
    assert result.emotion_confidence >= 0.7
|
|
|
|
@pytest.mark.asyncio
async def test_stt_remembers_transcript_in_session_memory() -> None:
    """Verify that a transcript is stored in the memory of the request's session.

    The module-level ``memory_store`` of the STT module is replaced with a
    fresh ``ConversationMemoryStore`` so the test can inspect what
    ``transcribe`` recorded for the ``voice-session`` session id.
    """
    encoded_audio = base64.b64encode(b"dummy_audio_data").decode()
    store = ConversationMemoryStore()

    def fake_asr(path):
        # Stand-in for the real ASR callable: ignores its input.
        return {"text": "Atceries šo balss pieprasījumu"}

    model_env = patch.dict("os.environ", {"STT_MODEL": "MarisUK/test-stt"}, clear=False)
    pipeline_stub = patch(
        "maris_core.audio.stt._build_asr_pipeline",
        return_value=fake_asr,
    )
    store_stub = patch("maris_core.audio.stt.memory_store", store)

    with model_env, pipeline_stub, store_stub:
        request = SttRequest(audio_base64=encoded_audio, session_id="voice-session")
        await transcribe(request)

    # Query the local store directly; it outlives the patch context.
    found = store.retrieve_relevant_context("voice-session", "balss pieprasījumu")
    assert found
    assert found[0].content == "Atceries šo balss pieprasījumu"
|
|
|
|
@pytest.mark.asyncio
async def test_tts_requires_explicit_maris_model() -> None:
    """Verify that TTS without a configured model does not return fallback audio.

    With ``TTS_MODEL`` overridden to an empty string, ``synthesize`` must
    raise an ``HTTPException`` whose status code is 503.
    """
    empty_model_env = patch.dict("os.environ", {"TTS_MODEL": ""}, clear=False)

    with empty_model_env:
        request = TtsRequest(text="Sveiki, es esmu Maris AI")
        with pytest.raises(HTTPException) as raised:
            await synthesize(request)

    # Checked after the ``raises`` block has exited so the assertion executes.
    assert raised.value.status_code == 503
|
|
|
|
@pytest.mark.asyncio
async def test_tts_remembers_assistant_response_in_session_memory() -> None:
    """Verify that synthesized text is stored in the memory of the request's session.

    ``transformers`` and ``scipy`` are replaced in ``sys.modules`` with
    lightweight stand-ins, and the TTS module's ``memory_store`` is swapped
    for a fresh ``ConversationMemoryStore`` that the test inspects afterwards.
    """
    store = ConversationMemoryStore()

    class FakeAudio:
        """Minimal audio stand-in exposing only ``squeeze``."""

        def squeeze(self):
            return [0, 1, 0, -1]

    def fake_pipeline_factory(*args, **kwargs):
        # Mirrors ``transformers.pipeline``: returns a callable synthesizer.
        def run_tts(text):
            return {"sampling_rate": 2, "audio": FakeAudio()}

        return run_tts

    def discard_write(*args, **kwargs):
        # Replacement for ``wavfile.write`` that does nothing.
        return None

    stub_transformers = types.SimpleNamespace(pipeline=fake_pipeline_factory)
    stub_wavfile = types.SimpleNamespace(write=discard_write)
    stub_scipy = types.SimpleNamespace(io=types.SimpleNamespace(wavfile=stub_wavfile))

    model_env = patch.dict("os.environ", {"TTS_MODEL": "MarisUK/test-tts"}, clear=False)
    store_stub = patch("maris_core.audio.tts.memory_store", store)
    module_stubs = patch.dict(
        sys.modules,
        {
            "transformers": stub_transformers,
            "scipy": stub_scipy,
        },
    )

    with model_env, store_stub, module_stubs:
        request = TtsRequest(text="Šo atbildi vajag atcerēties", session_id="voice-session")
        await synthesize(request)

    # Query the local store directly; it outlives the patch context.
    found = store.retrieve_relevant_context("voice-session", "atbildi atcerēties")
    assert found
    assert found[0].content == "Šo atbildi vajag atcerēties"
|
|