File size: 8,758 Bytes
d5fb3e9
0bcdd07
d5fb3e9
 
 
 
 
 
284dfa9
0bcdd07
 
 
 
d5fb3e9
 
 
 
 
 
 
0bcdd07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5fb3e9
 
 
 
 
 
 
 
0bcdd07
d5fb3e9
0bcdd07
 
284dfa9
 
0bcdd07
 
 
d5fb3e9
 
284dfa9
 
d5fb3e9
0bcdd07
d5fb3e9
0bcdd07
284dfa9
0bcdd07
 
284dfa9
0bcdd07
284dfa9
0bcdd07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84c39c6
0bcdd07
 
84c39c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bcdd07
 
 
284dfa9
 
0bcdd07
284dfa9
 
0bcdd07
 
84c39c6
0bcdd07
 
 
 
 
 
284dfa9
0bcdd07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import os
import time
os.environ["MOCK_LLM"] = "true"  # NOTE(review): set before the app imports below — presumably read at import time by app.main/app.llm; confirm

import pytest
from httpx import AsyncClient, ASGITransport

from app.main import app
from app.schemas import ClinicalBrief
from app.llm import MockLLM, CombinedOutput


# ─────────────────────── fixtures ───────────────────────

@pytest.fixture
async def client():
    """Yield an AsyncClient wired straight into the ASGI app (no real network)."""
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as async_client:
        yield async_client


# ────────────────────── unit tests ──────────────────────

def test_mock_llm_combined_call_basic_extraction():
    """MockLLM should extract chief complaint, onset and location in one call."""
    transcript = (
        "Patient: I have chest pain since yesterday\n"
        "AI: Where is it?\n"
        "Patient: Center of my chest"
    )
    out = MockLLM().combined_call(transcript, CombinedOutput().model_dump_json())
    assert out.chief_complaint == "chest pain"
    assert out.onset == "yesterday"
    assert out.location == "center of chest"
    # A non-empty reply means the mock keeps asking the next missing question.
    assert out.reply


def test_mock_llm_emergency_detection():
    """MockLLM should detect emergency keywords and set emergency=True."""
    out = MockLLM().combined_call(
        "Patient: I am having crushing chest pain",
        CombinedOutput().model_dump_json(),
    )
    assert out.emergency is True


def test_mock_llm_does_not_repeat_filled_questions():
    """If onset is already known, the next question should NOT ask about onset again."""
    known_state = CombinedOutput(chief_complaint="chest pain", onset="yesterday")
    out = MockLLM().combined_call(
        "Patient: chest pain yesterday\nAI: ok\nPatient: anything new",
        known_state.model_dump_json(),
    )
    # The already-captured answer must survive the call untouched.
    assert out.onset == "yesterday"
    # The follow-up question must not probe onset a second time.
    assert "when" not in out.reply.lower()


def test_mock_llm_severity_extraction():
    """Severity from different phrasings should always normalize to X/10."""
    llm = MockLLM()
    # State is identical for every phrasing, so build the JSON once.
    state = CombinedOutput(
        chief_complaint="chest pain", onset="yesterday",
        location="chest", duration="constant", character="tight"
    ).model_dump_json()
    cases = {
        "it is a 7 out of 10": "7/10",
        "about 8 on the scale": None,     # may not extract without explicit context
        "i'd say 9 on a scale": None,
    }
    for phrase, expected in cases.items():
        out = llm.combined_call(f"Patient: {phrase}", state)
        if expected is not None:
            assert out.severity == expected, f"Failed for: '{phrase}'"


def test_mock_llm_ros_extraction():
    """ROS should populate correctly when patient mentions system symptoms."""
    llm = MockLLM()
    full_hpi = CombinedOutput(
        chief_complaint="chest pain", onset="yesterday", location="center of chest",
        duration="constant", character="tight", severity="7/10",
        aggravating="walking", relieving="resting"
    ).model_dump_json()

    cardiac = llm.combined_call("Patient: palpitations present no leg swelling", full_hpi)
    assert "cardiac" in cardiac.ros

    respiratory = llm.combined_call("Patient: mild shortness of breath", full_hpi)
    assert "respiratory" in respiratory.ros


def test_mock_llm_speed():
    """
    MockLLM combined_call must complete under 100ms per call.
    (Real LLM test is separate — this validates no accidental model load in mock mode.)
    """
    llm = MockLLM()
    state = CombinedOutput().model_dump_json()
    message = "Patient: I have chest pain since this morning in the center of my chest"

    durations = []
    for _ in range(5):
        start = time.perf_counter()
        llm.combined_call(message, state)
        durations.append(time.perf_counter() - start)

    avg_ms = sum(durations) / len(durations) * 1000
    print(f"\n[speed] MockLLM avg combined_call: {avg_ms:.1f}ms")
    assert avg_ms < 100, f"MockLLM too slow: {avg_ms:.1f}ms avg (should be <100ms)"


def test_combined_output_schema_round_trip():
    """CombinedOutput must survive JSON round-trip without data loss."""
    source = CombinedOutput(
        chief_complaint="headache",
        onset="3 days ago",
        location="forehead",
        duration="constant",
        character="throbbing",
        severity="6/10",
        aggravating="bright light",
        relieving="dark room",
        ros={"neuro": ["dizziness"], "ent": ["no ear pain"]},
        emergency=False,
        reply="Any vision changes?",
    )
    revived = CombinedOutput.model_validate_json(source.model_dump_json())
    assert revived.chief_complaint == "headache"
    assert revived.severity == "6/10"
    assert revived.ros["neuro"] == ["dizziness"]
    assert revived.reply == "Any vision changes?"


# ───────────────────── API integration tests ─────────────────────

@pytest.mark.asyncio(loop_scope="function")
async def test_health_endpoint(client):
    """/health must report ok status and confirm mock mode is active."""
    resp = await client.get("/health")
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "ok"
    assert body["mock_mode"] is True


@pytest.mark.asyncio(loop_scope="function")
async def test_emergency_triage_node(client):
    """Emergency phrase should bypass agent and return 911 message immediately."""
    session_id = "test_emergency"
    # Kick-off message establishes the session before the emergency phrase.
    setup = await client.post("/chat", json={"session_id": session_id, "message": "hello"})
    # Fail fast here: otherwise a broken bootstrap shows up as a confusing
    # assertion failure on the emergency response below.
    assert setup.status_code == 200
    response = await client.post(
        "/chat", json={"session_id": session_id, "message": "I am having crushing chest pain"}
    )
    assert response.status_code == 200
    data = response.json()
    # Emergency short-circuits the interview: session ends with a 911 directive.
    assert data["state"] == "done"
    assert "911" in data["reply"] or "emergency" in data["reply"].lower()


@pytest.mark.asyncio(loop_scope="function")
async def test_full_intake_multi_turn_extraction(client):
    """
    The agent should extract multiple fields per message and skip already-answered
    questions. After the kick-off plus 5 messages that collectively answer all HPI
    fields + 3 ROS systems, state should be 'done' and a valid brief returned.
    """
    session_id = "test_multi_extract"

    async def send(message: str) -> dict:
        # Post one chat turn; assert transport-level success so HTTP failures
        # don't masquerade as KeyErrors on the JSON payload.
        r = await client.post("/chat", json={"session_id": session_id, "message": message})
        assert r.status_code == 200, f"chat turn failed for message: {message!r}"
        return r.json()

    # Kick-off
    await send("hello")

    # Message 1: CC + onset + location
    data = await send("I have chest pain since yesterday in the center of my chest")
    assert data["state"] in ("intake", "hpi")

    # Message 2: duration + character + severity + aggravating + relieving
    data = await send(
        "It is constant, tight and squeezing, about a 7 out of 10. Walking worsens it and resting helps."
    )
    assert data["state"] in ("hpi", "ros")

    # Message 3: ROS cardiac
    data = await send("I have palpitations")
    assert data["state"] == "ros"

    # Message 4: ROS respiratory
    data = await send("mild shortness of breath")
    assert data["state"] == "ros"

    # Message 5: ROS GI — interview should now be complete
    data = await send("no nausea")
    assert data["state"] == "done"
    assert data["brief"] is not None

    # The brief must validate against the schema with all key fields populated.
    brief = ClinicalBrief.model_validate(data["brief"])
    assert brief.chief_complaint == "chest pain"
    assert brief.hpi.onset is not None
    assert brief.hpi.severity is not None
    assert len(brief.ros) >= 3


@pytest.mark.asyncio(loop_scope="function")
async def test_api_response_time(client):
    """API with MockLLM must respond in under 2 seconds per message."""
    session_id = "test_speed_api"
    messages = (
        "hello",
        "I have a headache since this morning",
        "It is on the left side of my head",
    )

    elapsed_times = []
    for msg in messages:
        start = time.perf_counter()
        resp = await client.post("/chat", json={"session_id": session_id, "message": msg})
        elapsed_times.append(time.perf_counter() - start)
        assert resp.status_code == 200

    avg_s = sum(elapsed_times) / len(elapsed_times)
    print(f"\n[speed] API avg response: {avg_s*1000:.0f}ms")
    assert avg_s < 2.0, f"API too slow: {avg_s:.2f}s avg (should be <2s in mock mode)"