Spaces:
Sleeping
Sleeping
File size: 13,204 Bytes
6ea946a 284dfa9 6ea946a daf4268 8d6f802 daf4268 8d6f802 27b1ed4 daf4268 27b1ed4 daf4268 8d6f802 daf4268 27b1ed4 daf4268 27b1ed4 daf4268 27b1ed4 daf4268 27b1ed4 daf4268 44d41e8 daf4268 44d41e8 daf4268 8d6f802 daf4268 8d6f802 daf4268 8d6f802 27b1ed4 daf4268 0bcdd07 44d41e8 eb1b955 0bcdd07 6ea946a 99c13fa 8d6f802 27b1ed4 0bcdd07 6ea946a 7f00c10 8d6f802 7f00c10 8d6f802 7f00c10 8d6f802 7f00c10 8d6f802 7f00c10 8d6f802 7f00c10 0bcdd07 8d6f802 7f00c10 8d6f802 7f00c10 8d6f802 7f00c10 0bcdd07 6ea946a 4e16e37 6ea946a 4e16e37 2ea503f 6ea946a 27b1ed4 daf4268 0bcdd07 daf4268 0bcdd07 daf4268 8d6f802 daf4268 284dfa9 0bcdd07 03af64f 4e16e37 daf4268 03af64f 4e16e37 daf4268 4e16e37 2ea503f daf4268 2ea503f 4e16e37 daf4268 4e16e37 8d6f802 4e16e37 eb1b955 4e16e37 8d6f802 4e16e37 0bcdd07 8d6f802 0bcdd07 284dfa9 8d6f802 284dfa9 8d6f802 0bcdd07 284dfa9 8d6f802 33531b2 8d6f802 33531b2 8d6f802 0bcdd07 4e16e37 284dfa9 8d6f802 284dfa9 0bcdd07 6ea946a 8d6f802 6ea946a 284dfa9 6ea946a 8d6f802 6ea946a 99c13fa 4e16e37 99c13fa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 
275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 | import os
import json
from pydantic import BaseModel
# NOTE: the prompt below hard-codes "3 body systems"; keep it in sync with
# ROS_REQUIRED. The ✅ markers it mentions are emitted by build_state_context.
SYSTEM_PROMPT = """You are a clinical intake assistant conducting a pre-visit patient interview.

YOUR WORKFLOW (follow this order):
1. INTAKE: Identify the patient's chief complaint (main reason for visit).
2. HPI (History of Present Illness): Collect these fields ONE AT A TIME, in order:
   - onset: when the symptom started
   - location: where in the body
   - duration: how long it has lasted
   - character: quality (sharp, dull, pressure, burning, etc.)
   - severity: how bad on a scale of 1-10
   - aggravating: what makes it worse
   - relieving: what makes it better
3. ROS (Review of Systems): Screen 3 body systems RELEVANT to the chief complaint.
   Examples of relevant systems:
   - Leg/knee/joint pain → musculoskeletal, neurological, vascular
   - Chest pain → cardiac, respiratory, gi
   - Headache → neurological, ophthalmologic, ent
   - Abdominal pain → gi, genitourinary, musculoskeletal
   - Back pain → musculoskeletal, neurological, genitourinary
4. DONE: When all HPI fields AND 3 ROS systems are filled, set reply to "Your clinical summary is ready. Please wait for the doctor."

CRITICAL RULES:
- NEVER re-ask a field that is already filled (marked ✅ in the status).
- Ask exactly ONE question per turn about the FIRST missing item.
- For HPI: accept any answer the patient gives, even vague ones like "moderate" or "not sure".
- For ROS: ALWAYS add the system to BOTH "ros" and "ros_asked" — even for negative answers.
  - Positive finding: "cardiac": ["palpitations present"]
  - Negative finding: "respiratory": ["no shortness of breath"]
  - Denied: "gi": ["denied nausea and vomiting"]
  A "no" is still a valid clinical finding. Never leave a ros system in ros_asked but absent from ros.
- Do NOT ask emotional/psychological questions — stick to physical symptoms.
- All string fields must be strings, not arrays.
- Output ONLY valid JSON, no extra text.

OUTPUT FORMAT:
{
  "chief_complaint": "..." or null,
  "onset": "..." or null,
  "location": "..." or null,
  "duration": "..." or null,
  "character": "..." or null,
  "severity": "..." or null,
  "aggravating": "..." or null,
  "relieving": "..." or null,
  "ros": {"system_name": ["finding1", "finding2"], ...},
  "ros_asked": ["system_name1", "system_name2"],
  "emergency": false,
  "reply": "Your single question"
}"""

# HPI fields, in the order they are collected.
HPI_FIELDS = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
# Number of ROS body systems that must be screened before the interview is done.
ROS_REQUIRED = 3


def build_state_context(current_json: str) -> str:
    """Render a field checklist and the current interview phase for the LLM prompt.

    Each chief-complaint / HPI item is listed as filled (✅ with its value) or
    missing (❌). ROS findings are listed with the number of systems still
    needed. The text ends with a ``CURRENT PHASE`` line (INTAKE, HPI, ROS or
    DONE) that names what to ask about next.

    Args:
        current_json: JSON text of the current clinical state. Invalid JSON,
            ``None``, or a document whose root is not an object is treated
            as an empty state.

    Returns:
        The multi-line status text.
    """
    try:
        state = json.loads(current_json)
    except (TypeError, ValueError):  # None input / malformed JSON
        state = {}
    if not isinstance(state, dict):
        # e.g. "null" or a bare list: nothing usable in it.
        state = {}

    lines = ["FIELD STATUS:"]

    cc = state.get("chief_complaint")
    if cc:
        lines.append(f'  ✅ chief_complaint: "{cc}"')
    else:
        lines.append("  ❌ chief_complaint: MISSING — ask what brings them in")

    for field in HPI_FIELDS:
        val = state.get(field)
        if val:
            lines.append(f'  ✅ {field}: "{val}"')
        else:
            lines.append(f"  ❌ {field}: MISSING")

    # `or` also covers explicit nulls ("ros": null) in the stored state.
    ros = state.get("ros") or {}
    ros_asked = state.get("ros_asked") or []
    for sys_name, findings in ros.items():
        lines.append(f"  ✅ ros.{sys_name}: {findings}")
    ros_remaining = ROS_REQUIRED - len(ros)
    if ros_remaining > 0:
        lines.append(f"  ❌ ros: {ros_remaining} more system(s) needed")
        if ros_asked:
            asked = ", ".join(str(name) for name in ros_asked)
            lines.append(f"  ℹ️ Already asked about: {asked} — DO NOT ask about these again")
    else:
        lines.append(f"  ✅ ros: all {ROS_REQUIRED} systems collected")

    missing_hpi = [name for name in HPI_FIELDS if not state.get(name)]
    if not cc:
        lines.append("\nCURRENT PHASE: INTAKE")
    elif missing_hpi:
        lines.append(f"\nCURRENT PHASE: HPI — ask about '{missing_hpi[0]}' next")
    elif ros_remaining > 0:
        lines.append(f"\nCURRENT PHASE: ROS — ask about the next body system relevant to '{cc}'")
        lines.append("  ⚠️ IMPORTANT: Store BOTH positive AND negative ROS findings in 'ros' dict.")
        lines.append("  ⚠️ A patient saying 'no' means: ros[\"system\"] = [\"no [symptom]\"]")
    else:
        lines.append("\nCURRENT PHASE: DONE — all data collected")
    return "\n".join(lines)
class CombinedOutput(BaseModel):
    """Result of one interview turn: the updated clinical state plus the next message.

    Every field has a default, so a partially filled state still validates.
    """

    # Chief complaint and HPI fields; None means "not collected yet".
    chief_complaint: str | None = None
    onset: str | None = None
    location: str | None = None
    duration: str | None = None
    character: str | None = None
    severity: str | None = None
    aggravating: str | None = None
    relieving: str | None = None
    # Review-of-Systems findings keyed by body-system name,
    # e.g. {"respiratory": ["no shortness of breath"]}.
    # NOTE(review): relies on pydantic copying the mutable {} / [] defaults per
    # instance instead of sharing them — confirm for the pinned pydantic version.
    ros: dict[str, list[str]] = {}
    # Body systems the patient has already been asked about.
    ros_asked: list[str] = []
    # Emergency flag (MockLLM sets it when the patient uses alarm keywords).
    emergency: bool = False
    # The assistant's next message to the patient.
    reply: str = ""
class MockLLM:
    """Minimal mock for testing — deterministic field walker.

    No model is called. Each turn copies the patient's latest message into the
    first empty field of the current ``stage``. It then replies with a question
    about the next empty field, so a scripted conversation yields a predictable
    state.
    """

    # Acknowledgements that are never taken as the chief complaint.
    _GREETINGS = frozenset({"hello", "hi", "hey", "ok", "okay", "start", "yes", "sure"})
    # ROS systems walked through, in this order.
    _ROS_SYSTEMS = ("cardiac", "respiratory", "gi")
    # Substrings (matched case-insensitively) that set the emergency flag during ROS.
    _EMERGENCY_KEYWORDS = ("crushing", "can't breathe", "dying")
    # Question wording for each HPI field.
    _HPI_LABELS = {
        "onset": "when it started",
        "location": "where you feel it",
        "duration": "how long it's lasted",
        "character": "what it feels like",
        "severity": "how severe it is (1-10)",
        "aggravating": "what makes it worse",
        "relieving": "what makes it better",
    }

    @staticmethod
    def _last_patient_message(transcript: str) -> str:
        """Return the text of the last ``Patient:`` line in *transcript*, or ``""``."""
        for line in reversed(transcript.strip().split("\n")):
            if line.startswith("Patient:"):
                # removeprefix (not replace) keeps any later "Patient:" inside the message.
                return line.removeprefix("Patient:").strip()
        return ""

    @classmethod
    def _is_emergency(cls, message: str) -> bool:
        """Return True if *message* contains one of the emergency keywords."""
        # Normalize the typographic apostrophe so "can’t breathe" also matches.
        text = message.lower().replace("\u2019", "'")
        return any(keyword in text for keyword in cls._EMERGENCY_KEYWORDS)

    def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> "CombinedOutput":
        """Advance the mock interview by one turn.

        Args:
            transcript: Conversation so far; patient lines start with ``Patient:``.
            current_json: JSON text of the current state. Invalid JSON or a
                non-object root is treated as an empty state.
            stage: ``"intake"``, ``"hpi"`` or ``"ros"``. Any other value returns
                the state unchanged.

        Returns:
            The updated state validated as a ``CombinedOutput``.
        """
        try:
            state = json.loads(current_json)
        except (TypeError, ValueError):
            state = {}
        if not isinstance(state, dict):
            state = {}
        message = self._last_patient_message(transcript)

        if stage == "intake":
            self._intake_turn(state, message)
        elif stage == "hpi":
            self._hpi_turn(state, message)
        elif stage == "ros":
            self._ros_turn(state, message)
        return CombinedOutput.model_validate(state)

    def _intake_turn(self, state: dict, message: str) -> None:
        """Capture the chief complaint (unless already set) and set the next reply."""
        if message and not state.get("chief_complaint"):
            # Skip greetings and very short acknowledgements.
            if message.lower() not in self._GREETINGS and len(message) > 4:
                state["chief_complaint"] = message
        complaint = state.get("chief_complaint")
        state["reply"] = f"When did the {complaint} start?" if complaint else "What brings you in today?"

    def _hpi_turn(self, state: dict, message: str) -> None:
        """Store *message* in the first empty HPI field and ask about the next one."""
        missing = [name for name in HPI_FIELDS if not state.get(name)]
        if missing and message:
            state[missing.pop(0)] = message
        if missing:
            label = self._HPI_LABELS.get(missing[0], missing[0])
            state["reply"] = f"Can you tell me {label}?"
        else:
            state["reply"] = "Thank you, let me ask about other symptoms."

    def _ros_turn(self, state: dict, message: str) -> None:
        """Record *message* as the next ROS system's finding and ask about the one after."""
        if self._is_emergency(message):
            state["emergency"] = True
        # `or` also covers explicit nulls in the stored state.
        ros = state.get("ros") or {}
        ros_asked = state.get("ros_asked") or []
        pending = [name for name in self._ROS_SYSTEMS if name not in ros]
        if pending and message:
            system = pending.pop(0)
            ros[system] = [message]
            if system not in ros_asked:
                ros_asked.append(system)
        state["ros"] = ros
        state["ros_asked"] = ros_asked
        if pending:
            state["reply"] = f"Any {pending[0]} symptoms?"
        else:
            state["reply"] = "Thank you — I have everything I need."
class OllamaLLM:
def __init__(self):
self.model_name = os.environ.get("MODEL_NAME", "qwen2.5:0.5b")
self.api_url = "http://localhost:11434/api/chat"
def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
state_context = build_state_context(current_json)
prompt = (
f"{state_context}\n\n"
f"CURRENT CLINICAL STATE (update with any new patient info):\n{current_json}\n\n"
f"CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
"TASK: Read the patient's latest message. Extract any new clinical facts into the JSON. "
"Then ask exactly ONE question about the FIRST missing item shown above. "
"For ROS: if the patient answers about a system (even 'no'), add it to BOTH ros AND ros_asked. "
"Return ONLY the updated JSON object."
)
import time
import requests
t_start = time.time()
print(f"[Ollama] Starting inference for model '{self.model_name}'...")
print(f"[Ollama] State context:\n{state_context}")
payload = {
"model": self.model_name,
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt}
],
"format": "json",
"stream": False,
"options": {
"temperature": 0.0,
"num_predict": 400
}
}
try:
response = requests.post(self.api_url, json=payload, timeout=60)
response.raise_for_status()
data = response.json()
raw = data.get("message", {}).get("content", "").strip()
except Exception as e:
print(f"[Ollama] ERROR calling Ollama API: {e}")
return CombinedOutput.model_validate_json(current_json)
print(f"[Ollama] Inference completed in {time.time() - t_start:.2f}s total.")
# Strip markdown fences
json_str = raw
if "```json" in json_str:
json_str = json_str.split("```json", 1).split("```")[1]
elif "```" in json_str:
json_str = json_str.split("```", 1)[3].split("```")[0]
start = json_str.find("{")
end = json_str.rfind("}") + 1
if start != -1 and end > start:
json_str = json_str[start:end]
try:
parsed = json.loads(json_str)
# ββ Coerce all HPI string fields: listβstr, empty/nullβNone ββ
for field in ["chief_complaint", "onset", "location", "duration",
"character", "severity", "aggravating", "relieving"]:
v = parsed.get(field)
if isinstance(v, list):
# e.g. ["Walking"] β "Walking"
parsed[field] = " ".join(str(x) for x in v) if v else None
elif v is not None and str(v).strip() in ("", "null"):
parsed[field] = None
result = CombinedOutput.model_validate(parsed)
except Exception as e:
print(f"[Ollama] JSON parse error: {e}\nRaw output: {raw[:300]}")
try:
result = CombinedOutput.model_validate_json(current_json)
result = result.model_copy(update={"reply": "Could you please repeat that? I want to make sure I understood correctly."})
return result
except Exception:
return CombinedOutput(reply="Could you please repeat that?")
# ββ Post-process: normalize ros_asked β ros ββββββββββββββββββββββ
# If LLM added a system to ros_asked but not ros (e.g. for "no" answers),
# capture the last patient message as the finding for that system.
if result.ros_asked:
last_user = ""
for line in reversed(transcript.strip().split("\n")):
if line.startswith("Patient:"):
last_user = line.replace("Patient:", "").strip()
break
updated_ros = dict(result.ros)
changed = False
for asked_sys in result.ros_asked:
if asked_sys not in updated_ros:
updated_ros[asked_sys] = [last_user] if last_user else ["no symptoms reported"]
print(f"[ROSNorm] Filled ros['{asked_sys}'] from patient message: '{last_user[:40]}'")
changed = True
if changed:
result = result.model_copy(update={"ros": updated_ros})
print(f"[Ollama] Parsed result β stage will be recomputed in graph.")
return result
# Lazily created, process-wide backend shared by every get_llm() caller.
_llm_instance = None


def get_llm():
    """Return the shared LLM backend, constructing it on the first call.

    The ``MOCK_LLM`` environment variable (default ``"true"``, compared
    case-insensitively) selects MockLLM; any other value selects OllamaLLM.
    """
    global _llm_instance
    if _llm_instance is not None:
        return _llm_instance
    use_mock = os.environ.get("MOCK_LLM", "true").lower() == "true"
    backend_cls = MockLLM if use_mock else OllamaLLM
    _llm_instance = backend_cls()
    return _llm_instance