Spaces:
Sleeping
Sleeping
priyansh-saxena1 committed on
Commit ·
284dfa9
1
Parent(s): 808ef75
feat : add dual agent architecture
Browse files- app/graph.py +176 -374
- app/llm.py +122 -17
- app/main.py +2 -2
- app/schemas.py +16 -11
- tests/test_e2e.py +54 -165
app/graph.py
CHANGED
|
@@ -1,403 +1,200 @@
|
|
|
|
|
|
|
|
| 1 |
from typing import Optional, TypedDict, Annotated
|
| 2 |
from langgraph.graph import StateGraph, START, END
|
| 3 |
from langgraph.checkpoint.memory import MemorySaver
|
| 4 |
-
import os
|
| 5 |
-
import re
|
| 6 |
-
|
| 7 |
-
_MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"
|
| 8 |
-
|
| 9 |
-
SYSTEM_PROMPT = """
|
| 10 |
-
You are a clinical intake assistant.
|
| 11 |
-
|
| 12 |
-
Rules:
|
| 13 |
-
- Ask exactly ONE question at a time
|
| 14 |
-
- Keep responses under 20 words
|
| 15 |
-
- Be clear and direct
|
| 16 |
-
- No explanations unless asked
|
| 17 |
-
"""
|
| 18 |
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
from app.llm import get_llm
|
| 22 |
-
llm = get_llm()
|
| 23 |
-
try:
|
| 24 |
-
return llm.ask(prompt, system=SYSTEM_PROMPT)
|
| 25 |
-
except TypeError:
|
| 26 |
-
return llm.ask(prompt)
|
| 27 |
-
|
| 28 |
|
| 29 |
def add_messages(left: list[dict], right: list[dict]) -> list[dict]:
|
| 30 |
return left + right
|
| 31 |
|
| 32 |
-
|
| 33 |
class IntakeState(TypedDict):
|
| 34 |
messages: Annotated[list[dict], add_messages]
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
ros: dict[str, list[str]]
|
| 38 |
current_node: str
|
| 39 |
clinical_brief: Optional[dict]
|
| 40 |
-
|
| 41 |
-
ros_current_index: int
|
| 42 |
-
ros_pending_system: Optional[str]
|
| 43 |
-
last_processed_message_index: int
|
| 44 |
-
vague_retry_field: Optional[str]
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
HPI_FIELDS = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
|
| 48 |
-
|
| 49 |
-
# Questions are templated — {cc} will be replaced with chief complaint
|
| 50 |
-
HPI_QUESTIONS = {
|
| 51 |
-
"onset": "When did {cc} start?",
|
| 52 |
-
"location": "Where exactly do you feel {cc}?",
|
| 53 |
-
"duration": "Is {cc} constant or does it come and go? How long does each episode last?",
|
| 54 |
-
"character": "How would you describe {cc} — sharp, dull, pressure, burning?",
|
| 55 |
-
"severity": "On a 1–10 scale, how severe is your {cc} right now?",
|
| 56 |
-
"aggravating": "Does anything make {cc} worse, like activity or certain foods?",
|
| 57 |
-
"relieving": "What helps relieve your {cc}?"
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
HPI_FIELD_CONTEXT = {
|
| 61 |
-
"onset": "when your symptoms first started",
|
| 62 |
-
"location": "where exactly you feel it",
|
| 63 |
-
"duration": "how long each episode lasts",
|
| 64 |
-
"character": "what the pain feels like",
|
| 65 |
-
"severity": "pain severity (1-10)",
|
| 66 |
-
"aggravating": "what makes symptoms worse",
|
| 67 |
-
"relieving": "what relieves symptoms",
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
CC_KEYWORDS_TO_ROS = {
|
| 71 |
-
"chest": ["cardiac", "respiratory", "gi"],
|
| 72 |
-
"pain": ["cardiac", "respiratory", "gi"],
|
| 73 |
-
"headache": ["neuro", "ent", "vision"],
|
| 74 |
-
"head": ["neuro", "ent", "vision"],
|
| 75 |
-
"breath": ["respiratory", "cardiac"],
|
| 76 |
-
"shortness": ["respiratory", "cardiac"],
|
| 77 |
-
"cough": ["respiratory", "ent"],
|
| 78 |
-
"dizzy": ["neuro", "cardiac"],
|
| 79 |
-
"nausea": ["gi", "constitutional"],
|
| 80 |
-
"vomiting": ["gi", "constitutional"],
|
| 81 |
-
}
|
| 82 |
-
|
| 83 |
-
DEFAULT_ROS = ["constitutional", "cardiac", "respiratory"]
|
| 84 |
-
|
| 85 |
-
ROS_SYSTEM_QUESTIONS = {
|
| 86 |
-
"cardiac": "Any palpitations, fluttering, or swelling in your legs or ankles?",
|
| 87 |
-
"respiratory": "Any shortness of breath, wheezing, or cough?",
|
| 88 |
-
"gi": "Any nausea, vomiting, heartburn, or abdominal pain?",
|
| 89 |
-
"neuro": "Any headaches, dizziness, numbness, or vision changes?",
|
| 90 |
-
"ent": "Any ear pain, sore throat, or sinus pressure?",
|
| 91 |
-
"vision": "Any blurry vision, double vision, or eye pain?",
|
| 92 |
-
"constitutional": "Any fever, chills, unexplained weight loss, or fatigue?",
|
| 93 |
-
}
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def get_relevant_ros_systems(cc: str) -> list[str]:
|
| 97 |
-
cc_lower = cc.lower()
|
| 98 |
-
seen = []
|
| 99 |
-
for keyword, systems in CC_KEYWORDS_TO_ROS.items():
|
| 100 |
-
if keyword in cc_lower:
|
| 101 |
-
for s in systems:
|
| 102 |
-
if s not in seen:
|
| 103 |
-
seen.append(s)
|
| 104 |
-
return seen if seen else DEFAULT_ROS
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def _fmt_question(field: str, cc: str) -> str:
|
| 108 |
-
"""Format an HPI question, injecting the chief complaint naturally."""
|
| 109 |
-
q = HPI_QUESTIONS[field]
|
| 110 |
-
cc_short = cc.split()[0:4] # first few words of complaint
|
| 111 |
-
cc_str = " ".join(cc_short).lower() if cc_short else "this"
|
| 112 |
-
return q.format(cc=cc_str)
|
| 113 |
-
|
| 114 |
|
| 115 |
-
|
| 116 |
-
answer = answer.strip()
|
| 117 |
-
if field == "severity":
|
| 118 |
-
match = re.search(r'(\d{1,2})\s*(?:out of|/|over)?\s*10', answer, re.IGNORECASE)
|
| 119 |
-
if match:
|
| 120 |
-
return f"{match.group(1)}/10"
|
| 121 |
-
# also handle bare numbers 1-10
|
| 122 |
-
match2 = re.search(r'\b([1-9]|10)\b', answer)
|
| 123 |
-
if match2:
|
| 124 |
-
return f"{match2.group(1)}/10"
|
| 125 |
-
return answer
|
| 126 |
|
|
|
|
|
|
|
| 127 |
|
| 128 |
-
def
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
|
| 133 |
-
def _parse_ros_answer(answer: str) -> list[str]:
|
| 134 |
"""
|
| 135 |
-
|
| 136 |
-
Handles comma-separated, 'and'-joined, and 'no X' style negative findings.
|
| 137 |
"""
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
# -------------------- NODES --------------------
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
has_new_user_msg = len(messages) > last_idx
|
| 162 |
-
greeting_reply = "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"
|
| 163 |
-
|
| 164 |
-
if has_new_user_msg:
|
| 165 |
-
user_msg = next((m for m in messages[last_idx:] if m["role"] == "user"), None)
|
| 166 |
-
if user_msg:
|
| 167 |
-
content = user_msg["content"].strip()
|
| 168 |
-
|
| 169 |
-
if content.lower() in GREETINGS or len(content) <= 4:
|
| 170 |
-
return {
|
| 171 |
-
"messages": [{"role": "assistant", "content": greeting_reply}],
|
| 172 |
-
"chief_complaint": "",
|
| 173 |
-
"current_node": "intake",
|
| 174 |
-
"last_processed_message_index": len(messages),
|
| 175 |
-
"vague_retry_field": None,
|
| 176 |
-
}
|
| 177 |
-
|
| 178 |
-
cc = content
|
| 179 |
-
if _MOCK():
|
| 180 |
-
reply = f"Got it — {cc}. I'll ask a few quick questions to document your visit."
|
| 181 |
-
else:
|
| 182 |
-
reply = _ask(
|
| 183 |
-
f"Patient's chief complaint is: '{cc}'. "
|
| 184 |
-
"Acknowledge it in one sentence and say you'll ask a few questions."
|
| 185 |
-
)
|
| 186 |
return {
|
| 187 |
-
"messages": [{"role": "assistant", "content":
|
| 188 |
-
"
|
| 189 |
-
"
|
| 190 |
-
"last_processed_message_index": len(messages),
|
| 191 |
-
"vague_retry_field": None,
|
| 192 |
}
|
| 193 |
-
|
| 194 |
-
return {
|
| 195 |
-
"messages": [{"role": "assistant", "content": greeting_reply}],
|
| 196 |
-
"chief_complaint": "",
|
| 197 |
-
"current_node": "intake",
|
| 198 |
-
"last_processed_message_index": last_idx,
|
| 199 |
-
"vague_retry_field": None,
|
| 200 |
-
}
|
| 201 |
|
| 202 |
|
| 203 |
-
def
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
vague_retry_field = state.get("vague_retry_field")
|
| 208 |
-
cc = state.get("chief_complaint", "")
|
| 209 |
-
|
| 210 |
-
next_field = vague_retry_field
|
| 211 |
-
if not next_field:
|
| 212 |
-
for field in HPI_FIELDS:
|
| 213 |
-
if field not in hpi or not hpi.get(field):
|
| 214 |
-
next_field = field
|
| 215 |
-
break
|
| 216 |
-
|
| 217 |
-
if next_field is None:
|
| 218 |
return {
|
| 219 |
-
"
|
| 220 |
-
"current_node": "
|
| 221 |
-
"last_processed_message_index": len(messages),
|
| 222 |
-
"vague_retry_field": None,
|
| 223 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
-
has_new_user_msg = len(messages) > last_idx
|
| 226 |
-
|
| 227 |
-
if has_new_user_msg:
|
| 228 |
-
user_msg = next((m for m in messages[last_idx:] if m["role"] == "user"), None)
|
| 229 |
-
|
| 230 |
-
if user_msg:
|
| 231 |
-
answer = user_msg["content"]
|
| 232 |
-
|
| 233 |
-
if _is_vague_answer(answer):
|
| 234 |
-
field_context = HPI_FIELD_CONTEXT[next_field]
|
| 235 |
-
|
| 236 |
-
if _MOCK():
|
| 237 |
-
reply = f"Could you be more specific? I need to know {field_context}."
|
| 238 |
-
else:
|
| 239 |
-
reply = _ask(
|
| 240 |
-
f"Patient response about {field_context} was vague. "
|
| 241 |
-
"Ask for clarification in one short sentence."
|
| 242 |
-
)
|
| 243 |
-
|
| 244 |
-
return {
|
| 245 |
-
"messages": [{"role": "assistant", "content": reply}],
|
| 246 |
-
"current_node": "hpi",
|
| 247 |
-
"last_processed_message_index": last_idx,
|
| 248 |
-
"vague_retry_field": next_field,
|
| 249 |
-
}
|
| 250 |
-
|
| 251 |
-
hpi[next_field] = extract_hpi_value(answer, next_field)
|
| 252 |
-
|
| 253 |
-
next_idx = HPI_FIELDS.index(next_field)
|
| 254 |
-
if next_idx < len(HPI_FIELDS) - 1:
|
| 255 |
-
next_field = HPI_FIELDS[next_idx + 1]
|
| 256 |
-
|
| 257 |
-
if _MOCK():
|
| 258 |
-
reply = _fmt_question(next_field, cc)
|
| 259 |
-
else:
|
| 260 |
-
reply = _ask(
|
| 261 |
-
f"Complaint: {cc}. Known info: {hpi}. "
|
| 262 |
-
f"Ask ONE question about {HPI_FIELD_CONTEXT[next_field]}."
|
| 263 |
-
)
|
| 264 |
-
|
| 265 |
-
return {
|
| 266 |
-
"messages": [{"role": "assistant", "content": reply}],
|
| 267 |
-
"hpi": hpi,
|
| 268 |
-
"current_node": "hpi",
|
| 269 |
-
"last_processed_message_index": len(messages),
|
| 270 |
-
"vague_retry_field": None,
|
| 271 |
-
}
|
| 272 |
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
"last_processed_message_index": len(messages),
|
| 278 |
-
"vague_retry_field": None,
|
| 279 |
-
}
|
| 280 |
-
|
| 281 |
-
if _MOCK():
|
| 282 |
-
reply = _fmt_question(next_field, cc)
|
| 283 |
else:
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
return {
|
| 290 |
-
"
|
| 291 |
-
"
|
| 292 |
-
"
|
| 293 |
-
"vague_retry_field": None,
|
| 294 |
}
|
| 295 |
|
| 296 |
|
| 297 |
-
def
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
ros_systems = state.get("ros_systems") or get_relevant_ros_systems(cc)
|
| 304 |
-
current_idx = state.get("ros_current_index", 0)
|
| 305 |
-
pending = state.get("ros_pending_system")
|
| 306 |
-
|
| 307 |
-
if current_idx >= len(ros_systems):
|
| 308 |
return {
|
| 309 |
-
"messages": [{"role": "assistant", "content": "
|
| 310 |
-
"current_node": "
|
| 311 |
-
"last_processed_message_index": len(messages),
|
| 312 |
}
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
| 330 |
return {
|
| 331 |
"messages": [{"role": "assistant", "content": reply}],
|
| 332 |
-
"
|
| 333 |
-
"current_node": "ros",
|
| 334 |
-
"ros_systems": ros_systems,
|
| 335 |
-
"ros_current_index": current_idx + 1,
|
| 336 |
-
"ros_pending_system": next_system,
|
| 337 |
-
"last_processed_message_index": len(messages),
|
| 338 |
}
|
| 339 |
|
| 340 |
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
"""
|
| 352 |
-
raw = raw.strip()
|
| 353 |
-
|
| 354 |
-
# Remove filler starters
|
| 355 |
-
fillers = [
|
| 356 |
-
r'^(yeah|yes|no|well|so|like|um|uh|i mean|i guess),?\s*',
|
| 357 |
-
r'^(it\'?s?\s+)',
|
| 358 |
-
r'^(the\s+)',
|
| 359 |
-
]
|
| 360 |
-
for pattern in fillers:
|
| 361 |
-
raw = re.sub(pattern, '', raw, flags=re.IGNORECASE).strip()
|
| 362 |
-
|
| 363 |
-
if not raw:
|
| 364 |
-
return "not specified"
|
| 365 |
-
|
| 366 |
-
# Capitalize first letter
|
| 367 |
-
return raw[0].upper() + raw[1:]
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
def brief_generator_node(state: IntakeState) -> dict:
|
| 371 |
-
raw_hpi = state.get("hpi", {})
|
| 372 |
-
|
| 373 |
-
# Clean each HPI field
|
| 374 |
-
cleaned_hpi = {f: _clean_hpi_value(f, raw_hpi.get(f) or "not specified") for f in HPI_FIELDS}
|
| 375 |
-
|
| 376 |
-
hpi_obj = HPIModel(**cleaned_hpi)
|
| 377 |
-
|
| 378 |
-
# Clean ROS — ensure each system has a proper list of findings
|
| 379 |
-
raw_ros = state.get("ros", {})
|
| 380 |
-
cleaned_ros: dict[str, list[str]] = {}
|
| 381 |
-
for system, findings in raw_ros.items():
|
| 382 |
-
clean_findings = []
|
| 383 |
-
for f in findings:
|
| 384 |
-
f = f.strip()
|
| 385 |
-
if f:
|
| 386 |
-
# Capitalize
|
| 387 |
-
f = f[0].upper() + f[1:]
|
| 388 |
-
clean_findings.append(f)
|
| 389 |
-
if clean_findings:
|
| 390 |
-
cleaned_ros[system] = clean_findings
|
| 391 |
-
|
| 392 |
-
brief = ClinicalBriefModel(
|
| 393 |
-
chief_complaint=state.get("chief_complaint", ""),
|
| 394 |
-
hpi=hpi_obj,
|
| 395 |
-
ros=cleaned_ros,
|
| 396 |
generated_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
|
| 397 |
)
|
| 398 |
|
| 399 |
return {
|
| 400 |
-
"messages": [{"role": "assistant", "content": "
|
| 401 |
"current_node": "done",
|
| 402 |
"clinical_brief": brief.model_dump(),
|
| 403 |
}
|
|
@@ -406,31 +203,36 @@ def brief_generator_node(state: IntakeState) -> dict:
|
|
| 406 |
def build_graph():
|
| 407 |
workflow = StateGraph(IntakeState)
|
| 408 |
|
| 409 |
-
workflow.add_node("
|
| 410 |
-
workflow.add_node("
|
| 411 |
-
workflow.add_node("
|
| 412 |
-
workflow.add_node("
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
)
|
| 428 |
-
workflow.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
checkpointer = MemorySaver()
|
|
|
|
| 431 |
graph = workflow.compile(
|
| 432 |
checkpointer=checkpointer,
|
| 433 |
-
interrupt_after=["
|
| 434 |
)
|
| 435 |
|
| 436 |
return graph, checkpointer
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
from typing import Optional, TypedDict, Annotated
|
| 4 |
from langgraph.graph import StateGraph, START, END
|
| 5 |
from langgraph.checkpoint.memory import MemorySaver
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
from app.llm import get_llm
|
| 8 |
+
from app.schemas import ClinicalStateExtraction, ClinicalBrief, HPI
|
| 9 |
|
| 10 |
+
def _MOCK() -> bool:
    """Return True when the ``MOCK_LLM`` environment variable enables mock mode.

    The variable is read on every call, so changing it at runtime takes
    effect immediately. It defaults to ``"true"`` when unset, and the
    comparison ignores case and surrounding whitespace.
    """
    return os.environ.get("MOCK_LLM", "true").strip().lower() == "true"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def add_messages(left: list[dict], right: list[dict]) -> list[dict]:
    """Reducer for the ``messages`` state field: append ``right`` to ``left``.

    Always returns a new list and never mutates either argument. A ``None``
    (or otherwise empty) side is treated as an empty list so a missing
    update cannot raise ``TypeError``.
    """
    return list(left or []) + list(right or [])
|
| 14 |
|
|
|
|
| 15 |
class IntakeState(TypedDict):
    """Shared state passed between the nodes of the intake graph."""

    # Conversation so far; node updates are merged through add_messages
    # (concatenation) rather than replacing the list.
    messages: Annotated[list[dict], add_messages]
    # JSON representation of ClinicalStateExtraction.
    clinical_state: str
    # Human-readable descriptions of what still has to be collected,
    # as written by evaluator_node.
    missing_fields: list[str]
    # Name of the next node; read by the routing functions in build_graph.
    current_node: str
    # Final summary (a dumped ClinicalBrief) set by scribe_node.
    clinical_brief: Optional[dict]
    # 'intake', 'hpi', 'ros', or 'done'
    frontend_stage: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
# -------------------- HELPER FUNCTIONS --------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
# HPI attributes that must all be filled before the intake leaves the 'hpi' stage.
HPI_REQUIRED = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
# Minimum number of body systems that must appear in the review of systems.
ROS_REQUIRED_COUNT = 3
|
| 27 |
|
| 28 |
+
def format_transcript(messages: list[dict]) -> str:
    """Render chat messages as a plain-text transcript, one line per message.

    Messages whose ``role`` is ``"assistant"`` are labelled ``AI``; every
    other role is labelled ``Patient``. All messages passed in are rendered:
    callers that want a shorter context window must slice the list first.

    Args:
        messages: Dicts with at least ``"role"`` and ``"content"`` keys.

    Returns:
        The transcript joined with newlines, or ``""`` for an empty list.
    """
    return "\n".join(
        f"{'AI' if m['role'] == 'assistant' else 'Patient'}: {m['content']}"
        for m in messages
    )
|
| 35 |
|
| 36 |
+
def evaluate_missing(
    state: "ClinicalStateExtraction",
    hpi_required: Optional[list[str]] = None,
    ros_required_count: Optional[int] = None,
) -> tuple[list[str], str]:
    """Work out what is still missing from the intake and the matching stage.

    The checks run in order and stop at the first stage with gaps:
    chief complaint ('intake'), the required HPI fields ('hpi'), then the
    number of body systems covered in the review of systems ('ros').

    Args:
        state: Extracted clinical state; read via ``chief_complaint``,
            ``hpi`` (attribute per HPI field) and ``ros`` (a mapping).
        hpi_required: HPI field names to require; defaults to ``HPI_REQUIRED``.
        ros_required_count: Minimum number of ROS systems; defaults to
            ``ROS_REQUIRED_COUNT``.

    Returns:
        ``(missing, stage)`` where ``missing`` is a list of human-readable
        descriptions and ``stage`` is 'intake', 'hpi', 'ros' or 'done'.
        ``missing`` is empty exactly when ``stage`` is 'done'.
    """
    if hpi_required is None:
        hpi_required = HPI_REQUIRED
    if ros_required_count is None:
        ros_required_count = ROS_REQUIRED_COUNT

    if not state.chief_complaint:
        return ["chief complaint (reason for visit)"], "intake"

    missing = []
    for field in hpi_required:
        # A missing hpi object or attribute counts as "not answered" rather
        # than raising AttributeError.
        val = getattr(state.hpi, field, None)
        text = str(val).strip().lower() if val else ""
        if not text or text == "not specified":
            missing.append(f"HPI: {field}")
    if missing:
        return missing, "hpi"

    # Need at least a few systems covered if possible.
    covered = len(state.ros or {})
    if covered < ros_required_count:
        remaining = ros_required_count - covered
        return [f"Review of Systems (ask about {remaining} more bodily systems)"], "ros"

    return [], "done"
|
| 63 |
|
| 64 |
|
| 65 |
# -------------------- NODES --------------------
|
| 66 |
|
| 67 |
+
def triage_node(state: "IntakeState") -> dict:
    """Screen the latest user message for emergency phrases before extraction.

    Returns a state update whose ``current_node`` is either ``"done"``
    (an emergency phrase was found; an override message is appended and
    ``frontend_stage`` is set to ``"done"``) or ``"extractor"``.

    Only the last message is inspected, and only when it comes from the user.
    """
    msgs = state.get("messages", [])
    if not msgs:
        # The router after this node only maps "done" and "extractor";
        # any other value (previously "triage") cannot be routed. The
        # extractor has its own branch for an empty conversation.
        return {"current_node": "extractor"}

    last_msg = msgs[-1]
    if last_msg["role"] == "user":
        # Fold the typographic apostrophe (U+2019) to ASCII so "can’t breathe",
        # as typed on most phones, still matches.
        content = last_msg["content"].lower().replace("\u2019", "'")
        emergencies = ("suicide", "kill myself", "crushing chest pain", "can't breathe", "heart attack")
        if any(phrase in content for phrase in emergencies):
            return {
                "messages": [{"role": "assistant", "content": "🚨 EMERGENCY OVERRIDE: Your symptoms sound like a medical emergency. Please call 911 or visit the nearest emergency room immediately."}],
                "current_node": "done",
                "frontend_stage": "done",
            }

    return {"current_node": "extractor"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
|
| 86 |
+
def extractor_node(state: IntakeState) -> dict:
    """Update the structured clinical state from the conversation transcript.

    Behaviour by case:
      * no messages yet: seed ``clinical_state`` with an empty
        ``ClinicalStateExtraction`` and continue to the evaluator;
      * last message is not from the user: nothing new to extract, so
        continue to the evaluator without touching the state;
      * otherwise: ask the LLM (``ask_json``) to merge the full transcript
        into the current state. If the result has ``emergency_detected``
        set, an override message is appended and the graph is marked done.

    Returns:
        A partial state update; ``current_node`` is ``"evaluator"`` or ``"done"``.
    """
    msgs = state.get("messages", [])
    if not msgs:
        # Initial state setup.
        return {
            "clinical_state": ClinicalStateExtraction().model_dump_json(),
            "current_node": "evaluator",
        }

    # Only run the extractor if the last message was from the user.
    if msgs[-1]["role"] != "user":
        return {"current_node": "evaluator"}

    llm = get_llm()
    transcript = format_transcript(msgs)
    # Fall back to an empty extraction when no state has been stored yet.
    current_state_json = state.get("clinical_state") or ClinicalStateExtraction().model_dump_json()

    # Extractor agent updates the state passively.
    new_state = llm.ask_json(transcript, current_state_json, ClinicalStateExtraction)

    update = {
        "clinical_state": new_state.model_dump_json(),
        "current_node": "evaluator",
    }
    # The extractor may flag an emergency that the keyword triage missed.
    if new_state.emergency_detected:
        update["messages"] = [{"role": "assistant", "content": "🚨 EMERGENCY OVERRIDE: Based on your details, you require immediate medical attention. Call 911."}]
        update["current_node"] = "done"
        update["frontend_stage"] = "done"
    return update
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
+
def evaluator_node(state: IntakeState) -> dict:
    """Decide whether the intake is complete or another question is needed.

    Parses ``clinical_state`` (an empty ``ClinicalStateExtraction`` is used
    when none is stored), runs ``evaluate_missing`` on it, and returns:
      * ``missing_fields``: the list from ``evaluate_missing``;
      * ``frontend_stage``: the stage from ``evaluate_missing``, or
        ``"done"`` when nothing is missing;
      * ``current_node``: ``"scribe"`` when nothing is missing, otherwise
        ``"conversationalist"``.
    """
    state_json = state.get("clinical_state")
    if state_json:
        clinical_state = ClinicalStateExtraction.model_validate_json(state_json)
    else:
        clinical_state = ClinicalStateExtraction()

    missing, stage = evaluate_missing(clinical_state)
    complete = not missing

    return {
        "missing_fields": missing,
        "frontend_stage": "done" if complete else stage,
        "current_node": "scribe" if complete else "conversationalist",
    }
|
| 145 |
|
| 146 |
|
| 147 |
+
def conversationalist_node(state: IntakeState) -> dict:
    """Produce the assistant's next question for the patient.

    * With no messages yet, emits the fixed greeting.
    * If the assistant spoke last, emits nothing (avoids double-speaking
      when there is no new user input).
    * Otherwise asks the LLM for exactly one question aimed at the first
      entry of ``missing_fields`` (or "general details" when the list is
      empty), giving it the known clinical state and the last six messages.

    Returns:
        A partial state update; ``current_node`` is always
        ``"conversationalist"``.
    """
    msgs = state.get("messages", [])
    # "or" also covers a key that is present but holds None/"".
    clinical_json = state.get("clinical_state") or "{}"
    missing = state.get("missing_fields", [])

    if not msgs:
        return {
            "messages": [{"role": "assistant", "content": "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"}],
            "current_node": "conversationalist",
        }

    # Check if the agent just spoke (prevent double-speaking if no user input).
    if msgs[-1]["role"] == "assistant":
        return {"current_node": "conversationalist"}

    # Target the top missing field.
    target = missing[0] if missing else "general details"

    system_prompt = (
        "You are an empathetic clinical intake assistant. "
        "Your sole job is to ask the next logical medical question in a conversational way. "
        f"We currently know this info about the patient:\n{clinical_json}\n\n"
        f"YOUR GOAL: You MUST naturally uncover the following missing information: {target}. "
        "Keep your response to exactly ONE question. Be concise and friendly."
    )

    # Context window: only the six most recent messages go to the LLM.
    transcript = format_transcript(msgs[-6:])
    llm = get_llm()
    reply = llm.ask(
        f"Transcript:\n{transcript}\n\nAsk the next question about: {target}.",
        system=system_prompt,
    )

    return {
        "messages": [{"role": "assistant", "content": reply}],
        "current_node": "conversationalist",
    }
|
| 181 |
|
| 182 |
|
| 183 |
+
def scribe_node(state: IntakeState) -> dict:
    """Build the final ``ClinicalBrief`` from the extracted clinical state.

    Returns a state update with a closing assistant message,
    ``current_node`` set to ``"done"`` and ``clinical_brief`` holding the
    dumped brief. ``generated_at`` is the current UTC time in ISO-8601
    with a trailing ``Z``.
    """
    from datetime import datetime, timezone

    state_json = state.get("clinical_state")
    # Mirror evaluator_node: an absent state becomes an empty extraction
    # instead of passing None to model_validate_json.
    if state_json:
        data = ClinicalStateExtraction.model_validate_json(state_json)
    else:
        data = ClinicalStateExtraction()

    brief = ClinicalBrief(
        chief_complaint=data.chief_complaint or "Not specified",
        hpi=data.hpi,
        ros=data.ros,
        generated_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
    )

    return {
        "messages": [{"role": "assistant", "content": "Thank you — I have everything I need. Your clinical summary is ready."}],
        "current_node": "done",
        "clinical_brief": brief.model_dump(),
    }
|
|
|
|
| 203 |
def build_graph():
    """Assemble and compile the intake graph.

    Flow: START -> triage -> extractor -> evaluator -> (conversationalist |
    scribe) -> END. Triage and the extractor can each short-circuit to END
    by setting ``current_node`` to ``"done"``.

    Returns:
        ``(graph, checkpointer)``: the compiled graph and the in-memory
        checkpointer it was compiled with.
    """
    workflow = StateGraph(IntakeState)

    workflow.add_node("triage", triage_node)
    workflow.add_node("extractor", extractor_node)
    workflow.add_node("evaluator", evaluator_node)
    workflow.add_node("conversationalist", conversationalist_node)
    workflow.add_node("scribe", scribe_node)

    # Each router clamps its result to the keys of its path map below, so an
    # unexpected current_node value takes the default branch instead of
    # failing the edge lookup.
    def route_triage(state: IntakeState) -> str:
        # If triage marked it 'done' (emergency), skip everything.
        return "done" if state.get("current_node") == "done" else "extractor"

    def route_extractor(state: IntakeState) -> str:
        # Extractor marks it 'done' if latent emergency, else 'evaluator'.
        return "done" if state.get("current_node") == "done" else "evaluator"

    def route_evaluator(state: IntakeState) -> str:
        return "scribe" if state.get("current_node") == "scribe" else "conversationalist"

    workflow.add_edge(START, "triage")
    workflow.add_conditional_edges("triage", route_triage, {"done": END, "extractor": "extractor"})
    workflow.add_conditional_edges("extractor", route_extractor, {"done": END, "evaluator": "evaluator"})
    workflow.add_conditional_edges("evaluator", route_evaluator, {"conversationalist": "conversationalist", "scribe": "scribe"})

    workflow.add_edge("conversationalist", END)
    workflow.add_edge("scribe", END)

    checkpointer = MemorySaver()
    # Interrupt after conversationalist so it waits for user input.
    graph = workflow.compile(
        checkpointer=checkpointer,
        interrupt_after=["conversationalist"],
    )

    return graph, checkpointer
|
app/llm.py
CHANGED
|
@@ -1,26 +1,79 @@
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
|
| 3 |
CLINICAL_SYSTEM_PROMPT = (
|
| 4 |
"You are a clinical intake assistant conducting a pre-visit patient interview. "
|
| 5 |
-
"
|
| 6 |
"Do not diagnose or give medical advice. Keep responses under 2 sentences. "
|
| 7 |
-
"Be empathetic but professional."
|
| 8 |
)
|
| 9 |
|
| 10 |
-
|
| 11 |
class MockLLM:
|
| 12 |
def __init__(self):
|
| 13 |
-
|
| 14 |
-
self.current_hpi_index = 0
|
| 15 |
-
self.ros_systems_done = False
|
| 16 |
-
|
| 17 |
-
def ask(self, instruction: str) -> str:
|
| 18 |
-
return "" # unused in mock mode — graph uses hardcoded questions
|
| 19 |
|
| 20 |
-
def
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
class TransformersLLM:
|
| 26 |
def __init__(self):
|
|
@@ -39,11 +92,11 @@ class TransformersLLM:
|
|
| 39 |
device_map="cpu",
|
| 40 |
)
|
| 41 |
|
| 42 |
-
def ask(self, instruction: str) -> str:
|
| 43 |
self._load()
|
| 44 |
import torch
|
| 45 |
messages = [
|
| 46 |
-
{"role": "system", "content":
|
| 47 |
{"role": "user", "content": instruction},
|
| 48 |
]
|
| 49 |
text = self.tokenizer.apply_chat_template(
|
|
@@ -53,8 +106,8 @@ class TransformersLLM:
|
|
| 53 |
with torch.no_grad():
|
| 54 |
outputs = self.model.generate(
|
| 55 |
**inputs,
|
| 56 |
-
max_new_tokens=
|
| 57 |
-
temperature=0.
|
| 58 |
do_sample=True,
|
| 59 |
pad_token_id=self.tokenizer.eos_token_id,
|
| 60 |
)
|
|
@@ -64,9 +117,61 @@ class TransformersLLM:
|
|
| 64 |
)
|
| 65 |
return response.strip()
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
_llm_instance = None
|
| 69 |
|
|
|
|
| 70 |
|
| 71 |
def get_llm():
|
| 72 |
global _llm_instance
|
|
|
|
| 1 |
import os
|
| 2 |
+
import json
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
|
| 5 |
CLINICAL_SYSTEM_PROMPT = (
|
| 6 |
"You are a clinical intake assistant conducting a pre-visit patient interview. "
|
| 7 |
+
"Be empathetic, warm, and highly professional. "
|
| 8 |
"Do not diagnose or give medical advice. Keep responses under 2 sentences. "
|
|
|
|
| 9 |
)
|
| 10 |
|
|
|
|
| 11 |
class MockLLM:
    """Deterministic keyword-driven stand-in for the real LLM, for tests.

    Both methods work purely by substring matching on lower-cased text;
    no model is loaded.
    """

    def __init__(self):
        pass

    def ask(self, instruction: str, system: str = CLINICAL_SYSTEM_PROMPT) -> str:
        """Return a canned question chosen by keywords in ``instruction``.

        ``system`` is accepted for interface compatibility and is not used.
        """
        text = instruction.lower()

        if "empathetic reply" in text:
            if "chest" in text:
                return "I'm sorry to hear about your chest pain. When did it start?"
            return "I understand. Can you tell me more?"

        # General fallback that allows tests to check for context.
        if "onset" in text:
            return "When did this start?"
        if "severity" in text or "scale" in text:
            return "On a scale of 1 to 10, how severe is this?"
        if "location" in text:
            return "Where exactly do you feel this?"

        return "Can you elaborate on that?"

    def ask_json(self, transcript: str, current_state: str, schema_cls: type[BaseModel]) -> BaseModel:
        """Merge keyword-detected facts from ``transcript`` into ``current_state``.

        Args:
            transcript: Conversation text; matched case-insensitively.
            current_state: JSON object string holding the state so far.
            schema_cls: Pydantic model class used to validate the result.

        Returns:
            ``schema_cls.model_validate`` applied to the updated state dict.
        """
        t_low = transcript.lower()
        state_dict = json.loads(current_state)

        def section(name: str) -> dict:
            # Create the nested dict only when a rule actually fires, so an
            # untouched section keeps whatever value it had.
            if not state_dict.get(name):
                state_dict[name] = {}
            return state_dict[name]

        # Very basic test logic.
        if "chest pain" in t_low:
            state_dict["chief_complaint"] = "chest pain"

        # (trigger substrings, HPI field, value stored when any trigger matches)
        hpi_rules = (
            (("yesterday", "morning"), "onset", "this morning" if "morning" in t_low else "yesterday"),
            (("center",), "location", "center of chest"),
            (("constant",), "duration", "constant"),
            (("pressure", "tight"), "character", "tight pressure"),
            (("7", "seven"), "severity", "7/10"),
            (("walk", "running"), "aggravating", "walking"),
            (("rest",), "relieving", "resting"),
        )
        for triggers, field, value in hpi_rules:
            if any(trigger in t_low for trigger in triggers):
                section("hpi")[field] = value

        # (trigger substring, ROS system, findings stored)
        ros_rules = (
            ("palpitations", "cardiac", ["palpitations", "no syncope"]),
            ("breath", "respiratory", ["shortness of breath", "no cough"]),
            ("nausea", "gi", ["no nausea"]),
        )
        for trigger, system_name, findings in ros_rules:
            if trigger in t_low:
                section("ros")[system_name] = findings

        if any(phrase in t_low for phrase in ("crushing chest pain", "heart attack", "emergency")):
            state_dict["emergency_detected"] = True

        # Guarantee schema matches via Pydantic model_validate.
        return schema_cls.model_validate(state_dict)
|
| 77 |
|
| 78 |
class TransformersLLM:
|
| 79 |
def __init__(self):
|
|
|
|
| 92 |
device_map="cpu",
|
| 93 |
)
|
| 94 |
|
| 95 |
+
def ask(self, instruction: str, system: str = CLINICAL_SYSTEM_PROMPT) -> str:
|
| 96 |
self._load()
|
| 97 |
import torch
|
| 98 |
messages = [
|
| 99 |
+
{"role": "system", "content": system},
|
| 100 |
{"role": "user", "content": instruction},
|
| 101 |
]
|
| 102 |
text = self.tokenizer.apply_chat_template(
|
|
|
|
| 106 |
with torch.no_grad():
|
| 107 |
outputs = self.model.generate(
|
| 108 |
**inputs,
|
| 109 |
+
max_new_tokens=100,
|
| 110 |
+
temperature=0.4,
|
| 111 |
do_sample=True,
|
| 112 |
pad_token_id=self.tokenizer.eos_token_id,
|
| 113 |
)
|
|
|
|
| 117 |
)
|
| 118 |
return response.strip()
|
| 119 |
|
| 120 |
+
def ask_json(self, transcript: str, current_state: str, schema_cls: type[BaseModel]) -> BaseModel:
    """Extract structured state from a transcript and return it as ``schema_cls``.

    Builds an extraction prompt from the current state JSON, the transcript and
    the JSON schema of ``schema_cls``, generates a completion with greedy
    decoding, and parses that completion into a validated model instance.

    Args:
        transcript: Transcript text the model should extract facts from.
        current_state: JSON string of the state so far. It is shown to the
            model and also used as the fallback result.
        schema_cls: Pydantic model class the output must validate against.

    Returns:
        The validated model output when it parses; otherwise ``current_state``
        validated as ``schema_cls``; otherwise a default ``schema_cls()``.
    """
    self._load()
    import torch

    system = (
        "You are a clinical data extraction engine. "
        "Your objective is to read the patient transcript and output exactly a valid JSON document "
        "that matches the requested schema. Extract all relevant medical facts you can find. "
        "Merge new facts into the existing state."
    )
    # Serialise the schema with json.dumps: interpolating the dict directly
    # would show the model a Python repr (single quotes, None/True), which is
    # not JSON even though the prompt demands valid JSON.
    schema_json = json.dumps(schema_cls.model_json_schema())
    instruction = (
        f"CURRENT STATE JSON (Update this based on the transcript):\n{current_state}\n\n"
        f"TRANSCRIPT:\n{transcript}\n\n"
        "Output ONLY valid JSON matching this schema structure:\n"
        f"{schema_json}"
    )

    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": instruction},
    ]
    text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = self.tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=400,
            # Greedy decoding keeps the JSON output deterministic. No
            # temperature is passed because it only applies when sampling.
            do_sample=False,
            pad_token_id=self.tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = self.tokenizer.decode(
        outputs[0][inputs.input_ids.shape[1]:],
        skip_special_tokens=True,
    )

    # Strip an optional Markdown code fence around the JSON payload.
    json_str = response.strip()
    if "```json" in json_str:
        json_str = json_str.split("```json")[-1].split("```")[0]
    elif "```" in json_str:
        # Element 1 is the text between the opening fence and its closing
        # fence; the last element is whatever follows the closing fence
        # (usually empty), which made every plain-fenced reply unparseable.
        json_str = json_str.split("```")[1]

    try:
        return schema_cls.model_validate(json.loads(json_str))
    except Exception:
        # Deliberate best-effort: a malformed completion must not crash the
        # caller, so fall back to the previous state, then to an empty model.
        try:
            return schema_cls.model_validate_json(current_state)
        except Exception:
            return schema_cls()
|
| 172 |
|
|
|
|
| 173 |
|
| 174 |
+
_llm_instance = None
|
| 175 |
|
| 176 |
def get_llm():
|
| 177 |
global _llm_instance
|
app/main.py
CHANGED
|
@@ -36,12 +36,12 @@ graph, checkpointer = build_graph()
|
|
| 36 |
|
| 37 |
|
| 38 |
def get_current_node(session_id: str) -> str:
|
| 39 |
-
"""Get
|
| 40 |
config = {"configurable": {"thread_id": session_id}}
|
| 41 |
try:
|
| 42 |
snapshot = graph.get_state(config)
|
| 43 |
if snapshot and snapshot.values:
|
| 44 |
-
return snapshot.values.get("
|
| 45 |
except Exception:
|
| 46 |
pass
|
| 47 |
return "intake"
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
def get_current_node(session_id: str) -> str:
    """Return the ``frontend_stage`` stored in the session's checkpoint.

    The session id is used as the graph ``thread_id``. ``"intake"`` is
    returned when the snapshot is missing or empty, when the key is absent,
    or when reading the state raises.
    """
    stage = "intake"
    thread_config = {"configurable": {"thread_id": session_id}}
    try:
        checkpoint = graph.get_state(thread_config)
        if checkpoint and checkpoint.values:
            stage = checkpoint.values.get("frontend_stage", "intake")
    except Exception:
        # Best-effort lookup: any failure keeps the default stage.
        pass
    return stage
|
app/schemas.py
CHANGED
|
@@ -1,18 +1,23 @@
|
|
| 1 |
-
from
|
| 2 |
-
|
| 3 |
|
| 4 |
class HPI(BaseModel):
|
| 5 |
-
onset: str
|
| 6 |
-
location: str
|
| 7 |
-
duration: str
|
| 8 |
-
character: str
|
| 9 |
-
severity: str
|
| 10 |
-
aggravating: str
|
| 11 |
-
relieving: str
|
| 12 |
-
|
| 13 |
|
| 14 |
class ClinicalBrief(BaseModel):
|
| 15 |
chief_complaint: str
|
| 16 |
hpi: HPI
|
| 17 |
-
ros:
|
| 18 |
generated_at: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Dict, List
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
|
| 4 |
class HPI(BaseModel):
    # Every attribute is an optional string that defaults to None; the
    # description of each field is carried into the model's JSON schema.
    onset: Optional[str] = Field(
        default=None,
        description="When the symptom started",
    )
    location: Optional[str] = Field(
        default=None,
        description="Where exactly the symptom is located",
    )
    duration: Optional[str] = Field(
        default=None,
        description="How long episodes last or if it is constant",
    )
    character: Optional[str] = Field(
        default=None,
        description="What the pain feels like (sharp, dull, pressure, etc.)",
    )
    severity: Optional[str] = Field(
        default=None,
        description="Pain scale severity (e.g., 7/10 or 'severe')",
    )
    aggravating: Optional[str] = Field(
        default=None,
        description="What makes the symptoms worse",
    )
    relieving: Optional[str] = Field(
        default=None,
        description="What helps relieve the symptoms",
    )
|
|
|
|
| 12 |
|
| 13 |
class ClinicalBrief(BaseModel):
    # All four fields are required (no defaults are declared).

    # Chief complaint text.
    chief_complaint: str

    # Nested HPI model.
    hpi: HPI

    # Maps a string key to a list of strings.
    ros: Dict[str, List[str]]

    # Generation marker kept as a plain string.
    generated_at: str
|
| 18 |
+
|
| 19 |
+
class ClinicalStateExtraction(BaseModel):
    # Every field has a default, so the model can be built with no arguments.
    chief_complaint: Optional[str] = Field(
        default=None,
        description="The main reason for the visit",
    )
    # default_factory gives each instance its own fresh HPI / dict.
    hpi: HPI = Field(default_factory=HPI)
    ros: Dict[str, List[str]] = Field(
        default_factory=dict,
        description=(
            "Review of systems, keys are system names, "
            "values are list of findings (positive or negative)"
        ),
    )
    emergency_detected: bool = Field(
        default=False,
        description=(
            "True ONLY if the patient mentions life-threatening symptoms requiring "
            "immediate 911/ER like severe crushing chest pain radiating to jaw, "
            "active severe bleeding, or suicidal ideation"
        ),
    )
|
tests/test_e2e.py
CHANGED
|
@@ -1,12 +1,11 @@
|
|
| 1 |
import os
|
| 2 |
-
|
| 3 |
os.environ["MOCK_LLM"] = "true"
|
| 4 |
|
| 5 |
import pytest
|
| 6 |
from httpx import AsyncClient, ASGITransport
|
| 7 |
|
| 8 |
from app.main import app
|
| 9 |
-
|
| 10 |
|
| 11 |
@pytest.fixture
|
| 12 |
async def client():
|
|
@@ -14,7 +13,6 @@ async def client():
|
|
| 14 |
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
| 15 |
yield c
|
| 16 |
|
| 17 |
-
|
| 18 |
@pytest.mark.asyncio(loop_scope="function")
|
| 19 |
async def test_health_endpoint(client):
|
| 20 |
response = await client.get("/health")
|
|
@@ -23,171 +21,62 @@ async def test_health_endpoint(client):
|
|
| 23 |
assert data["status"] == "ok"
|
| 24 |
assert data["mock_mode"] is True
|
| 25 |
|
| 26 |
-
|
| 27 |
@pytest.mark.asyncio(loop_scope="function")
|
| 28 |
-
async def
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
| 32 |
assert response.status_code == 200
|
| 33 |
data = response.json()
|
| 34 |
-
|
| 35 |
-
assert data["state"]
|
| 36 |
-
|
| 37 |
-
responses = [
|
| 38 |
-
"I have chest pain since this morning", # CC (intake)
|
| 39 |
-
"It started about 3 hours ago", # onset
|
| 40 |
-
"In the center of my chest", # location
|
| 41 |
-
"It has been constant for an hour", # duration
|
| 42 |
-
"It feels like pressure", # character
|
| 43 |
-
"About a 7 out of 10", # severity
|
| 44 |
-
"It gets worse when I walk", # aggravating
|
| 45 |
-
"Resting helps a little", # relieving
|
| 46 |
-
"palpitations present, no syncope", # cardiac ROS
|
| 47 |
-
"mild shortness of breath, no cough", # respiratory ROS
|
| 48 |
-
"no nausea or vomiting", # gi ROS
|
| 49 |
-
]
|
| 50 |
-
|
| 51 |
-
final_data = None
|
| 52 |
-
for resp_text in responses:
|
| 53 |
-
response = await client.post("/chat", json={"session_id": session_id, "message": resp_text})
|
| 54 |
-
assert response.status_code == 200
|
| 55 |
-
final_data = response.json()
|
| 56 |
-
|
| 57 |
-
assert final_data is not None
|
| 58 |
-
assert final_data["state"] == "done"
|
| 59 |
-
assert "brief" in final_data
|
| 60 |
-
assert final_data["brief"] is not None
|
| 61 |
-
|
| 62 |
-
brief = final_data["brief"]
|
| 63 |
-
assert "chief_complaint" in brief
|
| 64 |
-
assert "hpi" in brief
|
| 65 |
-
assert "ros" in brief
|
| 66 |
-
|
| 67 |
|
| 68 |
@pytest.mark.asyncio(loop_scope="function")
|
| 69 |
-
async def
|
| 70 |
-
"""
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
# First HPI question is about onset
|
| 77 |
-
vague_response = await client.post("/chat", json={"session_id": session_id, "message": "I don't know"})
|
| 78 |
-
assert vague_response.status_code == 200
|
| 79 |
-
data = vague_response.json()
|
| 80 |
-
reply_lower = data["reply"].lower()
|
| 81 |
-
# Should ask again — should mention specificity or the field context
|
| 82 |
-
assert "specific" in reply_lower or "when" in reply_lower or "start" in reply_lower
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
@pytest.mark.asyncio(loop_scope="function")
|
| 86 |
-
async def test_ros_scoping(client):
|
| 87 |
-
"""For chest pain, ROS should include cardiac and respiratory systems."""
|
| 88 |
-
session_id = "test_chest_pain"
|
| 89 |
-
|
| 90 |
await client.post("/chat", json={"session_id": session_id, "message": "hello"})
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
#
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
"It started 3 hours ago",
|
| 128 |
-
"In the center of my chest",
|
| 129 |
-
"Constant",
|
| 130 |
-
"Pressure-like",
|
| 131 |
-
"7 out of 10",
|
| 132 |
-
"Walking worsens it",
|
| 133 |
-
"Resting helps",
|
| 134 |
-
"palpitations present, no syncope",
|
| 135 |
-
"shortness of breath, no cough",
|
| 136 |
-
"no nausea or vomiting",
|
| 137 |
-
]
|
| 138 |
-
|
| 139 |
-
response = None
|
| 140 |
-
for msg in messages:
|
| 141 |
-
response = await client.post("/chat", json={"session_id": session_id, "message": msg})
|
| 142 |
-
assert response.status_code == 200
|
| 143 |
-
|
| 144 |
-
final_data = response.json()
|
| 145 |
-
|
| 146 |
-
if final_data.get("brief"):
|
| 147 |
-
brief = final_data["brief"]
|
| 148 |
-
from app.schemas import ClinicalBrief
|
| 149 |
-
validated = ClinicalBrief.model_validate(brief)
|
| 150 |
-
|
| 151 |
-
assert validated.chief_complaint
|
| 152 |
-
assert validated.hpi.onset
|
| 153 |
-
assert validated.hpi.location
|
| 154 |
-
assert validated.hpi.duration
|
| 155 |
-
assert validated.hpi.character
|
| 156 |
-
assert validated.hpi.severity
|
| 157 |
-
assert validated.hpi.aggravating
|
| 158 |
-
assert validated.hpi.relieving
|
| 159 |
-
assert validated.generated_at
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
@pytest.mark.asyncio(loop_scope="function")
|
| 163 |
-
async def test_brief_cleaning(client):
|
| 164 |
-
"""Brief generator should strip informal filler words from patient answers."""
|
| 165 |
-
session_id = "test_cleaning"
|
| 166 |
-
|
| 167 |
-
messages = [
|
| 168 |
-
"hello",
|
| 169 |
-
"I have chest pain",
|
| 170 |
-
"yeah like since yesterday evening", # filler "yeah like"
|
| 171 |
-
"like in my chest area", # filler "like"
|
| 172 |
-
"Constant",
|
| 173 |
-
"um tight and squeezing", # filler "um"
|
| 174 |
-
"7 out of 10",
|
| 175 |
-
"yeah walking makes it worse", # filler "yeah"
|
| 176 |
-
"Resting helps",
|
| 177 |
-
"palpitations, no syncope",
|
| 178 |
-
"mild shortness of breath",
|
| 179 |
-
"no nausea",
|
| 180 |
-
]
|
| 181 |
-
|
| 182 |
-
response = None
|
| 183 |
-
for msg in messages:
|
| 184 |
-
response = await client.post("/chat", json={"session_id": session_id, "message": msg})
|
| 185 |
-
assert response.status_code == 200
|
| 186 |
-
|
| 187 |
-
final_data = response.json()
|
| 188 |
-
if final_data.get("brief"):
|
| 189 |
-
hpi = final_data["brief"]["hpi"]
|
| 190 |
-
# "yeah like since yesterday evening" → should not start with "yeah"
|
| 191 |
-
if hpi.get("onset"):
|
| 192 |
-
assert not hpi["onset"].lower().startswith("yeah"), \
|
| 193 |
-
f"Filler not cleaned from onset: {hpi['onset']}"
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
os.environ["MOCK_LLM"] = "true"
|
| 3 |
|
| 4 |
import pytest
|
| 5 |
from httpx import AsyncClient, ASGITransport
|
| 6 |
|
| 7 |
from app.main import app
|
| 8 |
+
from app.schemas import ClinicalBrief
|
| 9 |
|
| 10 |
@pytest.fixture
|
| 11 |
async def client():
|
|
|
|
| 13 |
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
| 14 |
yield c
|
| 15 |
|
|
|
|
| 16 |
@pytest.mark.asyncio(loop_scope="function")
|
| 17 |
async def test_health_endpoint(client):
|
| 18 |
response = await client.get("/health")
|
|
|
|
| 21 |
assert data["status"] == "ok"
|
| 22 |
assert data["mock_mode"] is True
|
| 23 |
|
|
|
|
| 24 |
@pytest.mark.asyncio(loop_scope="function")
async def test_emergency_triage_guardrail(client):
    """A 'crushing chest pain' message must yield state 'done' and an emergency reply."""

    async def say(text):
        # Every turn goes to /chat under the same session id.
        return await client.post(
            "/chat",
            json={"session_id": "test_emergency", "message": text},
        )

    await say("hello")

    emergency_turn = await say("I am having crushing chest pain")
    assert emergency_turn.status_code == 200

    payload = emergency_turn.json()
    assert payload["state"] == "done"

    reply = payload["reply"]
    assert "911" in reply or "emergency" in reply.lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
@pytest.mark.asyncio(loop_scope="function")
async def test_shadow_extractor_logic(client):
    """
    Drive one session through three patient turns and check that the reported
    state moves hpi -> ros -> done, then validate the final brief field by field.
    """

    async def say(text):
        # Every turn goes to /chat under the same session id.
        return await client.post(
            "/chat",
            json={"session_id": "test_extraction", "message": text},
        )

    await say("hello")

    # (patient message, state expected in the response). Per the original
    # test notes, the mock LLM maps "chest pain" -> CC and "yesterday" -> onset.
    staged_turns = [
        # Chief complaint plus onset: more HPI info is still needed.
        ("I have chest pain since yesterday", "hpi"),
        # Remaining HPI details: completes HPI and moves to ROS.
        (
            "It is constant pressure in the center. Severity is 7. Walking makes it worse, rest helps.",
            "ros",
        ),
        # ROS details: the intake should be done.
        ("I have palpitations and shortness of breath. No nausea.", "done"),
    ]

    data = None
    for text, expected_state in staged_turns:
        res = await say(text)
        assert res.status_code == 200
        data = res.json()
        assert data["state"] == expected_state

    assert data["brief"] is not None

    brief = ClinicalBrief.model_validate(data["brief"])
    assert brief.chief_complaint == "chest pain"

    expected_hpi = {
        "onset": "yesterday",
        "location": "center of chest",
        "duration": "constant",
        "character": "tight pressure",
        "severity": "7/10",
        "aggravating": "walking",
        "relieving": "resting",
    }
    for field_name, expected_value in expected_hpi.items():
        assert getattr(brief.hpi, field_name) == expected_value

    for system in ("cardiac", "respiratory", "gi"):
        assert system in brief.ros
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|