priyansh-saxena1 committed on
Commit
284dfa9
·
1 Parent(s): 808ef75

feat : add dual agent architecture

Browse files
Files changed (5) hide show
  1. app/graph.py +176 -374
  2. app/llm.py +122 -17
  3. app/main.py +2 -2
  4. app/schemas.py +16 -11
  5. tests/test_e2e.py +54 -165
app/graph.py CHANGED
@@ -1,403 +1,200 @@
 
 
1
  from typing import Optional, TypedDict, Annotated
2
  from langgraph.graph import StateGraph, START, END
3
  from langgraph.checkpoint.memory import MemorySaver
4
- import os
5
- import re
6
-
7
- _MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"
8
-
9
- SYSTEM_PROMPT = """
10
- You are a clinical intake assistant.
11
-
12
- Rules:
13
- - Ask exactly ONE question at a time
14
- - Keep responses under 20 words
15
- - Be clear and direct
16
- - No explanations unless asked
17
- """
18
 
 
 
19
 
20
- def _ask(prompt: str) -> str:
21
- from app.llm import get_llm
22
- llm = get_llm()
23
- try:
24
- return llm.ask(prompt, system=SYSTEM_PROMPT)
25
- except TypeError:
26
- return llm.ask(prompt)
27
-
28
 
29
  def add_messages(left: list[dict], right: list[dict]) -> list[dict]:
30
  return left + right
31
 
32
-
33
  class IntakeState(TypedDict):
34
  messages: Annotated[list[dict], add_messages]
35
- chief_complaint: str
36
- hpi: dict
37
- ros: dict[str, list[str]]
38
  current_node: str
39
  clinical_brief: Optional[dict]
40
- ros_systems: list[str]
41
- ros_current_index: int
42
- ros_pending_system: Optional[str]
43
- last_processed_message_index: int
44
- vague_retry_field: Optional[str]
45
-
46
-
47
- HPI_FIELDS = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
48
-
49
- # Questions are templated — {cc} will be replaced with chief complaint
50
- HPI_QUESTIONS = {
51
- "onset": "When did {cc} start?",
52
- "location": "Where exactly do you feel {cc}?",
53
- "duration": "Is {cc} constant or does it come and go? How long does each episode last?",
54
- "character": "How would you describe {cc} — sharp, dull, pressure, burning?",
55
- "severity": "On a 1–10 scale, how severe is your {cc} right now?",
56
- "aggravating": "Does anything make {cc} worse, like activity or certain foods?",
57
- "relieving": "What helps relieve your {cc}?"
58
- }
59
-
60
- HPI_FIELD_CONTEXT = {
61
- "onset": "when your symptoms first started",
62
- "location": "where exactly you feel it",
63
- "duration": "how long each episode lasts",
64
- "character": "what the pain feels like",
65
- "severity": "pain severity (1-10)",
66
- "aggravating": "what makes symptoms worse",
67
- "relieving": "what relieves symptoms",
68
- }
69
-
70
- CC_KEYWORDS_TO_ROS = {
71
- "chest": ["cardiac", "respiratory", "gi"],
72
- "pain": ["cardiac", "respiratory", "gi"],
73
- "headache": ["neuro", "ent", "vision"],
74
- "head": ["neuro", "ent", "vision"],
75
- "breath": ["respiratory", "cardiac"],
76
- "shortness": ["respiratory", "cardiac"],
77
- "cough": ["respiratory", "ent"],
78
- "dizzy": ["neuro", "cardiac"],
79
- "nausea": ["gi", "constitutional"],
80
- "vomiting": ["gi", "constitutional"],
81
- }
82
-
83
- DEFAULT_ROS = ["constitutional", "cardiac", "respiratory"]
84
-
85
- ROS_SYSTEM_QUESTIONS = {
86
- "cardiac": "Any palpitations, fluttering, or swelling in your legs or ankles?",
87
- "respiratory": "Any shortness of breath, wheezing, or cough?",
88
- "gi": "Any nausea, vomiting, heartburn, or abdominal pain?",
89
- "neuro": "Any headaches, dizziness, numbness, or vision changes?",
90
- "ent": "Any ear pain, sore throat, or sinus pressure?",
91
- "vision": "Any blurry vision, double vision, or eye pain?",
92
- "constitutional": "Any fever, chills, unexplained weight loss, or fatigue?",
93
- }
94
-
95
-
96
- def get_relevant_ros_systems(cc: str) -> list[str]:
97
- cc_lower = cc.lower()
98
- seen = []
99
- for keyword, systems in CC_KEYWORDS_TO_ROS.items():
100
- if keyword in cc_lower:
101
- for s in systems:
102
- if s not in seen:
103
- seen.append(s)
104
- return seen if seen else DEFAULT_ROS
105
-
106
-
107
- def _fmt_question(field: str, cc: str) -> str:
108
- """Format an HPI question, injecting the chief complaint naturally."""
109
- q = HPI_QUESTIONS[field]
110
- cc_short = cc.split()[0:4] # first few words of complaint
111
- cc_str = " ".join(cc_short).lower() if cc_short else "this"
112
- return q.format(cc=cc_str)
113
-
114
 
115
- def extract_hpi_value(answer: str, field: str) -> str:
116
- answer = answer.strip()
117
- if field == "severity":
118
- match = re.search(r'(\d{1,2})\s*(?:out of|/|over)?\s*10', answer, re.IGNORECASE)
119
- if match:
120
- return f"{match.group(1)}/10"
121
- # also handle bare numbers 1-10
122
- match2 = re.search(r'\b([1-9]|10)\b', answer)
123
- if match2:
124
- return f"{match2.group(1)}/10"
125
- return answer
126
 
 
 
127
 
128
- def _is_vague_answer(answer: str) -> bool:
129
- vague_phrases = ["i don't know", "not sure", "dont know", "idk", "maybe", "i guess", "not really", "not sure"]
130
- return any(phrase in answer.lower() for phrase in vague_phrases)
 
 
 
 
131
 
132
-
133
- def _parse_ros_answer(answer: str) -> list[str]:
134
  """
135
- Parse a free-text ROS answer into a list of individual findings.
136
- Handles comma-separated, 'and'-joined, and 'no X' style negative findings.
137
  """
138
- # Split on commas, semicolons, and 'and'
139
- parts = re.split(r'[,;]|\band\b', answer, flags=re.IGNORECASE)
140
- findings = []
141
- for part in parts:
142
- part = part.strip()
143
- if part:
144
- findings.append(part)
145
- return findings if findings else [answer.strip()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
 
148
  # -------------------- NODES --------------------
149
 
150
- GREETINGS = {"hello", "hi", "hey", "start", "begin", "ok", "okay", "yes", "sure"}
151
-
152
-
153
- def intake_node(state: IntakeState) -> dict:
154
- messages = state.get("messages", [])
155
- last_idx = state.get("last_processed_message_index", 0)
156
- cc = state.get("chief_complaint", "")
157
-
158
- if cc:
159
- return {"current_node": "hpi"}
160
-
161
- has_new_user_msg = len(messages) > last_idx
162
- greeting_reply = "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"
163
-
164
- if has_new_user_msg:
165
- user_msg = next((m for m in messages[last_idx:] if m["role"] == "user"), None)
166
- if user_msg:
167
- content = user_msg["content"].strip()
168
-
169
- if content.lower() in GREETINGS or len(content) <= 4:
170
- return {
171
- "messages": [{"role": "assistant", "content": greeting_reply}],
172
- "chief_complaint": "",
173
- "current_node": "intake",
174
- "last_processed_message_index": len(messages),
175
- "vague_retry_field": None,
176
- }
177
-
178
- cc = content
179
- if _MOCK():
180
- reply = f"Got it — {cc}. I'll ask a few quick questions to document your visit."
181
- else:
182
- reply = _ask(
183
- f"Patient's chief complaint is: '{cc}'. "
184
- "Acknowledge it in one sentence and say you'll ask a few questions."
185
- )
186
  return {
187
- "messages": [{"role": "assistant", "content": reply}],
188
- "chief_complaint": cc,
189
- "current_node": "hpi",
190
- "last_processed_message_index": len(messages),
191
- "vague_retry_field": None,
192
  }
193
-
194
- return {
195
- "messages": [{"role": "assistant", "content": greeting_reply}],
196
- "chief_complaint": "",
197
- "current_node": "intake",
198
- "last_processed_message_index": last_idx,
199
- "vague_retry_field": None,
200
- }
201
 
202
 
203
- def hpi_node(state: IntakeState) -> dict:
204
- messages = state.get("messages", [])
205
- last_idx = state.get("last_processed_message_index", 0)
206
- hpi = dict(state.get("hpi", {}))
207
- vague_retry_field = state.get("vague_retry_field")
208
- cc = state.get("chief_complaint", "")
209
-
210
- next_field = vague_retry_field
211
- if not next_field:
212
- for field in HPI_FIELDS:
213
- if field not in hpi or not hpi.get(field):
214
- next_field = field
215
- break
216
-
217
- if next_field is None:
218
  return {
219
- "messages": [{"role": "assistant", "content": "Thank you. Now I'll ask about a few other symptoms."}],
220
- "current_node": "ros",
221
- "last_processed_message_index": len(messages),
222
- "vague_retry_field": None,
223
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
- has_new_user_msg = len(messages) > last_idx
226
-
227
- if has_new_user_msg:
228
- user_msg = next((m for m in messages[last_idx:] if m["role"] == "user"), None)
229
-
230
- if user_msg:
231
- answer = user_msg["content"]
232
-
233
- if _is_vague_answer(answer):
234
- field_context = HPI_FIELD_CONTEXT[next_field]
235
-
236
- if _MOCK():
237
- reply = f"Could you be more specific? I need to know {field_context}."
238
- else:
239
- reply = _ask(
240
- f"Patient response about {field_context} was vague. "
241
- "Ask for clarification in one short sentence."
242
- )
243
-
244
- return {
245
- "messages": [{"role": "assistant", "content": reply}],
246
- "current_node": "hpi",
247
- "last_processed_message_index": last_idx,
248
- "vague_retry_field": next_field,
249
- }
250
-
251
- hpi[next_field] = extract_hpi_value(answer, next_field)
252
-
253
- next_idx = HPI_FIELDS.index(next_field)
254
- if next_idx < len(HPI_FIELDS) - 1:
255
- next_field = HPI_FIELDS[next_idx + 1]
256
-
257
- if _MOCK():
258
- reply = _fmt_question(next_field, cc)
259
- else:
260
- reply = _ask(
261
- f"Complaint: {cc}. Known info: {hpi}. "
262
- f"Ask ONE question about {HPI_FIELD_CONTEXT[next_field]}."
263
- )
264
-
265
- return {
266
- "messages": [{"role": "assistant", "content": reply}],
267
- "hpi": hpi,
268
- "current_node": "hpi",
269
- "last_processed_message_index": len(messages),
270
- "vague_retry_field": None,
271
- }
272
 
273
- return {
274
- "messages": [{"role": "assistant", "content": "Thank you. Now I'll ask about a few other symptoms."}],
275
- "hpi": hpi,
276
- "current_node": "ros",
277
- "last_processed_message_index": len(messages),
278
- "vague_retry_field": None,
279
- }
280
-
281
- if _MOCK():
282
- reply = _fmt_question(next_field, cc)
283
  else:
284
- reply = _ask(
285
- f"Complaint: {cc}. Known info: {hpi}. "
286
- f"Ask ONE question about {HPI_FIELD_CONTEXT[next_field]}."
287
- )
288
-
 
 
 
 
 
 
289
  return {
290
- "messages": [{"role": "assistant", "content": reply}],
291
- "current_node": "hpi",
292
- "last_processed_message_index": last_idx,
293
- "vague_retry_field": None,
294
  }
295
 
296
 
297
- def ros_node(state: IntakeState) -> dict:
298
- messages = state.get("messages", [])
299
- last_idx = state.get("last_processed_message_index", 0)
300
- ros = dict(state.get("ros", {}))
301
- cc = state.get("chief_complaint", "")
302
-
303
- ros_systems = state.get("ros_systems") or get_relevant_ros_systems(cc)
304
- current_idx = state.get("ros_current_index", 0)
305
- pending = state.get("ros_pending_system")
306
-
307
- if current_idx >= len(ros_systems):
308
  return {
309
- "messages": [{"role": "assistant", "content": "Thank you — I have everything I need."}],
310
- "current_node": "brief_generator",
311
- "last_processed_message_index": len(messages),
312
  }
313
-
314
- has_new_user_msg = len(messages) > last_idx
315
-
316
- if has_new_user_msg and pending:
317
- answer = messages[-1]["content"]
318
- ros[pending] = _parse_ros_answer(answer)
319
-
320
- next_system = ros_systems[current_idx]
321
-
322
- if _MOCK():
323
- reply = ROS_SYSTEM_QUESTIONS.get(next_system, f"Any {next_system} symptoms? Mention present and absent.")
324
- else:
325
- reply = _ask(
326
- f"Ask about {next_system} symptoms. One short question. "
327
- "Ask for both present and absent symptoms."
328
- )
329
-
 
 
 
330
  return {
331
  "messages": [{"role": "assistant", "content": reply}],
332
- "ros": ros,
333
- "current_node": "ros",
334
- "ros_systems": ros_systems,
335
- "ros_current_index": current_idx + 1,
336
- "ros_pending_system": next_system,
337
- "last_processed_message_index": len(messages),
338
  }
339
 
340
 
341
- # -------------------- FINAL --------------------
342
-
343
- from datetime import datetime, timezone
344
- from app.schemas import HPI as HPIModel, ClinicalBrief as ClinicalBriefModel
345
-
346
-
347
- def _clean_hpi_value(field: str, raw: str) -> str:
348
- """
349
- Convert a raw patient answer into a clean clinical phrase.
350
- Removes filler words and informal language.
351
- """
352
- raw = raw.strip()
353
-
354
- # Remove filler starters
355
- fillers = [
356
- r'^(yeah|yes|no|well|so|like|um|uh|i mean|i guess),?\s*',
357
- r'^(it\'?s?\s+)',
358
- r'^(the\s+)',
359
- ]
360
- for pattern in fillers:
361
- raw = re.sub(pattern, '', raw, flags=re.IGNORECASE).strip()
362
-
363
- if not raw:
364
- return "not specified"
365
-
366
- # Capitalize first letter
367
- return raw[0].upper() + raw[1:]
368
-
369
-
370
- def brief_generator_node(state: IntakeState) -> dict:
371
- raw_hpi = state.get("hpi", {})
372
-
373
- # Clean each HPI field
374
- cleaned_hpi = {f: _clean_hpi_value(f, raw_hpi.get(f) or "not specified") for f in HPI_FIELDS}
375
-
376
- hpi_obj = HPIModel(**cleaned_hpi)
377
-
378
- # Clean ROS — ensure each system has a proper list of findings
379
- raw_ros = state.get("ros", {})
380
- cleaned_ros: dict[str, list[str]] = {}
381
- for system, findings in raw_ros.items():
382
- clean_findings = []
383
- for f in findings:
384
- f = f.strip()
385
- if f:
386
- # Capitalize
387
- f = f[0].upper() + f[1:]
388
- clean_findings.append(f)
389
- if clean_findings:
390
- cleaned_ros[system] = clean_findings
391
-
392
- brief = ClinicalBriefModel(
393
- chief_complaint=state.get("chief_complaint", ""),
394
- hpi=hpi_obj,
395
- ros=cleaned_ros,
396
  generated_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
397
  )
398
 
399
  return {
400
- "messages": [{"role": "assistant", "content": "Intake complete. Your clinical summary is ready."}],
401
  "current_node": "done",
402
  "clinical_brief": brief.model_dump(),
403
  }
@@ -406,31 +203,36 @@ def brief_generator_node(state: IntakeState) -> dict:
406
  def build_graph():
407
  workflow = StateGraph(IntakeState)
408
 
409
- workflow.add_node("intake", intake_node)
410
- workflow.add_node("hpi", hpi_node)
411
- workflow.add_node("ros", ros_node)
412
- workflow.add_node("brief_generator", brief_generator_node)
413
-
414
- def route(state: IntakeState) -> str:
415
- return state.get("current_node", "intake")
416
-
417
- workflow.add_edge(START, "intake")
418
-
419
- workflow.add_conditional_edges(
420
- "intake", route, {"intake": "intake", "hpi": "hpi"}
421
- )
422
- workflow.add_conditional_edges(
423
- "hpi", route, {"hpi": "hpi", "ros": "ros"}
424
- )
425
- workflow.add_conditional_edges(
426
- "ros", route, {"ros": "ros", "brief_generator": "brief_generator"}
427
- )
428
- workflow.add_edge("brief_generator", END)
 
 
 
 
429
 
430
  checkpointer = MemorySaver()
 
431
  graph = workflow.compile(
432
  checkpointer=checkpointer,
433
- interrupt_after=["intake", "hpi", "ros"]
434
  )
435
 
436
  return graph, checkpointer
 
1
+ import os
2
+ import json
3
  from typing import Optional, TypedDict, Annotated
4
  from langgraph.graph import StateGraph, START, END
5
  from langgraph.checkpoint.memory import MemorySaver
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ from app.llm import get_llm
8
+ from app.schemas import ClinicalStateExtraction, ClinicalBrief, HPI
9
 
10
+ _MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"
 
 
 
 
 
 
 
11
 
12
  def add_messages(left: list[dict], right: list[dict]) -> list[dict]:
13
  return left + right
14
 
 
15
class IntakeState(TypedDict):
    # Shared graph state; every node returns a partial update of these keys.

    # Conversation turns; updates are combined through add_messages.
    messages: Annotated[list[dict], add_messages]
    # ClinicalStateExtraction serialised as a JSON string.
    clinical_state: str
    # Descriptions of the items still to be collected from the patient.
    missing_fields: list[str]
    # Routing marker read by the conditional-edge routers.
    current_node: str
    # Dumped clinical brief; written by the scribe node.
    clinical_brief: Optional[dict]
    # Stage label for the UI: 'intake', 'hpi', 'ros', or 'done'.
    frontend_stage: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # -------------------- HELPER FUNCTIONS --------------------
 
 
 
 
 
 
 
 
 
 
24
 
25
+ HPI_REQUIRED = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
26
+ ROS_REQUIRED_COUNT = 3
27
 
28
def format_transcript(messages: list[dict]) -> str:
    """Render chat messages as a plain-text transcript, one turn per line.

    Every message passed in is included, in order. Callers that want a
    shorter context window must slice ``messages`` before calling.

    Args:
        messages: Chat turns as dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        Lines of the form ``"AI: ..."`` for assistant turns and
        ``"Patient: ..."`` for any other role, joined by newlines. An empty
        list yields an empty string.

    Raises:
        KeyError: If a message lacks ``"role"`` or ``"content"``.
    """
    return "\n".join(
        f"{'AI' if m['role'] == 'assistant' else 'Patient'}: {m['content']}"
        for m in messages
    )
35
 
36
def evaluate_missing(state: ClinicalStateExtraction) -> tuple[list[str], str]:
    """Work out what still has to be asked and which intake stage that implies.

    Checks run in a fixed order and stop at the first incomplete stage:
    chief complaint, then every HPI field in ``HPI_REQUIRED``, then review of
    systems coverage (at least ``ROS_REQUIRED_COUNT`` systems).

    Args:
        state: The clinical facts extracted so far.

    Returns:
        A ``(missing, stage)`` pair. ``missing`` lists human-readable
        descriptions of the outstanding items. ``stage`` is ``"intake"``,
        ``"hpi"``, ``"ros"``, or ``"done"``; when it is ``"done"`` the list
        is empty.
    """
    if not state.chief_complaint:
        return ["chief complaint (reason for visit)"], "intake"

    missing = []
    for field in HPI_REQUIRED:
        val = getattr(state.hpi, field)
        # Blank or whitespace-only answers and the "not specified"
        # placeholder all count as unanswered.
        if not val or val.strip().lower() in ("", "not specified"):
            missing.append(f"HPI: {field}")
    if missing:
        return missing, "hpi"

    # Tolerate ros being None as well as an empty mapping.
    covered = len(state.ros or {})
    if covered < ROS_REQUIRED_COUNT:
        remaining = ROS_REQUIRED_COUNT - covered
        return [f"Review of Systems (ask about {remaining} more bodily systems)"], "ros"

    return [], "done"
63
 
64
 
65
  # -------------------- NODES --------------------
66
 
67
def triage_node(state: IntakeState) -> dict:
    """Screen the latest patient message for explicit emergency phrases.

    Args:
        state: Current graph state; only ``"messages"`` is read.

    Returns:
        An emergency reply with ``current_node`` and ``frontend_stage`` set to
        ``"done"`` when the last user message contains one of the emergency
        phrases. Otherwise ``{"current_node": "extractor"}``.
    """
    emergencies = ("suicide", "kill myself", "crushing chest pain", "can't breathe", "heart attack")

    msgs = state.get("messages", [])
    if msgs and msgs[-1]["role"] == "user":
        # Normalise the typographic apostrophe so "can’t breathe" still matches.
        content = msgs[-1]["content"].lower().replace("\u2019", "'")
        if any(phrase in content for phrase in emergencies):
            return {
                "messages": [{"role": "assistant", "content": "🚨 EMERGENCY OVERRIDE: Your symptoms sound like a medical emergency. Please call 911 or visit the nearest emergency room immediately."}],
                "current_node": "done",
                "frontend_stage": "done",
            }

    # Always hand off to the extractor, including when there are no messages
    # yet: "done" and "extractor" are the only destinations the triage router
    # maps, so any other marker cannot be routed.
    return {"current_node": "extractor"}
 
 
 
 
 
 
84
 
85
 
86
def extractor_node(state: IntakeState) -> dict:
    """Fold the newest patient turn into the stored clinical state.

    With no messages, an empty ``ClinicalStateExtraction`` is stored. When the
    last turn is not from the user, nothing is extracted. Otherwise the LLM's
    ``ask_json`` is asked to update the stored state from the transcript; if
    the result has ``emergency_detected`` set, the node ends the conversation
    with an emergency message.
    """
    history = state.get("messages", [])

    if not history:
        # First pass: seed an empty clinical state.
        blank = ClinicalStateExtraction().model_dump_json()
        return {"clinical_state": blank, "current_node": "evaluator"}

    if history[-1]["role"] != "user":
        # Nothing new from the patient, so skip the LLM call.
        return {"current_node": "evaluator"}

    llm = get_llm()
    transcript = format_transcript(history)
    prior_json = state.get("clinical_state") or ClinicalStateExtraction().model_dump_json()

    extracted = llm.ask_json(transcript, prior_json, ClinicalStateExtraction)

    update = {
        "clinical_state": extracted.model_dump_json(),
        "current_node": "evaluator",
    }
    if extracted.emergency_detected:
        # Emergency flagged by the extractor: reply and stop the flow.
        update["messages"] = [{"role": "assistant", "content": "🚨 EMERGENCY OVERRIDE: Based on your details, you require immediate medical attention. Call 911."}]
        update["current_node"] = "done"
        update["frontend_stage"] = "done"
    return update
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
def evaluator_node(state: IntakeState) -> dict:
    """Decide whether to keep interviewing or hand over to the scribe.

    Parses the stored clinical state (or uses an empty one when none is
    stored), asks ``evaluate_missing`` what is outstanding, and returns the
    missing list, the frontend stage, and the next node marker.
    """
    raw = state.get("clinical_state")
    snapshot = (
        ClinicalStateExtraction.model_validate_json(raw)
        if raw
        else ClinicalStateExtraction()
    )

    missing, stage = evaluate_missing(snapshot)

    if missing:
        # Something is still outstanding: ask another question.
        return {
            "missing_fields": missing,
            "frontend_stage": stage,
            "current_node": "conversationalist",
        }

    # Everything collected: produce the brief.
    return {
        "missing_fields": missing,
        "frontend_stage": "done",
        "current_node": "scribe",
    }
145
 
146
 
147
def conversationalist_node(state: IntakeState) -> dict:
    """Produce the assistant's next question for the patient.

    Greets when there are no messages, stays silent when the assistant spoke
    last, and otherwise asks the LLM for one question aimed at the first entry
    of ``missing_fields`` (or "general details" when that list is empty).
    """
    history = state.get("messages", [])
    known_json = state.get("clinical_state", "{}")
    outstanding = state.get("missing_fields", [])

    if not history:
        greeting = "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"
        return {
            "messages": [{"role": "assistant", "content": greeting}],
            "current_node": "conversationalist",
        }

    if history[-1]["role"] == "assistant":
        # The assistant already has the last word; wait for the patient.
        return {"current_node": "conversationalist"}

    if outstanding:
        target = outstanding[0]
    else:
        target = "general details"

    prompt_parts = [
        "You are an empathetic clinical intake assistant. ",
        "Your sole job is to ask the next logical medical question in a conversational way. ",
        f"We currently know this info about the patient:\n{known_json}\n\n",
        f"YOUR GOAL: You MUST naturally uncover the following missing information: {target}. ",
        "Keep your response to exactly ONE question. Be concise and friendly.",
    ]
    system_prompt = "".join(prompt_parts)

    # Only the six most recent turns are shown to the model.
    recent = format_transcript(history[-6:])
    question = get_llm().ask(
        f"Transcript:\n{recent}\n\nAsk the next question about: {target}.",
        system=system_prompt,
    )

    return {
        "messages": [{"role": "assistant", "content": question}],
        "current_node": "conversationalist",
    }
181
 
182
 
183
def scribe_node(state: IntakeState) -> dict:
    """Build the final clinical brief from the stored clinical state.

    Args:
        state: Current graph state; only ``"clinical_state"`` is read.

    Returns:
        A closing assistant message, ``current_node`` and ``frontend_stage``
        set to ``"done"``, and the dumped ``ClinicalBrief`` under
        ``"clinical_brief"``.
    """
    from datetime import datetime, timezone

    state_json = state.get("clinical_state")
    # Fall back to an empty extraction when nothing is stored, matching how
    # extractor_node and evaluator_node treat a missing clinical_state.
    if state_json:
        data = ClinicalStateExtraction.model_validate_json(state_json)
    else:
        data = ClinicalStateExtraction()

    brief = ClinicalBrief(
        chief_complaint=data.chief_complaint or "Not specified",
        hpi=data.hpi,
        ros=data.ros,
        # UTC timestamp in ISO-8601 with a trailing "Z" instead of "+00:00".
        generated_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
    )

    return {
        "messages": [{"role": "assistant", "content": "Thank you — I have everything I need. Your clinical summary is ready."}],
        "current_node": "done",
        # Set explicitly, as the other terminal returns (emergency exits) do.
        "frontend_stage": "done",
        "clinical_brief": brief.model_dump(),
    }
 
203
def build_graph():
    """Assemble and compile the intake graph.

    Flow: START -> triage -> extractor -> evaluator -> (conversationalist |
    scribe) -> END. Triage and extractor can also go straight to END when
    they set ``current_node`` to ``"done"``.

    Returns:
        A ``(graph, checkpointer)`` pair: the compiled graph and the
        ``MemorySaver`` it was compiled with.
    """
    workflow = StateGraph(IntakeState)

    workflow.add_node("triage", triage_node)
    workflow.add_node("extractor", extractor_node)
    workflow.add_node("evaluator", evaluator_node)
    workflow.add_node("conversationalist", conversationalist_node)
    workflow.add_node("scribe", scribe_node)

    def _next_hop(state: IntakeState, allowed: tuple, fallback: str) -> str:
        # Clamp to the destinations the edge actually maps. A marker that is
        # missing or not in the path map falls back to the normal next step
        # instead of failing the path-map lookup.
        target = state.get("current_node", fallback)
        return target if target in allowed else fallback

    def route_triage(state: IntakeState) -> str:
        # "done" means triage flagged an emergency; skip everything else.
        return _next_hop(state, ("done", "extractor"), "extractor")

    def route_extractor(state: IntakeState) -> str:
        # "done" means the extractor flagged an emergency.
        return _next_hop(state, ("done", "evaluator"), "evaluator")

    def route_evaluator(state: IntakeState) -> str:
        return _next_hop(state, ("conversationalist", "scribe"), "conversationalist")

    workflow.add_edge(START, "triage")
    workflow.add_conditional_edges("triage", route_triage, {"done": END, "extractor": "extractor"})
    workflow.add_conditional_edges("extractor", route_extractor, {"done": END, "evaluator": "evaluator"})
    workflow.add_conditional_edges("evaluator", route_evaluator, {"conversationalist": "conversationalist", "scribe": "scribe"})

    workflow.add_edge("conversationalist", END)
    workflow.add_edge("scribe", END)

    checkpointer = MemorySaver()
    # Interrupt after conversationalist so it waits for user input.
    graph = workflow.compile(
        checkpointer=checkpointer,
        interrupt_after=["conversationalist"],
    )

    return graph, checkpointer
app/llm.py CHANGED
@@ -1,26 +1,79 @@
1
  import os
 
 
2
 
3
  CLINICAL_SYSTEM_PROMPT = (
4
  "You are a clinical intake assistant conducting a pre-visit patient interview. "
5
- "Ask one clear, concise, professional medical question at a time. "
6
  "Do not diagnose or give medical advice. Keep responses under 2 sentences. "
7
- "Be empathetic but professional."
8
  )
9
 
10
-
11
  class MockLLM:
12
  def __init__(self):
13
- self.hpi_fields = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
14
- self.current_hpi_index = 0
15
- self.ros_systems_done = False
16
-
17
- def ask(self, instruction: str) -> str:
18
- return "" # unused in mock mode — graph uses hardcoded questions
19
 
20
- def reset(self):
21
- self.current_hpi_index = 0
22
- self.ros_systems_done = False
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  class TransformersLLM:
26
  def __init__(self):
@@ -39,11 +92,11 @@ class TransformersLLM:
39
  device_map="cpu",
40
  )
41
 
42
- def ask(self, instruction: str) -> str:
43
  self._load()
44
  import torch
45
  messages = [
46
- {"role": "system", "content": CLINICAL_SYSTEM_PROMPT},
47
  {"role": "user", "content": instruction},
48
  ]
49
  text = self.tokenizer.apply_chat_template(
@@ -53,8 +106,8 @@ class TransformersLLM:
53
  with torch.no_grad():
54
  outputs = self.model.generate(
55
  **inputs,
56
- max_new_tokens=80,
57
- temperature=0.3,
58
  do_sample=True,
59
  pad_token_id=self.tokenizer.eos_token_id,
60
  )
@@ -64,9 +117,61 @@ class TransformersLLM:
64
  )
65
  return response.strip()
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- _llm_instance = None
69
 
 
70
 
71
  def get_llm():
72
  global _llm_instance
 
1
  import os
2
+ import json
3
+ from pydantic import BaseModel
4
 
5
  CLINICAL_SYSTEM_PROMPT = (
6
  "You are a clinical intake assistant conducting a pre-visit patient interview. "
7
+ "Be empathetic, warm, and highly professional. "
8
  "Do not diagnose or give medical advice. Keep responses under 2 sentences. "
 
9
  )
10
 
 
11
class MockLLM:
    """Deterministic keyword-driven stand-in for the real LLM, for tests."""

    def __init__(self):
        # No state is kept between calls.
        pass

    def ask(self, instruction: str, system: str = CLINICAL_SYSTEM_PROMPT) -> str:
        """Return a canned question chosen by keywords in ``instruction``.

        ``system`` is accepted for signature parity and is not consulted.
        """
        prompt = instruction.lower()

        if "empathetic reply" in prompt:
            if "chest" in prompt:
                return "I'm sorry to hear about your chest pain. When did it start?"
            return "I understand. Can you tell me more?"

        # First matching keyword group wins, so order matters.
        canned = (
            (("onset",), "When did this start?"),
            (("severity", "scale"), "On a scale of 1 to 10, how severe is this?"),
            (("location",), "Where exactly do you feel this?"),
        )
        for keywords, reply in canned:
            if any(word in prompt for word in keywords):
                return reply

        return "Can you elaborate on that?"

    def ask_json(self, transcript: str, current_state: str, schema_cls: type[BaseModel]) -> BaseModel:
        """Update ``current_state`` from keywords in ``transcript``.

        The JSON state is loaded, fields are overwritten wherever a trigger
        substring occurs in the lower-cased transcript, and the result is
        validated through ``schema_cls.model_validate``.
        """
        lowered = transcript.lower()
        draft = json.loads(current_state)

        def heard(*needles):
            return any(needle in lowered for needle in needles)

        def record(section, key, value):
            # Replace a missing, null, or empty section with a fresh dict.
            if not draft.get(section):
                draft[section] = {}
            draft[section][key] = value

        if heard("chest pain"):
            draft["chief_complaint"] = "chest pain"

        if heard("yesterday", "morning"):
            # "morning" takes precedence when both words appear.
            record("hpi", "onset", "this morning" if heard("morning") else "yesterday")

        hpi_rules = (
            (("center",), "location", "center of chest"),
            (("constant",), "duration", "constant"),
            (("pressure", "tight"), "character", "tight pressure"),
            (("7", "seven"), "severity", "7/10"),
            (("walk", "running"), "aggravating", "walking"),
            (("rest",), "relieving", "resting"),
        )
        for triggers, key, value in hpi_rules:
            if heard(*triggers):
                record("hpi", key, value)

        # Built per call so each result gets its own finding lists.
        ros_rules = (
            ("palpitations", "cardiac", ["palpitations", "no syncope"]),
            ("breath", "respiratory", ["shortness of breath", "no cough"]),
            ("nausea", "gi", ["no nausea"]),
        )
        for trigger, system_name, findings in ros_rules:
            if heard(trigger):
                record("ros", system_name, findings)

        if heard("crushing chest pain", "heart attack", "emergency"):
            draft["emergency_detected"] = True

        # Validation keeps the output aligned with the schema.
        return schema_cls.model_validate(draft)
77
 
78
  class TransformersLLM:
79
  def __init__(self):
 
92
  device_map="cpu",
93
  )
94
 
95
+ def ask(self, instruction: str, system: str = CLINICAL_SYSTEM_PROMPT) -> str:
96
  self._load()
97
  import torch
98
  messages = [
99
+ {"role": "system", "content": system},
100
  {"role": "user", "content": instruction},
101
  ]
102
  text = self.tokenizer.apply_chat_template(
 
106
  with torch.no_grad():
107
  outputs = self.model.generate(
108
  **inputs,
109
+ max_new_tokens=100,
110
+ temperature=0.4,
111
  do_sample=True,
112
  pad_token_id=self.tokenizer.eos_token_id,
113
  )
 
117
  )
118
  return response.strip()
119
 
120
def ask_json(self, transcript: str, current_state: str, schema_cls: type[BaseModel]) -> BaseModel:
    """Ask the model to update the clinical state and return it validated.

    Args:
        transcript: Conversation text to extract facts from.
        current_state: JSON string of the state known so far.
        schema_cls: Pydantic model class the output must validate against.

    Returns:
        An instance of ``schema_cls``. If the model output cannot be parsed or
        validated, ``current_state`` is re-validated and returned instead; if
        that also fails, a default ``schema_cls()`` is returned.
    """
    self._load()
    import logging
    import torch

    system = (
        "You are a clinical data extraction engine. "
        "Your objective is to read the patient transcript and output exactly a valid JSON document "
        "that matches the requested schema. Extract all relevant medical facts you can find. "
        "Merge new facts into the existing state."
    )
    instruction = (
        f"CURRENT STATE JSON (Update this based on the transcript):\n{current_state}\n\n"
        f"TRANSCRIPT:\n{transcript}\n\n"
        f"Output ONLY valid JSON matching this schema structure:\n"
        # Serialise the schema as real JSON; the dict repr would show the
        # model single quotes and Python literals.
        f"{json.dumps(schema_cls.model_json_schema())}"
    )

    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": instruction},
    ]
    text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = self.tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=400,
            # Greedy decoding keeps the JSON deterministic. No temperature is
            # passed because it only applies when sampling.
            do_sample=False,
            pad_token_id=self.tokenizer.eos_token_id,
        )
    response = self.tokenizer.decode(
        outputs[0][inputs.input_ids.shape[1]:],
        skip_special_tokens=True,
    )

    json_str = response.strip()
    if "```" in json_str:
        # Keep the text right after the first fence. Index 1 is the fenced
        # content; the last segment is whatever follows the closing fence.
        json_str = json_str.split("```")[1]
    # Trim anything around the outermost braces: a leading "json" language
    # tag, surrounding prose, or stray whitespace.
    start, end = json_str.find("{"), json_str.rfind("}")
    if start != -1 and end > start:
        json_str = json_str[start:end + 1]

    try:
        parsed = json.loads(json_str)
        return schema_cls.model_validate(parsed)
    except Exception:
        # Deliberate best-effort: a bad generation must not crash the intake,
        # but it should be visible in the logs.
        logging.getLogger(__name__).warning(
            "ask_json: model output was not valid %s JSON; keeping previous state",
            schema_cls.__name__,
            exc_info=True,
        )
        try:
            return schema_cls.model_validate_json(current_state)
        except Exception:
            return schema_cls()
172
 
 
173
 
174
+ _llm_instance = None
175
 
176
  def get_llm():
177
  global _llm_instance
app/main.py CHANGED
@@ -36,12 +36,12 @@ graph, checkpointer = build_graph()
36
 
37
 
38
  def get_current_node(session_id: str) -> str:
39
- """Get current node from checkpoint."""
40
  config = {"configurable": {"thread_id": session_id}}
41
  try:
42
  snapshot = graph.get_state(config)
43
  if snapshot and snapshot.values:
44
- return snapshot.values.get("current_node", "intake")
45
  except Exception:
46
  pass
47
  return "intake"
 
36
 
37
 
38
def get_current_node(session_id: str) -> str:
    """Return the ``frontend_stage`` stored in the session's graph state.

    Falls back to ``"intake"`` when the state is missing or empty, when the
    key is absent, or when reading the state raises.
    """
    fallback_stage = "intake"
    thread_config = {"configurable": {"thread_id": session_id}}
    try:
        checkpoint = graph.get_state(thread_config)
        values = checkpoint.values if checkpoint else None
        if values:
            return values.get("frontend_stage", fallback_stage)
    except Exception:
        # Deliberately best-effort: any lookup failure means "start at intake".
        pass
    return fallback_stage
app/schemas.py CHANGED
@@ -1,18 +1,23 @@
1
- from pydantic import BaseModel
2
-
3
 
4
  class HPI(BaseModel):
5
- onset: str
6
- location: str
7
- duration: str
8
- character: str
9
- severity: str
10
- aggravating: str
11
- relieving: str
12
-
13
 
14
  class ClinicalBrief(BaseModel):
15
  chief_complaint: str
16
  hpi: HPI
17
- ros: dict[str, list[str]]
18
  generated_at: str
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict, List
2
+ from pydantic import BaseModel, Field
3
 
4
class HPI(BaseModel):
    # History of present illness. Every field is optional and defaults to
    # None, so a partially filled HPI validates.
    # NOTE: documented with comments rather than a class docstring, because
    # pydantic copies a class docstring into model_json_schema().
    onset: Optional[str] = Field(
        default=None,
        description="When the symptom started",
    )
    location: Optional[str] = Field(
        default=None,
        description="Where exactly the symptom is located",
    )
    duration: Optional[str] = Field(
        default=None,
        description="How long episodes last or if it is constant",
    )
    character: Optional[str] = Field(
        default=None,
        description="What the pain feels like (sharp, dull, pressure, etc.)",
    )
    severity: Optional[str] = Field(
        default=None,
        description="Pain scale severity (e.g., 7/10 or 'severe')",
    )
    aggravating: Optional[str] = Field(
        default=None,
        description="What makes the symptoms worse",
    )
    relieving: Optional[str] = Field(
        default=None,
        description="What helps relieve the symptoms",
    )
 
12
 
13
class ClinicalBrief(BaseModel):
    # Final brief payload; all four fields are required.
    # NOTE: documented with comments rather than a class docstring, because
    # pydantic copies a class docstring into model_json_schema().
    chief_complaint: str
    hpi: HPI
    # Maps a system name to its list of findings.
    ros: Dict[str, List[str]]
    generated_at: str
18
+
19
class ClinicalStateExtraction(BaseModel):
    # Extraction target; every field has a default, so an empty instance is
    # valid.
    # NOTE: documented with comments rather than a class docstring, because
    # pydantic copies a class docstring into model_json_schema().
    chief_complaint: Optional[str] = Field(
        default=None,
        description="The main reason for the visit",
    )
    hpi: HPI = Field(default_factory=HPI)
    ros: Dict[str, List[str]] = Field(
        default_factory=dict,
        description="Review of systems, keys are system names, values are list of findings (positive or negative)",
    )
    emergency_detected: bool = Field(
        default=False,
        description="True ONLY if the patient mentions life-threatening symptoms requiring immediate 911/ER like severe crushing chest pain radiating to jaw, active severe bleeding, or suicidal ideation",
    )
tests/test_e2e.py CHANGED
@@ -1,12 +1,11 @@
1
  import os
2
-
3
  os.environ["MOCK_LLM"] = "true"
4
 
5
  import pytest
6
  from httpx import AsyncClient, ASGITransport
7
 
8
  from app.main import app
9
-
10
 
11
  @pytest.fixture
12
  async def client():
@@ -14,7 +13,6 @@ async def client():
14
  async with AsyncClient(transport=transport, base_url="http://test") as c:
15
  yield c
16
 
17
-
18
  @pytest.mark.asyncio(loop_scope="function")
19
  async def test_health_endpoint(client):
20
  response = await client.get("/health")
@@ -23,171 +21,62 @@ async def test_health_endpoint(client):
23
  assert data["status"] == "ok"
24
  assert data["mock_mode"] is True
25
 
26
-
27
  @pytest.mark.asyncio(loop_scope="function")
28
- async def test_full_intake_flow(client):
29
- session_id = "test1"
30
-
31
- response = await client.post("/chat", json={"session_id": session_id, "message": "hello"})
 
 
 
32
  assert response.status_code == 200
33
  data = response.json()
34
- assert data["reply"]
35
- assert data["state"] in ["intake", "hpi"]
36
-
37
- responses = [
38
- "I have chest pain since this morning", # CC (intake)
39
- "It started about 3 hours ago", # onset
40
- "In the center of my chest", # location
41
- "It has been constant for an hour", # duration
42
- "It feels like pressure", # character
43
- "About a 7 out of 10", # severity
44
- "It gets worse when I walk", # aggravating
45
- "Resting helps a little", # relieving
46
- "palpitations present, no syncope", # cardiac ROS
47
- "mild shortness of breath, no cough", # respiratory ROS
48
- "no nausea or vomiting", # gi ROS
49
- ]
50
-
51
- final_data = None
52
- for resp_text in responses:
53
- response = await client.post("/chat", json={"session_id": session_id, "message": resp_text})
54
- assert response.status_code == 200
55
- final_data = response.json()
56
-
57
- assert final_data is not None
58
- assert final_data["state"] == "done"
59
- assert "brief" in final_data
60
- assert final_data["brief"] is not None
61
-
62
- brief = final_data["brief"]
63
- assert "chief_complaint" in brief
64
- assert "hpi" in brief
65
- assert "ros" in brief
66
-
67
 
68
  @pytest.mark.asyncio(loop_scope="function")
69
- async def test_hpi_reprompt(client):
70
- """Vague answers (I don't know) should trigger a re-prompt."""
71
- session_id = "test_vague"
72
-
73
- await client.post("/chat", json={"session_id": session_id, "message": "hello"})
74
- await client.post("/chat", json={"session_id": session_id, "message": "I have chest pain"})
75
-
76
- # First HPI question is about onset
77
- vague_response = await client.post("/chat", json={"session_id": session_id, "message": "I don't know"})
78
- assert vague_response.status_code == 200
79
- data = vague_response.json()
80
- reply_lower = data["reply"].lower()
81
- # Should ask again — should mention specificity or the field context
82
- assert "specific" in reply_lower or "when" in reply_lower or "start" in reply_lower
83
-
84
-
85
- @pytest.mark.asyncio(loop_scope="function")
86
- async def test_ros_scoping(client):
87
- """For chest pain, ROS should include cardiac and respiratory systems."""
88
- session_id = "test_chest_pain"
89
-
90
  await client.post("/chat", json={"session_id": session_id, "message": "hello"})
91
- await client.post("/chat", json={"session_id": session_id, "message": "I have chest pain"})
92
-
93
- hpi_responses = [
94
- "It started 3 hours ago",
95
- "In the center of my chest",
96
- "It has been constant",
97
- "It feels like pressure",
98
- "7 out of 10",
99
- "Walking makes it worse",
100
- "Resting helps",
101
- ]
102
-
103
- for resp in hpi_responses:
104
- await client.post("/chat", json={"session_id": session_id, "message": resp})
105
-
106
- # Now in ROS answer cardiac system
107
- await client.post("/chat", json={"session_id": session_id, "message": "palpitations, no syncope"})
108
- # respiratory
109
- await client.post("/chat", json={"session_id": session_id, "message": "mild shortness of breath, no cough"})
110
- # gi
111
- final_response = await client.post("/chat", json={"session_id": session_id, "message": "no nausea"})
112
- final_data = final_response.json()
113
-
114
- if final_data.get("brief"):
115
- ros_keys = list(final_data["brief"]["ros"].keys())
116
- assert "cardiac" in ros_keys or "respiratory" in ros_keys
117
-
118
-
119
- @pytest.mark.asyncio(loop_scope="function")
120
- async def test_brief_structure(client):
121
- """Brief should have all 7 HPI fields, chief_complaint, ros, and generated_at."""
122
- session_id = "test_brief"
123
-
124
- messages = [
125
- "hello",
126
- "I have chest pain",
127
- "It started 3 hours ago",
128
- "In the center of my chest",
129
- "Constant",
130
- "Pressure-like",
131
- "7 out of 10",
132
- "Walking worsens it",
133
- "Resting helps",
134
- "palpitations present, no syncope",
135
- "shortness of breath, no cough",
136
- "no nausea or vomiting",
137
- ]
138
-
139
- response = None
140
- for msg in messages:
141
- response = await client.post("/chat", json={"session_id": session_id, "message": msg})
142
- assert response.status_code == 200
143
-
144
- final_data = response.json()
145
-
146
- if final_data.get("brief"):
147
- brief = final_data["brief"]
148
- from app.schemas import ClinicalBrief
149
- validated = ClinicalBrief.model_validate(brief)
150
-
151
- assert validated.chief_complaint
152
- assert validated.hpi.onset
153
- assert validated.hpi.location
154
- assert validated.hpi.duration
155
- assert validated.hpi.character
156
- assert validated.hpi.severity
157
- assert validated.hpi.aggravating
158
- assert validated.hpi.relieving
159
- assert validated.generated_at
160
-
161
-
162
- @pytest.mark.asyncio(loop_scope="function")
163
- async def test_brief_cleaning(client):
164
- """Brief generator should strip informal filler words from patient answers."""
165
- session_id = "test_cleaning"
166
-
167
- messages = [
168
- "hello",
169
- "I have chest pain",
170
- "yeah like since yesterday evening", # filler "yeah like"
171
- "like in my chest area", # filler "like"
172
- "Constant",
173
- "um tight and squeezing", # filler "um"
174
- "7 out of 10",
175
- "yeah walking makes it worse", # filler "yeah"
176
- "Resting helps",
177
- "palpitations, no syncope",
178
- "mild shortness of breath",
179
- "no nausea",
180
- ]
181
-
182
- response = None
183
- for msg in messages:
184
- response = await client.post("/chat", json={"session_id": session_id, "message": msg})
185
- assert response.status_code == 200
186
-
187
- final_data = response.json()
188
- if final_data.get("brief"):
189
- hpi = final_data["brief"]["hpi"]
190
- # "yeah like since yesterday evening" → should not start with "yeah"
191
- if hpi.get("onset"):
192
- assert not hpi["onset"].lower().startswith("yeah"), \
193
- f"Filler not cleaned from onset: {hpi['onset']}"
 
1
  import os
 
2
  os.environ["MOCK_LLM"] = "true"
3
 
4
  import pytest
5
  from httpx import AsyncClient, ASGITransport
6
 
7
  from app.main import app
8
+ from app.schemas import ClinicalBrief
9
 
10
  @pytest.fixture
11
  async def client():
 
13
  async with AsyncClient(transport=transport, base_url="http://test") as c:
14
  yield c
15
 
 
16
  @pytest.mark.asyncio(loop_scope="function")
17
  async def test_health_endpoint(client):
18
  response = await client.get("/health")
 
21
  assert data["status"] == "ok"
22
  assert data["mock_mode"] is True
23
 
 
24
@pytest.mark.asyncio(loop_scope="function")
async def test_emergency_triage_guardrail(client):
    """A 'crushing chest pain' message must end the session with an emergency reply."""
    session_id = "test_emergency"

    async def send(message):
        # Every turn of the conversation goes through the same /chat endpoint.
        return await client.post("/chat", json={"session_id": session_id, "message": message})

    await send("hello")

    response = await send("I am having crushing chest pain")
    assert response.status_code == 200
    payload = response.json()

    assert payload["state"] == "done"
    reply = payload["reply"]
    assert "911" in reply or "emergency" in reply.lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  @pytest.mark.asyncio(loop_scope="function")
39
+ async def test_shadow_extractor_logic(client):
40
+ """
41
+ Test that the shadow extractor gracefully fills in missing information behind the scenes,
42
+ transitioning the frontend stage from hpi to ros and finally done.
43
+ """
44
+ session_id = "test_extraction"
45
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  await client.post("/chat", json={"session_id": session_id, "message": "hello"})
47
+
48
+ # 1. Chief Complaint & some HPI
49
+ # The mock LLM maps "chest pain" -> CC, "yesterday" -> onset
50
+ res = await client.post("/chat", json={"session_id": session_id, "message": "I have chest pain since yesterday"})
51
+ assert res.status_code == 200
52
+ data = res.json()
53
+ assert data["state"] == "hpi" # Needs more HPI info
54
+
55
+ # 2. More HPI info
56
+ res = await client.post("/chat", json={"session_id": session_id, "message": "It is constant pressure in the center. Severity is 7. Walking makes it worse, rest helps."})
57
+ assert res.status_code == 200
58
+ data = res.json()
59
+ assert data["state"] == "ros" # Completes HPI, moves to ROS
60
+
61
+ # 3. ROS info
62
+ res = await client.post("/chat", json={"session_id": session_id, "message": "I have palpitations and shortness of breath. No nausea."})
63
+ assert res.status_code == 200
64
+ data = res.json()
65
+
66
+ # Should be done
67
+ assert data["state"] == "done"
68
+ assert data["brief"] is not None
69
+
70
+ brief = ClinicalBrief.model_validate(data["brief"])
71
+ assert brief.chief_complaint == "chest pain"
72
+ assert brief.hpi.onset == "yesterday"
73
+ assert brief.hpi.location == "center of chest"
74
+ assert brief.hpi.duration == "constant"
75
+ assert brief.hpi.character == "tight pressure"
76
+ assert brief.hpi.severity == "7/10"
77
+ assert brief.hpi.aggravating == "walking"
78
+ assert brief.hpi.relieving == "resting"
79
+
80
+ assert "cardiac" in brief.ros
81
+ assert "respiratory" in brief.ros
82
+ assert "gi" in brief.ros