Spaces:
Sleeping
Sleeping
File size: 8,758 Bytes
d5fb3e9 0bcdd07 d5fb3e9 284dfa9 0bcdd07 d5fb3e9 0bcdd07 d5fb3e9 0bcdd07 d5fb3e9 0bcdd07 284dfa9 0bcdd07 d5fb3e9 284dfa9 d5fb3e9 0bcdd07 d5fb3e9 0bcdd07 284dfa9 0bcdd07 284dfa9 0bcdd07 284dfa9 0bcdd07 84c39c6 0bcdd07 84c39c6 0bcdd07 284dfa9 0bcdd07 284dfa9 0bcdd07 84c39c6 0bcdd07 284dfa9 0bcdd07 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 | import os
import time
os.environ["MOCK_LLM"] = "true"
import pytest
from httpx import AsyncClient, ASGITransport
from app.main import app
from app.schemas import ClinicalBrief
from app.llm import MockLLM, CombinedOutput
# ─────────────────────── fixtures ───────────────────────
@pytest.fixture
async def client():
    """Yield an httpx AsyncClient wired straight into the ASGI app (no network)."""
    asgi = ASGITransport(app=app)
    async with AsyncClient(transport=asgi, base_url="http://test") as http:
        yield http
# ────────────────────── unit tests ──────────────────────
def test_mock_llm_combined_call_basic_extraction():
    """MockLLM should extract chief complaint, onset and location in one call."""
    llm = MockLLM()
    turns = [
        "Patient: I have chest pain since yesterday",
        "AI: Where is it?",
        "Patient: Center of my chest",
    ]
    out = llm.combined_call("\n".join(turns), CombinedOutput().model_dump_json())
    # All three fields should come out of a single extraction pass.
    assert out.chief_complaint == "chest pain"
    assert out.onset == "yesterday"
    assert out.location == "center of chest"
    # A follow-up question for the next missing field must be present.
    assert out.reply
def test_mock_llm_emergency_detection():
    """MockLLM should detect emergency keywords and set emergency=True."""
    out = MockLLM().combined_call(
        "Patient: I am having crushing chest pain",
        CombinedOutput().model_dump_json(),
    )
    assert out.emergency is True
def test_mock_llm_does_not_repeat_filled_questions():
    """If onset is already known, the next question should NOT ask about onset again."""
    llm = MockLLM()
    known = CombinedOutput(chief_complaint="chest pain", onset="yesterday")
    out = llm.combined_call(
        "Patient: chest pain yesterday\nAI: ok\nPatient: anything new",
        known.model_dump_json(),
    )
    # The already-filled value must survive the merge...
    assert out.onset == "yesterday"
    # ...and the follow-up must not re-ask the onset ("when") question.
    assert "when" not in out.reply.lower()
def test_mock_llm_severity_extraction():
    """Severity from different phrasings should always normalize to X/10."""
    llm = MockLLM()
    # Pre-fill every earlier HPI field so severity is the next open question.
    # (State is invariant across cases, so it is built once up front.)
    state = CombinedOutput(
        chief_complaint="chest pain", onset="yesterday",
        location="chest", duration="constant", character="tight",
    ).model_dump_json()
    cases = [
        ("it is a 7 out of 10", "7/10"),
        ("about 8 on the scale", None),  # may not extract without explicit context
        ("i'd say 9 on a scale", None),
    ]
    for phrase, expected in cases:
        out = llm.combined_call(f"Patient: {phrase}", state)
        if expected:
            assert out.severity == expected, f"Failed for: '{phrase}'"
def test_mock_llm_ros_extraction():
    """ROS should populate correctly when patient mentions system symptoms."""
    llm = MockLLM()
    # HPI fully answered, so the mock should be collecting review-of-systems.
    hpi_complete = CombinedOutput(
        chief_complaint="chest pain", onset="yesterday", location="center of chest",
        duration="constant", character="tight", severity="7/10",
        aggravating="walking", relieving="resting",
    ).model_dump_json()
    cardiac = llm.combined_call("Patient: palpitations present no leg swelling", hpi_complete)
    assert "cardiac" in cardiac.ros
    respiratory = llm.combined_call("Patient: mild shortness of breath", hpi_complete)
    assert "respiratory" in respiratory.ros
def test_mock_llm_speed():
    """
    MockLLM combined_call must complete under 100ms per call.

    (The real-LLM timing test lives elsewhere — this one guards against an
    accidental model load sneaking into mock mode.)
    """
    llm = MockLLM()
    state = CombinedOutput().model_dump_json()
    message = "Patient: I have chest pain since this morning in the center of my chest"
    samples = []
    for _ in range(5):
        start = time.perf_counter()
        llm.combined_call(message, state)
        samples.append(time.perf_counter() - start)
    avg_ms = sum(samples) / len(samples) * 1000
    print(f"\n[speed] MockLLM avg combined_call: {avg_ms:.1f}ms")
    assert avg_ms < 100, f"MockLLM too slow: {avg_ms:.1f}ms avg (should be <100ms)"
def test_combined_output_schema_round_trip():
    """CombinedOutput must survive JSON round-trip without data loss.

    Fix: the original only spot-checked 4 of the 11 populated fields, so a
    serialization bug in any other field (onset, location, ros['ent'], ...)
    would slip through. Keep the readable spot-checks, then compare the
    complete dumps for exact equality.
    """
    original = CombinedOutput(
        chief_complaint="headache",
        onset="3 days ago",
        location="forehead",
        duration="constant",
        character="throbbing",
        severity="6/10",
        aggravating="bright light",
        relieving="dark room",
        ros={"neuro": ["dizziness"], "ent": ["no ear pain"]},
        emergency=False,
        reply="Any vision changes?",
    )
    json_str = original.model_dump_json()
    restored = CombinedOutput.model_validate_json(json_str)
    # Readable spot-checks on representative fields...
    assert restored.chief_complaint == "headache"
    assert restored.severity == "6/10"
    assert restored.ros["neuro"] == ["dizziness"]
    assert restored.reply == "Any vision changes?"
    # ...and a full-fidelity check: EVERY field must survive, not just the sampled ones.
    assert restored.model_dump() == original.model_dump()
# ───────────────────── API integration tests ─────────────────────
@pytest.mark.asyncio(loop_scope="function")
async def test_health_endpoint(client):
    """/health should answer 200 and confirm the app booted in mock mode."""
    resp = await client.get("/health")
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "ok"
    assert body["mock_mode"] is True
@pytest.mark.asyncio(loop_scope="function")
async def test_emergency_triage_node(client):
    """Emergency phrase should bypass agent and return 911 message immediately."""
    sid = "test_emergency"
    # Open the session with a benign message first.
    await client.post("/chat", json={"session_id": sid, "message": "hello"})
    resp = await client.post(
        "/chat", json={"session_id": sid, "message": "I am having crushing chest pain"}
    )
    assert resp.status_code == 200
    body = resp.json()
    # Triage must short-circuit the normal flow and terminate the session.
    assert body["state"] == "done"
    assert "911" in body["reply"] or "emergency" in body["reply"].lower()
@pytest.mark.asyncio(loop_scope="function")
async def test_full_intake_multi_turn_extraction(client):
    """
    The agent should extract multiple fields per message and skip
    already-answered questions. After messages that collectively answer all
    HPI fields + 3 ROS systems, state should be 'done'.
    """
    sid = "test_multi_extract"
    # Kick-off turn: just opens the session.
    r = await client.post("/chat", json={"session_id": sid, "message": "hello"})
    assert r.status_code == 200
    # Each turn: (patient message, states the session may legally be in afterwards).
    turns = [
        # CC + onset + location in a single utterance
        ("I have chest pain since yesterday in the center of my chest",
         ("intake", "hpi")),
        # duration + character + severity + aggravating + relieving
        ("It is constant, tight and squeezing, about a 7 out of 10. Walking worsens it and resting helps.",
         ("hpi", "ros")),
        ("I have palpitations", ("ros",)),       # ROS: cardiac
        ("mild shortness of breath", ("ros",)),  # ROS: respiratory
        ("no nausea", ("done",)),                # ROS: GI — intake complete
    ]
    data = None
    for message, allowed_states in turns:
        r = await client.post("/chat", json={"session_id": sid, "message": message})
        data = r.json()
        assert data["state"] in allowed_states
    # The finished session must carry a brief that validates against the schema.
    assert data["brief"] is not None
    brief = ClinicalBrief.model_validate(data["brief"])
    assert brief.chief_complaint == "chest pain"
    assert brief.hpi.onset is not None
    assert brief.hpi.severity is not None
    assert len(brief.ros) >= 3
@pytest.mark.asyncio(loop_scope="function")
async def test_api_response_time(client):
    """API with MockLLM must respond in under 2 seconds per message."""
    sid = "test_speed_api"
    messages = [
        "hello",
        "I have a headache since this morning",
        "It is on the left side of my head",
    ]
    durations = []
    for msg in messages:
        start = time.perf_counter()
        resp = await client.post("/chat", json={"session_id": sid, "message": msg})
        durations.append(time.perf_counter() - start)
        assert resp.status_code == 200
    avg_s = sum(durations) / len(durations)
    print(f"\n[speed] API avg response: {avg_s*1000:.0f}ms")
    assert avg_s < 2.0, f"API too slow: {avg_s:.2f}s avg (should be <2s in mock mode)"
|