priyansh-saxena1 commited on
Commit
99c13fa
·
1 Parent(s): 058b7cd

fix: imoprve llm propmts

Browse files
Files changed (4) hide show
  1. Dockerfile +12 -2
  2. app/graph.py +111 -147
  3. app/llm.py +49 -55
  4. requirements.txt +4 -1
Dockerfile CHANGED
@@ -3,10 +3,20 @@ FROM python:3.11-slim
3
  WORKDIR /app
4
 
5
  COPY requirements.txt .
 
 
 
6
  RUN pip install --no-cache-dir -r requirements.txt
7
 
8
- # Default to mock mode for HF Spaces (no GPU available)
9
- ENV MOCK_LLM=true
 
 
 
 
 
 
 
10
 
11
  COPY app/ ./app/
12
  COPY tests/ ./tests/
 
3
  WORKDIR /app
4
 
5
  COPY requirements.txt .
6
+
7
+ # CPU-only torch (~220MB vs 2.4GB CUDA wheel)
8
+ RUN pip install --no-cache-dir torch --extra-index-url https://download.pytorch.org/whl/cpu
9
  RUN pip install --no-cache-dir -r requirements.txt
10
 
11
+ # Pre-download model weights at build time (baked into image)
12
+ # Swap model name here if you want a bigger one
13
+ ARG MODEL_NAME=Qwen/Qwen2.5-0.5B-Instruct
14
+ RUN python -c "from transformers import AutoModelForCausalLM, AutoTokenizer; \
15
+ AutoTokenizer.from_pretrained('${MODEL_NAME}'); \
16
+ AutoModelForCausalLM.from_pretrained('${MODEL_NAME}')"
17
+
18
+ ENV MOCK_LLM=false
19
+ ENV MODEL_NAME=${MODEL_NAME}
20
 
21
  COPY app/ ./app/
22
  COPY tests/ ./tests/
app/graph.py CHANGED
@@ -1,6 +1,29 @@
1
  from typing import Optional, TypedDict, Annotated
2
  from langgraph.graph import StateGraph, START, END
3
  from langgraph.checkpoint.memory import MemorySaver
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  def add_messages(left: list[dict], right: list[dict]) -> list[dict]:
@@ -38,9 +61,9 @@ HPI_FIELD_CONTEXT = {
38
  "location": "where exactly you feel it",
39
  "duration": "how long each episode lasts",
40
  "character": "what the pain feels like",
41
- "severity": "how severe the pain is on a 1-10 scale",
42
- "aggravating": "what makes your symptoms worse",
43
- "relieving": "what helps relieve your symptoms",
44
  }
45
 
46
  CC_KEYWORDS_TO_ROS = {
@@ -85,6 +108,8 @@ def _is_vague_answer(answer: str) -> bool:
85
  return any(phrase in answer_lower for phrase in vague_phrases)
86
 
87
 
 
 
88
  def intake_node(state: IntakeState) -> dict:
89
  messages = state.get("messages", [])
90
  last_idx = state.get("last_processed_message_index", 0)
@@ -96,23 +121,25 @@ def intake_node(state: IntakeState) -> dict:
96
  user_msg = messages[-1]
97
  if user_msg.get("role") == "user":
98
  cc = user_msg.get("content", "")
99
- reply = f"I understand you're experiencing {cc}. Let me ask you some questions about this."
 
 
 
 
 
 
 
100
  else:
101
- reply = "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"
102
  elif not cc:
103
- reply = "Hello, I'm conducting your pre-visit clinical intake. What brings you in today?"
104
  else:
105
- return {
106
- "current_node": "hpi",
107
- }
108
 
109
  return {
110
  "messages": [{"role": "assistant", "content": reply}],
111
  "chief_complaint": cc,
112
  "current_node": "hpi",
113
- "ros_systems": state.get("ros_systems", []),
114
- "ros_current_index": state.get("ros_current_index", 0),
115
- "ros_pending_system": state.get("ros_pending_system"),
116
  "last_processed_message_index": len(messages) if has_new_user_msg else last_idx,
117
  "vague_retry_field": None,
118
  }
@@ -123,6 +150,7 @@ def hpi_node(state: IntakeState) -> dict:
123
  last_idx = state.get("last_processed_message_index", 0)
124
  hpi = dict(state.get("hpi", {}))
125
  vague_retry_field = state.get("vague_retry_field")
 
126
 
127
  next_field = vague_retry_field
128
  if not next_field:
@@ -132,29 +160,32 @@ def hpi_node(state: IntakeState) -> dict:
132
  break
133
 
134
  if next_field is None:
135
- reply = "Thank you for providing that information. Now let me ask about other symptoms."
136
  return {
137
- "messages": [{"role": "assistant", "content": reply}],
138
  "current_node": "ros",
139
  "last_processed_message_index": len(messages),
140
  "vague_retry_field": None,
141
  }
142
 
143
  has_new_user_msg = len(messages) > last_idx
144
-
145
  if has_new_user_msg:
146
- user_msg = None
147
- for i in range(last_idx, len(messages)):
148
- if messages[i].get("role") == "user":
149
- user_msg = messages[i]
150
- break
151
-
152
  if user_msg:
153
- answer = user_msg.get("content", "")
154
 
155
  if _is_vague_answer(answer):
156
- field_context = HPI_FIELD_CONTEXT.get(next_field, "your symptoms")
157
- reply = f"Could you be more specific about {field_context}?"
 
 
 
 
 
 
 
 
158
  return {
159
  "messages": [{"role": "assistant", "content": reply}],
160
  "current_node": "hpi",
@@ -166,22 +197,40 @@ def hpi_node(state: IntakeState) -> dict:
166
 
167
  next_idx = HPI_FIELDS.index(next_field)
168
  if next_idx < len(HPI_FIELDS) - 1:
169
- next_q = HPI_FIELDS[next_idx + 1]
170
- reply = HPI_QUESTIONS[next_q]
171
- next_node = "hpi"
172
- else:
173
- reply = "Thank you. Now let me ask about other associated symptoms."
174
- next_node = "ros"
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  return {
177
- "messages": [{"role": "assistant", "content": reply}],
178
  "hpi": hpi,
179
- "current_node": next_node,
180
  "last_processed_message_index": len(messages),
181
  "vague_retry_field": None,
182
  }
183
 
184
- reply = HPI_QUESTIONS[next_field]
 
 
 
 
 
 
 
185
  return {
186
  "messages": [{"role": "assistant", "content": reply}],
187
  "current_node": "hpi",
@@ -196,153 +245,68 @@ def ros_node(state: IntakeState) -> dict:
196
  ros = dict(state.get("ros", {}))
197
  cc = state.get("chief_complaint", "")
198
 
199
- ros_systems = state.get("ros_systems", [])
200
- if not ros_systems:
201
- ros_systems = get_relevant_ros_systems(cc)
202
-
203
  current_idx = state.get("ros_current_index", 0)
204
- pending_system = state.get("ros_pending_system")
205
 
206
  if current_idx >= len(ros_systems):
207
- reply = "Thank you. I have enough information to generate your clinical brief."
208
  return {
209
- "messages": [{"role": "assistant", "content": reply}],
210
  "current_node": "brief_generator",
211
- "ros_systems": ros_systems,
212
- "ros_current_index": current_idx,
213
- "ros_pending_system": None,
214
  "last_processed_message_index": len(messages),
215
- "vague_retry_field": None,
216
  }
217
 
218
  has_new_user_msg = len(messages) > last_idx
219
 
220
- if has_new_user_msg:
221
- user_msg = messages[-1]
222
- if user_msg.get("role") == "user":
223
- answer = user_msg.get("content", "")
224
 
225
- if pending_system:
226
- positive_findings = []
227
- negative_findings = []
228
 
229
- findings = [f.strip() for f in answer.split(",")]
230
- for f in findings:
231
- f_lower = f.lower()
232
- if "no " in f_lower or "none" in f_lower:
233
- negative_findings.append(f)
234
- else:
235
- positive_findings.append(f)
236
-
237
- ros[pending_system] = positive_findings + negative_findings
238
-
239
- if current_idx < len(ros_systems):
240
- next_system = ros_systems[current_idx]
241
- reply = f"Let's review your {next_system} system. Any {next_system} symptoms? Please mention what's present and what's not."
242
- return {
243
- "messages": [{"role": "assistant", "content": reply}],
244
- "ros": ros,
245
- "current_node": "ros",
246
- "ros_systems": ros_systems,
247
- "ros_current_index": current_idx + 1,
248
- "ros_pending_system": next_system,
249
- "last_processed_message_index": len(messages),
250
- "vague_retry_field": None,
251
- }
252
- else:
253
- reply = "Thank you. I have enough information."
254
- return {
255
- "messages": [{"role": "assistant", "content": reply}],
256
- "ros": ros,
257
- "current_node": "brief_generator",
258
- "ros_systems": ros_systems,
259
- "ros_current_index": current_idx,
260
- "ros_pending_system": None,
261
- "last_processed_message_index": len(messages),
262
- "vague_retry_field": None,
263
- }
264
-
265
- if current_idx < len(ros_systems):
266
- next_system = ros_systems[current_idx]
267
- reply = f"Let's start with your {next_system} system. Any {next_system} symptoms? Please mention what's present and what's not."
268
- return {
269
- "messages": [{"role": "assistant", "content": reply}],
270
- "current_node": "ros",
271
- "ros_systems": ros_systems,
272
- "ros_current_index": current_idx + 1,
273
- "ros_pending_system": next_system,
274
- "last_processed_message_index": last_idx,
275
- "vague_retry_field": None,
276
- }
277
 
278
- reply = "Continuing review of systems..."
279
  return {
280
  "messages": [{"role": "assistant", "content": reply}],
 
281
  "current_node": "ros",
282
  "ros_systems": ros_systems,
283
- "ros_current_index": current_idx,
284
- "ros_pending_system": None,
285
- "last_processed_message_index": last_idx,
286
- "vague_retry_field": None,
287
  }
288
 
289
 
 
 
290
  from datetime import datetime, timezone
291
  from app.schemas import HPI as HPIModel, ClinicalBrief as ClinicalBriefModel
292
 
293
 
294
  def brief_generator_node(state: IntakeState) -> dict:
295
- ros = state.get("ros", {})
296
- hpi_data = state.get("hpi", {})
297
-
298
- hpi_obj = HPIModel(
299
- onset=hpi_data.get("onset") or "not specified",
300
- location=hpi_data.get("location") or "not specified",
301
- duration=hpi_data.get("duration") or "not specified",
302
- character=hpi_data.get("character") or "not specified",
303
- severity=hpi_data.get("severity") or "not specified",
304
- aggravating=hpi_data.get("aggravating") or "not specified",
305
- relieving=hpi_data.get("relieving") or "not specified",
306
- )
307
 
308
  brief = ClinicalBriefModel(
309
  chief_complaint=state.get("chief_complaint", ""),
310
  hpi=hpi_obj,
311
- ros=ros,
312
  generated_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
313
  )
314
 
315
- reply = "Your clinical intake is complete. Here is your summary."
316
  return {
317
- "messages": [{"role": "assistant", "content": reply}],
318
  "current_node": "done",
319
  "clinical_brief": brief.model_dump(),
320
- "ros_systems": state.get("ros_systems", []),
321
- "ros_current_index": state.get("ros_current_index", 0),
322
- "ros_pending_system": None,
323
- "last_processed_message_index": len(state.get("messages", [])),
324
- "vague_retry_field": None,
325
  }
326
 
327
 
328
- def route_from_intake(state: IntakeState) -> str:
329
- return "hpi"
330
-
331
-
332
- def route_from_hpi(state: IntakeState) -> str:
333
- hpi = state.get("hpi", {})
334
- all_filled = all(hpi.get(f) for f in HPI_FIELDS)
335
- return "ros" if all_filled else "hpi"
336
-
337
-
338
- def route_from_ros(state: IntakeState) -> str:
339
- ros_systems = state.get("ros_systems", [])
340
- current_index = state.get("ros_current_index", 0)
341
- all_processed = current_index >= len(ros_systems)
342
- return "brief_generator" if all_processed else "ros"
343
-
344
-
345
- def build_graph() -> tuple:
346
  workflow = StateGraph(IntakeState)
347
 
348
  workflow.add_node("intake", intake_node)
@@ -351,12 +315,12 @@ def build_graph() -> tuple:
351
  workflow.add_node("brief_generator", brief_generator_node)
352
 
353
  workflow.add_edge(START, "intake")
354
- workflow.add_conditional_edges("intake", route_from_intake, {"hpi": "hpi"})
355
- workflow.add_conditional_edges("hpi", route_from_hpi, {"hpi": "hpi", "ros": "ros"})
356
- workflow.add_conditional_edges("ros", route_from_ros, {"ros": "ros", "brief_generator": "brief_generator"})
357
  workflow.add_edge("brief_generator", END)
358
 
359
  checkpointer = MemorySaver()
360
- graph = workflow.compile(checkpointer=checkpointer, interrupt_after=["intake", "hpi", "ros"])
361
 
362
- return graph, checkpointer
 
1
  from typing import Optional, TypedDict, Annotated
2
  from langgraph.graph import StateGraph, START, END
3
  from langgraph.checkpoint.memory import MemorySaver
4
+ import os
5
+ from app.llm import get_llm
6
+
7
+ _MOCK = lambda: os.environ.get("MOCK_LLM", "true").lower() == "true"
8
+
9
+ SYSTEM_PROMPT = """
10
+ You are a clinical intake assistant.
11
+
12
+ Rules:
13
+ - Ask exactly ONE question at a time
14
+ - Keep responses under 20 words
15
+ - Be clear and direct
16
+ - No explanations unless asked
17
+ """
18
+
19
+
20
+ def _ask(prompt: str) -> str:
21
+ llm = get_llm()
22
+ try:
23
+ return llm.ask(prompt, system=SYSTEM_PROMPT)
24
+ except TypeError:
25
+ # fallback if system param not supported
26
+ return llm.ask(prompt)
27
 
28
 
29
  def add_messages(left: list[dict], right: list[dict]) -> list[dict]:
 
61
  "location": "where exactly you feel it",
62
  "duration": "how long each episode lasts",
63
  "character": "what the pain feels like",
64
+ "severity": "pain severity (1-10)",
65
+ "aggravating": "what makes symptoms worse",
66
+ "relieving": "what relieves symptoms",
67
  }
68
 
69
  CC_KEYWORDS_TO_ROS = {
 
108
  return any(phrase in answer_lower for phrase in vague_phrases)
109
 
110
 
111
+ # -------------------- NODES --------------------
112
+
113
  def intake_node(state: IntakeState) -> dict:
114
  messages = state.get("messages", [])
115
  last_idx = state.get("last_processed_message_index", 0)
 
121
  user_msg = messages[-1]
122
  if user_msg.get("role") == "user":
123
  cc = user_msg.get("content", "")
124
+
125
+ if _MOCK():
126
+ reply = f"I understand you're experiencing {cc}. Let me ask a few questions."
127
+ else:
128
+ reply = _ask(
129
+ f"Patient says: '{cc}'. "
130
+ "Reply in one short sentence. Acknowledge and say you will ask a few questions."
131
+ )
132
  else:
133
+ reply = "What brings you in today?"
134
  elif not cc:
135
+ reply = "What brings you in today?"
136
  else:
137
+ return {"current_node": "hpi"}
 
 
138
 
139
  return {
140
  "messages": [{"role": "assistant", "content": reply}],
141
  "chief_complaint": cc,
142
  "current_node": "hpi",
 
 
 
143
  "last_processed_message_index": len(messages) if has_new_user_msg else last_idx,
144
  "vague_retry_field": None,
145
  }
 
150
  last_idx = state.get("last_processed_message_index", 0)
151
  hpi = dict(state.get("hpi", {}))
152
  vague_retry_field = state.get("vague_retry_field")
153
+ cc = state.get("chief_complaint", "")
154
 
155
  next_field = vague_retry_field
156
  if not next_field:
 
160
  break
161
 
162
  if next_field is None:
 
163
  return {
164
+ "messages": [{"role": "assistant", "content": "Now I’ll ask about other symptoms."}],
165
  "current_node": "ros",
166
  "last_processed_message_index": len(messages),
167
  "vague_retry_field": None,
168
  }
169
 
170
  has_new_user_msg = len(messages) > last_idx
171
+
172
  if has_new_user_msg:
173
+ user_msg = next((m for m in messages[last_idx:] if m["role"] == "user"), None)
174
+
 
 
 
 
175
  if user_msg:
176
+ answer = user_msg["content"]
177
 
178
  if _is_vague_answer(answer):
179
+ field_context = HPI_FIELD_CONTEXT[next_field]
180
+
181
+ if _MOCK():
182
+ reply = f"Please be more specific about {field_context}."
183
+ else:
184
+ reply = _ask(
185
+ f"Patient response about {field_context} was vague. "
186
+ "Ask for clarification in one short sentence."
187
+ )
188
+
189
  return {
190
  "messages": [{"role": "assistant", "content": reply}],
191
  "current_node": "hpi",
 
197
 
198
  next_idx = HPI_FIELDS.index(next_field)
199
  if next_idx < len(HPI_FIELDS) - 1:
200
+ next_field = HPI_FIELDS[next_idx + 1]
201
+
202
+ if _MOCK():
203
+ reply = HPI_QUESTIONS[next_field]
204
+ else:
205
+ reply = _ask(
206
+ f"Complaint: {cc}. Known info: {hpi}. "
207
+ f"Ask ONE question about {HPI_FIELD_CONTEXT[next_field]}."
208
+ )
209
+
210
+ return {
211
+ "messages": [{"role": "assistant", "content": reply}],
212
+ "hpi": hpi,
213
+ "current_node": "hpi",
214
+ "last_processed_message_index": len(messages),
215
+ "vague_retry_field": None,
216
+ }
217
 
218
  return {
219
+ "messages": [{"role": "assistant", "content": "Now I’ll ask about other symptoms."}],
220
  "hpi": hpi,
221
+ "current_node": "ros",
222
  "last_processed_message_index": len(messages),
223
  "vague_retry_field": None,
224
  }
225
 
226
+ if _MOCK():
227
+ reply = HPI_QUESTIONS[next_field]
228
+ else:
229
+ reply = _ask(
230
+ f"Complaint: {cc}. Known info: {hpi}. "
231
+ f"Ask ONE question about {HPI_FIELD_CONTEXT[next_field]}."
232
+ )
233
+
234
  return {
235
  "messages": [{"role": "assistant", "content": reply}],
236
  "current_node": "hpi",
 
245
  ros = dict(state.get("ros", {}))
246
  cc = state.get("chief_complaint", "")
247
 
248
+ ros_systems = state.get("ros_systems") or get_relevant_ros_systems(cc)
 
 
 
249
  current_idx = state.get("ros_current_index", 0)
250
+ pending = state.get("ros_pending_system")
251
 
252
  if current_idx >= len(ros_systems):
 
253
  return {
254
+ "messages": [{"role": "assistant", "content": "I have enough information."}],
255
  "current_node": "brief_generator",
 
 
 
256
  "last_processed_message_index": len(messages),
 
257
  }
258
 
259
  has_new_user_msg = len(messages) > last_idx
260
 
261
+ if has_new_user_msg and pending:
262
+ answer = messages[-1]["content"]
263
+ ros[pending] = [f.strip() for f in answer.split(",")]
 
264
 
265
+ next_system = ros_systems[current_idx]
 
 
266
 
267
+ if _MOCK():
268
+ reply = f"Any {next_system} symptoms? Mention present and absent."
269
+ else:
270
+ reply = _ask(
271
+ f"Ask about {next_system} symptoms. One short question. "
272
+ "Ask for both present and absent symptoms."
273
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
 
275
  return {
276
  "messages": [{"role": "assistant", "content": reply}],
277
+ "ros": ros,
278
  "current_node": "ros",
279
  "ros_systems": ros_systems,
280
+ "ros_current_index": current_idx + 1,
281
+ "ros_pending_system": next_system,
282
+ "last_processed_message_index": len(messages),
 
283
  }
284
 
285
 
286
+ # -------------------- FINAL --------------------
287
+
288
  from datetime import datetime, timezone
289
  from app.schemas import HPI as HPIModel, ClinicalBrief as ClinicalBriefModel
290
 
291
 
292
  def brief_generator_node(state: IntakeState) -> dict:
293
+ hpi_obj = HPIModel(**{f: state.get("hpi", {}).get(f) or "not specified" for f in HPI_FIELDS})
 
 
 
 
 
 
 
 
 
 
 
294
 
295
  brief = ClinicalBriefModel(
296
  chief_complaint=state.get("chief_complaint", ""),
297
  hpi=hpi_obj,
298
+ ros=state.get("ros", {}),
299
  generated_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
300
  )
301
 
 
302
  return {
303
+ "messages": [{"role": "assistant", "content": "Intake complete. Here is your summary."}],
304
  "current_node": "done",
305
  "clinical_brief": brief.model_dump(),
 
 
 
 
 
306
  }
307
 
308
 
309
+ def build_graph():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  workflow = StateGraph(IntakeState)
311
 
312
  workflow.add_node("intake", intake_node)
 
315
  workflow.add_node("brief_generator", brief_generator_node)
316
 
317
  workflow.add_edge(START, "intake")
318
+ workflow.add_edge("intake", "hpi")
319
+ workflow.add_edge("hpi", "ros")
320
+ workflow.add_edge("ros", "brief_generator")
321
  workflow.add_edge("brief_generator", END)
322
 
323
  checkpointer = MemorySaver()
324
+ graph = workflow.compile(checkpointer=checkpointer)
325
 
326
+ return graph, checkpointer
app/llm.py CHANGED
@@ -1,82 +1,76 @@
1
  import os
2
 
 
 
 
 
 
 
3
 
4
- class MockLLM:
5
- """Mock LLM for testing - returns hardcoded clinical responses."""
6
 
 
7
  def __init__(self):
8
  self.hpi_fields = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
9
  self.current_hpi_index = 0
10
  self.ros_systems_done = False
11
- self.ros_current_system = 0
12
-
13
- def generate_response(self, conversation_history: list[dict], current_node: str) -> str:
14
- if current_node == "intake":
15
- return "I have chest pain since this morning"
16
-
17
- if current_node == "hpi":
18
- responses = [
19
- "It started about 3 hours ago",
20
- "In the center of my chest",
21
- "It has been constant",
22
- "It feels like pressure",
23
- "About a 7 out of 10",
24
- "It gets worse when I walk",
25
- "Resting helps a little"
26
- ]
27
- if self.current_hpi_index < len(responses):
28
- response = responses[self.current_hpi_index]
29
- self.current_hpi_index += 1
30
- return response
31
- return "I already answered all those questions"
32
-
33
- if current_node == "ros":
34
- if not self.ros_systems_done:
35
- self.ros_systems_done = True
36
- return "cardiac:palpitations present,no syncope|respiratory:mild shortness of breath,no cough"
37
- return "done"
38
 
39
- return ""
 
40
 
41
  def reset(self):
42
  self.current_hpi_index = 0
43
  self.ros_systems_done = False
44
- self.ros_current_system = 0
45
 
46
 
47
- class RealLLM:
48
- """Real LLM using llama-cpp-python with lazy loading."""
49
-
50
  def __init__(self):
51
  self.model = None
52
- self.model_path = "/models/qwen2.5-0.5b-instruct-q4_k_m.gguf"
 
53
 
54
- def _load_model(self):
55
  if self.model is None:
56
- from llama_cpp import Llama
57
- self.model = Llama(
58
- model_path=self.model_path,
59
- n_ctx=2048,
60
- n_threads=4
 
 
61
  )
62
 
63
- def generate_response(self, conversation_history: list[dict], current_node: str) -> str:
64
- self._load_model()
65
-
66
- system_prompt = (
67
- "You are a clinical AI assistant conducting patient intake. "
68
- "Ask one question at a time. Be concise and professional."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  )
 
70
 
71
- messages = [{"role": "system", "content": system_prompt}]
72
- messages.extend(conversation_history)
73
 
74
- output = self.model.create_chat_completion(messages, max_tokens=256)
75
- return output["choices"][0]["message"]["content"]
76
 
77
 
78
  def get_llm():
79
- mock_mode = os.environ.get("MOCK_LLM", "false").lower() == "true"
80
- if mock_mode:
81
- return MockLLM()
82
- return RealLLM()
 
 
1
  import os
2
 
3
+ CLINICAL_SYSTEM_PROMPT = (
4
+ "You are a clinical intake assistant conducting a pre-visit patient interview. "
5
+ "Ask one clear, concise, professional medical question at a time. "
6
+ "Do not diagnose or give medical advice. Keep responses under 2 sentences. "
7
+ "Be empathetic but professional."
8
+ )
9
 
 
 
10
 
11
+ class MockLLM:
12
  def __init__(self):
13
  self.hpi_fields = ["onset", "location", "duration", "character", "severity", "aggravating", "relieving"]
14
  self.current_hpi_index = 0
15
  self.ros_systems_done = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ def ask(self, instruction: str) -> str:
18
+ return "" # unused in mock mode — graph uses hardcoded questions
19
 
20
  def reset(self):
21
  self.current_hpi_index = 0
22
  self.ros_systems_done = False
 
23
 
24
 
25
+ class TransformersLLM:
 
 
26
  def __init__(self):
27
  self.model = None
28
+ self.tokenizer = None
29
+ self.model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")
30
 
31
+ def _load(self):
32
  if self.model is None:
33
+ from transformers import AutoModelForCausalLM, AutoTokenizer
34
+ import torch
35
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
36
+ self.model = AutoModelForCausalLM.from_pretrained(
37
+ self.model_name,
38
+ torch_dtype=torch.float32,
39
+ device_map="cpu",
40
  )
41
 
42
+ def ask(self, instruction: str) -> str:
43
+ self._load()
44
+ import torch
45
+ messages = [
46
+ {"role": "system", "content": CLINICAL_SYSTEM_PROMPT},
47
+ {"role": "user", "content": instruction},
48
+ ]
49
+ text = self.tokenizer.apply_chat_template(
50
+ messages, tokenize=False, add_generation_prompt=True
51
+ )
52
+ inputs = self.tokenizer(text, return_tensors="pt")
53
+ with torch.no_grad():
54
+ outputs = self.model.generate(
55
+ **inputs,
56
+ max_new_tokens=80,
57
+ temperature=0.3,
58
+ do_sample=True,
59
+ pad_token_id=self.tokenizer.eos_token_id,
60
+ )
61
+ response = self.tokenizer.decode(
62
+ outputs[0][inputs.input_ids.shape[1]:],
63
+ skip_special_tokens=True,
64
  )
65
+ return response.strip()
66
 
 
 
67
 
68
+ _llm_instance = None
 
69
 
70
 
71
  def get_llm():
72
+ global _llm_instance
73
+ if _llm_instance is None:
74
+ mock_mode = os.environ.get("MOCK_LLM", "true").lower() == "true"
75
+ _llm_instance = MockLLM() if mock_mode else TransformersLLM()
76
+ return _llm_instance
requirements.txt CHANGED
@@ -5,4 +5,7 @@ pydantic
5
  pytest
6
  httpx
7
  pytest-asyncio
8
- aiofiles
 
 
 
 
5
  pytest
6
  httpx
7
  pytest-asyncio
8
+ aiofiles
9
+ transformers
10
+ huggingface_hub
11
+ accelerate