Spaces:
Sleeping
Sleeping
priyansh-saxena1 commited on
Commit Β·
27b1ed4
1
Parent(s): b7c799b
feat: stage-specific prompts + contextual ROS
Browse files- app/graph.py +11 -21
- app/llm.py +89 -38
app/graph.py
CHANGED
|
@@ -119,12 +119,19 @@ def agent_node(state: IntakeState) -> dict:
|
|
| 119 |
current_json = state.get("clinical_state") or CombinedOutput().model_dump_json()
|
| 120 |
transcript = format_transcript(msgs)
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
import time
|
| 123 |
t_agent = time.time()
|
| 124 |
-
print(f"[{time.time():.3f}] [Graph Node] Requesting LLM inference...")
|
| 125 |
|
| 126 |
llm = get_llm()
|
| 127 |
-
result: CombinedOutput = llm.combined_call(transcript, current_json)
|
| 128 |
|
| 129 |
# ββ Loop Guard: if LLM returned same reply as last turn, force-fill stuck field ββ
|
| 130 |
if _detect_repeat({"messages": msgs + [{"role": "assistant", "content": result.reply}]}):
|
|
@@ -141,11 +148,6 @@ def agent_node(state: IntakeState) -> dict:
|
|
| 141 |
break
|
| 142 |
|
| 143 |
# ββ ROS Hallucination Guard: LLM can only ADD one new ROS system per turn ββ
|
| 144 |
-
ROS_QUESTIONS = {
|
| 145 |
-
"cardiac": "Have you experienced any palpitations, leg swelling, or dizziness?",
|
| 146 |
-
"respiratory": "Have you had any shortness of breath, coughing, or wheezing?",
|
| 147 |
-
"gi": "Have you had any nausea, vomiting, or heartburn?",
|
| 148 |
-
}
|
| 149 |
try:
|
| 150 |
prev_state = json.loads(current_json)
|
| 151 |
prev_ros = prev_state.get("ros") or {}
|
|
@@ -153,7 +155,7 @@ def agent_node(state: IntakeState) -> dict:
|
|
| 153 |
prev_ros = {}
|
| 154 |
new_ros_keys = [k for k in result.ros if k not in prev_ros]
|
| 155 |
if len(new_ros_keys) > 1:
|
| 156 |
-
print(f"[ROSGuard] LLM
|
| 157 |
allowed_ros = dict(prev_ros)
|
| 158 |
allowed_ros[new_ros_keys[0]] = result.ros[new_ros_keys[0]]
|
| 159 |
object.__setattr__(result, "ros", allowed_ros)
|
|
@@ -162,19 +164,7 @@ def agent_node(state: IntakeState) -> dict:
|
|
| 162 |
|
| 163 |
stage = compute_stage(result)
|
| 164 |
missing = missing_from(result)
|
| 165 |
-
|
| 166 |
-
# ββ ROS Question Forcing: if all HPI done but ROS incomplete, force a specific ROS question ββ
|
| 167 |
-
if stage == "ros":
|
| 168 |
-
current_ros = result.ros or {}
|
| 169 |
-
for sys_name, question in ROS_QUESTIONS.items():
|
| 170 |
-
if sys_name not in current_ros:
|
| 171 |
-
print(f"[ROSForce] Forcing question for missing ROS system: {sys_name}")
|
| 172 |
-
reply = question
|
| 173 |
-
break
|
| 174 |
-
else:
|
| 175 |
-
reply = result.reply or "Could you tell me more?"
|
| 176 |
-
else:
|
| 177 |
-
reply = result.reply or "Could you tell me more?"
|
| 178 |
|
| 179 |
# All fields complete β build the brief inline so it's available this turn
|
| 180 |
if stage == "done":
|
|
|
|
| 119 |
current_json = state.get("clinical_state") or CombinedOutput().model_dump_json()
|
| 120 |
transcript = format_transcript(msgs)
|
| 121 |
|
| 122 |
+
# Compute the current stage BEFORE the LLM call so we can pick the right prompt
|
| 123 |
+
try:
|
| 124 |
+
pre_state = CombinedOutput.model_validate_json(current_json)
|
| 125 |
+
current_stage = compute_stage(pre_state)
|
| 126 |
+
except Exception:
|
| 127 |
+
current_stage = "intake"
|
| 128 |
+
|
| 129 |
import time
|
| 130 |
t_agent = time.time()
|
| 131 |
+
print(f"[{time.time():.3f}] [Graph Node] Requesting LLM inference (stage={current_stage})...")
|
| 132 |
|
| 133 |
llm = get_llm()
|
| 134 |
+
result: CombinedOutput = llm.combined_call(transcript, current_json, stage=current_stage)
|
| 135 |
|
| 136 |
# ββ Loop Guard: if LLM returned same reply as last turn, force-fill stuck field ββ
|
| 137 |
if _detect_repeat({"messages": msgs + [{"role": "assistant", "content": result.reply}]}):
|
|
|
|
| 148 |
break
|
| 149 |
|
| 150 |
# ββ ROS Hallucination Guard: LLM can only ADD one new ROS system per turn ββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
try:
|
| 152 |
prev_state = json.loads(current_json)
|
| 153 |
prev_ros = prev_state.get("ros") or {}
|
|
|
|
| 155 |
prev_ros = {}
|
| 156 |
new_ros_keys = [k for k in result.ros if k not in prev_ros]
|
| 157 |
if len(new_ros_keys) > 1:
|
| 158 |
+
print(f"[ROSGuard] LLM added {len(new_ros_keys)} new ROS systems in one turn: {new_ros_keys}. Keeping only first.")
|
| 159 |
allowed_ros = dict(prev_ros)
|
| 160 |
allowed_ros[new_ros_keys[0]] = result.ros[new_ros_keys[0]]
|
| 161 |
object.__setattr__(result, "ros", allowed_ros)
|
|
|
|
| 164 |
|
| 165 |
stage = compute_stage(result)
|
| 166 |
missing = missing_from(result)
|
| 167 |
+
reply = result.reply or "Could you tell me more?"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
# All fields complete β build the brief inline so it's available this turn
|
| 170 |
if stage == "done":
|
app/llm.py
CHANGED
|
@@ -3,46 +3,97 @@ import json
|
|
| 3 |
import re
|
| 4 |
from pydantic import BaseModel
|
| 5 |
|
| 6 |
-
|
| 7 |
|
| 8 |
-
JOB
|
| 9 |
-
CRITICAL: If the patient denies a symptom, or replies with "none", "zero", "no", or "nothing", you MUST extract that exact word (e.g. "zero"). DO NOT leave it null if the patient has answered the question negatively.
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
- Output ONLY valid JSON, nothing else.
|
| 15 |
- Do NOT diagnose or give medical advice.
|
| 16 |
-
- Do NOT ask more than one question.
|
| 17 |
-
- If all fields are complete, set reply to "Thank you β I have everything I need."
|
| 18 |
|
| 19 |
-
OUTPUT FORMAT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
{
|
| 21 |
-
"chief_complaint": "
|
| 22 |
-
"
|
| 23 |
-
"
|
| 24 |
-
"
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
3. GI: nausea, vomiting, heartburn
|
| 37 |
-
For each system the patient denies symptoms, store as ["no palpitations", "no leg swelling"]. Do NOT ask emotional or psychological questions β stick to the 3 systems above.
|
| 38 |
-
|
| 39 |
-
Use null for any field not yet known. Keep existing values if the patient didn't add new info.
|
| 40 |
-
|
| 41 |
-
IMPORTANT β ACCEPTING VAGUE ANSWERS:
|
| 42 |
-
- If the patient gives ANY answer (even "none", "zero", "not sure", "it goes away", "very mild"), that IS a valid value. Store it as a string.
|
| 43 |
-
- For relieving/aggravating: if patient implies rest helps (e.g. "very mild when not running", "zero at rest"), set relieving="rest" and aggravating="physical activity/running".
|
| 44 |
-
- Do NOT ask the same question twice. If the patient has answered (even vaguely), move on to the next missing field.
|
| 45 |
-
- "zero", "none", "not really", "it's fine otherwise" β treat as valid answer, fill the field."""
|
| 46 |
|
| 47 |
|
| 48 |
class CombinedOutput(BaseModel):
|
|
@@ -60,7 +111,7 @@ class CombinedOutput(BaseModel):
|
|
| 60 |
|
| 61 |
|
| 62 |
class MockLLM:
|
| 63 |
-
def combined_call(self, transcript: str, current_json: str) -> CombinedOutput:
|
| 64 |
"""Single call: extract + generate reply. No real inference in mock mode."""
|
| 65 |
t = transcript.lower()
|
| 66 |
try:
|
|
@@ -163,7 +214,7 @@ class OllamaLLM:
|
|
| 163 |
self.model_name = os.environ.get("MODEL_NAME", "qwen2.5:0.5b")
|
| 164 |
self.api_url = "http://localhost:11434/api/chat"
|
| 165 |
|
| 166 |
-
def combined_call(self, transcript: str, current_json: str) -> CombinedOutput:
|
| 167 |
"""
|
| 168 |
Calls the local Ollama instance using the /chat endpoint so system tags
|
| 169 |
are properly applied.
|
|
@@ -185,7 +236,7 @@ class OllamaLLM:
|
|
| 185 |
payload = {
|
| 186 |
"model": self.model_name,
|
| 187 |
"messages": [
|
| 188 |
-
{"role": "system", "content":
|
| 189 |
{"role": "user", "content": prompt}
|
| 190 |
],
|
| 191 |
"format": "json",
|
|
|
|
| 3 |
import re
|
| 4 |
from pydantic import BaseModel
|
| 5 |
|
| 6 |
+
INTAKE_PROMPT = """You are a clinical intake assistant. The patient just arrived.
|
| 7 |
|
| 8 |
+
JOB: Extract the chief complaint from the conversation. Ask ONE simple question to identify their main symptom.
|
|
|
|
| 9 |
|
| 10 |
+
RULES:
|
| 11 |
+
- Output ONLY valid JSON.
|
| 12 |
+
- If you already know the chief complaint, ask about onset to move forward.
|
|
|
|
| 13 |
- Do NOT diagnose or give medical advice.
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
OUTPUT FORMAT:
|
| 16 |
+
{
|
| 17 |
+
"chief_complaint": "the main symptom" or null,
|
| 18 |
+
"onset": null, "location": null, "duration": null,
|
| 19 |
+
"character": null, "severity": null, "aggravating": null, "relieving": null,
|
| 20 |
+
"ros": {},
|
| 21 |
+
"reply": "Your question to the patient"
|
| 22 |
+
}"""
|
| 23 |
+
|
| 24 |
+
HPI_PROMPT = """You are a clinical intake assistant collecting History of Present Illness (HPI) using OLDCARTS.
|
| 25 |
+
|
| 26 |
+
JOB 1 (EXTRACT): Read the conversation and update the JSON with any new patient info. If a patient denies something or says "none"/"zero"/"no", store that exact word β do NOT leave it null.
|
| 27 |
+
|
| 28 |
+
JOB 2 (RESPOND): Ask ONE question about the FIRST missing field below. Do NOT re-ask fields already filled.
|
| 29 |
+
|
| 30 |
+
FIELDS TO COLLECT (in order):
|
| 31 |
+
- onset: when the symptom started
|
| 32 |
+
- location: where in the body
|
| 33 |
+
- duration: how long it has lasted
|
| 34 |
+
- character: quality of pain (sharp, dull, pressure, burning, etc.)
|
| 35 |
+
- severity: how bad on a scale of 1-10
|
| 36 |
+
- aggravating: what makes it worse
|
| 37 |
+
- relieving: what makes it better
|
| 38 |
+
|
| 39 |
+
RULES:
|
| 40 |
+
- Output ONLY valid JSON, no extra text.
|
| 41 |
+
- Ask exactly ONE question per turn.
|
| 42 |
+
- Keep existing values. Use null for unknowns.
|
| 43 |
+
|
| 44 |
+
OUTPUT FORMAT:
|
| 45 |
+
{
|
| 46 |
+
"chief_complaint": "...",
|
| 47 |
+
"onset": "..." or null,
|
| 48 |
+
"location": "..." or null,
|
| 49 |
+
"duration": "..." or null,
|
| 50 |
+
"character": "..." or null,
|
| 51 |
+
"severity": "..." or null,
|
| 52 |
+
"aggravating": "..." or null,
|
| 53 |
+
"relieving": "..." or null,
|
| 54 |
+
"ros": {},
|
| 55 |
+
"reply": "Your single question"
|
| 56 |
+
}"""
|
| 57 |
+
|
| 58 |
+
ROS_PROMPT = """You are a clinical intake assistant performing a Review of Systems (ROS).
|
| 59 |
+
|
| 60 |
+
All HPI fields are already collected. Now you must screen for symptoms in OTHER body systems that are RELEVANT to the patient's chief complaint.
|
| 61 |
+
|
| 62 |
+
JOB 1 (EXTRACT): The patient just answered a question about a body system. Extract their answer into the "ros" dict under the appropriate system key (e.g. "musculoskeletal": ["joint stiffness", "no swelling"]).
|
| 63 |
+
|
| 64 |
+
JOB 2 (RESPOND): Ask about the NEXT relevant body system that is NOT yet in the "ros" dict.
|
| 65 |
+
|
| 66 |
+
CHOOSING SYSTEMS: Pick 3 systems that are clinically relevant to the chief complaint. Examples:
|
| 67 |
+
- Leg/knee/joint pain β musculoskeletal, neurological, vascular
|
| 68 |
+
- Chest pain β cardiac, respiratory, gi
|
| 69 |
+
- Headache β neurological, ophthalmologic, ent
|
| 70 |
+
- Abdominal pain β gi, genitourinary, musculoskeletal
|
| 71 |
+
- Back pain β musculoskeletal, neurological, genitourinary
|
| 72 |
+
|
| 73 |
+
RULES:
|
| 74 |
+
- Output ONLY valid JSON.
|
| 75 |
+
- Ask about ONE system at a time.
|
| 76 |
+
- If the patient denies symptoms, store as ["no X", "no Y"].
|
| 77 |
+
- Once 3 systems are in "ros", set reply to "Thank you β I have everything I need."
|
| 78 |
+
- Do NOT ask emotional, psychological, or off-topic questions.
|
| 79 |
+
|
| 80 |
+
OUTPUT FORMAT:
|
| 81 |
{
|
| 82 |
+
"chief_complaint": "...", "onset": "...", "location": "...", "duration": "...",
|
| 83 |
+
"character": "...", "severity": "...", "aggravating": "...", "relieving": "...",
|
| 84 |
+
"ros": {"system_name": ["findings"], ...},
|
| 85 |
+
"reply": "Your single ROS question"
|
| 86 |
+
}"""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_system_prompt(stage: str) -> str:
|
| 90 |
+
"""Return the appropriate system prompt for the current clinical stage."""
|
| 91 |
+
if stage == "ros":
|
| 92 |
+
return ROS_PROMPT
|
| 93 |
+
elif stage == "hpi":
|
| 94 |
+
return HPI_PROMPT
|
| 95 |
+
else:
|
| 96 |
+
return INTAKE_PROMPT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
class CombinedOutput(BaseModel):
|
|
|
|
| 111 |
|
| 112 |
|
| 113 |
class MockLLM:
|
| 114 |
+
def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
|
| 115 |
"""Single call: extract + generate reply. No real inference in mock mode."""
|
| 116 |
t = transcript.lower()
|
| 117 |
try:
|
|
|
|
| 214 |
self.model_name = os.environ.get("MODEL_NAME", "qwen2.5:0.5b")
|
| 215 |
self.api_url = "http://localhost:11434/api/chat"
|
| 216 |
|
| 217 |
+
def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
|
| 218 |
"""
|
| 219 |
Calls the local Ollama instance using the /chat endpoint so system tags
|
| 220 |
are properly applied.
|
|
|
|
| 236 |
payload = {
|
| 237 |
"model": self.model_name,
|
| 238 |
"messages": [
|
| 239 |
+
{"role": "system", "content": get_system_prompt(stage)},
|
| 240 |
{"role": "user", "content": prompt}
|
| 241 |
],
|
| 242 |
"format": "json",
|