JDFPalladium committed on
Commit
24e3e87
·
1 Parent(s): faaa805

resolving conflicts

Browse files
Files changed (4) hide show
  1. app.py +12 -11
  2. chatlib/assistant_node.py +73 -5
  3. chatlib/idsr_check.py +103 -24
  4. chatlib/state_types.py +5 -0
app.py CHANGED
@@ -45,7 +45,9 @@ def idsr_check_tool(query):
45
  """Check if the patient case description matches any known diseases."""
46
  result = idsr_check(query, llm=llm)
47
 
48
- return {"answer": result.get("answer", ""), "last_tool": "idsr_check"}
 
 
49
 
50
 
51
  tools = [rag_retrieve_tool, sql_chain_tool, idsr_check_tool]
@@ -58,7 +60,7 @@ You are a helpful assistant supporting clinicians during patient visits. You hav
58
 
59
  - rag_retrieve: to access HIV clinical guidelines
60
  - sql_chain: to access HIV data about the patient with whom the clinician is meeting. When using this tool, always run rag_retrieve first to get context
61
- - idsr_check: to check if the patient case description matches any known diseases
62
 
63
  When a tool is needed, respond only with a JSON object specifying the tool to call and its minimal arguments, for example:
64
  {
@@ -100,16 +102,9 @@ def chat_with_patient(question: str, thread_id: str = None): # type: ignore
100
 
101
  question = detect_and_redact_phi(question)["redacted_text"]
102
 
 
103
  input_state: AppState = {
104
- "messages": [HumanMessage(content=question)],
105
- "question": "",
106
- "rag_result": "",
107
- "answer": "",
108
- "last_answer": "",
109
- "last_user_message": "",
110
- "last_tool": None,
111
- "idsr_disclaimer": False,
112
- "summary": None,
113
  }
114
 
115
  config = {"configurable": {"thread_id": thread_id, "user_id": thread_id}}
@@ -125,6 +120,12 @@ def chat_with_patient(question: str, thread_id: str = None): # type: ignore
125
 
126
 
127
  with gr.Blocks() as app:
 
 
 
 
 
 
128
  question_input = gr.Textbox(label="Question")
129
  thread_id_state = gr.State()
130
  output_chat = gr.Textbox(label="Assistant Response")
 
45
  """Check if the patient case description matches any known diseases."""
46
  result = idsr_check(query, llm=llm)
47
 
48
+ return {"answer": result.get("answer", ""),
49
+ "last_tool": "idsr_check",
50
+ "context": result.get("context", None)}
51
 
52
 
53
  tools = [rag_retrieve_tool, sql_chain_tool, idsr_check_tool]
 
60
 
61
  - rag_retrieve: to access HIV clinical guidelines
62
  - sql_chain: to access HIV data about the patient with whom the clinician is meeting. When using this tool, always run rag_retrieve first to get context
63
+ - idsr_check: to check if the patient case description matches any known diseases.
64
 
65
  When a tool is needed, respond only with a JSON object specifying the tool to call and its minimal arguments, for example:
66
  {
 
102
 
103
  question = detect_and_redact_phi(question)["redacted_text"]
104
 
105
+ # First turn: initialize state
106
  input_state: AppState = {
107
+ "messages": [HumanMessage(content=question)]
 
 
 
 
 
 
 
 
108
  }
109
 
110
  config = {"configurable": {"thread_id": thread_id, "user_id": thread_id}}
 
120
 
121
 
122
  with gr.Blocks() as app:
123
+ gr.Markdown(
124
+ """
125
+ # Clinician Assistant
126
+ Welcome! Enter your clinical question below. The assistant can access HIV guidelines, patient data, and disease surveillance tools.
127
+ """
128
+ )
129
  question_input = gr.Textbox(label="Question")
130
  thread_id_state = gr.State()
131
  output_chat = gr.Textbox(label="Assistant Response")
chatlib/assistant_node.py CHANGED
@@ -33,6 +33,22 @@ def summarize_conversation(messages, llm):
33
 
34
 
35
  def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  messages = state.get("messages", [])
37
  base_messages = [sys_msg]
38
  messages = base_messages + [m for m in messages if not isinstance(m, SystemMessage)]
@@ -48,18 +64,66 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
48
  state["answer"] = ""
49
  state["rag_result"] = ""
50
 
51
- # Update state from any ToolMessages appended by previous tool calls
52
- # Only consider the most recent ToolMessage for updating state
53
  for msg in reversed(messages):
54
  if isinstance(msg, ToolMessage):
55
  try:
56
  content = msg.content
57
  data = json.loads(content) if isinstance(content, str) else content
58
- state.update(data)
59
- break # only process the most recent ToolMessage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  except json.JSONDecodeError:
61
  break
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # Invoke LLM with tools (this returns AIMessage with tool_calls if tool call is needed)
64
  new_message = llm_with_tools.invoke(messages)
65
  messages.append(new_message)
@@ -99,6 +163,10 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
99
  final_content = disclaimer + final_content
100
  state["idsr_disclaimer_shown"] = True
101
 
 
 
 
 
102
  # Replace the last AIMessage content with final_content to avoid duplicates
103
  for i in reversed(range(len(messages))):
104
  if isinstance(messages[i], AIMessage):
@@ -114,7 +182,7 @@ def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
114
  m for m in non_sys_messages if isinstance(m, (HumanMessage, AIMessage))
115
  ]
116
 
117
- if len(human_ai_messages) > 15:
118
  summary_text = summarize_conversation(messages, llm)
119
  summary_msg = SystemMessage(
120
  content="Summary of earlier conversation:\n" + summary_text
 
33
 
34
 
35
  def assistant(state: AppState, sys_msg, llm, llm_with_tools) -> AppState:
36
+
37
+ # Initialize missing keys with defaults
38
+ state.setdefault("question", "")
39
+ state.setdefault("rag_result", "")
40
+ state.setdefault("answer", "")
41
+ state.setdefault("last_answer", None)
42
+ state.setdefault("last_user_message", None)
43
+ state.setdefault("last_tool", None)
44
+ state.setdefault("idsr_disclaimer_shown", False)
45
+ state.setdefault("summary", None)
46
+ state.setdefault("context", None)
47
+ state.setdefault("context_versions", {})
48
+ state.setdefault("last_context_injected_versions", {})
49
+ state.setdefault("context_version_ready_for_injection", 0)
50
+ state.setdefault("context_first_response_sent", True)
51
+
52
  messages = state.get("messages", [])
53
  base_messages = [sys_msg]
54
  messages = base_messages + [m for m in messages if not isinstance(m, SystemMessage)]
 
64
  state["answer"] = ""
65
  state["rag_result"] = ""
66
 
67
+ # Process latest ToolMessage and update context_version
 
68
  for msg in reversed(messages):
69
  if isinstance(msg, ToolMessage):
70
  try:
71
  content = msg.content
72
  data = json.loads(content) if isinstance(content, str) else content
73
+
74
+ tool_name = data.get("last_tool")
75
+ new_context = data.get("context")
76
+
77
+ if tool_name:
78
+ old_context = state.get("context", "")
79
+ old_version = state["context_versions"].get(tool_name, 0)
80
+
81
+ if new_context is not None and new_context != old_context:
82
+ state["context"] = new_context
83
+ state["context_versions"][tool_name] = old_version + 1
84
+ state["context_first_response_sent"] = False # Reset flag on new context
85
+
86
+ state["last_tool"] = tool_name
87
+
88
+ for k, v in data.items():
89
+ if k not in ("context", "last_tool"):
90
+ state[k] = v
91
+
92
+ break
93
  except json.JSONDecodeError:
94
  break
95
 
96
+ tool_name = "idsr_check"
97
+ current_version = state["context_versions"].get(tool_name, 0)
98
+ last_injected_version = state["last_context_injected_versions"].get(tool_name, 0)
99
+
100
+ # On turns where user message is unchanged, advance ready_for_injection to current_version
101
+ if not user_message_changed and state["context_version_ready_for_injection"] < current_version:
102
+ state["context_version_ready_for_injection"] = current_version
103
+
104
+ # Inject context system message only if:
105
+ # - last_tool matches tool_name
106
+ # - context exists
107
+ # - ready_for_injection > last injected version
108
+ # - AND first AI response after new context has been sent
109
+ if (
110
+ state.get("last_tool") == tool_name
111
+ and state.get("context")
112
+ and state["context_version_ready_for_injection"] > last_injected_version
113
+ and state.get("context_first_response_sent", True)
114
+ ):
115
+ context_msg = SystemMessage(
116
+ content=(
117
+ f"The following information was retrieved from the {tool_name.upper()} database and may help answer the user's question:\n\n"
118
+ f"{state['context']}\n\n"
119
+ "Use this information when responding."
120
+ )
121
+ )
122
+ messages.append(context_msg)
123
+
124
+ state["last_context_injected_versions"][tool_name] = state["context_version_ready_for_injection"]
125
+ state["last_tool"] = None
126
+
127
  # Invoke LLM with tools (this returns AIMessage with tool_calls if tool call is needed)
128
  new_message = llm_with_tools.invoke(messages)
129
  messages.append(new_message)
 
163
  final_content = disclaimer + final_content
164
  state["idsr_disclaimer_shown"] = True
165
 
166
+ # After generating AI message, mark first response sent
167
+ if state.get("last_tool") == tool_name or state.get("context_first_response_sent") is False:
168
+ state["context_first_response_sent"] = True
169
+
170
  # Replace the last AIMessage content with final_content to avoid duplicates
171
  for i in reversed(range(len(messages))):
172
  if isinstance(messages[i], AIMessage):
 
182
  m for m in non_sys_messages if isinstance(m, (HumanMessage, AIMessage))
183
  ]
184
 
185
+ if len(human_ai_messages) > 10:
186
  summary_text = summarize_conversation(messages, llm)
187
  summary_msg = SystemMessage(
188
  content="Summary of earlier conversation:\n" + summary_text
chatlib/idsr_check.py CHANGED
@@ -9,6 +9,8 @@ from langchain_core.output_parsers import PydanticOutputParser
9
  import json
10
  import math
11
  from collections import Counter
 
 
12
 
13
 
14
  with open("./guidance_docs/idsr_keywords.txt", "r", encoding="utf-8") as f:
@@ -39,6 +41,15 @@ keyword_weights = {
39
  kw: math.log(total_docs / (1 + count)) for kw, count in keyword_doc_counts.items()
40
  }
41
 
 
 
 
 
 
 
 
 
 
42
 
43
  def score_doc(doc_to_score, matched_keywords):
44
  doc_keywords = set(doc_to_score.metadata.get("matched_keywords", []))
@@ -110,9 +121,9 @@ def hybrid_search_with_query_keywords(
110
 
111
  ranked_docs = sorted(scored_docs, key=lambda x: -x[1])
112
  top_docs = [doc for doc, score in ranked_docs if score > 0]
113
- top_3_docs = top_docs[:3]
114
 
115
- merged = {doc.page_content: doc for doc in semantic_hits + top_3_docs}
116
  return list(merged.values())
117
 
118
 
@@ -130,47 +141,96 @@ def idsr_check(query: str, llm) -> AppState:
130
  results = hybrid_search_with_query_keywords(
131
  query, vectorstore, tagged_documents, keywords, llm
132
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  disease_definitions = "\n\n".join(
135
  [
136
- f"{doc.metadata.get('disease_name', 'Unknown Disease')}:\n{doc.page_content}"
137
  for doc in results
138
  ]
139
  )
140
 
141
- prompt = """
142
- You are a medical assistant reviewing a brief clinical case in Kenya to help identify which diseases the patient may plausibly have. You have access to several disease definitions.
143
-
144
- Your task is as follows:
145
- 1. Carefully compare the case description to each disease definition.
146
- 2. If a disease seems like a possible match based on the available information, list it and explain why.
147
- 3. Only include rare diseases (e.g., eradicated or non-endemic to Kenya) if the match is extremely strong. Prioritize common and plausible conditions.
148
- 4. If no disease clearly matches, say: "No strong match found."
149
- 5. Ask clarifying questions if helpful to make better match suggestions.
150
- 6. After asking clarifying questions, proceed with an assessment anyway based on what is already available.
151
 
152
- Case:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  {query}
154
 
155
- Diseases:
156
  {disease_definitions}
157
 
158
- Your response should be brief and include as appropriate:
 
 
 
 
 
 
 
 
 
159
 
160
  Possible matches:
161
- - Disease Name: [Likely] - Reason
162
- - Disease Name: [Probable] - Reason
163
- (Only include diseases that clearly fit based on the information. If none, say "No strong match found.")
164
 
165
- Clarifying questions (optional, only if needed):
166
  - Question 1
167
  - Question 2
168
 
169
- At the end, always give a brief recommendation like:
170
- - Recommendation: "Suggest monitoring for the listed conditions." OR "No disease meets criteria based on current data — suggest gathering additional history on [x, y, z]."
171
 
172
  """.format(
173
- query=query, disease_definitions=disease_definitions
 
 
 
174
  )
175
 
176
  llm_response = llm.invoke(prompt)
@@ -180,4 +240,23 @@ def idsr_check(query: str, llm) -> AppState:
180
  else "No relevant disease information found."
181
  )
182
 
183
- return {"answer": answer_text, "last_tool": "idsr_check"} # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import json
10
  import math
11
  from collections import Counter
12
+ import sqlite3
13
+ import os
14
 
15
 
16
  with open("./guidance_docs/idsr_keywords.txt", "r", encoding="utf-8") as f:
 
41
  kw: math.log(total_docs / (1 + count)) for kw, count in keyword_doc_counts.items()
42
  }
43
 
44
+ ## prepare to get location data
45
+ # first, get sitecode from environment variable
46
+ sitecode = os.environ.get("SITECODE")
47
+ # next, connect to location database and get county where code = sitecode
48
+ conn = sqlite3.connect('data/location_data.sqlite')
49
+ cursor = conn.cursor()
50
+ cursor.execute("SELECT County FROM sitecode_county_xwalk WHERE Code = ?", (sitecode,))
51
+ county = cursor.fetchone()
52
+ conn.close()
53
 
54
  def score_doc(doc_to_score, matched_keywords):
55
  doc_keywords = set(doc_to_score.metadata.get("matched_keywords", []))
 
121
 
122
  ranked_docs = sorted(scored_docs, key=lambda x: -x[1])
123
  top_docs = [doc for doc, score in ranked_docs if score > 0]
124
+ top_5_docs = top_docs[:5]
125
 
126
+ merged = {doc.page_content: doc for doc in semantic_hits + top_5_docs}
127
  return list(merged.values())
128
 
129
 
 
141
  results = hybrid_search_with_query_keywords(
142
  query, vectorstore, tagged_documents, keywords, llm
143
  )
144
+
145
+ # set up connection to location database and get EpidemicInfo for any diseases in the disease_name metadata field of the results from the hybrid search
146
+ conn = sqlite3.connect('data/location_data.sqlite')
147
+ cursor = conn.cursor()
148
+ disease_names = [doc.metadata.get("disease_name") for doc in results]
149
+ placeholders = ",".join("?" * len(disease_names))
150
+ query_str = f"SELECT Disease, EpidemicInfo FROM who_bulletin WHERE Disease IN ({placeholders})"
151
+ cursor.execute(query_str, disease_names)
152
+ epidemic_info = cursor.fetchall()
153
+ conn.close()
154
+
155
+ # print(doc.metadata.get("disease_name") for doc in results)
156
+
157
+ # set up connection to location database and get results where County = county and Disease is in
158
+ # the disease_name metadata field of the results from the hybrid search
159
+ conn = sqlite3.connect('data/location_data.sqlite')
160
+ cursor = conn.cursor()
161
+ if county: # Ensure county is not None
162
+ county_name = county[0]
163
+ disease_names = [doc.metadata.get("disease_name") for doc in results]
164
+ placeholders = ",".join("?" * len(disease_names))
165
+ query_str = f"SELECT County, Disease, Prevalence, Notes FROM county_disease_info WHERE County = ? AND Disease IN ({placeholders})"
166
+ cursor.execute(query_str, (county_name, *disease_names))
167
+ county_info = cursor.fetchall()
168
+
169
+ # Get climate information for the county from the rainy seasons table
170
+ # Get the current month
171
+ from datetime import datetime
172
+ current_month = datetime.now().strftime("%B") # Full month name, e.g. "March"
173
+ cursor.execute("SELECT RainySeason FROM county_rainy_seasons WHERE County = ? and Month = ?", (county_name, current_month))
174
+ rainy_season = cursor.fetchone()
175
+ rainy_season = rainy_season[0] if rainy_season else "Unknown"
176
+
177
+ # close the connection
178
+ conn.close()
179
 
180
  disease_definitions = "\n\n".join(
181
  [
182
+ f"### Disease: {doc.metadata.get('disease_name', 'Unknown Disease')}:\n{doc.page_content}"
183
  for doc in results
184
  ]
185
  )
186
 
 
 
 
 
 
 
 
 
 
 
187
 
188
+ prompt = """
189
+ You are a medical assistant reviewing a brief clinical case in Kenya to help identify which diseases the patient may plausibly have.
190
+ You have access to several disease definitions. You also have access to information about the prevalence of each disease in the county
191
+ where the patient is located. The prevalence of some diseases varies by season, and some diseases are also more likely when there is a
192
+ declared epidemic. Information on the timing of the rainy season and any declared epidemics is also provided.
193
+
194
+ ## Instructions:
195
+ 1. Carefully compare the case description to each disease definition, taking into account the prevalence and seasonality information.
196
+ 2. If a disease seems like a possible match based on the available information, list it and explain why.
197
+ 3. Only include rare diseases, or diseases that don't fit seasonally, if the match is extremely strong. Prioritize common and plausible conditions.
198
+ 4. You don't need to suggest matches if none of the diseases seem relevant.
199
+ 5. Ask clarifying questions if helpful to make better match suggestions. Possible questions might include asking about specific symptoms, demographic characteristics, exposures, or travel history.
200
+ 6. At the end, give a brief recommendation on next steps, such as monitoring for certain conditions or gathering additional history.
201
+
202
+ ## Case:
203
  {query}
204
 
205
+ ## Diseases:
206
  {disease_definitions}
207
 
208
+ ## Locational context:
209
+ In {county_name}, the current rainy season status is {rainy_season}.
210
+
211
+ The above diseases have the following prevalence (county, disease name, prevalence, seasonality):
212
+ {county_info}
213
+
214
+ Here are any relevant epidemic alerts for these diseases:
215
+ {epidemic_info}
216
+
217
+ ## Expected Output
218
 
219
  Possible matches:
220
+ - Disease Name: Reason
221
+ - Disease Name: Reason
 
222
 
223
+ Clarifying questions:
224
  - Question 1
225
  - Question 2
226
 
227
+ Recommendation:
 
228
 
229
  """.format(
230
+ query=query, disease_definitions=disease_definitions, county_name=county_name if county else "Unknown County",
231
+ rainy_season=rainy_season if county else "Unknown",
232
+ county_info="\n".join([f"- {row[0]}, {row[1]}, Prevalence: {row[2]}, Seasonality: {row[3]}" for row in county_info]) if county else "No county information available.",
233
+ epidemic_info="\n".join([f"- {row[0]}: {row[1]}" for row in epidemic_info]) if epidemic_info else "No epidemic information available."
234
  )
235
 
236
  llm_response = llm.invoke(prompt)
 
240
  else "No relevant disease information found."
241
  )
242
 
243
+ # Set up context to return.
244
+ # First, use an LLM to identify which diseases from disease_definitions were mentioned in the answer_text
245
+ disease_names_in_answer = [doc.metadata.get("disease_name") for doc in results if doc.metadata.get("disease_name") in answer_text]
246
+ # Next, filter the results to only include those diseases
247
+ filtered_results = [doc for doc in results if doc.metadata.get("disease_name") in disease_names_in_answer]
248
+ # Finally, create context string with only those diseases, plus any county_info and epidemic_info
249
+ context_parts = []
250
+ if filtered_results:
251
+ context_parts.append("### Disease Definitions:\n" + "\n\n".join(
252
+ [
253
+ f"### Disease: {doc.metadata.get('disease_name', 'Unknown Disease')}:\n{doc.page_content}"
254
+ for doc in filtered_results
255
+ ]
256
+ ))
257
+ if county and county_info:
258
+ context_parts.append("### County Disease Information:\n" + "\n".join([f"- {row[0]}, {row[1]}, Prevalence: {row[2]}, Seasonality: {row[3]}" for row in county_info]))
259
+ if epidemic_info:
260
+ context_parts.append("### Epidemic Information:\n" + "\n".join([f"- {row[0]}: {row[1]}" for row in epidemic_info]))
261
+
262
+ return {"answer": answer_text, "last_tool": "idsr_check", "context": context_parts} # type: ignore
chatlib/state_types.py CHANGED
@@ -15,3 +15,8 @@ class AppState(TypedDict):
15
  last_tool: Optional[str] = None
16
  idsr_disclaimer_shown: bool = False
17
  summary: Optional[str] = None
 
 
 
 
 
 
15
  last_tool: Optional[str] = None
16
  idsr_disclaimer_shown: bool = False
17
  summary: Optional[str] = None
18
+ context: Optional[str] = None
19
+ context_versions: dict[str, int] = {}
20
+ last_context_injected_versions: dict[str, int] = {}
21
+ context_version_ready_for_injection: int = 0
22
+ context_first_response_sent: bool = True