"""medintake-ai / app/graph.py

LangGraph state machine for the pre-visit clinical intake agent: a fast
keyword triage node followed by a single-LLM-call agent node, checkpointed
per conversation thread via MemorySaver.
"""
import json
import os
import time
from datetime import datetime, timezone
from typing import Optional, TypedDict, Annotated
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from app.llm import get_llm, CombinedOutput, HPI_FIELDS, ROS_REQUIRED
from app.schemas import ClinicalBrief, HPI
def _MOCK() -> bool:
    """True when the mock LLM backend should be used (MOCK_LLM env var, default true)."""
    return os.environ.get("MOCK_LLM", "true").lower() == "true"
def add_messages(left: list[dict], right: list[dict]) -> list[dict]:
    """State reducer: append newly returned messages to the stored history."""
    return left + right
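# With this reducer, a node that returns {"messages": [reply]} appends `reply`
# to the history rather than overwriting it; merging [assistant_msg] into an
# existing [user_msg] yields [user_msg, assistant_msg].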
class IntakeState(TypedDict):
messages: Annotated[list[dict], add_messages]
clinical_state: str # JSON of CombinedOutput (accumulated clinical data)
missing_fields: list[str]
current_node: str
clinical_brief: Optional[dict]
frontend_stage: str # 'intake', 'hpi', 'ros', 'done'
ros_stuck_count: int # consecutive turns stuck in ROS with no progress
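# Only `messages` carries a reducer; every other IntakeState key is replaced
# wholesale by whatever value the most recent node returned for it.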
EMERGENCY_PHRASES = [
"crushing chest pain", "can't breathe", "cannot breathe",
"heart attack", "suicide", "kill myself", "can't move", "dying"
]
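# Matched as lowercase substrings of the latest user message (see triage_node),
# so e.g. "I think I'm having a heart attack" triggers the emergency abort.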
# ------------------------------------------------------------------ helpers --
def format_transcript(messages: list[dict]) -> str:
lines = []
for m in messages:
role = "AI" if m["role"] == "assistant" else "Patient"
lines.append(f"{role}: {m['content']}")
return "\n".join(lines)
def compute_stage(state: CombinedOutput) -> str:
if not state.chief_complaint:
return "intake"
for f in HPI_FIELDS:
if not getattr(state, f):
return "hpi"
if len(state.ros) < ROS_REQUIRED:
return "ros"
return "done"
def missing_from(state: CombinedOutput) -> list[str]:
missing = []
if not state.chief_complaint:
missing.append("chief complaint")
return missing
for f in HPI_FIELDS:
if not getattr(state, f):
missing.append(f"HPI:{f}")
if len(state.ros) < ROS_REQUIRED:
missing.append(f"ROS ({ROS_REQUIRED - len(state.ros)} more systems needed)")
return missing
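# e.g. with the chief complaint captured, onset still empty, and one ROS system
# recorded, this yields ["HPI:onset", "ROS (N more systems needed)"] where N is
# ROS_REQUIRED - 1 (field names follow HPI_FIELDS from app.llm).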
def _get_last_user_message(msgs: list[dict]) -> str:
for m in reversed(msgs):
if m.get("role") == "user":
return m.get("content", "")
return ""
def _detect_repeat(msgs: list[dict], new_reply: str) -> bool:
"""Return True if new_reply is identical to the last two stored assistant replies."""
assistant_replies = [m.get("content", "") for m in msgs if m.get("role") == "assistant"]
if len(assistant_replies) >= 2:
return new_reply == assistant_replies[-1] == assistant_replies[-2]
return False
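# e.g. _detect_repeat(msgs, "Any other symptoms?") is True only when the two
# most recent stored assistant turns both said exactly "Any other symptoms?".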
# ------------------------------------------------------------------- nodes ---
def triage_node(state: IntakeState) -> dict:
"""Fast keyword check β€” no LLM call. Abort immediately on emergency phrases."""
msgs = state.get("messages", [])
if msgs and msgs[-1]["role"] == "user":
content = msgs[-1]["content"].lower()
if any(p in content for p in EMERGENCY_PHRASES):
return {
"messages": [{
"role": "assistant",
"content": (
"🚨 EMERGENCY: Your symptoms require immediate attention. "
"Please call 911 or go to your nearest emergency room right away."
)
}],
"current_node": "done",
"frontend_stage": "done",
}
return {"current_node": "agent"}
def agent_node(state: IntakeState) -> dict:
"""
    Core agent: one LLM call per turn.
Extracts clinical data, generates next question, builds brief when complete.
"""
msgs = state.get("messages", [])
    # First turn: no user message yet (empty history, or only the opening
    # greeting) -> send the greeting and seed an empty clinical state.
if not msgs or (len(msgs) == 1 and msgs[0]["role"] == "assistant"):
return {
"messages": [{"role": "assistant", "content": "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"}],
"clinical_state": CombinedOutput().model_dump_json(),
"frontend_stage": "intake",
"current_node": "agent",
"ros_stuck_count": 0,
}
if msgs[-1]["role"] == "assistant":
return {"current_node": "agent"}
current_json = state.get("clinical_state") or CombinedOutput().model_dump_json()
transcript = format_transcript(msgs)
ros_stuck_count = state.get("ros_stuck_count", 0)
try:
pre_state = CombinedOutput.model_validate_json(current_json)
current_stage = compute_stage(pre_state)
except Exception:
current_stage = "intake"
    print(f"[{time.time():.3f}] [Graph Node] Requesting LLM inference (stage={current_stage})...")
llm = get_llm()
result: CombinedOutput = llm.combined_call(transcript, current_json, stage=current_stage)
# ── ROS Hallucination Guard: max 1 new ROS system per turn ──────────
try:
prev_ros = json.loads(current_json).get("ros") or {}
except Exception:
prev_ros = {}
new_ros_keys = [k for k in result.ros if k not in prev_ros]
if len(new_ros_keys) > 1:
print(f"[ROSGuard] LLM added {len(new_ros_keys)} systems in one turn: {new_ros_keys}. Keeping first only.")
allowed_ros = dict(prev_ros)
allowed_ros[new_ros_keys[0]] = result.ros[new_ros_keys[0]]
result = result.model_copy(update={"ros": allowed_ros})
# ── Loop Guard ───────────────────────────────────────────────────────
try:
prev_state_obj = CombinedOutput.model_validate_json(current_json)
prev_filled = sum(1 for f in HPI_FIELDS if getattr(prev_state_obj, f, None)) + len(prev_state_obj.ros)
new_filled = sum(1 for f in HPI_FIELDS if getattr(result, f, None)) + len(result.ros)
made_progress = new_filled > prev_filled
except Exception:
made_progress = True # assume progress on parse error
hpi_complete = all(getattr(result, f, None) for f in HPI_FIELDS)
if not made_progress:
last_user_msg = _get_last_user_message(msgs)
if not hpi_complete:
            # HPI stuck: force-fill the first empty field
for stuck_field in HPI_FIELDS:
if not getattr(result, stuck_field, None):
result = result.model_copy(update={stuck_field: last_user_msg or "not specified"})
print(f"[LoopGuard] Force-filled HPI '{stuck_field}' = '{last_user_msg or 'not specified'}'")
break
else:
            # ROS stuck: force-store the user's answer into a pending ros_asked system
ros_stuck_count += 1
pending = [s for s in result.ros_asked if s not in result.ros]
if pending:
# Store whatever the user just said as the finding for this system
new_ros = dict(result.ros)
new_ros[pending[0]] = [last_user_msg] if last_user_msg else ["no symptoms reported"]
result = result.model_copy(update={"ros": new_ros})
print(f"[LoopGuard] Force-stored ROS['{pending[0]}'] = [{last_user_msg[:40]}]")
elif ros_stuck_count >= 2:
                # LLM isn't even updating ros_asked: force a dummy system to unblock
stub_key = f"general_{len(result.ros)}"
new_ros = dict(result.ros)
new_ros[stub_key] = [last_user_msg] if last_user_msg else ["no additional symptoms"]
result = result.model_copy(update={"ros": new_ros})
print(f"[LoopGuard] Force-added stub ROS['{stub_key}'] after {ros_stuck_count} stuck turns.")
ros_stuck_count = 0
else:
ros_stuck_count = 0 # reset counter when progress is made
print(f"[{time.time():.3f}] [Graph Node] LLM returned. Preparing node dictionaries...")
stage = compute_stage(result)
missing = missing_from(result)
reply = result.reply or "Could you tell me more?"
    # Sanitize the reply: avoid storing empty or whitespace-only replies
if not reply.strip():
reply = "Could you tell me more?"
# All fields complete β€” build the brief inline
if stage == "done":
brief = ClinicalBrief(
chief_complaint=result.chief_complaint or "Not specified",
hpi=HPI(
onset=result.onset or "Not specified",
location=result.location or "Not specified",
duration=result.duration or "Not specified",
character=result.character or "Not specified",
severity=result.severity or "Not specified",
aggravating=result.aggravating or "Not specified",
relieving=result.relieving or "Not specified",
),
ros=result.ros,
generated_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
)
return {
"messages": [{"role": "assistant", "content": "Your clinical summary is ready. Please wait for the doctor."}],
"clinical_state": result.model_dump_json(),
"missing_fields": [],
"frontend_stage": "done",
"current_node": "done",
"clinical_brief": brief.model_dump(),
"ros_stuck_count": 0,
}
return {
"messages": [{"role": "assistant", "content": reply}],
"clinical_state": result.model_dump_json(),
"missing_fields": missing,
"frontend_stage": stage,
"current_node": "agent",
"ros_stuck_count": ros_stuck_count,
}
# -------------------------------------------------------------- graph build --
def build_graph():
workflow = StateGraph(IntakeState)
workflow.add_node("triage", triage_node)
workflow.add_node("agent", agent_node)
def route_triage(state: IntakeState) -> str:
return state.get("current_node", "agent")
workflow.add_edge(START, "triage")
workflow.add_conditional_edges("triage", route_triage, {"done": END, "agent": "agent"})
workflow.add_edge("agent", END)
checkpointer = MemorySaver()
graph = workflow.compile(checkpointer=checkpointer)
    # NOTE: interrupt_after removed. State accumulates across fresh invokes via
    # the checkpointer plus the add_messages reducer, which is the intended
    # behavior (has_next is always False).
return graph, checkpointer
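

# --------------------------------------------------------------------- demo --
# Minimal smoke-test sketch. Assumes MOCK_LLM=true (the default) so get_llm()
# serves a mock backend; "demo-thread" is an illustrative thread id. Each
# invoke resumes the same checkpointed thread, so state accumulates turn by
# turn.
if __name__ == "__main__":
    graph, _checkpointer = build_graph()
    config = {"configurable": {"thread_id": "demo-thread"}}
    out = graph.invoke(
        {"messages": [{"role": "user", "content": "I have a headache"}]},
        config,
    )
    print(out["messages"][-1]["content"], "| stage:", out.get("frontend_stage"))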