File size: 3,928 Bytes
f440f03 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 | """Tests audio STT/TTS."""
import base64
import sys
import types
from unittest.mock import patch
import pytest
from fastapi import HTTPException
from maris_core.audio.stt import SttRequest, transcribe
from maris_core.audio.tts import TtsRequest, synthesize
from maris_core.memory_context import ConversationMemoryStore
@pytest.mark.asyncio
async def test_stt_requires_explicit_maris_model() -> None:
    """Check that STT with no configured model fails instead of returning fallback text.

    With ``STT_MODEL`` set to an empty string, ``transcribe`` must raise an
    ``HTTPException`` with status 503.
    """
    encoded_audio = base64.b64encode(b"dummy_audio_data").decode()

    with patch.dict("os.environ", STT_MODEL=""):
        request = SttRequest(audio_base64=encoded_audio)
        with pytest.raises(HTTPException) as raised:
            await transcribe(request)

    # Checked after the raises-block: code after the failing await never runs.
    assert raised.value.status_code == 503
@pytest.mark.asyncio
async def test_stt_returns_emotional_metadata() -> None:
    """Check that a panicked transcript is tagged with distress emotion metadata.

    The ASR pipeline is stubbed to return a fixed Latvian utterance ("Help me,
    I'm panicking"). The response must carry the transcript, the
    ``distressed`` emotion, the ``empathetic_supportive`` style and a
    confidence of at least 0.7.
    """
    spoken_text = "Palīdzi, man ir panika"
    encoded_audio = base64.b64encode(b"dummy_audio_data").decode()

    def fake_asr(path: object) -> dict[str, str]:
        # Stub ASR pipeline: ignores the audio path and returns a fixed transcript.
        return {"text": spoken_text}

    with (
        patch.dict("os.environ", STT_MODEL="MarisUK/test-stt"),
        patch("maris_core.audio.stt._build_asr_pipeline", return_value=fake_asr),
    ):
        response = await transcribe(SttRequest(audio_base64=encoded_audio))

    assert response.transcript == spoken_text
    assert response.detected_emotion == "distressed"
    assert response.response_style == "empathetic_supportive"
    assert response.emotion_confidence >= 0.7
@pytest.mark.asyncio
async def test_stt_remembers_transcript_in_session_memory() -> None:
    """Check that a session-scoped transcription is stored in conversation memory.

    ``memory_store`` in the STT module is swapped for a fresh
    ``ConversationMemoryStore``. After transcribing, a relevance query on the
    same session must return the transcript as its first match.
    """
    session = "voice-session"
    spoken_text = "Atceries šo balss pieprasījumu"
    encoded_audio = base64.b64encode(b"dummy_audio_data").decode()
    memory = ConversationMemoryStore()

    def fake_asr(path: object) -> dict[str, str]:
        # Stub ASR pipeline: ignores the audio path and returns a fixed transcript.
        return {"text": spoken_text}

    with (
        patch.dict("os.environ", STT_MODEL="MarisUK/test-stt"),
        patch("maris_core.audio.stt._build_asr_pipeline", return_value=fake_asr),
        patch("maris_core.audio.stt.memory_store", memory),
    ):
        await transcribe(SttRequest(audio_base64=encoded_audio, session_id=session))

    found = memory.retrieve_relevant_context(session, "balss pieprasījumu")
    assert found
    assert found[0].content == spoken_text
@pytest.mark.asyncio
async def test_tts_requires_explicit_maris_model() -> None:
    """Check that TTS with no configured model fails instead of returning fallback audio.

    With ``TTS_MODEL`` set to an empty string, ``synthesize`` must raise an
    ``HTTPException`` with status 503.
    """
    with patch.dict("os.environ", TTS_MODEL=""):
        request = TtsRequest(text="Sveiki, es esmu Maris AI")
        with pytest.raises(HTTPException) as raised:
            await synthesize(request)

    # Checked after the raises-block: code after the failing await never runs.
    assert raised.value.status_code == 503
@pytest.mark.asyncio
async def test_tts_remembers_assistant_response_in_session_memory() -> None:
    """Check that text sent to TTS for a session is stored in conversation memory.

    ``transformers`` and ``scipy`` are replaced in ``sys.modules`` by
    lightweight fakes for the duration of the call. ``memory_store`` in the
    TTS module is swapped for a fresh ``ConversationMemoryStore``. After
    synthesis, a relevance query on the same session must return the
    synthesized text as its first match.
    """
    memory = ConversationMemoryStore()

    class FakeAudio:
        """Stand-in for a model output array; only ``squeeze()`` is provided."""

        def squeeze(self) -> list[int]:
            return [0, 1, 0, -1]

    def fake_pipeline(*args: object, **kwargs: object) -> object:
        """Mimic ``transformers.pipeline``: ignore config, return a synthesizer."""

        def run_synthesis(text: str) -> dict[str, object]:
            return {"sampling_rate": 2, "audio": FakeAudio()}

        return run_synthesis

    fake_transformers = types.SimpleNamespace(pipeline=fake_pipeline)
    fake_wavfile = types.SimpleNamespace(write=lambda *args, **kwargs: None)
    fake_scipy_io = types.SimpleNamespace(wavfile=fake_wavfile)
    fake_scipy = types.SimpleNamespace(io=fake_scipy_io)

    with (
        patch.dict("os.environ", {"TTS_MODEL": "MarisUK/test-tts"}, clear=False),
        patch("maris_core.audio.tts.memory_store", memory),
        patch.dict(
            sys.modules,
            {
                "transformers": fake_transformers,
                "scipy": fake_scipy,
                # Register the submodules as well. ``import scipy.io.wavfile``
                # and ``from scipy.io import wavfile`` resolve through these
                # keys. Without them the import either fails ("'scipy' is not a
                # package") or picks up an already-imported real scipy module.
                "scipy.io": fake_scipy_io,
                "scipy.io.wavfile": fake_wavfile,
            },
        ),
    ):
        request = TtsRequest(text="Šo atbildi vajag atcerēties", session_id="voice-session")
        await synthesize(request)

    matches = memory.retrieve_relevant_context("voice-session", "atbildi atcerēties")
    assert matches
    assert matches[0].content == "Šo atbildi vajag atcerēties"
|