medintake-ai / app /llm.py
priyansh-saxena1
feat: unified prompt with state visibility
8d6f802
import os
import json
from pydantic import BaseModel
# System prompt sent with every OllamaLLM request. The status markers it refers
# to are the ones produced by build_state_context, so edit both together.
# The prompt spells out when to set "emergency"; without that rule the model
# had no guidance for the flag listed in OUTPUT FORMAT.
SYSTEM_PROMPT = """You are a clinical intake assistant conducting a pre-visit patient interview.
YOUR WORKFLOW (follow this order):
1. INTAKE: Identify the patient's chief complaint (main reason for visit).
2. HPI (History of Present Illness): Collect these fields ONE AT A TIME, in order:
- onset: when the symptom started
- location: where in the body
- duration: how long it has lasted
- character: quality (sharp, dull, pressure, burning, etc.)
- severity: how bad on a scale of 1-10
- aggravating: what makes it worse
- relieving: what makes it better
3. ROS (Review of Systems): Screen 3 body systems RELEVANT to the chief complaint.
Examples of relevant systems:
- Leg/knee/joint pain β†’ musculoskeletal, neurological, vascular
- Chest pain β†’ cardiac, respiratory, gi
- Headache β†’ neurological, ophthalmologic, ent
- Abdominal pain β†’ gi, genitourinary, musculoskeletal
- Back pain β†’ musculoskeletal, neurological, genitourinary
4. DONE: When all HPI fields AND 3 ROS systems are filled, set reply to "Your clinical summary is ready. Please wait for the doctor."
CRITICAL RULES:
- NEVER re-ask a field that is already filled (marked βœ… in the status).
- Ask exactly ONE question per turn about the FIRST missing item.
- For HPI: accept any answer the patient gives, even vague ones like "moderate" or "not sure".
- For ROS: ALWAYS add the system to BOTH "ros" and "ros_asked" β€” even for negative answers.
- Positive finding: "cardiac": ["palpitations present"]
- Negative finding: "respiratory": ["no shortness of breath"]
- Denied: "gi": ["denied nausea and vomiting"]
A "no" is still a valid clinical finding. Never leave a ros system in ros_asked but absent from ros.
- Do NOT ask emotional/psychological questions β€” stick to physical symptoms.
- Set "emergency": true ONLY if the patient reports a possible medical emergency (e.g. crushing chest pain, cannot breathe, fainting); otherwise keep it false.
- All string fields must be strings, not arrays.
- Output ONLY valid JSON, no extra text.
OUTPUT FORMAT:
{
"chief_complaint": "..." or null,
"onset": "..." or null,
"location": "..." or null,
"duration": "..." or null,
"character": "..." or null,
"severity": "..." or null,
"aggravating": "..." or null,
"relieving": "..." or null,
"ros": {"system_name": ["finding1", "finding2"], ...},
"ros_asked": ["system_name1", "system_name2"],
"emergency": false,
"reply": "Your single question"
}"""

# HPI fields in the order they are asked. This must match step 2 of
# SYSTEM_PROMPT: the first empty field in this list is the next one requested.
HPI_FIELDS = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]

# Number of review-of-systems entries needed before the interview is complete
# (the prompt text above also says "3").
ROS_REQUIRED = 3
def build_state_context(current_json: str) -> str:
    """Render a checklist of collected vs. missing intake fields for the LLM.

    The checklist is prepended to the user message so the model can see
    which fields are already filled (and must not be re-asked) and which item
    it should ask about next.

    Args:
        current_json: Current clinical state serialized as JSON. Invalid JSON,
            or JSON that is not an object, is treated as an empty state.
            ``ros`` values that are not objects and ``ros_asked`` values that
            are not arrays are treated as empty.

    Returns:
        A multi-line status block. It ends with a ``CURRENT PHASE`` section
        naming INTAKE, HPI, ROS or DONE, plus the next thing to ask.
    """
    try:
        state = json.loads(current_json)
    except (TypeError, ValueError):
        state = {}
    if not isinstance(state, dict):
        # e.g. "null" or a JSON array: behave as if nothing is collected yet.
        state = {}

    lines = ["FIELD STATUS:"]

    cc = state.get("chief_complaint")
    if cc:
        lines.append(f' βœ… chief_complaint: "{cc}"')
    else:
        lines.append(" ❌ chief_complaint: MISSING β€” ask what brings them in")

    for field in HPI_FIELDS:
        val = state.get(field)
        if val:
            lines.append(f' βœ… {field}: "{val}"')
        else:
            lines.append(f" ❌ {field}: MISSING")

    # Guard against explicit nulls or wrong shapes, which previously crashed
    # at len(ros) / ", ".join(ros_asked).
    ros = state.get("ros")
    if not isinstance(ros, dict):
        ros = {}
    ros_asked = state.get("ros_asked")
    if not isinstance(ros_asked, list):
        ros_asked = []

    for sys_name, findings in ros.items():
        lines.append(f' βœ… ros.{sys_name}: {findings}')
    ros_remaining = ROS_REQUIRED - len(ros)
    if ros_remaining > 0:
        lines.append(f" ❌ ros: {ros_remaining} more system(s) needed")
        if ros_asked:
            asked = ", ".join(str(name) for name in ros_asked)
            lines.append(f" ℹ️ Already asked about: {asked} β€” DO NOT ask about these again")
    else:
        lines.append(f" βœ… ros: all {ROS_REQUIRED} systems collected")

    # Phases run in a fixed order: the earliest incomplete phase wins.
    first_missing = next((f for f in HPI_FIELDS if not state.get(f)), None)
    if not cc:
        lines.append("\nCURRENT PHASE: INTAKE")
    elif first_missing is not None:
        lines.append(f"\nCURRENT PHASE: HPI β€” ask about '{first_missing}' next")
    elif ros_remaining > 0:
        lines.append(f"\nCURRENT PHASE: ROS β€” ask about the next body system relevant to '{cc}'")
        lines.append(" ⚠️ IMPORTANT: Store BOTH positive AND negative ROS findings in 'ros' dict.")
        lines.append(" ⚠️ A patient saying 'no' means: ros[\"system\"] = [\"no [symptom]\"]")
    else:
        lines.append("\nCURRENT PHASE: DONE β€” all data collected")
    return "\n".join(lines)
class CombinedOutput(BaseModel):
    """Structured result of one intake turn: updated clinical state plus the next reply.

    Both backends in this module (MockLLM and OllamaLLM) return this model.
    The field names mirror the OUTPUT FORMAT that SYSTEM_PROMPT asks the model
    to produce.
    """

    # Chief complaint and HPI answers; None (or empty) means "not collected yet".
    chief_complaint: str | None = None
    onset: str | None = None
    location: str | None = None
    duration: str | None = None
    character: str | None = None
    severity: str | None = None
    aggravating: str | None = None
    relieving: str | None = None
    # Review of systems: system name -> findings, negative answers included.
    ros: dict[str, list[str]] = {}
    # Names of the systems the patient has already been asked about.
    ros_asked: list[str] = []
    # Set to True by a backend that detects a possible emergency in the turn.
    emergency: bool = False
    # The assistant's next message to the patient.
    reply: str = ""
class MockLLM:
    """Minimal mock for testing: a deterministic field walker.

    Exposes the same ``combined_call`` signature as ``OllamaLLM`` but never
    calls a model. Each turn it copies the patient's latest message into the
    first empty field of the requested stage, then asks about the next empty
    one.
    """

    # Messages treated as greetings rather than as a chief complaint.
    _GREETINGS = frozenset({"hello", "hi", "hey", "ok", "okay", "start", "yes", "sure"})
    # Substrings (matched case-insensitively) that set ``emergency``.
    _EMERGENCY_KEYWORDS = ("crushing", "can't breathe", "dying")
    # ROS systems walked in this fixed order, regardless of the complaint.
    _ROS_SYSTEMS = ("cardiac", "respiratory", "gi")
    # Question fragments used when asking for each HPI field.
    _HPI_LABELS = {
        "onset": "when it started",
        "location": "where you feel it",
        "duration": "how long it's lasted",
        "character": "what it feels like",
        "severity": "how severe it is (1-10)",
        "aggravating": "what makes it worse",
        "relieving": "what makes it better",
    }

    @staticmethod
    def _load_state(current_json: str) -> dict:
        """Parse the state JSON. Unparsable or non-object input yields {}."""
        try:
            state = json.loads(current_json)
        except (TypeError, ValueError):
            return {}
        return state if isinstance(state, dict) else {}

    @staticmethod
    def _last_patient_message(transcript: str) -> str:
        """Return the text of the most recent ``Patient:`` line, or "" if there is none."""
        for line in reversed(transcript.strip().split("\n")):
            if line.startswith("Patient:"):
                return line.replace("Patient:", "").strip()
        return ""

    def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
        """Advance the mock interview by one turn.

        Args:
            transcript: Conversation so far; patient turns start with "Patient:".
            current_json: Current clinical state as JSON.
            stage: "intake", "hpi" or "ros". Any other value only applies the
                emergency check and echoes the state back.

        Returns:
            The updated state validated as a ``CombinedOutput``.
        """
        state = self._load_state(current_json)
        last_patient_msg = self._last_patient_message(transcript)

        # Flag emergencies in every stage, as the real model may.
        if any(k in last_patient_msg.lower() for k in self._EMERGENCY_KEYWORDS):
            state["emergency"] = True

        if stage == "intake":
            if last_patient_msg and not state.get("chief_complaint"):
                # Ignore greetings, even with trailing punctuation ("Hello!").
                normalized = last_patient_msg.lower().rstrip("!?.,; ")
                if normalized not in self._GREETINGS and len(last_patient_msg) > 4:
                    state["chief_complaint"] = last_patient_msg
            state["reply"] = (
                "What brings you in today?"
                if not state.get("chief_complaint")
                else f"When did the {state['chief_complaint']} start?"
            )
        elif stage == "hpi":
            # The latest answer fills the first empty HPI field...
            if last_patient_msg:
                first_empty = next((f for f in HPI_FIELDS if not state.get(f)), None)
                if first_empty is not None:
                    state[first_empty] = last_patient_msg
            # ...and the next empty one is asked about.
            next_empty = next((f for f in HPI_FIELDS if not state.get(f)), None)
            if next_empty is None:
                state["reply"] = "Thank you, let me ask about other symptoms."
            else:
                state["reply"] = f"Can you tell me {self._HPI_LABELS.get(next_empty, next_empty)}?"
        elif stage == "ros":
            # Treat null or malformed ros/ros_asked as empty instead of crashing.
            ros = state.get("ros")
            if not isinstance(ros, dict):
                ros = {}
            ros_asked = state.get("ros_asked")
            if not isinstance(ros_asked, list):
                ros_asked = []
            pending = [name for name in self._ROS_SYSTEMS if name not in ros]
            # Store the latest message under the first system without findings.
            if pending and last_patient_msg:
                answered = pending.pop(0)
                ros[answered] = [last_patient_msg]
                if answered not in ros_asked:
                    ros_asked.append(answered)
            state["ros"] = ros
            state["ros_asked"] = ros_asked
            state["reply"] = (
                f"Any {pending[0]} symptoms?"
                if pending
                else "Thank you β€” I have everything I need."
            )
        return CombinedOutput.model_validate(state)
class OllamaLLM:
    """LLM backend that queries an Ollama chat endpoint.

    Each turn sends ``SYSTEM_PROMPT`` plus a user message with three parts:
    the field checklist from ``build_state_context``, the current state JSON,
    and the transcript. The model's JSON answer is parsed into a
    ``CombinedOutput``.
    """

    # Used when OLLAMA_API_URL is not set (matches the former hard-coded value).
    _DEFAULT_API_URL = "http://localhost:11434/api/chat"
    # Fields that CombinedOutput types as ``str | None``.
    _STRING_FIELDS = (
        "chief_complaint", "onset", "location", "duration",
        "character", "severity", "aggravating", "relieving",
    )

    def __init__(self):
        self.model_name = os.environ.get("MODEL_NAME", "qwen2.5:0.5b")
        self.api_url = os.environ.get("OLLAMA_API_URL", self._DEFAULT_API_URL)

    def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
        """Run one intake turn through the model.

        Args:
            transcript: Conversation so far; patient turns start with "Patient:".
            current_json: Current clinical state as JSON.
            stage: Accepted for interface parity with MockLLM. It is unused
                here because the phase is derived from the state itself.

        Returns:
            The updated state.
            - On an API failure, the previous state is returned unchanged.
            - On unparsable model output, the previous state is returned
              with a "please repeat" reply.
            - If ``current_json`` itself is invalid in either case, a blank
              state asking the patient to repeat is returned.
        """
        import time

        import requests  # local import: only needed when the real backend is used

        state_context = build_state_context(current_json)
        prompt = (
            f"{state_context}\n\n"
            f"CURRENT CLINICAL STATE (update with any new patient info):\n{current_json}\n\n"
            f"CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
            "TASK: Read the patient's latest message. Extract any new clinical facts into the JSON. "
            "Then ask exactly ONE question about the FIRST missing item shown above. "
            "For ROS: if the patient answers about a system (even 'no'), add it to BOTH ros AND ros_asked. "
            "Return ONLY the updated JSON object."
        )

        t_start = time.time()
        print(f"[Ollama] Starting inference for model '{self.model_name}'...")
        print(f"[Ollama] State context:\n{state_context}")
        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "format": "json",
            "stream": False,
            "options": {
                "temperature": 0.0,
                "num_predict": 400,
            },
        }
        try:
            response = requests.post(self.api_url, json=payload, timeout=60)
            response.raise_for_status()
            data = response.json()
            raw = data.get("message", {}).get("content", "").strip()
        except Exception as e:
            # Network, HTTP or response-shape failure: keep the previous state.
            print(f"[Ollama] ERROR calling Ollama API: {e}")
            return self._previous_state(current_json)
        print(f"[Ollama] Inference completed in {time.time() - t_start:.2f}s total.")

        try:
            parsed = json.loads(self._extract_json_object(raw))
            result = CombinedOutput.model_validate(self._coerce_output(parsed))
        except Exception as e:
            # Deliberate catch-all: a bad model answer must never crash the turn.
            print(f"[Ollama] JSON parse error: {e}\nRaw output: {raw[:300]}")
            return self._previous_state(
                current_json,
                reply="Could you please repeat that? I want to make sure I understood correctly.",
            )

        result = self._fill_missing_ros(result, transcript)
        print("[Ollama] Parsed result β€” stage will be recomputed in graph.")
        return result

    @staticmethod
    def _previous_state(current_json: str, reply: str | None = None) -> CombinedOutput:
        """Rebuild the pre-call state, optionally replacing its reply; never raises."""
        try:
            result = CombinedOutput.model_validate_json(current_json)
        except Exception:
            return CombinedOutput(reply="Could you please repeat that?")
        if reply is not None:
            result = result.model_copy(update={"reply": reply})
        return result

    @staticmethod
    def _extract_json_object(raw: str) -> str:
        """Strip Markdown code fences from *raw* and return its outermost {...} span."""
        text = raw
        # Index [1] is the text after the opening fence; maxsplit=1 yields at most two parts.
        if "```json" in text:
            text = text.split("```json", 1)[1].split("```", 1)[0]
        elif "```" in text:
            text = text.split("```", 1)[1].split("```", 1)[0]
        start = text.find("{")
        end = text.rfind("}") + 1
        if start != -1 and end > start:
            text = text[start:end]
        return text

    @staticmethod
    def _as_str_list(value: object) -> list[str]:
        """Normalize None, a scalar, or a list into a list of non-blank strings."""
        items = value if isinstance(value, list) else [value]
        return [str(x) for x in items if x is not None and str(x).strip()]

    @classmethod
    def _coerce_output(cls, parsed: object) -> dict:
        """Repair common formatting slips in the model's JSON before validation.

        - String fields given as lists are joined, and other non-string values
          (e.g. ``"severity": 7``) are converted with ``str()``, because
          pydantic does not convert numbers to ``str``. "" and "null" become
          None.
        - ``ros`` null becomes {}. Each system's findings become a list of
          strings, so a bare string becomes a one-item list.
        - ``ros_asked`` null or a bare string becomes a list.
        - ``emergency`` null becomes False.

        Raises:
            ValueError: if *parsed* is not a JSON object.
        """
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
        for field in cls._STRING_FIELDS:
            if field not in parsed:
                continue
            value = parsed[field]
            if isinstance(value, list):
                # e.g. ["Walking"] -> "Walking"
                value = " ".join(str(x) for x in value) if value else None
            elif value is not None and not isinstance(value, str):
                value = str(value)
            if value is not None and value.strip() in ("", "null"):
                value = None
            parsed[field] = value
        if "ros" in parsed:
            ros = parsed["ros"]
            if ros is None:
                parsed["ros"] = {}
            elif isinstance(ros, dict):
                parsed["ros"] = {str(name): cls._as_str_list(findings) for name, findings in ros.items()}
        if "ros_asked" in parsed:
            ros_asked = parsed["ros_asked"]
            if ros_asked is None or isinstance(ros_asked, (str, list)):
                parsed["ros_asked"] = cls._as_str_list(ros_asked)
        if "emergency" in parsed and parsed["emergency"] is None:
            parsed["emergency"] = False
        return parsed

    @staticmethod
    def _fill_missing_ros(result: CombinedOutput, transcript: str) -> CombinedOutput:
        """Give every system listed in ``ros_asked`` an entry in ``ros``.

        Covers the case where the model marked a system as asked but left it
        out of ``ros`` (e.g. for "no" answers). The latest patient message is
        recorded as the finding; if there is none, "no symptoms reported" is
        used instead.
        """
        if not result.ros_asked:
            return result
        last_user = ""
        for line in reversed(transcript.strip().split("\n")):
            if line.startswith("Patient:"):
                last_user = line.replace("Patient:", "").strip()
                break
        updated_ros = dict(result.ros)
        changed = False
        for asked_sys in result.ros_asked:
            if asked_sys not in updated_ros:
                updated_ros[asked_sys] = [last_user] if last_user else ["no symptoms reported"]
                print(f"[ROSNorm] Filled ros['{asked_sys}'] from patient message: '{last_user[:40]}'")
                changed = True
        return result.model_copy(update={"ros": updated_ros}) if changed else result
# Process-wide backend instance, created lazily by get_llm().
_llm_instance = None


def get_llm():
    """Return the shared LLM backend, creating it on first use.

    The ``MOCK_LLM`` environment variable selects the backend:
    - unset, or "true", "1", "yes" or "on" (case-insensitive, surrounding
      whitespace ignored), gives ``MockLLM``;
    - any other value gives ``OllamaLLM``.
    The choice is made once per process. Later changes to ``MOCK_LLM`` have
    no effect.
    """
    global _llm_instance
    if _llm_instance is None:
        flag = os.environ.get("MOCK_LLM", "true").strip().lower()
        _llm_instance = MockLLM() if flag in {"true", "1", "yes", "on"} else OllamaLLM()
    return _llm_instance