priyansh-saxena1 committed
Commit · 0bcdd07
Parent(s): 284dfa9
fix: optimize loading

Browse files:
- app/graph.py: +140 -158
- app/llm.py: +190 -121
- tests/test_e2e.py: +177 -42
app/graph.py
CHANGED
@@ -4,235 +4,217 @@ from typing import Optional, TypedDict, Annotated

Before (removed lines; gaps from this truncated rendering are marked with …):

from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver

from app.llm import get_llm
from app.schemas import …

_MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"


def add_messages(left: list[dict], right: list[dict]) -> list[dict]:
    return left + right


class IntakeState(TypedDict):
    messages: Annotated[list[dict], add_messages]
    clinical_state: str
    missing_fields: list[str]
    current_node: str
    clinical_brief: Optional[dict]
    frontend_stage: str


def format_transcript(messages: list[dict]) -> str:
    …
    # Only send the last couple of turns to not overwhelm if it's long, but ideally all
    for m in messages:
        role = "AI" if m["role"] == "assistant" else "Patient"
        …
    return "\n".join(…)


def evaluate_missing(…):  # signature truncated; name recovered from the call site below
    missing = []
    stage = "intake"

    if not state.chief_complaint:
        missing.append("chief complaint…
        return missing
    …
    stage = "ros"
    # Need at least a few systems covered if possible
    if len(state.ros.keys()) < ROS_REQUIRED_COUNT:
        missing.append(f"Review of Systems (ask about {ROS_REQUIRED_COUNT - len(state.ros.keys())} more bodily systems)")
        return missing, stage

    return [], "done"


# -------------------- NODES --------------------

def triage_node(state: IntakeState) -> dict:
    msgs = state.get("messages", [])
    if …
    …
    last_msg = msgs[-1]
    if last_msg["role"] == "user":
        content = last_msg["content"].lower()
        emergencies = ["suicide", "kill myself", "crushing chest pain", "can't breathe", "heart attack"]
        if any(e in content for e in emergencies):
            return {
                "messages": [{…
                "current_node": "done",
                "frontend_stage": "done"
            }

    return {"current_node": "extractor"}


def …  # extractor node; def line truncated
    msgs = state.get("messages", [])
    …
        return {
            "…
            "…
        }
    …
    llm = get_llm()
    transcript = format_transcript(msgs)

    current_state_json = state.get("clinical_state")
    if not current_state_json:
        current_state_json = ClinicalStateExtraction().model_dump_json()

    # Extractor Agent updates the state passively
    new_state = llm.ask_json(transcript, current_state_json, ClinicalStateExtraction)

    # Check if the extractor detected a latent emergency
    if new_state.emergency_detected:
        return {
            "messages": [{"role": "assistant", "content": "🚨 EMERGENCY OVERRIDE: Based on your details, you require immediate medical attention. Call 911."}],
            "current_node": "done",
            "frontend_stage": "done",
            "clinical_state": new_state.model_dump_json()
        }

    return {
        "clinical_state": new_state.model_dump_json(),
        "current_node": "evaluator"
    }


def …  # evaluator node; def line truncated
    state_json = state.get("clinical_state")
    if not state_json:
        clinical_state = ClinicalStateExtraction()
    else:
        clinical_state = ClinicalStateExtraction.model_validate_json(state_json)

    missing, stage = evaluate_missing(clinical_state)

    if not missing:
        return {
            "…
            "frontend_stage": "done",
            "current_node": "scribe"
        }

    return {
        "missing_fields": missing,
        "frontend_stage": stage,
        "current_node": "conversationalist"
    }


def …  # conversationalist node; def line and opening lines truncated
    …
        return {
            "messages": [{"role": "assistant", "content": "…
            "…
        }

    # Check if the agent just spoke (prevent double-speaking if no user input)
    if msgs[-1]["role"] == "assistant":
        return {"current_node": "conversationalist"}

    # Dynamic target targeting the top missing field
    target = missing[0] if missing else "general details"

    system_prompt = (
        "You are an empathetic clinical intake assistant. "
        "Your sole job is to ask the next logical medical question in a conversational way. "
        f"We currently know this info about the patient:\n{clinical_json}\n\n"
        f"YOUR GOAL: You MUST naturally uncover the following missing information: {target}. "
        "Keep your response to exactly ONE question. Be concise and friendly."
    )

    transcript = format_transcript(msgs[-6:])  # Context window
    llm = get_llm()
    reply = llm.ask(f"Transcript:\n{transcript}\n\nAsk the next question about: {target}.", system=system_prompt)

    return {
        "messages": [{"role": "assistant", "content": reply}],
        "…
    }


def scribe_node(state: IntakeState) -> dict:
    …
    from datetime import datetime, timezone

    brief = ClinicalBrief(
        chief_complaint=data.chief_complaint or "Not specified",
        hpi=…
        ros=data.ros,
        generated_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
    )

    return {
        "messages": [{"role": "assistant", "content": "…
        "current_node": "done",
        "clinical_brief": brief.model_dump(),
    }


def build_graph():
    workflow = StateGraph(IntakeState)

    workflow.add_node("triage", triage_node)
    workflow.add_node("…  # extractor registration; line truncated
    workflow.add_node("evaluator", evaluator_node)
    workflow.add_node("conversationalist", conversationalist_node)
    workflow.add_node("scribe", scribe_node)

    def route_triage(state: IntakeState) -> str:
        …
        return state.get("current_node", "extractor")

    def route_extractor(state: IntakeState) -> str:
        # Extractor marks it 'done' if latent emergency, else 'evaluator'
        return state.get("current_node", "evaluator")

    def route_evaluator(state: IntakeState) -> str:
        return state.get("current_node", "conversationalist")

    workflow.add_edge(START, "triage")
    workflow.add_conditional_edges("triage", route_triage, {"done": END, "…
    workflow.…
    workflow.add_conditional_edges("evaluator", route_evaluator, {"conversationalist": "conversationalist", "scribe": "scribe"})

    workflow.add_edge("conversationalist", END)
    workflow.add_edge("scribe", END)

    checkpointer = MemorySaver()
    # Interrupt after …
    graph = workflow.compile(
        checkpointer=checkpointer,
        interrupt_after=["…
    )

    return graph, checkpointer
After (added lines; file lines 1-3 precede this hunk and are unchanged):

from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver

from app.llm import get_llm, CombinedOutput
from app.schemas import ClinicalBrief, HPI, ClinicalStateExtraction

_MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"


def add_messages(left: list[dict], right: list[dict]) -> list[dict]:
    return left + right


class IntakeState(TypedDict):
    messages: Annotated[list[dict], add_messages]
    clinical_state: str  # JSON of CombinedOutput (accumulated clinical data)
    missing_fields: list[str]
    current_node: str
    clinical_brief: Optional[dict]
    frontend_stage: str  # 'intake', 'hpi', 'ros', 'done'


HPI_FIELDS = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
ROS_REQUIRED = 3

EMERGENCY_PHRASES = [
    "crushing chest pain", "can't breathe", "cannot breathe",
    "heart attack", "suicide", "kill myself", "can't move", "dying"
]


# ------------------------------------------------------------------ helpers --

def format_transcript(messages: list[dict]) -> str:
    lines = []
    for m in messages:
        role = "AI" if m["role"] == "assistant" else "Patient"
        lines.append(f"{role}: {m['content']}")
    return "\n".join(lines)


def compute_stage(state: CombinedOutput) -> str:
    if not state.chief_complaint:
        return "intake"
    for f in HPI_FIELDS:
        if not getattr(state, f):
            return "hpi"
    if len(state.ros) < ROS_REQUIRED:
        return "ros"
    return "done"


def missing_from(state: CombinedOutput) -> list[str]:
    missing = []
    if not state.chief_complaint:
        missing.append("chief complaint")
        return missing
    for f in HPI_FIELDS:
        if not getattr(state, f):
            missing.append(f"HPI:{f}")
    if len(state.ros) < ROS_REQUIRED:
        missing.append(f"ROS ({ROS_REQUIRED - len(state.ros)} more systems needed)")
    return missing


# ------------------------------------------------------------------- nodes ---

def triage_node(state: IntakeState) -> dict:
    """Fast keyword check - no LLM call. Abort immediately on emergency phrases."""
    msgs = state.get("messages", [])
    if msgs and msgs[-1]["role"] == "user":
        content = msgs[-1]["content"].lower()
        if any(p in content for p in EMERGENCY_PHRASES):
            return {
                "messages": [{
                    "role": "assistant",
                    "content": (
                        "🚨 EMERGENCY: Your symptoms require immediate attention. "
                        "Please call 911 or go to your nearest emergency room right away."
                    )
                }],
                "current_node": "done",
                "frontend_stage": "done",
            }
    return {"current_node": "agent"}


def agent_node(state: IntakeState) -> dict:
    """
    Core agent node - ONE combined LLM call per turn:
      1. Extracts any new clinical data from the transcript.
      2. Generates the next conversational question.
      3. If all data is collected, builds the ClinicalBrief inline (no separate scribe node).
    """
    msgs = state.get("messages", [])

    # On first call with no messages, return opening greeting
    if not msgs or (len(msgs) == 1 and msgs[0]["role"] == "assistant"):
        return {
            "messages": [{"role": "assistant", "content": "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"}],
            "clinical_state": CombinedOutput().model_dump_json(),
            "frontend_stage": "intake",
            "current_node": "agent",
        }

    if msgs[-1]["role"] == "assistant":
        return {"current_node": "agent"}

    current_json = state.get("clinical_state") or CombinedOutput().model_dump_json()
    transcript = format_transcript(msgs)

    llm = get_llm()
    result: CombinedOutput = llm.combined_call(transcript, current_json)

    if result.emergency:
        return {
            "messages": [{"role": "assistant", "content": (
                "🚨 EMERGENCY: Your symptoms require immediate attention. "
                "Please call 911 or go to your nearest emergency room right away."
            )}],
            "clinical_state": result.model_dump_json(),
            "current_node": "done",
            "frontend_stage": "done",
        }

    stage = compute_stage(result)
    missing = missing_from(result)
    reply = result.reply or "Could you tell me more?"

    # All fields complete - build the brief inline so it's available this turn
    if stage == "done":
        from datetime import datetime, timezone
        brief = ClinicalBrief(
            chief_complaint=result.chief_complaint or "Not specified",
            hpi=HPI(
                onset=result.onset or "Not specified",
                location=result.location or "Not specified",
                duration=result.duration or "Not specified",
                character=result.character or "Not specified",
                severity=result.severity or "Not specified",
                aggravating=result.aggravating or "Not specified",
                relieving=result.relieving or "Not specified",
            ),
            ros=result.ros,
            generated_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        )
        return {
            "messages": [{"role": "assistant", "content": "Your clinical summary is ready. Please wait for the doctor."}],
            "clinical_state": result.model_dump_json(),
            "missing_fields": [],
            "frontend_stage": "done",
            "current_node": "done",
            "clinical_brief": brief.model_dump(),
        }

    return {
        "messages": [{"role": "assistant", "content": reply}],
        "clinical_state": result.model_dump_json(),
        "missing_fields": missing,
        "frontend_stage": stage,
        "current_node": "agent",
    }


def scribe_node(state: IntakeState) -> dict:
    """Build the final ClinicalBrief from the accumulated CombinedOutput state."""
    state_json = state.get("clinical_state", "{}")
    data = CombinedOutput.model_validate_json(state_json)

    from datetime import datetime, timezone

    brief = ClinicalBrief(
        chief_complaint=data.chief_complaint or "Not specified",
        hpi=HPI(
            onset=data.onset or "Not specified",
            location=data.location or "Not specified",
            duration=data.duration or "Not specified",
            character=data.character or "Not specified",
            severity=data.severity or "Not specified",
            aggravating=data.aggravating or "Not specified",
            relieving=data.relieving or "Not specified",
        ),
        ros=data.ros,
        generated_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
    )

    return {
        "messages": [{"role": "assistant", "content": "Your clinical summary is ready. Please wait for the doctor."}],
        "current_node": "done",
        "frontend_stage": "done",
        "clinical_brief": brief.model_dump(),
    }


# -------------------------------------------------------------- graph build --

def build_graph():
    workflow = StateGraph(IntakeState)

    workflow.add_node("triage", triage_node)
    workflow.add_node("agent", agent_node)

    def route_triage(state: IntakeState) -> str:
        return state.get("current_node", "agent")

    workflow.add_edge(START, "triage")
    workflow.add_conditional_edges("triage", route_triage, {"done": END, "agent": "agent"})
    workflow.add_edge("agent", END)

    checkpointer = MemorySaver()
    # Interrupt after agent so it pauses for user input each turn
    graph = workflow.compile(
        checkpointer=checkpointer,
        interrupt_after=["agent"]
    )

    return graph, checkpointer
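For context, a minimal driver sketch showing how one chat turn could be run against the compiled graph. This is not part of the commit; it assumes the standard LangGraph invoke/config API, and the session id is hypothetical.

# Sketch only (not in this commit). Assumes standard LangGraph checkpointer usage.
from app.graph import build_graph

graph, checkpointer = build_graph()
config = {"configurable": {"thread_id": "session-123"}}  # hypothetical: one thread per chat session

# One /chat turn: append the user's message; interrupt_after=["agent"] pauses
# the run after the agent node so the API layer can return the reply.
result = graph.invoke(
    {"messages": [{"role": "user", "content": "I have chest pain since yesterday"}]},
    config=config,
)
print(result["messages"][-1]["content"])   # the assistant's next question (or emergency/done message)
print(result["frontend_stage"], result["missing_fields"])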
app/llm.py
CHANGED
@@ -1,104 +1,180 @@

Before (removed lines; gaps from this truncated rendering are marked with …):

import os
import json
from pydantic import BaseModel

…  # removed module-level prompt constant; its body is truncated in this view
)


class MockLLM:
    def …  # ask(); def line truncated
        …
        return "I'm sorry to hear about your chest pain. When did it start?"
        return "I understand. Can you tell me more?"

        # General fallback that allows tests to check for context
        if "onset" in instruction.lower():
            return "When did this start?"
        elif "severity" in instruction.lower() or "scale" in instruction.lower():
            return "On a scale of 1 to 10, how severe is this?"
        elif "location" in instruction.lower():
            return "Where exactly do you feel this?"

        return "Can you elaborate on that?"

    def …  # ask_json(); def line and most of the body truncated
        …
        if "crushing chest pain" in t_low or "heart attack" in t_low or "emergency" in t_low:
            state_dict["emergency_detected"] = True

        # Guarantee schema matches via Pydantic model_validate
        return schema_cls.model_validate(state_dict)


class TransformersLLM:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")

    def _load(self):
        if self.model is None:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            import torch
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=…
                device_map="cpu",
            )

    def …  # ask(); def line truncated
        …
        import torch
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": instruction},
        ]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

@@ -106,9 +182,8 @@ class TransformersLLM:

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=…
                …
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        response = self.tokenizer.decode(

@@ -117,58 +192,52 @@ class TransformersLLM:

        )
        return response.strip()

    def …  # ask_json(); def line truncated
        self._load()
        …
        "…
        "…
        )
        instruction = (
            f"CURRENT STATE JSON (Update this based on the transcript):\n{current_state}\n\n"
            f"TRANSCRIPT:\n{transcript}\n\n"
            f"Output ONLY valid JSON matching this schema structure:\n"
            f"{schema_cls.model_json_schema()}"
        )

        messages = [
            {"role": "system", "content": …
            {"role": "user", "content": …
        ]
        …
            max_new_tokens=400,
            temperature=0.1,  # Keep low for JSON determinism
            do_sample=False,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        response = self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )

        # Attempt to parse json from output
        json_str = response.strip()
        if "```json" in json_str:
            json_str = json_str.split("```json")[…
        elif "```" in json_str:
            json_str = json_str.split("```")[…
        …
        try:
            parsed = json.loads(json_str)
            return …
        except Exception:
            …
            try:
                …
            except Exception:
                return …


_llm_instance = None
After (added lines):

import os
import json
import re
from pydantic import BaseModel

COMBINED_SYSTEM_PROMPT = """You are a clinical intake assistant AI. You have two jobs per turn:

JOB 1 (EXTRACT): Read the FULL conversation and update the clinical JSON state with any new information the patient provided. Only extract facts explicitly stated.

JOB 2 (RESPOND): Based on what is STILL MISSING from the clinical state, ask the patient ONE natural, empathetic question. Do NOT ask about things already filled in.

CRITICAL RULES:
- Output ONLY valid JSON, nothing else.
- Do NOT diagnose or give medical advice.
- Do NOT ask more than one question.
- If all fields are complete, set reply to "Thank you - I have everything I need."
- Emergency override: if patient mentions "crushing chest pain", "can't breathe", "suicide", or similar life-threatening phrases, set emergency=true.

OUTPUT FORMAT (strictly follow this, no extra text):
{
  "chief_complaint": "...",
  "onset": "...",
  "location": "...",
  "duration": "...",
  "character": "...",
  "severity": "...",
  "aggravating": "...",
  "relieving": "...",
  "ros": {"system_name": ["finding1", "finding2"]},
  "emergency": false,
  "reply": "The single question to ask the patient next"
}

Use null for any field not yet known. Keep existing values if the patient didn't add new info."""


class CombinedOutput(BaseModel):
    chief_complaint: str | None = None
    onset: str | None = None
    location: str | None = None
    duration: str | None = None
    character: str | None = None
    severity: str | None = None
    aggravating: str | None = None
    relieving: str | None = None
    ros: dict[str, list[str]] = {}
    emergency: bool = False
    reply: str = ""


class MockLLM:
    def combined_call(self, transcript: str, current_json: str) -> CombinedOutput:
        """Single call: extract + generate reply. No real inference in mock mode."""
        t = transcript.lower()
        try:
            state = json.loads(current_json)
        except Exception:
            state = {}

        # --- Extraction ---
        if "chest pain" in t and not state.get("chief_complaint"):
            state["chief_complaint"] = "chest pain"
        if any(w in t for w in ["yesterday", "this morning", "last night", "hours ago", "days ago", "since"]):
            if not state.get("onset"):
                if "yesterday" in t:
                    state["onset"] = "yesterday"
                elif "this morning" in t or "morning" in t:
                    state["onset"] = "this morning"
                else:
                    state["onset"] = "recently"
        if any(w in t for w in ["center", "left", "right", "chest", "stomach", "head", "arm"]):
            if not state.get("location"):
                if "center" in t:
                    state["location"] = "center of chest"
                elif "left" in t:
                    state["location"] = "left side of chest"
        if any(w in t for w in ["constant", "intermittent", "comes and goes", "all day", "hours"]):
            if not state.get("duration"):
                state["duration"] = "constant" if "constant" in t else "intermittent"
        if any(w in t for w in ["pressure", "tight", "squeezing", "sharp", "dull", "burning", "stabbing"]):
            if not state.get("character"):
                if "tight" in t or "squeezing" in t:
                    state["character"] = "tight, squeezing pressure"
                elif "sharp" in t:
                    state["character"] = "sharp"
        # Severity - match "N out of 10", "N/10", or an explicit severity score
        sev_match = re.search(r'\b([1-9]|10)\s*(?:out of|/|over)\s*10\b', t, re.IGNORECASE)
        if not sev_match:
            sev_match = re.search(r'\bseverity\s+(?:is\s+)?([1-9]|10)\b', t, re.IGNORECASE)
        if sev_match and not state.get("severity"):
            state["severity"] = f"{sev_match.group(1)}/10"
        if any(w in t for w in ["walk", "run", "climb", "exert", "stress", "eating", "lying"]):
            if not state.get("aggravating"):
                if "walk" in t: state["aggravating"] = "walking"
                elif "run" in t: state["aggravating"] = "running"
                elif "climb" in t: state["aggravating"] = "climbing stairs"
        if any(w in t for w in ["rest", "sit", "antacid", "medication", "nitroglycerin"]):
            if not state.get("relieving"):
                state["relieving"] = "resting"
        if "palpitation" in t:
            ros = state.get("ros", {})
            ros["cardiac"] = ["palpitations present"] + (["no leg swelling"] if "no" in t and "swell" in t else [])
            state["ros"] = ros
        if "breath" in t or "wheez" in t or "cough" in t:
            ros = state.get("ros", {})
            ros["respiratory"] = ["shortness of breath" if "breath" in t else "no shortness of breath",
                                  "no cough" if ("no" in t and "cough" in t) else ("cough" if "cough" in t else "no cough")]
            state["ros"] = ros
        if "nausea" in t or "vomit" in t or "heartburn" in t:
            ros = state.get("ros", {})
            ros["gi"] = ["no nausea" if ("no" in t and "nausea" in t) else "nausea",
                         "no vomiting" if ("no" in t and "vomit" in t) else "vomiting present"]
            state["ros"] = ros

        state["emergency"] = any(e in t for e in ["crushing chest pain", "heart attack", "can't breathe", "suicide", "kill myself"])

        # --- Determine next question ---
        if not state.get("chief_complaint"):
            state["reply"] = "What brings you in today?"
        elif not state.get("onset"):
            cc = state.get("chief_complaint", "this")
            state["reply"] = f"When did the {cc} start?"
        elif not state.get("location"):
            state["reply"] = "Where exactly do you feel it?"
        elif not state.get("duration"):
            state["reply"] = "Is it constant or does it come and go?"
        elif not state.get("character"):
            state["reply"] = "How would you describe it - sharp, dull, pressure, or tightness?"
        elif not state.get("severity"):
            state["reply"] = "On a scale of 1 to 10, how severe is it right now?"
        elif not state.get("aggravating"):
            state["reply"] = "Does anything make it worse, like physical activity?"
        elif not state.get("relieving"):
            state["reply"] = "What helps relieve it?"
        else:
            ros = state.get("ros", {})
            cc = state.get("chief_complaint", "chest pain")
            if "cardiac" not in ros:
                state["reply"] = "Any heart-related symptoms - palpitations or leg swelling?"
            elif "respiratory" not in ros:
                state["reply"] = "Any shortness of breath, wheezing, or coughing?"
            elif "gi" not in ros:
                state["reply"] = "Any nausea, vomiting, or heartburn?"
            else:
                state["reply"] = "Thank you - I have everything I need."

        return CombinedOutput.model_validate(state)


class TransformersLLM:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")
        self._load_lock = False

    def _load(self):
        if self.model is None and not self._load_lock:
            self._load_lock = True
            from transformers import AutoModelForCausalLM, AutoTokenizer
            import torch
            print(f"[LLM] Loading model {self.model_name}...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # Use float16 - halves memory footprint and is ~2x faster than float32 on CPU
            dtype = torch.float16
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=dtype,
                device_map="cpu",
                low_cpu_mem_usage=True,
            )
            self.model.eval()
            print("[LLM] Model ready.")

    def _infer(self, messages: list[dict], max_tokens: int = 350) -> str:
        """Single shared inference method. Greedy decode for speed."""
        import torch
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = …  # (unchanged context line 181, not re-rendered in this diff view)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=False,  # Greedy - deterministic and fastest
                pad_token_id=self.tokenizer.eos_token_id,
            )
        response = self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],  # (unchanged context lines 190-191,
            skip_special_tokens=True,                #  recovered from the removed ask_json)
        )
        return response.strip()

    def combined_call(self, transcript: str, current_json: str) -> CombinedOutput:
        """
        Single LLM call that BOTH extracts clinical data AND generates the next reply.
        This halves latency vs. running extractor + conversationalist separately.
        """
        self._load()

        prompt = (
            f"CURRENT CLINICAL STATE (update with any new patient info):\n{current_json}\n\n"
            f"FULL CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
            "Instructions: Extract all new clinical facts from the transcript, merge them into the state, "
            "and generate exactly ONE empathetic follow-up question for whatever is still missing. "
            "Return ONLY the JSON object, no other text."
        )
        messages = [
            {"role": "system", "content": COMBINED_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]

        raw = self._infer(messages, max_tokens=350)

        # Parse JSON robustly
        json_str = raw
        if "```json" in json_str:
            json_str = json_str.split("```json", 1)[1].split("```")[0]
        elif "```" in json_str:
            json_str = json_str.split("```", 1)[1].split("```")[0]

        # Find first { ... } block
        start = json_str.find("{")
        end = json_str.rfind("}") + 1
        if start != -1 and end > start:
            json_str = json_str[start:end]

        try:
            parsed = json.loads(json_str)
            return CombinedOutput.model_validate(parsed)
        except Exception as e:
            print(f"[LLM] JSON parse error: {e}\nRaw output: {raw[:300]}")
            # Return current state + error reply - never crash
            try:
                base = CombinedOutput.model_validate_json(current_json)
                base.reply = "Could you please repeat that? I want to make sure I understood correctly."
                return base
            except Exception:
                return CombinedOutput(reply="Could you please repeat that?")


_llm_instance = None
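The get_llm() factory referenced from app/graph.py sits below this hunk and is unchanged, so the diff does not render it. A plausible shape, inferred from the _llm_instance singleton and the MOCK_LLM flag used elsewhere in the repo; treat the body as an assumption, not the commit's code:

# Assumed sketch of the unchanged factory (not shown in this diff):
def get_llm():
    global _llm_instance
    if _llm_instance is None:
        # Mirror the MOCK_LLM check used in app/graph.py (assumption)
        if os.environ.get("MOCK_LLM", "true").lower() == "true":
            _llm_instance = MockLLM()
        else:
            _llm_instance = TransformersLLM()
    return _llm_instance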
tests/test_e2e.py
CHANGED
@@ -1,4 +1,5 @@

Before (removed lines; gaps from this truncated rendering are marked with …):

import os
os.environ["MOCK_LLM"] = "true"

import pytest

@@ -6,6 +7,10 @@ from httpx import AsyncClient, ASGITransport

from app.main import app
from app.schemas import ClinicalBrief

@pytest.fixture
async def client():

@@ -13,6 +18,114 @@ async def client():

    async with AsyncClient(transport=transport, base_url="http://test") as c:
        yield c

@pytest.mark.asyncio(loop_scope="function")
async def test_health_endpoint(client):
    response = await client.get("/health")

@@ -21,62 +134,84 @@ async def test_health_endpoint(client):

    assert data["status"] == "ok"
    assert data["mock_mode"] is True

@pytest.mark.asyncio(loop_scope="function")
async def …  # emergency test; def line truncated
    """…
    session_id = "test_emergency"

    await client.post("/chat", json={"session_id": session_id, "message": "hello"})
    …
    assert response.status_code == 200
    data = response.json()

    assert data["state"] == "done"
    assert "911" in data["reply"] or "emergency" in data["reply"].lower()

@pytest.mark.asyncio(loop_scope="function")
async def …  # full-intake test; def line truncated
    """
    …
    """
    session_id = "…

    await client.post("/chat", json={"session_id": session_id, "message": "hello"})

    # 1. Chief Complaint & some HPI
    # The mock LLM maps "chest pain" -> CC, "yesterday" -> onset
    res = await client.post("/chat", json={"session_id": session_id, "message": "I have chest pain since yesterday"})
    assert res.status_code == 200
    data = res.json()
    assert data["state"] == "hpi"  # Needs more HPI info

    # 2. More HPI info
    res = await client.post("/chat", json={"session_id": session_id, "message": "It is constant pressure in the center. Severity is 7. Walking makes it worse, rest helps."})
    assert res.status_code == 200
    data = res.json()
    assert data["state"] == "ros"  # Completes HPI, moves to ROS

    # …
    …
    assert …

    assert data["state"] == "done"
    assert data["brief"] is not None

    brief = ClinicalBrief.model_validate(data["brief"])
    assert brief.chief_complaint == "chest pain"
    assert brief.hpi.onset …
    assert brief.hpi.…
    assert brief.…
    …
After (added lines):

import os
import time
os.environ["MOCK_LLM"] = "true"

import pytest
from httpx import AsyncClient, ASGITransport  # (unchanged context line 6, named in the hunk header)
from app.main import app
from app.schemas import ClinicalBrief
from app.llm import MockLLM, CombinedOutput


# ─────────────────────── fixtures ───────────────────────

@pytest.fixture
async def client():
    transport = …  # (unchanged context line 17, not re-rendered in this diff view)
    async with AsyncClient(transport=transport, base_url="http://test") as c:
        yield c


# ────────────────────── unit tests ──────────────────────

def test_mock_llm_combined_call_basic_extraction():
    """MockLLM should extract chief complaint, onset and location in one call."""
    llm = MockLLM()
    transcript = "Patient: I have chest pain since yesterday\nAI: Where is it?\nPatient: Center of my chest"
    result = llm.combined_call(transcript, CombinedOutput().model_dump_json())
    assert result.chief_complaint == "chest pain"
    assert result.onset == "yesterday"
    assert result.location == "center of chest"
    assert result.reply  # Should ask the next missing question


def test_mock_llm_emergency_detection():
    """MockLLM should detect emergency keywords and set emergency=True."""
    llm = MockLLM()
    transcript = "Patient: I am having crushing chest pain"
    result = llm.combined_call(transcript, CombinedOutput().model_dump_json())
    assert result.emergency is True


def test_mock_llm_does_not_repeat_filled_questions():
    """If onset is already known, the next question should NOT ask about onset again."""
    llm = MockLLM()
    current = CombinedOutput(chief_complaint="chest pain", onset="yesterday").model_dump_json()
    transcript = "Patient: chest pain yesterday\nAI: ok\nPatient: anything new"
    result = llm.combined_call(transcript, current)
    assert result.onset == "yesterday"  # Should be preserved
    assert "when" not in result.reply.lower()  # Should not re-ask onset


def test_mock_llm_severity_extraction():
    """Severity from different phrasings should always normalize to X/10."""
    llm = MockLLM()
    for phrase, expected in [
        ("it is a 7 out of 10", "7/10"),
        ("about 8 on the scale", None),  # may not extract without explicit context
        ("i'd say 9 on a scale", None),
    ]:
        state = CombinedOutput(
            chief_complaint="chest pain", onset="yesterday",
            location="chest", duration="constant", character="tight"
        ).model_dump_json()
        result = llm.combined_call(f"Patient: {phrase}", state)
        if expected:
            assert result.severity == expected, f"Failed for: '{phrase}'"


def test_mock_llm_ros_extraction():
    """ROS should populate correctly when patient mentions system symptoms."""
    llm = MockLLM()
    full_hpi = CombinedOutput(
        chief_complaint="chest pain", onset="yesterday", location="center of chest",
        duration="constant", character="tight", severity="7/10",
        aggravating="walking", relieving="resting"
    ).model_dump_json()
    result = llm.combined_call("Patient: palpitations present no leg swelling", full_hpi)
    assert "cardiac" in result.ros

    result2 = llm.combined_call("Patient: mild shortness of breath", full_hpi)
    assert "respiratory" in result2.ros


def test_mock_llm_speed():
    """
    MockLLM combined_call must complete under 100ms per call.
    (Real LLM test is separate - this validates no accidental model load in mock mode.)
    """
    llm = MockLLM()
    state = CombinedOutput().model_dump_json()

    times = []
    for _ in range(5):
        t0 = time.perf_counter()
        llm.combined_call("Patient: I have chest pain since this morning in the center of my chest", state)
        times.append(time.perf_counter() - t0)

    avg_ms = (sum(times) / len(times)) * 1000
    print(f"\n[speed] MockLLM avg combined_call: {avg_ms:.1f}ms")
    assert avg_ms < 100, f"MockLLM too slow: {avg_ms:.1f}ms avg (should be <100ms)"


def test_combined_output_schema_round_trip():
    """CombinedOutput must survive JSON round-trip without data loss."""
    original = CombinedOutput(
        chief_complaint="headache",
        onset="3 days ago",
        location="forehead",
        duration="constant",
        character="throbbing",
        severity="6/10",
        aggravating="bright light",
        relieving="dark room",
        ros={"neuro": ["dizziness"], "ent": ["no ear pain"]},
        emergency=False,
        reply="Any vision changes?",
    )
    json_str = original.model_dump_json()
    restored = CombinedOutput.model_validate_json(json_str)
    assert restored.chief_complaint == "headache"
    assert restored.severity == "6/10"
    assert restored.ros["neuro"] == ["dizziness"]
    assert restored.reply == "Any vision changes?"


# ─────────────────── API integration tests ───────────────────

@pytest.mark.asyncio(loop_scope="function")
async def test_health_endpoint(client):
    response = await client.get("/health")
    …  # (unchanged context lines 132-133, not re-rendered in this diff view)
    assert data["status"] == "ok"
    assert data["mock_mode"] is True


@pytest.mark.asyncio(loop_scope="function")
async def test_emergency_triage_node(client):
    """Emergency phrase should bypass agent and return 911 message immediately."""
    session_id = "test_emergency"
    await client.post("/chat", json={"session_id": session_id, "message": "hello"})
    response = await client.post(
        "/chat", json={"session_id": session_id, "message": "I am having crushing chest pain"}
    )
    assert response.status_code == 200
    data = response.json()
    assert data["state"] == "done"
    assert "911" in data["reply"] or "emergency" in data["reply"].lower()


@pytest.mark.asyncio(loop_scope="function")
async def test_full_intake_multi_turn_extraction(client):
    """
    The agent should extract multiple fields per message and skip already-answered questions.
    After 3 messages that collectively answer all HPI fields + 3 ROS systems, state should be 'done'.
    """
    session_id = "test_multi_extract"

    # Kick-off
    r = await client.post("/chat", json={"session_id": session_id, "message": "hello"})
    assert r.status_code == 200

    # Message 1: CC + onset + location
    r = await client.post("/chat", json={
        "session_id": session_id,
        "message": "I have chest pain since yesterday in the center of my chest"
    })
    data = r.json()
    assert data["state"] in ("intake", "hpi")

    # Message 2: duration + character + severity + aggravating + relieving
    r = await client.post("/chat", json={
        "session_id": session_id,
        "message": "It is constant, tight and squeezing, about a 7 out of 10. Walking worsens it and resting helps."
    })
    data = r.json()
    assert data["state"] in ("hpi", "ros")

    # Message 3: cover 3 ROS systems in one shot
    r = await client.post("/chat", json={
        "session_id": session_id,
        "message": "I have palpitations, mild shortness of breath, and no nausea"
    })
    data = r.json()
    # Should be done now
    assert data["state"] == "done"
    assert data["brief"] is not None

    brief = ClinicalBrief.model_validate(data["brief"])
    assert brief.chief_complaint == "chest pain"
    assert brief.hpi.onset is not None
    assert brief.hpi.severity is not None
    assert len(brief.ros) >= 2


@pytest.mark.asyncio(loop_scope="function")
async def test_api_response_time(client):
    """API with MockLLM must respond in under 2 seconds per message."""
    session_id = "test_speed_api"

    times = []
    messages = [
        "hello",
        "I have a headache since this morning",
        "It is on the left side of my head",
    ]
    for msg in messages:
        t0 = time.perf_counter()
        r = await client.post("/chat", json={"session_id": session_id, "message": msg})
        elapsed = time.perf_counter() - t0
        times.append(elapsed)
        assert r.status_code == 200

    avg_s = sum(times) / len(times)
    print(f"\n[speed] API avg response: {avg_s*1000:.0f}ms")
    assert avg_s < 2.0, f"API too slow: {avg_s:.2f}s avg (should be <2s in mock mode)"
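To run this suite locally (assuming pytest and pytest-asyncio are installed; the asyncio loop_scope markers above require pytest-asyncio):

# The test file pins MOCK_LLM=true at import, so no real model is loaded;
# -s surfaces the [speed] print statements from the timing tests.
python -m pytest tests/test_e2e.py -v -s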
|