Priyansh Saxena committed on
Commit
6ea946a
·
1 Parent(s): df4a61a

feat: add LLM providers and graph orchestration

Browse files
Files changed (2) hide show
  1. app/graph.py +362 -0
  2. app/llm.py +82 -0
app/graph.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, TypedDict, Annotated
2
+ from langgraph.graph import StateGraph, START, END
3
+ from langgraph.checkpoint.memory import MemorySaver
4
+
5
+
6
def add_messages(left: list[dict], right: list[dict]) -> list[dict]:
    """Reducer for the ``messages`` channel: append new messages to old ones.

    Args:
        left: Messages already held in the state.
        right: Messages produced by the latest update.

    Returns:
        A new list holding ``left`` followed by ``right``; neither input is
        modified.
    """
    merged = list(left)
    merged.extend(right)
    return merged
8
+
9
+
10
class IntakeState(TypedDict):
    """State dictionary shared by every node of the intake graph."""

    # Conversation log; updates are merged with add_messages (concatenation).
    messages: Annotated[list[dict], add_messages]
    # The patient's presenting problem, captured by intake_node.
    chief_complaint: str
    # HPI answers keyed by the field names listed in HPI_FIELDS.
    hpi: dict
    # Review-of-systems findings: system name -> list of reported findings.
    ros: dict[str, list[str]]
    # Name of the stage each node reports as coming next.
    current_node: str
    # Final brief as a plain dict; written by brief_generator_node.
    clinical_brief: Optional[dict]
    # Ordered list of systems that ros_node walks through.
    ros_systems: list[str]
    # Index into ros_systems of the next system to ask about.
    ros_current_index: int
    # System whose question has been asked and is awaiting an answer.
    ros_pending_system: Optional[str]
    # Length of `messages` at the point the nodes last consumed input.
    last_processed_message_index: int
    # HPI field being re-asked after a vague answer, if any.
    vague_retry_field: Optional[str]
22
+
23
+
24
# HPI fields in the order hpi_node asks about them.
HPI_FIELDS = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]

# Question shown to the patient for each HPI field.
HPI_QUESTIONS = {
    "onset": "When did your symptoms first start?",
    "location": "Where exactly do you feel the pain or discomfort?",
    "duration": "How long does each episode last? Is it constant or intermittent?",
    "character": "Can you describe what the pain feels like?",
    "severity": "On a scale of 1 to 10, how severe is your pain?",
    "aggravating": "What makes your symptoms worse?",
    "relieving": "What helps relieve your symptoms?"
}

# Phrase inserted into the "Could you be more specific about ...?" follow-up
# that hpi_node sends after a vague answer.
HPI_FIELD_CONTEXT = {
    "onset": "when your symptoms first started",
    "location": "where exactly you feel it",
    "duration": "how long each episode lasts",
    "character": "what the pain feels like",
    "severity": "how severe the pain is on a 1-10 scale",
    "aggravating": "what makes your symptoms worse",
    "relieving": "what helps relieve your symptoms",
}

# Chief-complaint keyword -> body systems to review.
# get_relevant_ros_systems returns the systems of the FIRST keyword (in this
# insertion order) found as a substring of the complaint, so order matters.
CC_KEYWORDS_TO_ROS = {
    "chest": ["cardiac", "respiratory", "gi"],
    "pain": ["cardiac", "respiratory", "gi"],
    "headache": ["neuro", "ent", "vision"],
    "head": ["neuro", "ent", "vision"],
    "breath": ["respiratory", "cardiac"],
    "shortness": ["respiratory", "cardiac"],
    "cough": ["respiratory", "ent"],
    "dizzy": ["neuro", "cardiac"],
    "nausea": ["gi", "constitutional"],
    "vomiting": ["gi", "constitutional"],
}

# Systems reviewed when no keyword matches the chief complaint.
DEFAULT_ROS = ["constitutional", "cardiac", "respiratory"]
60
+
61
+
62
def get_relevant_ros_systems(cc: str) -> list[str]:
    """Pick the body systems to review for a chief complaint.

    The complaint is lower-cased and scanned for the keywords of
    ``CC_KEYWORDS_TO_ROS`` in that mapping's insertion order; the first
    keyword found as a substring decides the result.

    Args:
        cc: Free-text chief complaint. ``None`` or an empty string falls
            back to the default systems.

    Returns:
        A fresh list of system names. A copy is returned so that callers
        storing or mutating the result cannot alter the module-level tables.
    """
    cc_lower = (cc or "").lower()
    for keyword, systems in CC_KEYWORDS_TO_ROS.items():
        if keyword in cc_lower:
            return list(systems)
    return list(DEFAULT_ROS)
68
+
69
+
70
+ import re
71
+
72
+
73
def extract_hpi_value(answer: str, field: str) -> str:
    """Normalise a patient's answer for one HPI field.

    For the ``"severity"`` field, a rating written as ``N/10`` or
    ``N out of 10`` (``N`` between 0 and 10, decimals allowed) is reduced to
    the compact form ``"N/10"``. Every other answer is returned with
    surrounding whitespace removed.

    Args:
        answer: Raw text typed by the patient; ``None`` is treated as empty.
        field: Name of the HPI field the answer belongs to.

    Returns:
        The normalised value.
    """
    answer = (answer or "").strip()
    if field == "severity":
        # The separator is mandatory: with an optional one, run-together
        # digits such as "2010" would be misread as the rating "20/10".
        # The look-arounds keep the rating from matching inside a longer
        # number (e.g. "7/100" or the "5" of an unrelated "1.5").
        match = re.search(
            r'(?<![\d.])(\d{1,2}(?:\.\d+)?)\s*(?:out\s+of|/)\s*10(?!\.?\d)',
            answer,
            re.IGNORECASE,
        )
        if match and float(match.group(1)) <= 10:
            return f"{match.group(1)}/10"
    return answer
80
+
81
+
82
+ def _is_vague_answer(answer: str) -> bool:
83
+ vague_phrases = ["i don't know", "not sure", "dont know", "idk", "maybe", "i guess"]
84
+ answer_lower = answer.lower()
85
+ return any(phrase in answer_lower for phrase in vague_phrases)
86
+
87
+
88
def intake_node(state: IntakeState) -> dict:
    """Greet the patient or capture the chief complaint.

    If a chief complaint is already stored, the node only reports ``"hpi"``
    as the next stage. Otherwise, when the latest unprocessed message comes
    from the user and is not blank, its text (whitespace-trimmed) becomes the
    chief complaint and an acknowledgement is sent. In every other case the
    greeting question is sent and the complaint stays empty.

    Args:
        state: Current intake state.

    Returns:
        A partial state update.
    """
    greeting = "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"

    messages = state.get("messages", [])
    last_idx = state.get("last_processed_message_index", 0)
    cc = state.get("chief_complaint", "")

    if cc:
        # Complaint already captured on an earlier pass; nothing to say.
        return {
            "current_node": "hpi",
        }

    has_new_msg = len(messages) > last_idx

    reply = greeting
    if has_new_msg and messages[-1].get("role") == "user":
        # A blank message is not a complaint: without this check the reply
        # would read "experiencing ." and an empty complaint would be kept.
        complaint = (messages[-1].get("content") or "").strip()
        if complaint:
            cc = complaint
            reply = f"I understand you're experiencing {cc}. Let me ask you some questions about this."

    return {
        "messages": [{"role": "assistant", "content": reply}],
        "chief_complaint": cc,
        "current_node": "hpi",
        "ros_systems": state.get("ros_systems", []),
        "ros_current_index": state.get("ros_current_index", 0),
        "ros_pending_system": state.get("ros_pending_system"),
        "last_processed_message_index": len(messages) if has_new_msg else last_idx,
        "vague_retry_field": None,
    }
119
+
120
+
121
def _first_unanswered_hpi_field(hpi: dict) -> Optional[str]:
    """Return the first field of HPI_FIELDS with no (or an empty) answer."""
    return next((name for name in HPI_FIELDS if not hpi.get(name)), None)


def hpi_node(state: IntakeState) -> dict:
    """Collect the History of Present Illness one field per turn.

    The field being handled is the one awaiting clarification
    (``vague_retry_field``) or, failing that, the first unanswered field.
    The first user message at or after ``last_processed_message_index`` is
    taken as the answer:

    * no such message: the field's question is (re)asked;
    * blank or vague answer: a clarification is requested and the field is
      remembered in ``vague_retry_field``;
    * otherwise the normalised answer is stored and the next unanswered
      field's question is asked, or the hand-over to the review of systems
      is announced when nothing is left.

    Args:
        state: Current intake state.

    Returns:
        A partial state update.
    """
    messages = state.get("messages", [])
    last_idx = state.get("last_processed_message_index", 0)
    hpi = dict(state.get("hpi") or {})

    next_field = state.get("vague_retry_field") or _first_unanswered_hpi_field(hpi)

    if next_field is None:
        reply = "Thank you for providing that information. Now let me ask about other symptoms."
        return {
            "messages": [{"role": "assistant", "content": reply}],
            "current_node": "ros",
            "last_processed_message_index": len(messages),
            "vague_retry_field": None,
        }

    user_msg = next(
        (msg for msg in messages[last_idx:] if msg.get("role") == "user"),
        None,
    )

    if user_msg is None:
        # Nothing new from the patient yet: ask the question for this field.
        return {
            "messages": [{"role": "assistant", "content": HPI_QUESTIONS[next_field]}],
            "current_node": "hpi",
            "last_processed_message_index": last_idx,
            "vague_retry_field": None,
        }

    answer = user_msg.get("content") or ""

    if not answer.strip() or _is_vague_answer(answer):
        field_context = HPI_FIELD_CONTEXT.get(next_field, "your symptoms")
        reply = f"Could you be more specific about {field_context}?"
        return {
            "messages": [{"role": "assistant", "content": reply}],
            "current_node": "hpi",
            # Mark the vague message as consumed. Leaving the index where it
            # was would make the next turn find this same message first and
            # ask for clarification forever.
            "last_processed_message_index": len(messages),
            "vague_retry_field": next_field,
        }

    hpi[next_field] = extract_hpi_value(answer, next_field)

    # Choose the follow-up from what is actually still missing, so the
    # question asked always matches the field handled on the next turn.
    upcoming_field = _first_unanswered_hpi_field(hpi)
    if upcoming_field is None:
        reply = "Thank you. Now let me ask about other associated symptoms."
        next_node = "ros"
    else:
        reply = HPI_QUESTIONS[upcoming_field]
        next_node = "hpi"

    return {
        "messages": [{"role": "assistant", "content": reply}],
        "hpi": hpi,
        "current_node": next_node,
        "last_processed_message_index": len(messages),
        "vague_retry_field": None,
    }
191
+
192
+
193
def _split_ros_findings(answer: str) -> list[str]:
    """Split a comma-separated answer into findings, positives before negatives."""
    positives: list[str] = []
    negatives: list[str] = []
    for piece in answer.split(","):
        finding = piece.strip()
        if not finding:
            continue  # skip empty segments such as a trailing comma
        # Whole-word match so that e.g. "piano " is not read as a denial.
        if re.search(r"\b(?:no|none)\b", finding, re.IGNORECASE):
            negatives.append(finding)
        else:
            positives.append(finding)
    return positives + negatives


def _ros_question(opener: str, system: str) -> str:
    """Build the question asking about one body system."""
    return f"{opener} your {system} system. Any {system} symptoms? Please mention what's present and what's not."


def ros_node(state: IntakeState) -> dict:
    """Walk through the review of systems, one system per turn.

    On each pass the node:

    1. records the latest user message as the findings for the system whose
       question is pending (if any) - this happens *before* the completion
       check, so the answer for the final system is kept;
    2. re-asks the pending system when no user answer has arrived, instead
       of silently moving on;
    3. asks about the next system, or announces that the brief can be
       generated once every system has been covered.

    The list of systems comes from the state, or is derived from the chief
    complaint the first time through.

    Args:
        state: Current intake state.

    Returns:
        A partial state update.
    """
    messages = state.get("messages", [])
    last_idx = state.get("last_processed_message_index", 0)
    ros = dict(state.get("ros") or {})

    ros_systems = state.get("ros_systems") or get_relevant_ros_systems(
        state.get("chief_complaint", "")
    )
    current_idx = state.get("ros_current_index", 0)
    pending_system = state.get("ros_pending_system")

    has_new_msg = len(messages) > last_idx
    answered = has_new_msg and messages[-1].get("role") == "user"

    if pending_system and not answered:
        # Still waiting for this system's answer: repeat the question and
        # leave the position untouched.
        return {
            "messages": [{"role": "assistant", "content": _ros_question("Let's review", pending_system)}],
            "current_node": "ros",
            "ros_systems": ros_systems,
            "ros_current_index": current_idx,
            "ros_pending_system": pending_system,
            "last_processed_message_index": last_idx,
            "vague_retry_field": None,
        }

    if pending_system:
        ros[pending_system] = _split_ros_findings(messages[-1].get("content") or "")

    if current_idx >= len(ros_systems):
        reply = "Thank you. I have enough information to generate your clinical brief."
        return {
            "messages": [{"role": "assistant", "content": reply}],
            "ros": ros,
            "current_node": "brief_generator",
            "ros_systems": ros_systems,
            "ros_current_index": current_idx,
            "ros_pending_system": None,
            "last_processed_message_index": len(messages),
            "vague_retry_field": None,
        }

    next_system = ros_systems[current_idx]
    opener = "Let's review" if answered else "Let's start with"
    return {
        "messages": [{"role": "assistant", "content": _ros_question(opener, next_system)}],
        "ros": ros,
        "current_node": "ros",
        "ros_systems": ros_systems,
        "ros_current_index": current_idx + 1,
        "ros_pending_system": next_system,
        "last_processed_message_index": len(messages) if answered else last_idx,
        "vague_retry_field": None,
    }
288
+
289
+
290
+ from datetime import datetime, timezone
291
+ from app.schemas import HPI as HPIModel, ClinicalBrief as ClinicalBriefModel
292
+
293
+
294
def brief_generator_node(state: IntakeState) -> dict:
    """Build the clinical brief from the collected state and finish the intake.

    Each HPI field that is missing or empty is filled with
    ``"not specified"``. The brief is stamped with the current UTC time in
    ISO format with a ``Z`` suffix and returned as a plain dict under
    ``clinical_brief``.

    Args:
        state: Current intake state.

    Returns:
        A partial state update marking the conversation as ``"done"``.
    """
    collected_hpi = state.get("hpi", {})
    field_names = (
        "onset",
        "location",
        "duration",
        "character",
        "severity",
        "aggravating",
        "relieving",
    )
    hpi_kwargs = {
        name: collected_hpi.get(name) or "not specified" for name in field_names
    }
    hpi_obj = HPIModel(**hpi_kwargs)

    timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    brief = ClinicalBriefModel(
        chief_complaint=state.get("chief_complaint", ""),
        hpi=hpi_obj,
        ros=state.get("ros", {}),
        generated_at=timestamp,
    )

    closing_message = {
        "role": "assistant",
        "content": "Your clinical intake is complete. Here is your summary.",
    }
    update = {
        "messages": [closing_message],
        "current_node": "done",
        "clinical_brief": brief.model_dump(),
    }
    update["ros_systems"] = state.get("ros_systems", [])
    update["ros_current_index"] = state.get("ros_current_index", 0)
    update["ros_pending_system"] = None
    update["last_processed_message_index"] = len(state.get("messages", []))
    update["vague_retry_field"] = None
    return update
326
+
327
+
328
def route_from_intake(state: IntakeState) -> str:
    """Route out of the intake node; the destination is always ``"hpi"``."""
    destination = "hpi"
    return destination
330
+
331
+
332
def route_from_hpi(state: IntakeState) -> str:
    """Route out of the HPI node.

    Returns:
        ``"ros"`` once every field in ``HPI_FIELDS`` has a non-empty answer,
        otherwise ``"hpi"``.
    """
    answers = state.get("hpi", {})
    for field_name in HPI_FIELDS:
        if not answers.get(field_name):
            return "hpi"
    return "ros"
336
+
337
+
338
def route_from_ros(state: IntakeState) -> str:
    """Route out of the review-of-systems node.

    The review is finished only when the index has passed the last system
    AND no question is still awaiting an answer. ``ros_node`` advances the
    index at the moment it *asks* about a system, so checking the index
    alone would leave for the brief generator before the final system's
    answer has been given.

    Returns:
        ``"brief_generator"`` when the review is finished, otherwise
        ``"ros"``.
    """
    if state.get("ros_pending_system"):
        return "ros"

    ros_systems = state.get("ros_systems", [])
    current_index = state.get("ros_current_index", 0)
    all_processed = current_index >= len(ros_systems)
    return "brief_generator" if all_processed else "ros"
343
+
344
+
345
def build_graph() -> tuple:
    """Assemble and compile the intake workflow.

    The graph runs ``intake -> hpi -> ros -> brief_generator``, with the
    ``hpi`` and ``ros`` nodes able to route back to themselves. It is
    compiled with an in-memory checkpointer and an interrupt after each of
    the ``intake``, ``hpi`` and ``ros`` nodes.

    Returns:
        A ``(graph, checkpointer)`` pair: the compiled graph and the
        ``MemorySaver`` it was compiled with.
    """
    builder = StateGraph(IntakeState)

    node_handlers = {
        "intake": intake_node,
        "hpi": hpi_node,
        "ros": ros_node,
        "brief_generator": brief_generator_node,
    }
    for node_name, handler in node_handlers.items():
        builder.add_node(node_name, handler)

    builder.add_edge(START, "intake")
    builder.add_conditional_edges("intake", route_from_intake, {"hpi": "hpi"})
    builder.add_conditional_edges("hpi", route_from_hpi, {"hpi": "hpi", "ros": "ros"})
    builder.add_conditional_edges(
        "ros",
        route_from_ros,
        {"ros": "ros", "brief_generator": "brief_generator"},
    )
    builder.add_edge("brief_generator", END)

    saver = MemorySaver()
    compiled = builder.compile(
        checkpointer=saver,
        interrupt_after=["intake", "hpi", "ros"],
    )
    return compiled, saver
app/llm.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
class MockLLM:
    """Scripted stand-in for an LLM that replays canned patient answers."""

    # Answers handed out, in order, on successive "hpi" requests.
    _HPI_SCRIPT = (
        "It started about 3 hours ago",
        "In the center of my chest",
        "It has been constant",
        "It feels like pressure",
        "About a 7 out of 10",
        "It gets worse when I walk",
        "Resting helps a little",
    )

    def __init__(self):
        """Create a mock positioned at the start of its script."""
        self.hpi_fields = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
        self.reset()

    def generate_response(self, conversation_history: list[dict], current_node: str) -> str:
        """Return the next canned reply for the given stage.

        Args:
            conversation_history: Accepted for interface compatibility; not
                consulted.
            current_node: Stage name: ``"intake"``, ``"hpi"`` or ``"ros"``.

        Returns:
            The scripted reply, or an empty string for any other stage.
        """
        if current_node == "intake":
            return "I have chest pain since this morning"
        if current_node == "hpi":
            return self._next_hpi_answer()
        if current_node == "ros":
            return self._next_ros_answer()
        return ""

    def _next_hpi_answer(self) -> str:
        """Hand out the next HPI answer, or a fixed line once they run out."""
        position = self.current_hpi_index
        if position >= len(self._HPI_SCRIPT):
            return "I already answered all those questions"
        self.current_hpi_index = position + 1
        return self._HPI_SCRIPT[position]

    def _next_ros_answer(self) -> str:
        """Give the combined ROS answer once, then "done" on later calls."""
        if self.ros_systems_done:
            return "done"
        self.ros_systems_done = True
        return "cardiac:palpitations present,no syncope|respiratory:mild shortness of breath,no cough"

    def reset(self):
        """Rewind the script to its starting position."""
        self.current_hpi_index = 0
        self.ros_systems_done = False
        self.ros_current_system = 0
45
+
46
+
47
class RealLLM:
    """LLM backed by llama-cpp-python; the model is loaded on first use."""

    def __init__(
        self,
        model_path: str = "/models/qwen2.5-0.5b-instruct-q4_k_m.gguf",
        n_ctx: int = 2048,
        n_threads: int = 4,
    ):
        """Store the configuration without loading the model.

        Args:
            model_path: Path to the GGUF model file.
            n_ctx: Context size passed to ``Llama``.
            n_threads: Thread count passed to ``Llama``.
        """
        self.model = None
        self.model_path = model_path
        self.n_ctx = n_ctx
        self.n_threads = n_threads

    def _load_model(self):
        """Instantiate the model on the first call; later calls do nothing."""
        if self.model is None:
            # Imported here so the module can be used (e.g. with MockLLM)
            # without llama_cpp being imported at module load.
            from llama_cpp import Llama
            self.model = Llama(
                model_path=self.model_path,
                n_ctx=self.n_ctx,
                n_threads=self.n_threads,
            )

    def generate_response(self, conversation_history: list[dict], current_node: str) -> str:
        """Generate the assistant's next message.

        Args:
            conversation_history: Chat messages sent to the model after the
                system prompt. The list is not modified.
            current_node: Accepted for interface compatibility; not
                consulted.

        Returns:
            The text of the first completion choice, or an empty string if
            the model returned no content.
        """
        self._load_model()

        system_prompt = (
            "You are a clinical AI assistant conducting patient intake. "
            "Ask one question at a time. Be concise and professional."
        )

        messages = [{"role": "system", "content": system_prompt}, *conversation_history]

        output = self.model.create_chat_completion(messages, max_tokens=256)
        # Coerce a missing content value to "" to honour the str return type.
        return output["choices"][0]["message"]["content"] or ""
76
+
77
+
78
def get_llm():
    """Create the LLM selected by the ``MOCK_LLM`` environment variable.

    Returns:
        A ``MockLLM`` when ``MOCK_LLM`` equals ``"true"`` (case-insensitive),
        otherwise a ``RealLLM``.
    """
    flag = os.environ.get("MOCK_LLM", "false")
    use_mock = flag.lower() == "true"
    return MockLLM() if use_mock else RealLLM()