priyansh-saxena1 committed on
Commit
27b1ed4
·
1 Parent(s): b7c799b

feat: stage-specific prompts + contextual ROS

Browse files
Files changed (2) hide show
  1. app/graph.py +11 -21
  2. app/llm.py +89 -38
app/graph.py CHANGED
@@ -119,12 +119,19 @@ def agent_node(state: IntakeState) -> dict:
119
  current_json = state.get("clinical_state") or CombinedOutput().model_dump_json()
120
  transcript = format_transcript(msgs)
121
 
 
 
 
 
 
 
 
122
  import time
123
  t_agent = time.time()
124
- print(f"[{time.time():.3f}] [Graph Node] Requesting LLM inference...")
125
 
126
  llm = get_llm()
127
- result: CombinedOutput = llm.combined_call(transcript, current_json)
128
 
129
  # ── Loop Guard: if LLM returned same reply as last turn, force-fill stuck field ──
130
  if _detect_repeat({"messages": msgs + [{"role": "assistant", "content": result.reply}]}):
@@ -141,11 +148,6 @@ def agent_node(state: IntakeState) -> dict:
141
  break
142
 
143
  # ── ROS Hallucination Guard: LLM can only ADD one new ROS system per turn ──
144
- ROS_QUESTIONS = {
145
- "cardiac": "Have you experienced any palpitations, leg swelling, or dizziness?",
146
- "respiratory": "Have you had any shortness of breath, coughing, or wheezing?",
147
- "gi": "Have you had any nausea, vomiting, or heartburn?",
148
- }
149
  try:
150
  prev_state = json.loads(current_json)
151
  prev_ros = prev_state.get("ros") or {}
@@ -153,7 +155,7 @@ def agent_node(state: IntakeState) -> dict:
153
  prev_ros = {}
154
  new_ros_keys = [k for k in result.ros if k not in prev_ros]
155
  if len(new_ros_keys) > 1:
156
- print(f"[ROSGuard] LLM hallucinated {len(new_ros_keys)} new ROS systems in one turn: {new_ros_keys}. Keeping only first.")
157
  allowed_ros = dict(prev_ros)
158
  allowed_ros[new_ros_keys[0]] = result.ros[new_ros_keys[0]]
159
  object.__setattr__(result, "ros", allowed_ros)
@@ -162,19 +164,7 @@ def agent_node(state: IntakeState) -> dict:
162
 
163
  stage = compute_stage(result)
164
  missing = missing_from(result)
165
-
166
- # ── ROS Question Forcing: if all HPI done but ROS incomplete, force a specific ROS question ──
167
- if stage == "ros":
168
- current_ros = result.ros or {}
169
- for sys_name, question in ROS_QUESTIONS.items():
170
- if sys_name not in current_ros:
171
- print(f"[ROSForce] Forcing question for missing ROS system: {sys_name}")
172
- reply = question
173
- break
174
- else:
175
- reply = result.reply or "Could you tell me more?"
176
- else:
177
- reply = result.reply or "Could you tell me more?"
178
 
179
  # All fields complete — build the brief inline so it's available this turn
180
  if stage == "done":
 
119
  current_json = state.get("clinical_state") or CombinedOutput().model_dump_json()
120
  transcript = format_transcript(msgs)
121
 
122
+ # Compute the current stage BEFORE the LLM call so we can pick the right prompt
123
+ try:
124
+ pre_state = CombinedOutput.model_validate_json(current_json)
125
+ current_stage = compute_stage(pre_state)
126
+ except Exception:
127
+ current_stage = "intake"
128
+
129
  import time
130
  t_agent = time.time()
131
+ print(f"[{time.time():.3f}] [Graph Node] Requesting LLM inference (stage={current_stage})...")
132
 
133
  llm = get_llm()
134
+ result: CombinedOutput = llm.combined_call(transcript, current_json, stage=current_stage)
135
 
136
  # ── Loop Guard: if LLM returned same reply as last turn, force-fill stuck field ──
137
  if _detect_repeat({"messages": msgs + [{"role": "assistant", "content": result.reply}]}):
 
148
  break
149
 
150
  # ── ROS Hallucination Guard: LLM can only ADD one new ROS system per turn ──
 
 
 
 
 
151
  try:
152
  prev_state = json.loads(current_json)
153
  prev_ros = prev_state.get("ros") or {}
 
155
  prev_ros = {}
156
  new_ros_keys = [k for k in result.ros if k not in prev_ros]
157
  if len(new_ros_keys) > 1:
158
+ print(f"[ROSGuard] LLM added {len(new_ros_keys)} new ROS systems in one turn: {new_ros_keys}. Keeping only first.")
159
  allowed_ros = dict(prev_ros)
160
  allowed_ros[new_ros_keys[0]] = result.ros[new_ros_keys[0]]
161
  object.__setattr__(result, "ros", allowed_ros)
 
164
 
165
  stage = compute_stage(result)
166
  missing = missing_from(result)
167
+ reply = result.reply or "Could you tell me more?"
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  # All fields complete — build the brief inline so it's available this turn
170
  if stage == "done":
app/llm.py CHANGED
@@ -3,46 +3,97 @@ import json
3
  import re
4
  from pydantic import BaseModel
5
 
6
- COMBINED_SYSTEM_PROMPT = """You are a clinical intake assistant AI. You have two jobs per turn:
7
 
8
- JOB 1 (EXTRACT): Read the FULL conversation and update the clinical JSON state with any new information the patient provided.
9
- CRITICAL: If the patient denies a symptom, or replies with "none", "zero", "no", or "nothing", you MUST extract that exact word (e.g. "zero"). DO NOT leave it null if the patient has answered the question negatively.
10
 
11
- JOB 2 (RESPOND): Based on what is STILL MISSING from the clinical state, ask the patient ONE natural, empathetic question. Do NOT ask about things already filled in.
12
-
13
- CRITICAL RULES:
14
- - Output ONLY valid JSON, nothing else.
15
  - Do NOT diagnose or give medical advice.
16
- - Do NOT ask more than one question.
17
- - If all fields are complete, set reply to "Thank you β€” I have everything I need."
18
 
19
- OUTPUT FORMAT (strictly follow this, no extra text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  {
21
- "chief_complaint": "the main symptom or reason for visit",
22
- "onset": "when the symptom started",
23
- "location": "where in the body",
24
- "duration": "how long it has lasted, constant or intermittent",
25
- "character": "quality of pain: sharp, dull, tightening, pressure, burning, squeezing, etc.",
26
- "severity": "how bad, e.g. mild, moderate, severe, or a number out of 10",
27
- "aggravating": "what makes it worse",
28
- "relieving": "what makes it better or go away",
29
- "ros": {"cardiac": ["findings"], "respiratory": ["findings"], "gi": ["findings"]},
30
- "reply": "The single question to ask the patient next"
31
- }
32
-
33
- REVIEW OF SYSTEMS (ROS): Once all HPI fields above are filled, ask about these 3 systems ONE AT A TIME:
34
- 1. Cardiac: palpitations, leg swelling, dizziness
35
- 2. Respiratory: shortness of breath, cough, wheezing
36
- 3. GI: nausea, vomiting, heartburn
37
- For each system the patient denies symptoms, store as ["no palpitations", "no leg swelling"]. Do NOT ask emotional or psychological questions — stick to the 3 systems above.
38
-
39
- Use null for any field not yet known. Keep existing values if the patient didn't add new info.
40
-
41
- IMPORTANT β€” ACCEPTING VAGUE ANSWERS:
42
- - If the patient gives ANY answer (even "none", "zero", "not sure", "it goes away", "very mild"), that IS a valid value. Store it as a string.
43
- - For relieving/aggravating: if patient implies rest helps (e.g. "very mild when not running", "zero at rest"), set relieving="rest" and aggravating="physical activity/running".
44
- - Do NOT ask the same question twice. If the patient has answered (even vaguely), move on to the next missing field.
45
- - "zero", "none", "not really", "it's fine otherwise" → treat as valid answer, fill the field."""
46
 
47
 
48
  class CombinedOutput(BaseModel):
@@ -60,7 +111,7 @@ class CombinedOutput(BaseModel):
60
 
61
 
62
  class MockLLM:
63
- def combined_call(self, transcript: str, current_json: str) -> CombinedOutput:
64
  """Single call: extract + generate reply. No real inference in mock mode."""
65
  t = transcript.lower()
66
  try:
@@ -163,7 +214,7 @@ class OllamaLLM:
163
  self.model_name = os.environ.get("MODEL_NAME", "qwen2.5:0.5b")
164
  self.api_url = "http://localhost:11434/api/chat"
165
 
166
- def combined_call(self, transcript: str, current_json: str) -> CombinedOutput:
167
  """
168
  Calls the local Ollama instance using the /chat endpoint so system tags
169
  are properly applied.
@@ -185,7 +236,7 @@ class OllamaLLM:
185
  payload = {
186
  "model": self.model_name,
187
  "messages": [
188
- {"role": "system", "content": COMBINED_SYSTEM_PROMPT},
189
  {"role": "user", "content": prompt}
190
  ],
191
  "format": "json",
 
3
  import re
4
  from pydantic import BaseModel
5
 
6
+ INTAKE_PROMPT = """You are a clinical intake assistant. The patient just arrived.
7
 
8
+ JOB: Extract the chief complaint from the conversation. Ask ONE simple question to identify their main symptom.
 
9
 
10
+ RULES:
11
+ - Output ONLY valid JSON.
12
+ - If you already know the chief complaint, ask about onset to move forward.
 
13
  - Do NOT diagnose or give medical advice.
 
 
14
 
15
+ OUTPUT FORMAT:
16
+ {
17
+ "chief_complaint": "the main symptom" or null,
18
+ "onset": null, "location": null, "duration": null,
19
+ "character": null, "severity": null, "aggravating": null, "relieving": null,
20
+ "ros": {},
21
+ "reply": "Your question to the patient"
22
+ }"""
23
+
24
+ HPI_PROMPT = """You are a clinical intake assistant collecting History of Present Illness (HPI) using OLDCARTS.
25
+
26
+ JOB 1 (EXTRACT): Read the conversation and update the JSON with any new patient info. If a patient denies something or says "none"/"zero"/"no", store that exact word — do NOT leave it null.
27
+
28
+ JOB 2 (RESPOND): Ask ONE question about the FIRST missing field below. Do NOT re-ask fields already filled.
29
+
30
+ FIELDS TO COLLECT (in order):
31
+ - onset: when the symptom started
32
+ - location: where in the body
33
+ - duration: how long it has lasted
34
+ - character: quality of pain (sharp, dull, pressure, burning, etc.)
35
+ - severity: how bad on a scale of 1-10
36
+ - aggravating: what makes it worse
37
+ - relieving: what makes it better
38
+
39
+ RULES:
40
+ - Output ONLY valid JSON, no extra text.
41
+ - Ask exactly ONE question per turn.
42
+ - Keep existing values. Use null for unknowns.
43
+
44
+ OUTPUT FORMAT:
45
+ {
46
+ "chief_complaint": "...",
47
+ "onset": "..." or null,
48
+ "location": "..." or null,
49
+ "duration": "..." or null,
50
+ "character": "..." or null,
51
+ "severity": "..." or null,
52
+ "aggravating": "..." or null,
53
+ "relieving": "..." or null,
54
+ "ros": {},
55
+ "reply": "Your single question"
56
+ }"""
57
+
58
+ ROS_PROMPT = """You are a clinical intake assistant performing a Review of Systems (ROS).
59
+
60
+ All HPI fields are already collected. Now you must screen for symptoms in OTHER body systems that are RELEVANT to the patient's chief complaint.
61
+
62
+ JOB 1 (EXTRACT): The patient just answered a question about a body system. Extract their answer into the "ros" dict under the appropriate system key (e.g. "musculoskeletal": ["joint stiffness", "no swelling"]).
63
+
64
+ JOB 2 (RESPOND): Ask about the NEXT relevant body system that is NOT yet in the "ros" dict.
65
+
66
+ CHOOSING SYSTEMS: Pick 3 systems that are clinically relevant to the chief complaint. Examples:
67
+ - Leg/knee/joint pain → musculoskeletal, neurological, vascular
68
+ - Chest pain → cardiac, respiratory, gi
69
+ - Headache → neurological, ophthalmologic, ent
70
+ - Abdominal pain → gi, genitourinary, musculoskeletal
71
+ - Back pain → musculoskeletal, neurological, genitourinary
72
+
73
+ RULES:
74
+ - Output ONLY valid JSON.
75
+ - Ask about ONE system at a time.
76
+ - If the patient denies symptoms, store as ["no X", "no Y"].
77
+ - Once 3 systems are in "ros", set reply to "Thank you — I have everything I need."
78
+ - Do NOT ask emotional, psychological, or off-topic questions.
79
+
80
+ OUTPUT FORMAT:
81
  {
82
+ "chief_complaint": "...", "onset": "...", "location": "...", "duration": "...",
83
+ "character": "...", "severity": "...", "aggravating": "...", "relieving": "...",
84
+ "ros": {"system_name": ["findings"], ...},
85
+ "reply": "Your single ROS question"
86
+ }"""
87
+
88
+
89
+ def get_system_prompt(stage: str) -> str:
90
+ """Return the appropriate system prompt for the current clinical stage."""
91
+ if stage == "ros":
92
+ return ROS_PROMPT
93
+ elif stage == "hpi":
94
+ return HPI_PROMPT
95
+ else:
96
+ return INTAKE_PROMPT
 
 
 
 
 
 
 
 
 
 
97
 
98
 
99
  class CombinedOutput(BaseModel):
 
111
 
112
 
113
  class MockLLM:
114
+ def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
115
  """Single call: extract + generate reply. No real inference in mock mode."""
116
  t = transcript.lower()
117
  try:
 
214
  self.model_name = os.environ.get("MODEL_NAME", "qwen2.5:0.5b")
215
  self.api_url = "http://localhost:11434/api/chat"
216
 
217
+ def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
218
  """
219
  Calls the local Ollama instance using the /chat endpoint so system tags
220
  are properly applied.
 
236
  payload = {
237
  "model": self.model_name,
238
  "messages": [
239
+ {"role": "system", "content": get_system_prompt(stage)},
240
  {"role": "user", "content": prompt}
241
  ],
242
  "format": "json",