File size: 3,928 Bytes
f440f03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Tests for audio speech-to-text (STT) and text-to-speech (TTS)."""

import base64
import sys
import types
from unittest.mock import patch

import pytest
from fastapi import HTTPException

from maris_core.audio.stt import SttRequest, transcribe
from maris_core.audio.tts import TtsRequest, synthesize
from maris_core.memory_context import ConversationMemoryStore


@pytest.mark.asyncio
async def test_stt_requires_explicit_maris_model() -> None:
    """Check that STT without a configured model does not return fallback text.

    With ``STT_MODEL`` set to an empty string, ``transcribe`` must raise
    ``HTTPException`` with status code 503 rather than produce a transcript.
    """
    encoded_audio = base64.b64encode(b"dummy_audio_data").decode()
    with patch.dict("os.environ", {"STT_MODEL": ""}, clear=False):
        request = SttRequest(audio_base64=encoded_audio)
        with pytest.raises(HTTPException) as raised:
            await transcribe(request)
    # The captured exception stays available after the env patch is undone.
    assert raised.value.status_code == 503


@pytest.mark.asyncio
async def test_stt_returns_emotional_metadata() -> None:
    """Check that a distressed transcript is returned with emotion metadata.

    The ASR pipeline builder is patched to return a callable that always yields
    a fixed Latvian phrase (roughly "Help me, I'm panicking"). The response must
    echo that phrase and report a ``distressed`` emotion, an
    ``empathetic_supportive`` response style and a confidence of at least 0.7.
    """
    audio_b64 = base64.b64encode(b"dummy_audio_data").decode()
    spoken_text = "Palīdzi, man ir panika"

    def fake_asr(path: object) -> dict[str, str]:
        # Stand-in for the real ASR pipeline: ignores the audio path entirely.
        return {"text": spoken_text}

    with (
        patch.dict("os.environ", {"STT_MODEL": "MarisUK/test-stt"}, clear=False),
        patch("maris_core.audio.stt._build_asr_pipeline", return_value=fake_asr),
    ):
        result = await transcribe(SttRequest(audio_base64=audio_b64))

    assert result.transcript == spoken_text
    assert result.detected_emotion == "distressed"
    assert result.response_style == "empathetic_supportive"
    assert result.emotion_confidence >= 0.7


@pytest.mark.asyncio
async def test_stt_remembers_transcript_in_session_memory() -> None:
    """Check that a transcript is stored in its session's conversation memory.

    ``memory_store`` in the STT module is swapped for a fresh
    ``ConversationMemoryStore``. After transcribing with
    ``session_id="voice-session"``, a context lookup for part of the phrase
    must return the full transcript as its first match.
    """
    session_memory = ConversationMemoryStore()
    audio_b64 = base64.b64encode(b"dummy_audio_data").decode()
    # Roughly "Remember this voice request".
    spoken_text = "Atceries šo balss pieprasījumu"

    def fake_asr(path: object) -> dict[str, str]:
        # Stand-in for the real ASR pipeline: ignores the audio path entirely.
        return {"text": spoken_text}

    with (
        patch.dict("os.environ", {"STT_MODEL": "MarisUK/test-stt"}, clear=False),
        patch("maris_core.audio.stt._build_asr_pipeline", return_value=fake_asr),
        patch("maris_core.audio.stt.memory_store", session_memory),
    ):
        request = SttRequest(audio_base64=audio_b64, session_id="voice-session")
        await transcribe(request)

    found = session_memory.retrieve_relevant_context("voice-session", "balss pieprasījumu")
    assert found
    assert found[0].content == spoken_text


@pytest.mark.asyncio
async def test_tts_requires_explicit_maris_model() -> None:
    """Check that TTS without a configured model does not return fallback audio.

    With ``TTS_MODEL`` set to an empty string, ``synthesize`` must raise
    ``HTTPException`` with status code 503 rather than produce audio.
    """
    with patch.dict("os.environ", {"TTS_MODEL": ""}, clear=False):
        request = TtsRequest(text="Sveiki, es esmu Maris AI")
        with pytest.raises(HTTPException) as raised:
            await synthesize(request)
    # The captured exception stays available after the env patch is undone.
    assert raised.value.status_code == 503


@pytest.mark.asyncio
async def test_tts_remembers_assistant_response_in_session_memory() -> None:
    """Check that synthesized text is stored in its session's conversation memory.

    ``transformers`` and ``scipy`` are replaced in ``sys.modules`` with minimal
    fakes. The fake ``pipeline`` returns a callable producing a tiny audio
    result, and the fake ``wavfile.write`` does nothing. ``memory_store`` in the
    TTS module is swapped for a fresh ``ConversationMemoryStore``. After
    synthesizing with ``session_id="voice-session"``, a context lookup for part
    of the text must return the full text as its first match.
    """
    memory = ConversationMemoryStore()

    class FakeAudio:
        # Stand-in for the pipeline's audio output; only ``squeeze`` is provided.
        def squeeze(self):  # type: ignore[no-untyped-def]
            return [0, 1, 0, -1]

    fake_transformers = types.SimpleNamespace(
        pipeline=lambda *args, **kwargs: (  # type: ignore[no-untyped-def]
            lambda text: {"sampling_rate": 2, "audio": FakeAudio()}
        )
    )
    fake_wavfile = types.SimpleNamespace(write=lambda *args, **kwargs: None)
    fake_scipy_io = types.SimpleNamespace(wavfile=fake_wavfile)
    fake_scipy = types.SimpleNamespace(io=fake_scipy_io)

    with (
        patch.dict("os.environ", {"TTS_MODEL": "MarisUK/test-tts"}, clear=False),
        patch("maris_core.audio.tts.memory_store", memory),
        patch.dict(
            sys.modules,
            {
                "transformers": fake_transformers,
                "scipy": fake_scipy,
                # The fake top-level ``scipy`` has no ``__path__``. Without these
                # entries, ``import scipy.io.wavfile`` or
                # ``from scipy.io import wavfile`` would fail ("'scipy' is not a
                # package") or pick up a real, previously imported submodule.
                # The same objects are used so attribute access and submodule
                # imports resolve consistently.
                "scipy.io": fake_scipy_io,
                "scipy.io.wavfile": fake_wavfile,
            },
        ),
    ):
        await synthesize(
            TtsRequest(text="Šo atbildi vajag atcerēties", session_id="voice-session")
        )

    matches = memory.retrieve_relevant_context("voice-session", "atbildi atcerēties")
    assert matches
    assert matches[0].content == "Šo atbildi vajag atcerēties"