File size: 13,204 Bytes
6ea946a
284dfa9
 
6ea946a
daf4268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d6f802
 
 
 
 
 
daf4268
8d6f802
27b1ed4
 
 
 
daf4268
27b1ed4
 
 
 
 
 
 
daf4268
8d6f802
daf4268
27b1ed4
 
 
daf4268
 
27b1ed4
 
daf4268
 
 
 
 
27b1ed4
daf4268
27b1ed4
daf4268
 
 
 
 
 
 
 
 
 
 
 
 
 
44d41e8
daf4268
 
 
 
 
 
44d41e8
 
daf4268
 
 
 
 
8d6f802
daf4268
 
 
 
 
 
 
8d6f802
 
daf4268
 
8d6f802
27b1ed4
daf4268
0bcdd07
 
 
 
 
 
 
 
 
 
 
 
44d41e8
eb1b955
0bcdd07
 
6ea946a
99c13fa
8d6f802
 
27b1ed4
0bcdd07
 
 
 
6ea946a
7f00c10
 
 
 
 
 
 
 
 
 
 
8d6f802
 
 
 
 
 
 
 
 
7f00c10
 
8d6f802
7f00c10
 
 
 
8d6f802
7f00c10
8d6f802
 
 
 
 
 
 
 
 
7f00c10
 
 
8d6f802
7f00c10
 
0bcdd07
8d6f802
 
 
 
 
 
 
7f00c10
 
 
 
 
8d6f802
 
 
7f00c10
8d6f802
 
7f00c10
 
 
 
0bcdd07
 
 
 
 
6ea946a
4e16e37
6ea946a
4e16e37
2ea503f
6ea946a
27b1ed4
daf4268
 
0bcdd07
daf4268
0bcdd07
daf4268
 
 
8d6f802
daf4268
284dfa9
0bcdd07
03af64f
4e16e37
daf4268
03af64f
4e16e37
daf4268
 
4e16e37
 
2ea503f
daf4268
2ea503f
 
4e16e37
 
 
 
daf4268
4e16e37
 
8d6f802
4e16e37
 
 
 
eb1b955
4e16e37
8d6f802
4e16e37
 
 
0bcdd07
8d6f802
0bcdd07
284dfa9
8d6f802
284dfa9
8d6f802
0bcdd07
 
 
 
 
 
284dfa9
 
8d6f802
 
33531b2
 
 
8d6f802
 
 
 
33531b2
8d6f802
 
 
0bcdd07
4e16e37
284dfa9
8d6f802
 
 
284dfa9
0bcdd07
6ea946a
8d6f802
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ea946a
284dfa9
6ea946a
8d6f802
6ea946a
99c13fa
 
 
4e16e37
99c13fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
import os
import json
from pydantic import BaseModel

SYSTEM_PROMPT = """You are a clinical intake assistant conducting a pre-visit patient interview.

YOUR WORKFLOW (follow this order):
1. INTAKE: Identify the patient's chief complaint (main reason for visit).
2. HPI (History of Present Illness): Collect these fields ONE AT A TIME, in order:
   - onset: when the symptom started
   - location: where in the body
   - duration: how long it has lasted
   - character: quality (sharp, dull, pressure, burning, etc.)
   - severity: how bad on a scale of 1-10
   - aggravating: what makes it worse
   - relieving: what makes it better
3. ROS (Review of Systems): Screen 3 body systems RELEVANT to the chief complaint.
   Examples of relevant systems:
   - Leg/knee/joint pain β†’ musculoskeletal, neurological, vascular
   - Chest pain β†’ cardiac, respiratory, gi
   - Headache β†’ neurological, ophthalmologic, ent
   - Abdominal pain β†’ gi, genitourinary, musculoskeletal
   - Back pain β†’ musculoskeletal, neurological, genitourinary
4. DONE: When all HPI fields AND 3 ROS systems are filled, set reply to "Your clinical summary is ready. Please wait for the doctor."

CRITICAL RULES:
- NEVER re-ask a field that is already filled (marked βœ… in the status).
- Ask exactly ONE question per turn about the FIRST missing item.
- For HPI: accept any answer the patient gives, even vague ones like "moderate" or "not sure".
- For ROS: ALWAYS add the system to BOTH "ros" and "ros_asked" β€” even for negative answers.
  - Positive finding: "cardiac": ["palpitations present"]
  - Negative finding: "respiratory": ["no shortness of breath"]
  - Denied:          "gi": ["denied nausea and vomiting"]
  A "no" is still a valid clinical finding. Never leave a ros system in ros_asked but absent from ros.
- Do NOT ask emotional/psychological questions β€” stick to physical symptoms.
- All string fields must be strings, not arrays.
- Output ONLY valid JSON, no extra text.

OUTPUT FORMAT:
{
  "chief_complaint": "..." or null,
  "onset": "..." or null,
  "location": "..." or null,
  "duration": "..." or null,
  "character": "..." or null,
  "severity": "..." or null,
  "aggravating": "..." or null,
  "relieving": "..." or null,
  "ros": {"system_name": ["finding1", "finding2"], ...},
  "ros_asked": ["system_name1", "system_name2"],
  "emergency": false,
  "reply": "Your single question"
}"""

# History-of-present-illness fields, in the order they are asked: the HPI
# phase always targets the first field in this list that is still empty.
HPI_FIELDS = [
    "onset",
    "location",
    "duration",
    "character",
    "severity",
    "aggravating",
    "relieving",
]

# How many distinct review-of-systems entries must exist in "ros" before the
# interview is considered DONE (also stated in prose in SYSTEM_PROMPT).
ROS_REQUIRED = 3


def build_state_context(current_json: str) -> str:
    try:
        state = json.loads(current_json)
    except Exception:
        state = {}

    lines = ["FIELD STATUS:"]

    cc = state.get("chief_complaint")
    if cc:
        lines.append(f'  βœ… chief_complaint: "{cc}"')
    else:
        lines.append("  ❌ chief_complaint: MISSING β€” ask what brings them in")

    for field in HPI_FIELDS:
        val = state.get(field)
        if val:
            lines.append(f'  βœ… {field}: "{val}"')
        else:
            lines.append(f"  ❌ {field}: MISSING")

    ros = state.get("ros", {})
    ros_asked = state.get("ros_asked", [])
    if ros:
        for sys_name, findings in ros.items():
            lines.append(f'  βœ… ros.{sys_name}: {findings}')
    ros_remaining = ROS_REQUIRED - len(ros)
    if ros_remaining > 0:
        lines.append(f"  ❌ ros: {ros_remaining} more system(s) needed")
        if ros_asked:
            lines.append(f"  ℹ️ Already asked about: {', '.join(ros_asked)} β€” DO NOT ask about these again")
    else:
        lines.append(f"  βœ… ros: all {ROS_REQUIRED} systems collected")

    if not cc:
        phase = "INTAKE"
        lines.append(f"\nCURRENT PHASE: {phase}")
    elif any(not state.get(f) for f in HPI_FIELDS):
        phase = "HPI"
        first_missing = next(f for f in HPI_FIELDS if not state.get(f))
        lines.append(f"\nCURRENT PHASE: {phase} β€” ask about '{first_missing}' next")
    elif ros_remaining > 0:
        phase = "ROS"
        lines.append(f"\nCURRENT PHASE: {phase} β€” ask about the next body system relevant to '{cc}'")
        lines.append(f"  ⚠️  IMPORTANT: Store BOTH positive AND negative ROS findings in 'ros' dict.")
        lines.append(f"  ⚠️  A patient saying 'no' means: ros[\"system\"] = [\"no [symptom]\"]")
    else:
        phase = "DONE"
        lines.append(f"\nCURRENT PHASE: {phase} β€” all data collected")

    return "\n".join(lines)


class CombinedOutput(BaseModel):
    """Result of one interview turn: the extracted clinical fields plus the next reply.

    Field names mirror the OUTPUT FORMAT keys in SYSTEM_PROMPT. String fields
    left as None (or empty) are reported as MISSING by build_state_context.
    """

    # Main reason for the visit; None until collected.
    chief_complaint: str | None = None
    # History-of-present-illness fields (same names as HPI_FIELDS).
    onset: str | None = None
    location: str | None = None
    duration: str | None = None
    character: str | None = None
    severity: str | None = None
    aggravating: str | None = None
    relieving: str | None = None
    # Review of systems: system name -> findings; negative answers are stored too.
    # NOTE(review): pydantic is expected to copy these mutable defaults per
    # instance; confirm for the installed pydantic version.
    ros: dict[str, list[str]] = {}
    # Systems the patient has already been asked about.
    ros_asked: list[str] = []
    # True when an emergency was detected (MockLLM sets it on keywords such as "crushing").
    emergency: bool = False
    # The message/question to show the patient next.
    reply: str = ""


class MockLLM:
    """Minimal mock for testing β€” deterministic field walker."""

    def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
        try:
            state = json.loads(current_json)
        except Exception:
            state = {}

        lines = transcript.strip().split("\n")
        last_patient_msg = ""
        for line in reversed(lines):
            if line.startswith("Patient:"):
                last_patient_msg = line.replace("Patient:", "").strip()
                break

        ros_systems = ["cardiac", "respiratory", "gi"]

        if stage == "intake":
            if last_patient_msg and not state.get("chief_complaint"):
                # Strip greeting words
                greetings = {"hello", "hi", "hey", "ok", "okay", "start", "yes", "sure"}
                if last_patient_msg.lower() not in greetings and len(last_patient_msg) > 4:
                    state["chief_complaint"] = last_patient_msg
            state["reply"] = (
                "What brings you in today?"
                if not state.get("chief_complaint")
                else f"When did the {state['chief_complaint']} start?"
            )

        elif stage == "hpi":
            for field in HPI_FIELDS:
                if not state.get(field):
                    if last_patient_msg:
                        state[field] = last_patient_msg
                    break
            for field in HPI_FIELDS:
                if not state.get(field):
                    labels = {
                        "onset": "when it started",
                        "location": "where you feel it",
                        "duration": "how long it's lasted",
                        "character": "what it feels like",
                        "severity": "how severe it is (1-10)",
                        "aggravating": "what makes it worse",
                        "relieving": "what makes it better",
                    }
                    state["reply"] = f"Can you tell me {labels.get(field, field)}?"
                    break
            else:
                state["reply"] = "Thank you, let me ask about other symptoms."

        elif stage == "ros":
            ros = state.get("ros", {})
            ros_asked = state.get("ros_asked", [])

            # Detect emergency keywords
            if any(k in last_patient_msg.lower() for k in ["crushing", "can't breathe", "dying"]):
                state["emergency"] = True

            # Store last patient message into the first un-asked system
            for sys_name in ros_systems:
                if sys_name not in ros:
                    if last_patient_msg:
                        ros[sys_name] = [last_patient_msg]
                        state["ros"] = ros
                        if sys_name not in ros_asked:
                            ros_asked.append(sys_name)
                            state["ros_asked"] = ros_asked
                    break

            # Ask about the next un-asked system
            for sys_name in ros_systems:
                if sys_name not in ros:
                    state["reply"] = f"Any {sys_name} symptoms?"
                    break
            else:
                state["reply"] = "Thank you β€” I have everything I need."

        return CombinedOutput.model_validate(state)


class OllamaLLM:
    def __init__(self):
        self.model_name = os.environ.get("MODEL_NAME", "qwen2.5:0.5b")
        self.api_url = "http://localhost:11434/api/chat"

    def combined_call(self, transcript: str, current_json: str, stage: str = "intake") -> CombinedOutput:
        state_context = build_state_context(current_json)

        prompt = (
            f"{state_context}\n\n"
            f"CURRENT CLINICAL STATE (update with any new patient info):\n{current_json}\n\n"
            f"CONVERSATION TRANSCRIPT:\n{transcript}\n\n"
            "TASK: Read the patient's latest message. Extract any new clinical facts into the JSON. "
            "Then ask exactly ONE question about the FIRST missing item shown above. "
            "For ROS: if the patient answers about a system (even 'no'), add it to BOTH ros AND ros_asked. "
            "Return ONLY the updated JSON object."
        )

        import time
        import requests

        t_start = time.time()
        print(f"[Ollama] Starting inference for model '{self.model_name}'...")
        print(f"[Ollama] State context:\n{state_context}")

        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt}
            ],
            "format": "json",
            "stream": False,
            "options": {
                "temperature": 0.0,
                "num_predict": 400
            }
        }

        try:
            response = requests.post(self.api_url, json=payload, timeout=60)
            response.raise_for_status()
            data = response.json()
            raw = data.get("message", {}).get("content", "").strip()
        except Exception as e:
            print(f"[Ollama] ERROR calling Ollama API: {e}")
            return CombinedOutput.model_validate_json(current_json)

        print(f"[Ollama] Inference completed in {time.time() - t_start:.2f}s total.")

        # Strip markdown fences
        json_str = raw
        if "```json" in json_str:
            json_str = json_str.split("```json", 1).split("```")[1]
        elif "```" in json_str:
            json_str = json_str.split("```", 1)[3].split("```")[0]

        start = json_str.find("{")
        end = json_str.rfind("}") + 1
        if start != -1 and end > start:
            json_str = json_str[start:end]

        try:
            parsed = json.loads(json_str)

            # ── Coerce all HPI string fields: listβ†’str, empty/nullβ†’None ──
            for field in ["chief_complaint", "onset", "location", "duration",
                          "character", "severity", "aggravating", "relieving"]:
                v = parsed.get(field)
                if isinstance(v, list):
                    # e.g. ["Walking"] β†’ "Walking"
                    parsed[field] = " ".join(str(x) for x in v) if v else None
                elif v is not None and str(v).strip() in ("", "null"):
                    parsed[field] = None

            result = CombinedOutput.model_validate(parsed)

        except Exception as e:
            print(f"[Ollama] JSON parse error: {e}\nRaw output: {raw[:300]}")
            try:
                result = CombinedOutput.model_validate_json(current_json)
                result = result.model_copy(update={"reply": "Could you please repeat that? I want to make sure I understood correctly."})
                return result
            except Exception:
                return CombinedOutput(reply="Could you please repeat that?")

        # ── Post-process: normalize ros_asked β†’ ros ──────────────────────
        # If LLM added a system to ros_asked but not ros (e.g. for "no" answers),
        # capture the last patient message as the finding for that system.
        if result.ros_asked:
            last_user = ""
            for line in reversed(transcript.strip().split("\n")):
                if line.startswith("Patient:"):
                    last_user = line.replace("Patient:", "").strip()
                    break

            updated_ros = dict(result.ros)
            changed = False
            for asked_sys in result.ros_asked:
                if asked_sys not in updated_ros:
                    updated_ros[asked_sys] = [last_user] if last_user else ["no symptoms reported"]
                    print(f"[ROSNorm] Filled ros['{asked_sys}'] from patient message: '{last_user[:40]}'")
                    changed = True
            if changed:
                result = result.model_copy(update={"ros": updated_ros})

        print(f"[Ollama] Parsed result β€” stage will be recomputed in graph.")
        return result


# Lazily created, process-wide LLM backend shared by every get_llm() caller.
_llm_instance = None


def get_llm():
    """Return the shared LLM backend, constructing it on the first call.

    The MOCK_LLM environment variable (default "true") is read once, at first
    construction. If it equals "true" case-insensitively, MockLLM is used;
    any other value selects OllamaLLM. Later calls return the cached instance.
    """
    global _llm_instance
    if _llm_instance is not None:
        return _llm_instance

    use_mock = os.environ.get("MOCK_LLM", "true").lower() == "true"
    backend_cls = MockLLM if use_mock else OllamaLLM
    _llm_instance = backend_cls()
    return _llm_instance