Spaces:
Sleeping
Sleeping
priyansh-saxena1 commited on
Commit Β·
daf4268
1
Parent(s): f42a7a8
feat: unified prompt with state visibility
Browse files- app/graph.py +2 -3
- app/llm.py +97 -87
app/graph.py
CHANGED
|
@@ -4,7 +4,7 @@ from typing import Optional, TypedDict, Annotated
|
|
| 4 |
from langgraph.graph import StateGraph, START, END
|
| 5 |
from langgraph.checkpoint.memory import MemorySaver
|
| 6 |
|
| 7 |
-
from app.llm import get_llm, CombinedOutput
|
| 8 |
from app.schemas import ClinicalBrief, HPI, ClinicalStateExtraction
|
| 9 |
|
| 10 |
_MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"
|
|
@@ -23,8 +23,7 @@ class IntakeState(TypedDict):
|
|
| 23 |
frontend_stage: str # 'intake', 'hpi', 'ros', 'done'
|
| 24 |
|
| 25 |
|
| 26 |
-
HPI_FIELDS
|
| 27 |
-
ROS_REQUIRED = 3
|
| 28 |
|
| 29 |
EMERGENCY_PHRASES = [
|
| 30 |
"crushing chest pain", "can't breathe", "cannot breathe",
|
|
|
|
| 4 |
from langgraph.graph import StateGraph, START, END
|
| 5 |
from langgraph.checkpoint.memory import MemorySaver
|
| 6 |
|
| 7 |
+
from app.llm import get_llm, CombinedOutput, HPI_FIELDS, ROS_REQUIRED
|
| 8 |
from app.schemas import ClinicalBrief, HPI, ClinicalStateExtraction
|
| 9 |
|
| 10 |
_MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"
|
|
|
|
| 23 |
frontend_stage: str # 'intake', 'hpi', 'ros', 'done'
|
| 24 |
|
| 25 |
|
| 26 |
+
# HPI_FIELDS and ROS_REQUIRED imported from app.llm
|
|
|
|
| 27 |
|
| 28 |
EMERGENCY_PHRASES = [
|
| 29 |
"crushing chest pain", "can't breathe", "cannot breathe",
|
app/llm.py
CHANGED
|
@@ -2,47 +2,39 @@ import os
|
|
| 2 |
import json
|
| 3 |
from pydantic import BaseModel
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
-
|
| 12 |
-
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
-
|
| 31 |
-
-
|
| 32 |
-
-
|
| 33 |
-
- character: quality of pain (sharp, dull, pressure, burning, etc.)
|
| 34 |
-
- severity: how bad on a scale of 1-10
|
| 35 |
-
- aggravating: what makes it worse
|
| 36 |
-
- relieving: what makes it better
|
| 37 |
-
|
| 38 |
-
RULES:
|
| 39 |
- Output ONLY valid JSON, no extra text.
|
| 40 |
-
- Ask exactly ONE question per turn.
|
| 41 |
-
- Keep existing values. Use null for unknowns.
|
| 42 |
|
| 43 |
OUTPUT FORMAT:
|
| 44 |
{
|
| 45 |
-
"chief_complaint": "...",
|
| 46 |
"onset": "..." or null,
|
| 47 |
"location": "..." or null,
|
| 48 |
"duration": "..." or null,
|
|
@@ -50,49 +42,68 @@ OUTPUT FORMAT:
|
|
| 50 |
"severity": "..." or null,
|
| 51 |
"aggravating": "..." or null,
|
| 52 |
"relieving": "..." or null,
|
| 53 |
-
"ros": {},
|
|
|
|
| 54 |
"reply": "Your single question"
|
| 55 |
}"""
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
All HPI fields are already collected. Now you must screen for symptoms in OTHER body systems that are RELEVANT to the patient's chief complaint.
|
| 60 |
-
|
| 61 |
-
JOB 1 (EXTRACT): The patient just answered a question about a body system. Extract their answer into the "ros" dict under the appropriate system key (e.g. "musculoskeletal": ["joint stiffness", "no swelling"]).
|
| 62 |
|
| 63 |
-
JOB 2 (RESPOND): Ask about the NEXT relevant body system that is NOT yet in the "ros" dict.
|
| 64 |
|
| 65 |
-
|
| 66 |
-
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
|
| 72 |
-
|
| 73 |
-
- Output ONLY valid JSON.
|
| 74 |
-
- Ask about ONE system at a time.
|
| 75 |
-
- If the patient denies symptoms, store as ["no X", "no Y"].
|
| 76 |
-
- Once 3 systems are in "ros", set reply to "Thank you β I have everything I need."
|
| 77 |
-
- Do NOT ask emotional, psychological, or off-topic questions.
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
|
|
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
-
"""Return the appropriate system prompt for the current clinical stage."""
|
| 90 |
-
if stage == "ros":
|
| 91 |
-
return ROS_PROMPT
|
| 92 |
-
elif stage == "hpi":
|
| 93 |
-
return HPI_PROMPT
|
| 94 |
-
else:
|
| 95 |
-
return INTAKE_PROMPT
|
| 96 |
|
| 97 |
|
| 98 |
class CombinedOutput(BaseModel):
|
|
@@ -117,7 +128,6 @@ class MockLLM:
|
|
| 117 |
except Exception:
|
| 118 |
state = {}
|
| 119 |
|
| 120 |
-
# Mock just steps through HPI fields in order, using the patient's last message as the value
|
| 121 |
lines = transcript.strip().split("\n")
|
| 122 |
last_patient_msg = ""
|
| 123 |
for line in reversed(lines):
|
|
@@ -134,13 +144,11 @@ class MockLLM:
|
|
| 134 |
state["reply"] = "What brings you in today?" if not state.get("chief_complaint") else f"When did the {state['chief_complaint']} start?"
|
| 135 |
|
| 136 |
elif stage == "hpi":
|
| 137 |
-
|
| 138 |
-
for field in hpi_fields[1:]: # skip chief_complaint, already filled
|
| 139 |
if not state.get(field):
|
| 140 |
if last_patient_msg:
|
| 141 |
state[field] = last_patient_msg
|
| 142 |
break
|
| 143 |
-
# Ask about the next missing field
|
| 144 |
for field in hpi_fields[1:]:
|
| 145 |
if not state.get(field):
|
| 146 |
labels = {"onset": "when it started", "location": "where you feel it",
|
|
@@ -154,14 +162,12 @@ class MockLLM:
|
|
| 154 |
|
| 155 |
elif stage == "ros":
|
| 156 |
ros = state.get("ros", {})
|
| 157 |
-
# Fill the first empty ROS system
|
| 158 |
for sys_name in ros_systems:
|
| 159 |
if sys_name not in ros:
|
| 160 |
if last_patient_msg:
|
| 161 |
ros[sys_name] = [last_patient_msg]
|
| 162 |
state["ros"] = ros
|
| 163 |
break
|
| 164 |
-
# Ask about next missing system
|
| 165 |
for sys_name in ros_systems:
|
| 166 |
if sys_name not in ros:
|
| 167 |
state["reply"] = f"Any {sys_name} symptoms?"
|
|
@@ -179,34 +185,38 @@ class OllamaLLM:
|
|
| 179 |
|
| 180 |
def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
|
| 181 |
"""
|
| 182 |
-
|
| 183 |
-
|
| 184 |
"""
|
|
|
|
|
|
|
| 185 |
prompt = (
|
|
|
|
| 186 |
f"CURRENT CLINICAL STATE (update with any new patient info):\n{current_json}\n\n"
|
| 187 |
-
f"
|
| 188 |
-
"
|
| 189 |
-
"
|
| 190 |
-
"Return ONLY the JSON object
|
| 191 |
)
|
| 192 |
|
| 193 |
import time
|
| 194 |
import requests
|
| 195 |
-
|
| 196 |
t_start = time.time()
|
| 197 |
print(f"[Ollama] Starting inference for model '{self.model_name}'...")
|
| 198 |
-
|
|
|
|
| 199 |
payload = {
|
| 200 |
"model": self.model_name,
|
| 201 |
"messages": [
|
| 202 |
-
{"role": "system", "content":
|
| 203 |
{"role": "user", "content": prompt}
|
| 204 |
],
|
| 205 |
"format": "json",
|
| 206 |
"stream": False,
|
| 207 |
"options": {
|
| 208 |
"temperature": 0.0,
|
| 209 |
-
"num_predict":
|
| 210 |
}
|
| 211 |
}
|
| 212 |
|
|
|
|
| 2 |
import json
|
| 3 |
from pydantic import BaseModel
|
| 4 |
|
| 5 |
+
# ββ Single unified system prompt β LLM sees the full workflow ββ
|
| 6 |
+
SYSTEM_PROMPT = """You are a clinical intake assistant conducting a pre-visit patient interview.
|
| 7 |
+
|
| 8 |
+
YOUR WORKFLOW (follow this order):
|
| 9 |
+
1. INTAKE: Identify the patient's chief complaint (main reason for visit).
|
| 10 |
+
2. HPI (History of Present Illness): Collect these fields ONE AT A TIME, in order:
|
| 11 |
+
- onset: when the symptom started
|
| 12 |
+
- location: where in the body
|
| 13 |
+
- duration: how long it has lasted
|
| 14 |
+
- character: quality (sharp, dull, pressure, burning, etc.)
|
| 15 |
+
- severity: how bad on a scale of 1-10
|
| 16 |
+
- aggravating: what makes it worse
|
| 17 |
+
- relieving: what makes it better
|
| 18 |
+
3. ROS (Review of Systems): Screen 3 body systems RELEVANT to the chief complaint.
|
| 19 |
+
Examples of relevant systems:
|
| 20 |
+
- Leg/knee/joint pain β musculoskeletal, neurological, vascular
|
| 21 |
+
- Chest pain β cardiac, respiratory, gi
|
| 22 |
+
- Headache β neurological, ophthalmologic, ent
|
| 23 |
+
- Abdominal pain β gi, genitourinary, musculoskeletal
|
| 24 |
+
- Back pain β musculoskeletal, neurological, genitourinary
|
| 25 |
+
4. DONE: When all HPI fields AND 3 ROS systems are filled, set reply to "Your clinical summary is ready. Please wait for the doctor."
|
| 26 |
+
|
| 27 |
+
CRITICAL RULES:
|
| 28 |
+
- NEVER re-ask a field that is already filled (marked β
in the status).
|
| 29 |
+
- Ask exactly ONE question per turn about the FIRST missing item.
|
| 30 |
+
- If a patient says "none"/"zero"/"no"/"denied", store that exact answer β do NOT leave it null.
|
| 31 |
+
- For ROS: store findings as a list, e.g. "musculoskeletal": ["joint stiffness", "no swelling"].
|
| 32 |
+
- Do NOT ask emotional/psychological questions β stick to physical symptoms.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
- Output ONLY valid JSON, no extra text.
|
|
|
|
|
|
|
| 34 |
|
| 35 |
OUTPUT FORMAT:
|
| 36 |
{
|
| 37 |
+
"chief_complaint": "..." or null,
|
| 38 |
"onset": "..." or null,
|
| 39 |
"location": "..." or null,
|
| 40 |
"duration": "..." or null,
|
|
|
|
| 42 |
"severity": "..." or null,
|
| 43 |
"aggravating": "..." or null,
|
| 44 |
"relieving": "..." or null,
|
| 45 |
+
"ros": {"system_name": ["finding1", "finding2"], ...},
|
| 46 |
+
"emergency": false,
|
| 47 |
"reply": "Your single question"
|
| 48 |
}"""
|
| 49 |
|
| 50 |
+
HPI_FIELDS = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
|
| 51 |
+
ROS_REQUIRED = 3
|
|
|
|
|
|
|
|
|
|
| 52 |
|
|
|
|
| 53 |
|
| 54 |
+
def build_state_context(current_json: str) -> str:
|
| 55 |
+
"""Build a human-readable status summary so the LLM knows exactly what's filled and missing."""
|
| 56 |
+
try:
|
| 57 |
+
state = json.loads(current_json)
|
| 58 |
+
except Exception:
|
| 59 |
+
state = {}
|
| 60 |
|
| 61 |
+
lines = ["FIELD STATUS:"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
# Chief complaint
|
| 64 |
+
cc = state.get("chief_complaint")
|
| 65 |
+
if cc:
|
| 66 |
+
lines.append(f' β
chief_complaint: "{cc}"')
|
| 67 |
+
else:
|
| 68 |
+
lines.append(" β chief_complaint: MISSING β ask what brings them in")
|
| 69 |
+
|
| 70 |
+
# HPI fields
|
| 71 |
+
for field in HPI_FIELDS:
|
| 72 |
+
val = state.get(field)
|
| 73 |
+
if val:
|
| 74 |
+
lines.append(f' β
{field}: "{val}"')
|
| 75 |
+
else:
|
| 76 |
+
lines.append(f" β {field}: MISSING")
|
| 77 |
+
|
| 78 |
+
# ROS
|
| 79 |
+
ros = state.get("ros", {})
|
| 80 |
+
if ros:
|
| 81 |
+
for sys_name, findings in ros.items():
|
| 82 |
+
lines.append(f' β
ros.{sys_name}: {findings}')
|
| 83 |
+
ros_remaining = ROS_REQUIRED - len(ros)
|
| 84 |
+
if ros_remaining > 0:
|
| 85 |
+
lines.append(f" β ros: {ros_remaining} more system(s) needed")
|
| 86 |
+
else:
|
| 87 |
+
lines.append(f" β
ros: all {ROS_REQUIRED} systems collected")
|
| 88 |
+
|
| 89 |
+
# Determine current phase
|
| 90 |
+
if not cc:
|
| 91 |
+
phase = "INTAKE"
|
| 92 |
+
elif any(not state.get(f) for f in HPI_FIELDS):
|
| 93 |
+
phase = "HPI"
|
| 94 |
+
first_missing = next(f for f in HPI_FIELDS if not state.get(f))
|
| 95 |
+
lines.append(f"\nCURRENT PHASE: {phase} β ask about '{first_missing}' next")
|
| 96 |
+
elif ros_remaining > 0:
|
| 97 |
+
phase = "ROS"
|
| 98 |
+
lines.append(f"\nCURRENT PHASE: {phase} β ask about the next body system relevant to '{cc}'")
|
| 99 |
+
else:
|
| 100 |
+
phase = "DONE"
|
| 101 |
+
lines.append(f"\nCURRENT PHASE: {phase} β all data collected, set reply to completion message")
|
| 102 |
|
| 103 |
+
if not cc:
|
| 104 |
+
lines.append(f"\nCURRENT PHASE: {phase}")
|
| 105 |
|
| 106 |
+
return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
class CombinedOutput(BaseModel):
|
|
|
|
| 128 |
except Exception:
|
| 129 |
state = {}
|
| 130 |
|
|
|
|
| 131 |
lines = transcript.strip().split("\n")
|
| 132 |
last_patient_msg = ""
|
| 133 |
for line in reversed(lines):
|
|
|
|
| 144 |
state["reply"] = "What brings you in today?" if not state.get("chief_complaint") else f"When did the {state['chief_complaint']} start?"
|
| 145 |
|
| 146 |
elif stage == "hpi":
|
| 147 |
+
for field in hpi_fields[1:]:
|
|
|
|
| 148 |
if not state.get(field):
|
| 149 |
if last_patient_msg:
|
| 150 |
state[field] = last_patient_msg
|
| 151 |
break
|
|
|
|
| 152 |
for field in hpi_fields[1:]:
|
| 153 |
if not state.get(field):
|
| 154 |
labels = {"onset": "when it started", "location": "where you feel it",
|
|
|
|
| 162 |
|
| 163 |
elif stage == "ros":
|
| 164 |
ros = state.get("ros", {})
|
|
|
|
| 165 |
for sys_name in ros_systems:
|
| 166 |
if sys_name not in ros:
|
| 167 |
if last_patient_msg:
|
| 168 |
ros[sys_name] = [last_patient_msg]
|
| 169 |
state["ros"] = ros
|
| 170 |
break
|
|
|
|
| 171 |
for sys_name in ros_systems:
|
| 172 |
if sys_name not in ros:
|
| 173 |
state["reply"] = f"Any {sys_name} symptoms?"
|
|
|
|
| 185 |
|
| 186 |
def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
|
| 187 |
"""
|
| 188 |
+
Single LLM call: extracts clinical data + generates next question.
|
| 189 |
+
The unified prompt + state context gives the LLM full visibility.
|
| 190 |
"""
|
| 191 |
+
state_context = build_state_context(current_json)
|
| 192 |
+
|
| 193 |
prompt = (
|
| 194 |
+
f"{state_context}\n\n"
|
| 195 |
f"CURRENT CLINICAL STATE (update with any new patient info):\n{current_json}\n\n"
|
| 196 |
+
f"CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
|
| 197 |
+
"TASK: Read the patient's latest message. Extract any new clinical facts into the JSON. "
|
| 198 |
+
"Then ask exactly ONE question about the FIRST missing item shown above. "
|
| 199 |
+
"Return ONLY the updated JSON object."
|
| 200 |
)
|
| 201 |
|
| 202 |
import time
|
| 203 |
import requests
|
| 204 |
+
|
| 205 |
t_start = time.time()
|
| 206 |
print(f"[Ollama] Starting inference for model '{self.model_name}'...")
|
| 207 |
+
print(f"[Ollama] State context:\n{state_context}")
|
| 208 |
+
|
| 209 |
payload = {
|
| 210 |
"model": self.model_name,
|
| 211 |
"messages": [
|
| 212 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 213 |
{"role": "user", "content": prompt}
|
| 214 |
],
|
| 215 |
"format": "json",
|
| 216 |
"stream": False,
|
| 217 |
"options": {
|
| 218 |
"temperature": 0.0,
|
| 219 |
+
"num_predict": 400
|
| 220 |
}
|
| 221 |
}
|
| 222 |
|