priyansh-saxena1 commited on
Commit
8d6f802
Β·
1 Parent(s): 44d41e8

feat: unified prompt with state visibility

Browse files
Files changed (2) hide show
  1. app/graph.py +77 -63
  2. app/llm.py +89 -40
app/graph.py CHANGED
@@ -5,7 +5,7 @@ from langgraph.graph import StateGraph, START, END
5
  from langgraph.checkpoint.memory import MemorySaver
6
 
7
  from app.llm import get_llm, CombinedOutput, HPI_FIELDS, ROS_REQUIRED
8
- from app.schemas import ClinicalBrief, HPI, ClinicalStateExtraction
9
 
10
  _MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"
11
 
@@ -21,10 +21,9 @@ class IntakeState(TypedDict):
21
  current_node: str
22
  clinical_brief: Optional[dict]
23
  frontend_stage: str # 'intake', 'hpi', 'ros', 'done'
 
24
 
25
 
26
- # HPI_FIELDS and ROS_REQUIRED imported from app.llm
27
-
28
  EMERGENCY_PHRASES = [
29
  "crushing chest pain", "can't breathe", "cannot breathe",
30
  "heart attack", "suicide", "kill myself", "can't move", "dying"
@@ -65,11 +64,19 @@ def missing_from(state: CombinedOutput) -> list[str]:
65
  return missing
66
 
67
 
68
- def _detect_repeat(state) -> bool:
69
- """Return True if the last two assistant replies are identical."""
70
- msgs = state.get("messages", [])
 
 
 
 
 
 
71
  assistant_replies = [m.get("content", "") for m in msgs if m.get("role") == "assistant"]
72
- return len(assistant_replies) >= 2 and assistant_replies[-1] == assistant_replies[-2]
 
 
73
 
74
 
75
  # ------------------------------------------------------------------- nodes ---
@@ -96,20 +103,19 @@ def triage_node(state: IntakeState) -> dict:
96
 
97
  def agent_node(state: IntakeState) -> dict:
98
  """
99
- Core agent node β€” ONE combined LLM call per turn:
100
- 1. Extracts any new clinical data from the transcript.
101
- 2. Generates the next conversational question.
102
- 3. If all data is collected, builds the ClinicalBrief inline (no separate scribe node).
103
  """
104
  msgs = state.get("messages", [])
105
 
106
- # On first call with no messages, return opening greeting
107
  if not msgs or (len(msgs) == 1 and msgs[0]["role"] == "assistant"):
108
  return {
109
  "messages": [{"role": "assistant", "content": "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"}],
110
  "clinical_state": CombinedOutput().model_dump_json(),
111
  "frontend_stage": "intake",
112
  "current_node": "agent",
 
113
  }
114
 
115
  if msgs[-1]["role"] == "assistant":
@@ -117,8 +123,8 @@ def agent_node(state: IntakeState) -> dict:
117
 
118
  current_json = state.get("clinical_state") or CombinedOutput().model_dump_json()
119
  transcript = format_transcript(msgs)
 
120
 
121
- # Compute the current stage BEFORE the LLM call so we can pick the right prompt
122
  try:
123
  pre_state = CombinedOutput.model_validate_json(current_json)
124
  current_stage = compute_stage(pre_state)
@@ -126,62 +132,66 @@ def agent_node(state: IntakeState) -> dict:
126
  current_stage = "intake"
127
 
128
  import time
129
- t_agent = time.time()
130
  print(f"[{time.time():.3f}] [Graph Node] Requesting LLM inference (stage={current_stage})...")
131
 
132
  llm = get_llm()
133
  result: CombinedOutput = llm.combined_call(transcript, current_json, stage=current_stage)
134
 
135
- # ── Loop Guard: if LLM returned same reply as last turn, force-fill stuck field ──
136
- if _detect_repeat({"messages": msgs + [{"role": "assistant", "content": result.reply}]}):
137
- # Check if the LLM made progress extracting data despite repeating the reply
138
- try:
139
- prev_state = CombinedOutput.model_validate_json(current_json)
140
- prev_filled = sum(1 for f in HPI_FIELDS if getattr(prev_state, f, None)) + len(prev_state.ros)
141
- new_filled = sum(1 for f in HPI_FIELDS if getattr(result, f, None)) + len(result.ros)
142
- made_progress = new_filled > prev_filled
143
- except Exception:
144
- made_progress = False
145
-
146
- hpi_filled = all(getattr(result, f, None) for f in HPI_FIELDS)
147
-
148
- if not hpi_filled:
149
- if not made_progress:
150
- # Still in HPI and stuck β€” force-fill the first empty HPI field
151
- for stuck_field in HPI_FIELDS:
152
- if getattr(result, stuck_field, None) is None:
153
- object.__setattr__(result, stuck_field, "not specified")
154
- print(f"[LoopGuard] Force-filled HPI '{stuck_field}' = 'not specified' to break repeat loop")
155
- break
156
-
157
- # Ensure we ask a new question to break the loop
158
- new_missing = missing_from(result)
159
- if new_missing:
160
- object.__setattr__(result, "reply", f"Thank you. Now, could you tell me about {new_missing[0].replace('HPI:', '')}?")
161
- else:
162
- object.__setattr__(result, "reply", "Thank you β€” I have everything I need.")
163
- else:
164
- # In ROS stage
165
- if not made_progress:
166
- print("[LoopGuard] LLM stuck in ROS without extracting data. Skipping system.")
167
-
168
- if len(result.ros) < ROS_REQUIRED:
169
- object.__setattr__(result, "reply", f"Thank you. Are there any other symptoms you've been experiencing?")
170
- else:
171
- object.__setattr__(result, "reply", "Thank you β€” I have everything I need.")
172
-
173
- # ── ROS Hallucination Guard: LLM can only ADD one new ROS system per turn ──
174
  try:
175
- prev_state = json.loads(current_json)
176
- prev_ros = prev_state.get("ros") or {}
177
  except Exception:
178
  prev_ros = {}
 
179
  new_ros_keys = [k for k in result.ros if k not in prev_ros]
180
  if len(new_ros_keys) > 1:
181
- print(f"[ROSGuard] LLM added {len(new_ros_keys)} new ROS systems in one turn: {new_ros_keys}. Keeping only first.")
182
  allowed_ros = dict(prev_ros)
183
  allowed_ros[new_ros_keys[0]] = result.ros[new_ros_keys[0]]
184
- object.__setattr__(result, "ros", allowed_ros)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  print(f"[{time.time():.3f}] [Graph Node] LLM returned. Preparing node dictionaries...")
187
 
@@ -189,7 +199,11 @@ def agent_node(state: IntakeState) -> dict:
189
  missing = missing_from(result)
190
  reply = result.reply or "Could you tell me more?"
191
 
192
- # All fields complete β€” build the brief inline so it's available this turn
 
 
 
 
193
  if stage == "done":
194
  from datetime import datetime, timezone
195
  brief = ClinicalBrief(
@@ -213,6 +227,7 @@ def agent_node(state: IntakeState) -> dict:
213
  "frontend_stage": "done",
214
  "current_node": "done",
215
  "clinical_brief": brief.model_dump(),
 
216
  }
217
 
218
  return {
@@ -221,6 +236,7 @@ def agent_node(state: IntakeState) -> dict:
221
  "missing_fields": missing,
222
  "frontend_stage": stage,
223
  "current_node": "agent",
 
224
  }
225
 
226
 
@@ -240,10 +256,8 @@ def build_graph():
240
  workflow.add_edge("agent", END)
241
 
242
  checkpointer = MemorySaver()
243
- # Interrupt after agent so it pauses for user input each turn
244
- graph = workflow.compile(
245
- checkpointer=checkpointer,
246
- interrupt_after=["agent"]
247
- )
248
 
249
  return graph, checkpointer
 
5
  from langgraph.checkpoint.memory import MemorySaver
6
 
7
  from app.llm import get_llm, CombinedOutput, HPI_FIELDS, ROS_REQUIRED
8
+ from app.schemas import ClinicalBrief, HPI
9
 
10
  _MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"
11
 
 
21
  current_node: str
22
  clinical_brief: Optional[dict]
23
  frontend_stage: str # 'intake', 'hpi', 'ros', 'done'
24
+ ros_stuck_count: int # consecutive turns stuck in ROS with no progress
25
 
26
 
 
 
27
  EMERGENCY_PHRASES = [
28
  "crushing chest pain", "can't breathe", "cannot breathe",
29
  "heart attack", "suicide", "kill myself", "can't move", "dying"
 
64
  return missing
65
 
66
 
67
+ def _get_last_user_message(msgs: list[dict]) -> str:
68
+ for m in reversed(msgs):
69
+ if m.get("role") == "user":
70
+ return m.get("content", "")
71
+ return ""
72
+
73
+
74
+ def _detect_repeat(msgs: list[dict], new_reply: str) -> bool:
75
+ """Return True if new_reply is identical to the last two stored assistant replies."""
76
  assistant_replies = [m.get("content", "") for m in msgs if m.get("role") == "assistant"]
77
+ if len(assistant_replies) >= 2:
78
+ return new_reply == assistant_replies[-1] == assistant_replies[-2]
79
+ return False
80
 
81
 
82
  # ------------------------------------------------------------------- nodes ---
 
103
 
104
  def agent_node(state: IntakeState) -> dict:
105
  """
106
+ Core agent β€” one LLM call per turn.
107
+ Extracts clinical data, generates next question, builds brief when complete.
 
 
108
  """
109
  msgs = state.get("messages", [])
110
 
111
+ # First call: no messages yet β†’ return opening greeting
112
  if not msgs or (len(msgs) == 1 and msgs[0]["role"] == "assistant"):
113
  return {
114
  "messages": [{"role": "assistant", "content": "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"}],
115
  "clinical_state": CombinedOutput().model_dump_json(),
116
  "frontend_stage": "intake",
117
  "current_node": "agent",
118
+ "ros_stuck_count": 0,
119
  }
120
 
121
  if msgs[-1]["role"] == "assistant":
 
123
 
124
  current_json = state.get("clinical_state") or CombinedOutput().model_dump_json()
125
  transcript = format_transcript(msgs)
126
+ ros_stuck_count = state.get("ros_stuck_count", 0)
127
 
 
128
  try:
129
  pre_state = CombinedOutput.model_validate_json(current_json)
130
  current_stage = compute_stage(pre_state)
 
132
  current_stage = "intake"
133
 
134
  import time
 
135
  print(f"[{time.time():.3f}] [Graph Node] Requesting LLM inference (stage={current_stage})...")
136
 
137
  llm = get_llm()
138
  result: CombinedOutput = llm.combined_call(transcript, current_json, stage=current_stage)
139
 
140
+ # ── ROS Hallucination Guard: max 1 new ROS system per turn ──────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  try:
142
+ prev_ros = json.loads(current_json).get("ros") or {}
 
143
  except Exception:
144
  prev_ros = {}
145
+
146
  new_ros_keys = [k for k in result.ros if k not in prev_ros]
147
  if len(new_ros_keys) > 1:
148
+ print(f"[ROSGuard] LLM added {len(new_ros_keys)} systems in one turn: {new_ros_keys}. Keeping first only.")
149
  allowed_ros = dict(prev_ros)
150
  allowed_ros[new_ros_keys[0]] = result.ros[new_ros_keys[0]]
151
+ result = result.model_copy(update={"ros": allowed_ros})
152
+
153
+ # ── Loop Guard ───────────────────────────────────────────────────────
154
+ try:
155
+ prev_state_obj = CombinedOutput.model_validate_json(current_json)
156
+ prev_filled = sum(1 for f in HPI_FIELDS if getattr(prev_state_obj, f, None)) + len(prev_state_obj.ros)
157
+ new_filled = sum(1 for f in HPI_FIELDS if getattr(result, f, None)) + len(result.ros)
158
+ made_progress = new_filled > prev_filled
159
+ except Exception:
160
+ made_progress = True # assume progress on parse error
161
+
162
+ hpi_complete = all(getattr(result, f, None) for f in HPI_FIELDS)
163
+
164
+ if not made_progress:
165
+ last_user_msg = _get_last_user_message(msgs)
166
+
167
+ if not hpi_complete:
168
+ # HPI stuck β€” force-fill the first empty field
169
+ for stuck_field in HPI_FIELDS:
170
+ if not getattr(result, stuck_field, None):
171
+ result = result.model_copy(update={stuck_field: last_user_msg or "not specified"})
172
+ print(f"[LoopGuard] Force-filled HPI '{stuck_field}' = '{last_user_msg or 'not specified'}'")
173
+ break
174
+ else:
175
+ # ROS stuck β€” force-store the user's answer into a pending ros_asked system
176
+ ros_stuck_count += 1
177
+ pending = [s for s in result.ros_asked if s not in result.ros]
178
+
179
+ if pending:
180
+ # Store whatever the user just said as the finding for this system
181
+ new_ros = dict(result.ros)
182
+ new_ros[pending[0]] = [last_user_msg] if last_user_msg else ["no symptoms reported"]
183
+ result = result.model_copy(update={"ros": new_ros})
184
+ print(f"[LoopGuard] Force-stored ROS['{pending[0]}'] = [{last_user_msg[:40]}]")
185
+ elif ros_stuck_count >= 2:
186
+ # LLM isn't even updating ros_asked β€” force a dummy system to unblock
187
+ stub_key = f"general_{len(result.ros)}"
188
+ new_ros = dict(result.ros)
189
+ new_ros[stub_key] = [last_user_msg] if last_user_msg else ["no additional symptoms"]
190
+ result = result.model_copy(update={"ros": new_ros})
191
+ print(f"[LoopGuard] Force-added stub ROS['{stub_key}'] after {ros_stuck_count} stuck turns.")
192
+ ros_stuck_count = 0
193
+ else:
194
+ ros_stuck_count = 0 # reset counter when progress is made
195
 
196
  print(f"[{time.time():.3f}] [Graph Node] LLM returned. Preparing node dictionaries...")
197
 
 
199
  missing = missing_from(result)
200
  reply = result.reply or "Could you tell me more?"
201
 
202
+ # Sanitize reply β€” avoid storing empty or whitespace-only replies
203
+ if not reply.strip():
204
+ reply = "Could you tell me more?"
205
+
206
+ # All fields complete β€” build the brief inline
207
  if stage == "done":
208
  from datetime import datetime, timezone
209
  brief = ClinicalBrief(
 
227
  "frontend_stage": "done",
228
  "current_node": "done",
229
  "clinical_brief": brief.model_dump(),
230
+ "ros_stuck_count": 0,
231
  }
232
 
233
  return {
 
236
  "missing_fields": missing,
237
  "frontend_stage": stage,
238
  "current_node": "agent",
239
+ "ros_stuck_count": ros_stuck_count,
240
  }
241
 
242
 
 
256
  workflow.add_edge("agent", END)
257
 
258
  checkpointer = MemorySaver()
259
+ graph = workflow.compile(checkpointer=checkpointer)
260
+ # NOTE: interrupt_after removed β€” state accumulates via MemorySaver reducer
261
+ # on every fresh invoke, which is correct behavior (has_next is always False)
 
 
262
 
263
  return graph, checkpointer
app/llm.py CHANGED
@@ -2,7 +2,6 @@ import os
2
  import json
3
  from pydantic import BaseModel
4
 
5
- # ── Single unified system prompt β€” LLM sees the full workflow ──
6
  SYSTEM_PROMPT = """You are a clinical intake assistant conducting a pre-visit patient interview.
7
 
8
  YOUR WORKFLOW (follow this order):
@@ -27,9 +26,14 @@ YOUR WORKFLOW (follow this order):
27
  CRITICAL RULES:
28
  - NEVER re-ask a field that is already filled (marked βœ… in the status).
29
  - Ask exactly ONE question per turn about the FIRST missing item.
30
- - If a patient says "none"/"zero"/"no"/"denied", store that exact answer β€” do NOT leave it null.
31
- - For ROS: store findings as a list, e.g. "musculoskeletal": ["joint stiffness", "no swelling"].
 
 
 
 
32
  - Do NOT ask emotional/psychological questions β€” stick to physical symptoms.
 
33
  - Output ONLY valid JSON, no extra text.
34
 
35
  OUTPUT FORMAT:
@@ -43,7 +47,7 @@ OUTPUT FORMAT:
43
  "aggravating": "..." or null,
44
  "relieving": "..." or null,
45
  "ros": {"system_name": ["finding1", "finding2"], ...},
46
- "ros_asked": ["system_name1"] (append any new system you ask about here to prevent repeating),
47
  "emergency": false,
48
  "reply": "Your single question"
49
  }"""
@@ -53,7 +57,6 @@ ROS_REQUIRED = 3
53
 
54
 
55
  def build_state_context(current_json: str) -> str:
56
- """Build a human-readable status summary so the LLM knows exactly what's filled and missing."""
57
  try:
58
  state = json.loads(current_json)
59
  except Exception:
@@ -61,14 +64,12 @@ def build_state_context(current_json: str) -> str:
61
 
62
  lines = ["FIELD STATUS:"]
63
 
64
- # Chief complaint
65
  cc = state.get("chief_complaint")
66
  if cc:
67
  lines.append(f' βœ… chief_complaint: "{cc}"')
68
  else:
69
  lines.append(" ❌ chief_complaint: MISSING β€” ask what brings them in")
70
 
71
- # HPI fields
72
  for field in HPI_FIELDS:
73
  val = state.get(field)
74
  if val:
@@ -76,7 +77,6 @@ def build_state_context(current_json: str) -> str:
76
  else:
77
  lines.append(f" ❌ {field}: MISSING")
78
 
79
- # ROS
80
  ros = state.get("ros", {})
81
  ros_asked = state.get("ros_asked", [])
82
  if ros:
@@ -90,9 +90,9 @@ def build_state_context(current_json: str) -> str:
90
  else:
91
  lines.append(f" βœ… ros: all {ROS_REQUIRED} systems collected")
92
 
93
- # Determine current phase
94
  if not cc:
95
  phase = "INTAKE"
 
96
  elif any(not state.get(f) for f in HPI_FIELDS):
97
  phase = "HPI"
98
  first_missing = next(f for f in HPI_FIELDS if not state.get(f))
@@ -100,12 +100,11 @@ def build_state_context(current_json: str) -> str:
100
  elif ros_remaining > 0:
101
  phase = "ROS"
102
  lines.append(f"\nCURRENT PHASE: {phase} β€” ask about the next body system relevant to '{cc}'")
 
 
103
  else:
104
  phase = "DONE"
105
- lines.append(f"\nCURRENT PHASE: {phase} β€” all data collected, set reply to completion message")
106
-
107
- if not cc:
108
- lines.append(f"\nCURRENT PHASE: {phase}")
109
 
110
  return "\n".join(lines)
111
 
@@ -126,7 +125,8 @@ class CombinedOutput(BaseModel):
126
 
127
 
128
  class MockLLM:
129
- """Minimal mock for testing β€” no regex, no extraction logic. Just walks through fields."""
 
130
  def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
131
  try:
132
  state = json.loads(current_json)
@@ -140,39 +140,62 @@ class MockLLM:
140
  last_patient_msg = line.replace("Patient:", "").strip()
141
  break
142
 
143
- hpi_fields = ["chief_complaint", "onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
144
  ros_systems = ["cardiac", "respiratory", "gi"]
145
 
146
  if stage == "intake":
147
  if last_patient_msg and not state.get("chief_complaint"):
148
- state["chief_complaint"] = last_patient_msg
149
- state["reply"] = "What brings you in today?" if not state.get("chief_complaint") else f"When did the {state['chief_complaint']} start?"
 
 
 
 
 
 
 
150
 
151
  elif stage == "hpi":
152
- for field in hpi_fields[1:]:
153
  if not state.get(field):
154
  if last_patient_msg:
155
  state[field] = last_patient_msg
156
  break
157
- for field in hpi_fields[1:]:
158
  if not state.get(field):
159
- labels = {"onset": "when it started", "location": "where you feel it",
160
- "duration": "how long it's lasted", "character": "what it feels like",
161
- "severity": "how severe it is (1-10)", "aggravating": "what makes it worse",
162
- "relieving": "what makes it better"}
 
 
 
 
 
163
  state["reply"] = f"Can you tell me {labels.get(field, field)}?"
164
  break
165
  else:
166
- state["reply"] = "Thank you, moving on to review of systems."
167
 
168
  elif stage == "ros":
169
  ros = state.get("ros", {})
 
 
 
 
 
 
 
170
  for sys_name in ros_systems:
171
  if sys_name not in ros:
172
  if last_patient_msg:
173
  ros[sys_name] = [last_patient_msg]
174
  state["ros"] = ros
 
 
 
175
  break
 
 
176
  for sys_name in ros_systems:
177
  if sys_name not in ros:
178
  state["reply"] = f"Any {sys_name} symptoms?"
@@ -189,10 +212,6 @@ class OllamaLLM:
189
  self.api_url = "http://localhost:11434/api/chat"
190
 
191
  def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
192
- """
193
- Single LLM call: extracts clinical data + generates next question.
194
- The unified prompt + state context gives the LLM full visibility.
195
- """
196
  state_context = build_state_context(current_json)
197
 
198
  prompt = (
@@ -201,6 +220,7 @@ class OllamaLLM:
201
  f"CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
202
  "TASK: Read the patient's latest message. Extract any new clinical facts into the JSON. "
203
  "Then ask exactly ONE question about the FIRST missing item shown above. "
 
204
  "Return ONLY the updated JSON object."
205
  )
206
 
@@ -224,25 +244,24 @@ class OllamaLLM:
224
  "num_predict": 400
225
  }
226
  }
227
-
228
  try:
229
  response = requests.post(self.api_url, json=payload, timeout=60)
230
  response.raise_for_status()
231
  data = response.json()
232
  raw = data.get("message", {}).get("content", "").strip()
233
  except Exception as e:
234
- print(f"[Ollama] ERROR calling local Ollama API: {e}")
235
- print("[Ollama] Make sure Ollama is installed and running, and the model is downloaded!")
236
  return CombinedOutput.model_validate_json(current_json)
237
 
238
  print(f"[Ollama] Inference completed in {time.time() - t_start:.2f}s total.")
239
 
240
- # Parse JSON robustly
241
  json_str = raw
242
  if "```json" in json_str:
243
- json_str = json_str.split("```json", 1)[1].split("```")[0]
244
  elif "```" in json_str:
245
- json_str = json_str.split("```", 1)[1].split("```")[0]
246
 
247
  start = json_str.find("{")
248
  end = json_str.rfind("}") + 1
@@ -251,25 +270,55 @@ class OllamaLLM:
251
 
252
  try:
253
  parsed = json.loads(json_str)
254
- # Coerce empty strings and literal "null" back to None
 
255
  for field in ["chief_complaint", "onset", "location", "duration",
256
  "character", "severity", "aggravating", "relieving"]:
257
  v = parsed.get(field)
258
- if v is not None and str(v).strip() in ("", "null"):
 
 
 
259
  parsed[field] = None
260
- return CombinedOutput.model_validate(parsed)
 
 
261
  except Exception as e:
262
  print(f"[Ollama] JSON parse error: {e}\nRaw output: {raw[:300]}")
263
  try:
264
- base = CombinedOutput.model_validate_json(current_json)
265
- base.reply = "Could you please repeat that? I want to make sure I understood correctly."
266
- return base
267
  except Exception:
268
  return CombinedOutput(reply="Could you please repeat that?")
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  _llm_instance = None
272
 
 
273
  def get_llm():
274
  global _llm_instance
275
  if _llm_instance is None:
 
2
  import json
3
  from pydantic import BaseModel
4
 
 
5
  SYSTEM_PROMPT = """You are a clinical intake assistant conducting a pre-visit patient interview.
6
 
7
  YOUR WORKFLOW (follow this order):
 
26
  CRITICAL RULES:
27
  - NEVER re-ask a field that is already filled (marked βœ… in the status).
28
  - Ask exactly ONE question per turn about the FIRST missing item.
29
+ - For HPI: accept any answer the patient gives, even vague ones like "moderate" or "not sure".
30
+ - For ROS: ALWAYS add the system to BOTH "ros" and "ros_asked" β€” even for negative answers.
31
+ - Positive finding: "cardiac": ["palpitations present"]
32
+ - Negative finding: "respiratory": ["no shortness of breath"]
33
+ - Denied: "gi": ["denied nausea and vomiting"]
34
+ A "no" is still a valid clinical finding. Never leave a ros system in ros_asked but absent from ros.
35
  - Do NOT ask emotional/psychological questions β€” stick to physical symptoms.
36
+ - All string fields must be strings, not arrays.
37
  - Output ONLY valid JSON, no extra text.
38
 
39
  OUTPUT FORMAT:
 
47
  "aggravating": "..." or null,
48
  "relieving": "..." or null,
49
  "ros": {"system_name": ["finding1", "finding2"], ...},
50
+ "ros_asked": ["system_name1", "system_name2"],
51
  "emergency": false,
52
  "reply": "Your single question"
53
  }"""
 
57
 
58
 
59
  def build_state_context(current_json: str) -> str:
 
60
  try:
61
  state = json.loads(current_json)
62
  except Exception:
 
64
 
65
  lines = ["FIELD STATUS:"]
66
 
 
67
  cc = state.get("chief_complaint")
68
  if cc:
69
  lines.append(f' βœ… chief_complaint: "{cc}"')
70
  else:
71
  lines.append(" ❌ chief_complaint: MISSING β€” ask what brings them in")
72
 
 
73
  for field in HPI_FIELDS:
74
  val = state.get(field)
75
  if val:
 
77
  else:
78
  lines.append(f" ❌ {field}: MISSING")
79
 
 
80
  ros = state.get("ros", {})
81
  ros_asked = state.get("ros_asked", [])
82
  if ros:
 
90
  else:
91
  lines.append(f" βœ… ros: all {ROS_REQUIRED} systems collected")
92
 
 
93
  if not cc:
94
  phase = "INTAKE"
95
+ lines.append(f"\nCURRENT PHASE: {phase}")
96
  elif any(not state.get(f) for f in HPI_FIELDS):
97
  phase = "HPI"
98
  first_missing = next(f for f in HPI_FIELDS if not state.get(f))
 
100
  elif ros_remaining > 0:
101
  phase = "ROS"
102
  lines.append(f"\nCURRENT PHASE: {phase} β€” ask about the next body system relevant to '{cc}'")
103
+ lines.append(f" ⚠️ IMPORTANT: Store BOTH positive AND negative ROS findings in 'ros' dict.")
104
+ lines.append(f" ⚠️ A patient saying 'no' means: ros[\"system\"] = [\"no [symptom]\"]")
105
  else:
106
  phase = "DONE"
107
+ lines.append(f"\nCURRENT PHASE: {phase} β€” all data collected")
 
 
 
108
 
109
  return "\n".join(lines)
110
 
 
125
 
126
 
127
  class MockLLM:
128
+ """Minimal mock for testing β€” deterministic field walker."""
129
+
130
  def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
131
  try:
132
  state = json.loads(current_json)
 
140
  last_patient_msg = line.replace("Patient:", "").strip()
141
  break
142
 
 
143
  ros_systems = ["cardiac", "respiratory", "gi"]
144
 
145
  if stage == "intake":
146
  if last_patient_msg and not state.get("chief_complaint"):
147
+ # Strip greeting words
148
+ greetings = {"hello", "hi", "hey", "ok", "okay", "start", "yes", "sure"}
149
+ if last_patient_msg.lower() not in greetings and len(last_patient_msg) > 4:
150
+ state["chief_complaint"] = last_patient_msg
151
+ state["reply"] = (
152
+ "What brings you in today?"
153
+ if not state.get("chief_complaint")
154
+ else f"When did the {state['chief_complaint']} start?"
155
+ )
156
 
157
  elif stage == "hpi":
158
+ for field in HPI_FIELDS:
159
  if not state.get(field):
160
  if last_patient_msg:
161
  state[field] = last_patient_msg
162
  break
163
+ for field in HPI_FIELDS:
164
  if not state.get(field):
165
+ labels = {
166
+ "onset": "when it started",
167
+ "location": "where you feel it",
168
+ "duration": "how long it's lasted",
169
+ "character": "what it feels like",
170
+ "severity": "how severe it is (1-10)",
171
+ "aggravating": "what makes it worse",
172
+ "relieving": "what makes it better",
173
+ }
174
  state["reply"] = f"Can you tell me {labels.get(field, field)}?"
175
  break
176
  else:
177
+ state["reply"] = "Thank you, let me ask about other symptoms."
178
 
179
  elif stage == "ros":
180
  ros = state.get("ros", {})
181
+ ros_asked = state.get("ros_asked", [])
182
+
183
+ # Detect emergency keywords
184
+ if any(k in last_patient_msg.lower() for k in ["crushing", "can't breathe", "dying"]):
185
+ state["emergency"] = True
186
+
187
+ # Store last patient message into the first un-asked system
188
  for sys_name in ros_systems:
189
  if sys_name not in ros:
190
  if last_patient_msg:
191
  ros[sys_name] = [last_patient_msg]
192
  state["ros"] = ros
193
+ if sys_name not in ros_asked:
194
+ ros_asked.append(sys_name)
195
+ state["ros_asked"] = ros_asked
196
  break
197
+
198
+ # Ask about the next un-asked system
199
  for sys_name in ros_systems:
200
  if sys_name not in ros:
201
  state["reply"] = f"Any {sys_name} symptoms?"
 
212
  self.api_url = "http://localhost:11434/api/chat"
213
 
214
  def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
 
 
 
 
215
  state_context = build_state_context(current_json)
216
 
217
  prompt = (
 
220
  f"CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
221
  "TASK: Read the patient's latest message. Extract any new clinical facts into the JSON. "
222
  "Then ask exactly ONE question about the FIRST missing item shown above. "
223
+ "For ROS: if the patient answers about a system (even 'no'), add it to BOTH ros AND ros_asked. "
224
  "Return ONLY the updated JSON object."
225
  )
226
 
 
244
  "num_predict": 400
245
  }
246
  }
247
+
248
  try:
249
  response = requests.post(self.api_url, json=payload, timeout=60)
250
  response.raise_for_status()
251
  data = response.json()
252
  raw = data.get("message", {}).get("content", "").strip()
253
  except Exception as e:
254
+ print(f"[Ollama] ERROR calling Ollama API: {e}")
 
255
  return CombinedOutput.model_validate_json(current_json)
256
 
257
  print(f"[Ollama] Inference completed in {time.time() - t_start:.2f}s total.")
258
 
259
+ # Strip markdown fences
260
  json_str = raw
261
  if "```json" in json_str:
262
+ json_str = json_str.split("```json", 1).split("```")[1]
263
  elif "```" in json_str:
264
+ json_str = json_str.split("```", 1)[3].split("```")[0]
265
 
266
  start = json_str.find("{")
267
  end = json_str.rfind("}") + 1
 
270
 
271
  try:
272
  parsed = json.loads(json_str)
273
+
274
+ # ── Coerce all HPI string fields: listβ†’str, empty/nullβ†’None ──
275
  for field in ["chief_complaint", "onset", "location", "duration",
276
  "character", "severity", "aggravating", "relieving"]:
277
  v = parsed.get(field)
278
+ if isinstance(v, list):
279
+ # e.g. ["Walking"] β†’ "Walking"
280
+ parsed[field] = " ".join(str(x) for x in v) if v else None
281
+ elif v is not None and str(v).strip() in ("", "null"):
282
  parsed[field] = None
283
+
284
+ result = CombinedOutput.model_validate(parsed)
285
+
286
  except Exception as e:
287
  print(f"[Ollama] JSON parse error: {e}\nRaw output: {raw[:300]}")
288
  try:
289
+ result = CombinedOutput.model_validate_json(current_json)
290
+ result = result.model_copy(update={"reply": "Could you please repeat that? I want to make sure I understood correctly."})
291
+ return result
292
  except Exception:
293
  return CombinedOutput(reply="Could you please repeat that?")
294
 
295
+ # ── Post-process: normalize ros_asked β†’ ros ──────────────────────
296
+ # If LLM added a system to ros_asked but not ros (e.g. for "no" answers),
297
+ # capture the last patient message as the finding for that system.
298
+ if result.ros_asked:
299
+ last_user = ""
300
+ for line in reversed(transcript.strip().split("\n")):
301
+ if line.startswith("Patient:"):
302
+ last_user = line.replace("Patient:", "").strip()
303
+ break
304
+
305
+ updated_ros = dict(result.ros)
306
+ changed = False
307
+ for asked_sys in result.ros_asked:
308
+ if asked_sys not in updated_ros:
309
+ updated_ros[asked_sys] = [last_user] if last_user else ["no symptoms reported"]
310
+ print(f"[ROSNorm] Filled ros['{asked_sys}'] from patient message: '{last_user[:40]}'")
311
+ changed = True
312
+ if changed:
313
+ result = result.model_copy(update={"ros": updated_ros})
314
+
315
+ print(f"[Ollama] Parsed result β€” stage will be recomputed in graph.")
316
+ return result
317
+
318
 
319
  _llm_instance = None
320
 
321
+
322
  def get_llm():
323
  global _llm_instance
324
  if _llm_instance is None: