priyansh-saxena1 commited on
Commit
daf4268
Β·
1 Parent(s): f42a7a8

feat: unified prompt with state visibility

Browse files
Files changed (2) hide show
  1. app/graph.py +2 -3
  2. app/llm.py +97 -87
app/graph.py CHANGED
@@ -4,7 +4,7 @@ from typing import Optional, TypedDict, Annotated
4
  from langgraph.graph import StateGraph, START, END
5
  from langgraph.checkpoint.memory import MemorySaver
6
 
7
- from app.llm import get_llm, CombinedOutput
8
  from app.schemas import ClinicalBrief, HPI, ClinicalStateExtraction
9
 
10
  _MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"
@@ -23,8 +23,7 @@ class IntakeState(TypedDict):
23
  frontend_stage: str # 'intake', 'hpi', 'ros', 'done'
24
 
25
 
26
- HPI_FIELDS = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
27
- ROS_REQUIRED = 3
28
 
29
  EMERGENCY_PHRASES = [
30
  "crushing chest pain", "can't breathe", "cannot breathe",
 
4
  from langgraph.graph import StateGraph, START, END
5
  from langgraph.checkpoint.memory import MemorySaver
6
 
7
+ from app.llm import get_llm, CombinedOutput, HPI_FIELDS, ROS_REQUIRED
8
  from app.schemas import ClinicalBrief, HPI, ClinicalStateExtraction
9
 
10
  _MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"
 
23
  frontend_stage: str # 'intake', 'hpi', 'ros', 'done'
24
 
25
 
26
+ # HPI_FIELDS and ROS_REQUIRED imported from app.llm
 
27
 
28
  EMERGENCY_PHRASES = [
29
  "crushing chest pain", "can't breathe", "cannot breathe",
app/llm.py CHANGED
@@ -2,47 +2,39 @@ import os
2
  import json
3
  from pydantic import BaseModel
4
 
5
- INTAKE_PROMPT = """You are a clinical intake assistant. The patient just arrived.
6
-
7
- JOB: Extract the chief complaint from the conversation. Ask ONE simple question to identify their main symptom.
8
-
9
- RULES:
10
- - Output ONLY valid JSON.
11
- - If you already know the chief complaint, ask about onset to move forward.
12
- - Do NOT diagnose or give medical advice.
13
-
14
- OUTPUT FORMAT:
15
- {
16
- "chief_complaint": "the main symptom" or null,
17
- "onset": null, "location": null, "duration": null,
18
- "character": null, "severity": null, "aggravating": null, "relieving": null,
19
- "ros": {},
20
- "reply": "Your question to the patient"
21
- }"""
22
-
23
- HPI_PROMPT = """You are a clinical intake assistant collecting History of Present Illness (HPI) using OLDCARTS.
24
-
25
- JOB 1 (EXTRACT): Read the conversation and update the JSON with any new patient info. If a patient denies something or says "none"/"zero"/"no", store that exact word β€” do NOT leave it null.
26
-
27
- JOB 2 (RESPOND): Ask ONE question about the FIRST missing field below. Do NOT re-ask fields already filled.
28
-
29
- FIELDS TO COLLECT (in order):
30
- - onset: when the symptom started
31
- - location: where in the body
32
- - duration: how long it has lasted
33
- - character: quality of pain (sharp, dull, pressure, burning, etc.)
34
- - severity: how bad on a scale of 1-10
35
- - aggravating: what makes it worse
36
- - relieving: what makes it better
37
-
38
- RULES:
39
  - Output ONLY valid JSON, no extra text.
40
- - Ask exactly ONE question per turn.
41
- - Keep existing values. Use null for unknowns.
42
 
43
  OUTPUT FORMAT:
44
  {
45
- "chief_complaint": "...",
46
  "onset": "..." or null,
47
  "location": "..." or null,
48
  "duration": "..." or null,
@@ -50,49 +42,68 @@ OUTPUT FORMAT:
50
  "severity": "..." or null,
51
  "aggravating": "..." or null,
52
  "relieving": "..." or null,
53
- "ros": {},
 
54
  "reply": "Your single question"
55
  }"""
56
 
57
- ROS_PROMPT = """You are a clinical intake assistant performing a Review of Systems (ROS).
58
-
59
- All HPI fields are already collected. Now you must screen for symptoms in OTHER body systems that are RELEVANT to the patient's chief complaint.
60
-
61
- JOB 1 (EXTRACT): The patient just answered a question about a body system. Extract their answer into the "ros" dict under the appropriate system key (e.g. "musculoskeletal": ["joint stiffness", "no swelling"]).
62
 
63
- JOB 2 (RESPOND): Ask about the NEXT relevant body system that is NOT yet in the "ros" dict.
64
 
65
- CHOOSING SYSTEMS: Pick 3 systems that are clinically relevant to the chief complaint. Examples:
66
- - Leg/knee/joint pain β†’ musculoskeletal, neurological, vascular
67
- - Chest pain β†’ cardiac, respiratory, gi
68
- - Headache β†’ neurological, ophthalmologic, ent
69
- - Abdominal pain β†’ gi, genitourinary, musculoskeletal
70
- - Back pain β†’ musculoskeletal, neurological, genitourinary
71
 
72
- RULES:
73
- - Output ONLY valid JSON.
74
- - Ask about ONE system at a time.
75
- - If the patient denies symptoms, store as ["no X", "no Y"].
76
- - Once 3 systems are in "ros", set reply to "Thank you β€” I have everything I need."
77
- - Do NOT ask emotional, psychological, or off-topic questions.
78
 
79
- OUTPUT FORMAT:
80
- {
81
- "chief_complaint": "...", "onset": "...", "location": "...", "duration": "...",
82
- "character": "...", "severity": "...", "aggravating": "...", "relieving": "...",
83
- "ros": {"system_name": ["findings"], ...},
84
- "reply": "Your single ROS question"
85
- }"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
 
 
87
 
88
- def get_system_prompt(stage: str) -> str:
89
- """Return the appropriate system prompt for the current clinical stage."""
90
- if stage == "ros":
91
- return ROS_PROMPT
92
- elif stage == "hpi":
93
- return HPI_PROMPT
94
- else:
95
- return INTAKE_PROMPT
96
 
97
 
98
  class CombinedOutput(BaseModel):
@@ -117,7 +128,6 @@ class MockLLM:
117
  except Exception:
118
  state = {}
119
 
120
- # Mock just steps through HPI fields in order, using the patient's last message as the value
121
  lines = transcript.strip().split("\n")
122
  last_patient_msg = ""
123
  for line in reversed(lines):
@@ -134,13 +144,11 @@ class MockLLM:
134
  state["reply"] = "What brings you in today?" if not state.get("chief_complaint") else f"When did the {state['chief_complaint']} start?"
135
 
136
  elif stage == "hpi":
137
- # Fill the first empty HPI field with the patient's answer
138
- for field in hpi_fields[1:]: # skip chief_complaint, already filled
139
  if not state.get(field):
140
  if last_patient_msg:
141
  state[field] = last_patient_msg
142
  break
143
- # Ask about the next missing field
144
  for field in hpi_fields[1:]:
145
  if not state.get(field):
146
  labels = {"onset": "when it started", "location": "where you feel it",
@@ -154,14 +162,12 @@ class MockLLM:
154
 
155
  elif stage == "ros":
156
  ros = state.get("ros", {})
157
- # Fill the first empty ROS system
158
  for sys_name in ros_systems:
159
  if sys_name not in ros:
160
  if last_patient_msg:
161
  ros[sys_name] = [last_patient_msg]
162
  state["ros"] = ros
163
  break
164
- # Ask about next missing system
165
  for sys_name in ros_systems:
166
  if sys_name not in ros:
167
  state["reply"] = f"Any {sys_name} symptoms?"
@@ -179,34 +185,38 @@ class OllamaLLM:
179
 
180
  def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
181
  """
182
- Calls the local Ollama instance using the /chat endpoint so system tags
183
- are properly applied.
184
  """
 
 
185
  prompt = (
 
186
  f"CURRENT CLINICAL STATE (update with any new patient info):\n{current_json}\n\n"
187
- f"FULL CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
188
- "Instructions: Extract all new clinical facts from the transcript, merge them into the state, "
189
- "and generate exactly ONE empathetic follow-up question for whatever is still missing. "
190
- "Return ONLY the JSON object, no other text."
191
  )
192
 
193
  import time
194
  import requests
195
-
196
  t_start = time.time()
197
  print(f"[Ollama] Starting inference for model '{self.model_name}'...")
198
-
 
199
  payload = {
200
  "model": self.model_name,
201
  "messages": [
202
- {"role": "system", "content": get_system_prompt(stage)},
203
  {"role": "user", "content": prompt}
204
  ],
205
  "format": "json",
206
  "stream": False,
207
  "options": {
208
  "temperature": 0.0,
209
- "num_predict": 250
210
  }
211
  }
212
 
 
2
  import json
3
  from pydantic import BaseModel
4
 
5
+ # ── Single unified system prompt β€” LLM sees the full workflow ──
6
+ SYSTEM_PROMPT = """You are a clinical intake assistant conducting a pre-visit patient interview.
7
+
8
+ YOUR WORKFLOW (follow this order):
9
+ 1. INTAKE: Identify the patient's chief complaint (main reason for visit).
10
+ 2. HPI (History of Present Illness): Collect these fields ONE AT A TIME, in order:
11
+ - onset: when the symptom started
12
+ - location: where in the body
13
+ - duration: how long it has lasted
14
+ - character: quality (sharp, dull, pressure, burning, etc.)
15
+ - severity: how bad on a scale of 1-10
16
+ - aggravating: what makes it worse
17
+ - relieving: what makes it better
18
+ 3. ROS (Review of Systems): Screen 3 body systems RELEVANT to the chief complaint.
19
+ Examples of relevant systems:
20
+ - Leg/knee/joint pain β†’ musculoskeletal, neurological, vascular
21
+ - Chest pain β†’ cardiac, respiratory, gi
22
+ - Headache β†’ neurological, ophthalmologic, ent
23
+ - Abdominal pain β†’ gi, genitourinary, musculoskeletal
24
+ - Back pain β†’ musculoskeletal, neurological, genitourinary
25
+ 4. DONE: When all HPI fields AND 3 ROS systems are filled, set reply to "Your clinical summary is ready. Please wait for the doctor."
26
+
27
+ CRITICAL RULES:
28
+ - NEVER re-ask a field that is already filled (marked βœ… in the status).
29
+ - Ask exactly ONE question per turn about the FIRST missing item.
30
+ - If a patient says "none"/"zero"/"no"/"denied", store that exact answer β€” do NOT leave it null.
31
+ - For ROS: store findings as a list, e.g. "musculoskeletal": ["joint stiffness", "no swelling"].
32
+ - Do NOT ask emotional/psychological questions β€” stick to physical symptoms.
 
 
 
 
 
 
33
  - Output ONLY valid JSON, no extra text.
 
 
34
 
35
  OUTPUT FORMAT:
36
  {
37
+ "chief_complaint": "..." or null,
38
  "onset": "..." or null,
39
  "location": "..." or null,
40
  "duration": "..." or null,
 
42
  "severity": "..." or null,
43
  "aggravating": "..." or null,
44
  "relieving": "..." or null,
45
+ "ros": {"system_name": ["finding1", "finding2"], ...},
46
+ "emergency": false,
47
  "reply": "Your single question"
48
  }"""
49
 
50
+ HPI_FIELDS = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
51
+ ROS_REQUIRED = 3
 
 
 
52
 
 
53
 
54
+ def build_state_context(current_json: str) -> str:
55
+ """Build a human-readable status summary so the LLM knows exactly what's filled and missing."""
56
+ try:
57
+ state = json.loads(current_json)
58
+ except Exception:
59
+ state = {}
60
 
61
+ lines = ["FIELD STATUS:"]
 
 
 
 
 
62
 
63
+ # Chief complaint
64
+ cc = state.get("chief_complaint")
65
+ if cc:
66
+ lines.append(f' βœ… chief_complaint: "{cc}"')
67
+ else:
68
+ lines.append(" ❌ chief_complaint: MISSING β€” ask what brings them in")
69
+
70
+ # HPI fields
71
+ for field in HPI_FIELDS:
72
+ val = state.get(field)
73
+ if val:
74
+ lines.append(f' βœ… {field}: "{val}"')
75
+ else:
76
+ lines.append(f" ❌ {field}: MISSING")
77
+
78
+ # ROS
79
+ ros = state.get("ros", {})
80
+ if ros:
81
+ for sys_name, findings in ros.items():
82
+ lines.append(f' βœ… ros.{sys_name}: {findings}')
83
+ ros_remaining = ROS_REQUIRED - len(ros)
84
+ if ros_remaining > 0:
85
+ lines.append(f" ❌ ros: {ros_remaining} more system(s) needed")
86
+ else:
87
+ lines.append(f" βœ… ros: all {ROS_REQUIRED} systems collected")
88
+
89
+ # Determine current phase
90
+ if not cc:
91
+ phase = "INTAKE"
92
+ elif any(not state.get(f) for f in HPI_FIELDS):
93
+ phase = "HPI"
94
+ first_missing = next(f for f in HPI_FIELDS if not state.get(f))
95
+ lines.append(f"\nCURRENT PHASE: {phase} β€” ask about '{first_missing}' next")
96
+ elif ros_remaining > 0:
97
+ phase = "ROS"
98
+ lines.append(f"\nCURRENT PHASE: {phase} β€” ask about the next body system relevant to '{cc}'")
99
+ else:
100
+ phase = "DONE"
101
+ lines.append(f"\nCURRENT PHASE: {phase} β€” all data collected, set reply to completion message")
102
 
103
+ if not cc:
104
+ lines.append(f"\nCURRENT PHASE: {phase}")
105
 
106
+ return "\n".join(lines)
 
 
 
 
 
 
 
107
 
108
 
109
  class CombinedOutput(BaseModel):
 
128
  except Exception:
129
  state = {}
130
 
 
131
  lines = transcript.strip().split("\n")
132
  last_patient_msg = ""
133
  for line in reversed(lines):
 
144
  state["reply"] = "What brings you in today?" if not state.get("chief_complaint") else f"When did the {state['chief_complaint']} start?"
145
 
146
  elif stage == "hpi":
147
+ for field in hpi_fields[1:]:
 
148
  if not state.get(field):
149
  if last_patient_msg:
150
  state[field] = last_patient_msg
151
  break
 
152
  for field in hpi_fields[1:]:
153
  if not state.get(field):
154
  labels = {"onset": "when it started", "location": "where you feel it",
 
162
 
163
  elif stage == "ros":
164
  ros = state.get("ros", {})
 
165
  for sys_name in ros_systems:
166
  if sys_name not in ros:
167
  if last_patient_msg:
168
  ros[sys_name] = [last_patient_msg]
169
  state["ros"] = ros
170
  break
 
171
  for sys_name in ros_systems:
172
  if sys_name not in ros:
173
  state["reply"] = f"Any {sys_name} symptoms?"
 
185
 
186
  def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
187
  """
188
+ Single LLM call: extracts clinical data + generates next question.
189
+ The unified prompt + state context gives the LLM full visibility.
190
  """
191
+ state_context = build_state_context(current_json)
192
+
193
  prompt = (
194
+ f"{state_context}\n\n"
195
  f"CURRENT CLINICAL STATE (update with any new patient info):\n{current_json}\n\n"
196
+ f"CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
197
+ "TASK: Read the patient's latest message. Extract any new clinical facts into the JSON. "
198
+ "Then ask exactly ONE question about the FIRST missing item shown above. "
199
+ "Return ONLY the updated JSON object."
200
  )
201
 
202
  import time
203
  import requests
204
+
205
  t_start = time.time()
206
  print(f"[Ollama] Starting inference for model '{self.model_name}'...")
207
+ print(f"[Ollama] State context:\n{state_context}")
208
+
209
  payload = {
210
  "model": self.model_name,
211
  "messages": [
212
+ {"role": "system", "content": SYSTEM_PROMPT},
213
  {"role": "user", "content": prompt}
214
  ],
215
  "format": "json",
216
  "stream": False,
217
  "options": {
218
  "temperature": 0.0,
219
+ "num_predict": 400
220
  }
221
  }
222