JDFPalladium committed on
Commit
35274a7
·
1 Parent(s): 1255a5e

adding idsr define tool and reflecting tweaks to other scripts and notebooks

Browse files
app.py CHANGED
@@ -22,6 +22,7 @@ from chatlib.state_types import AppState
22
  from chatlib.guidlines_rag_agent_li import rag_retrieve
23
  from chatlib.patient_all_data import sql_chain
24
  from chatlib.idsr_check import idsr_check
 
25
  from chatlib.phi_filter import detect_and_redact_phi
26
  from chatlib.assistant_node import assistant
27
 
@@ -52,8 +53,15 @@ def idsr_check_tool(query, sitecode):
52
  "context": result.get("context", None),
53
  }
54
 
 
 
 
 
 
 
 
55
 
56
- tools = [rag_retrieve_tool, sql_chain_tool, idsr_check_tool]
57
  llm_with_tools = llm.bind_tools(tools)
58
 
59
 
@@ -61,11 +69,12 @@ sys_msg = SystemMessage(
61
  content="""
62
  You are a helpful assistant supporting clinicians during patient visits. When a patient ID is provided, the clinician is meeting with that HIV-positive patient and may inquire about their history, lab results, or medications. If no patient ID is provided, the clinician may be asking general HIV clinical questions or presenting symptoms for a new patient.
63
 
64
- You have access to three tools to help you answer the clinician's questions.
65
 
66
- - rag_retrieve: to access HIV clinical guidelines
67
- - sql_chain: to access HIV data about the patient with whom the clinician is meeting. For straightforward factual questions about the patient, you may call sql_chain directly. For questions requiring clinical interpretation or classification, first call rag_retrieve to get relevant clinical guideline context, then include that context when calling sql_chain.
68
- - idsr_check: to check if the patient case description matches any known diseases.
 
69
 
70
  When a tool is needed, respond only with a JSON object specifying the tool to call and its minimal arguments, for example:
71
  {
@@ -107,6 +116,19 @@ For example:
107
  }
108
  }
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  There are only two cases where a tool is not needed:
111
  1. If the clinician's question is a simple greeting, farewell, or acknowledgement.
112
  2. The answer is clearly and completely present in the prior conversation turns.
 
22
  from chatlib.guidlines_rag_agent_li import rag_retrieve
23
  from chatlib.patient_all_data import sql_chain
24
  from chatlib.idsr_check import idsr_check
25
+ from chatlib.idsr_definition import idsr_define
26
  from chatlib.phi_filter import detect_and_redact_phi
27
  from chatlib.assistant_node import assistant
28
 
 
53
  "context": result.get("context", None),
54
  }
55
 
56
def idsr_define_tool(query):
    """Retrieve the official case definition for the disease named in the query."""
    response = idsr_define(query, llm=llm)
    answer = response.get("answer", "")
    return {
        "answer": answer,
        "last_tool": "idsr_define",
    }
63
 
64
# Register all four agent tools and bind them to the LLM so it can emit
# structured tool calls for any of them during a turn.
tools = [rag_retrieve_tool, sql_chain_tool, idsr_check_tool, idsr_define_tool]
llm_with_tools = llm.bind_tools(tools)
66
 
67
 
 
69
  content="""
70
  You are a helpful assistant supporting clinicians during patient visits. When a patient ID is provided, the clinician is meeting with that HIV-positive patient and may inquire about their history, lab results, or medications. If no patient ID is provided, the clinician may be asking general HIV clinical questions or presenting symptoms for a new patient.
71
 
72
+ You have access to four tools to help you answer the clinician's questions.
73
 
74
+ - rag_retrieve_tool: to access HIV clinical guidelines
75
+ - sql_chain_tool: to access HIV data about the patient with whom the clinician is meeting. For straightforward factual questions about the patient, you may call sql_chain directly. For questions requiring clinical interpretation or classification, first call rag_retrieve to get relevant clinical guideline context, then include that context when calling sql_chain.
76
+ - idsr_check_tool: to check if the patient case description matches any known diseases.
77
+ - idsr_define_tool: to retrieve the official case definition of a disease when the clinician asks about it (e.g., “What is the description of cholera?”). Do not use this tool for analyzing symptom descriptions — use `idsr_check_tool` for that.
78
 
79
  When a tool is needed, respond only with a JSON object specifying the tool to call and its minimal arguments, for example:
80
  {
 
116
  }
117
  }
118
 
119
+ When calling the "idsr_define_tool" tool, always include the following arguments in the JSON response:
120
+
121
+ - "query": the clinician's question
122
+
123
+ For example:
124
+
125
+ {
126
+ "tool": "idsr_define_tool",
127
+ "args": {
128
+ "query": "What is the description of cholera?"
129
+ }
130
+ }
131
+
132
  There are only two cases where a tool is not needed:
133
  1. If the clinician's question is a simple greeting, farewell, or acknowledgement.
134
  2. The answer is clearly and completely present in the prior conversation turns.
chatlib/idsr_definition.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+ from langchain_core.output_parsers import PydanticOutputParser
3
+ from pydantic import BaseModel, Field
4
+ from typing import Optional
5
+ from langchain_core.documents import Document
6
+ import json
7
+
8
# Load the pre-tagged IDSR disease documents once, at module import time.
# NOTE(review): the path is relative to the process working directory —
# confirm the app is always launched from the project root.
with open("./data/processed/tagged_documents.json", "r", encoding="utf-8") as f:
    doc_dicts = json.load(f)

# Rehydrate the serialized dicts into LangChain Document objects; each is
# expected to carry a "disease_name" key in its metadata (used below) —
# TODO confirm against the tagging pipeline.
tagged_documents = [Document(**d) for d in doc_dicts]
12
+
13
# Structured-output schema for the LLM disease-selection step; parsed by
# PydanticOutputParser in select_disease_from_query. (Deliberately no class
# docstring: pydantic folds docstrings into the JSON schema embedded in the
# format instructions sent to the LLM.)
class DiseaseSelectionOutput(BaseModel):
    # None/null signals "no confident match" rather than forcing a guess.
    disease_name: Optional[str] = Field(
        description="The most likely disease the user is asking about, or null if no match is confident"
    )
18
+
19
def select_disease_from_query(query: str, llm, tagged_docs: list[Document]) -> Optional[str]:
    """Ask the LLM which IDSR disease the clinician's query refers to.

    Args:
        query: The clinician's free-text question.
        llm: A LangChain-compatible chat model used for the selection call.
        tagged_docs: Documents whose metadata carries a "disease_name" key.

    Returns:
        The selected disease name, or None when no confident match exists
        (the model is instructed to emit null in that case).
    """
    # Skip documents with a missing/empty disease_name so the prompt never
    # shows a literal "- None" option to the model.
    disease_names = [
        name for doc in tagged_docs if (name := doc.metadata.get("disease_name"))
    ]
    disease_list = "\n".join(f"- {name}" for name in disease_names)

    parser = PydanticOutputParser(pydantic_object=DiseaseSelectionOutput)

    prompt = ChatPromptTemplate.from_template(
        """
You are helping a clinician retrieve a disease definition from a list of IDSR diseases.

Given the following query:
"{query}"

Select the single disease from the list below that the query most likely refers to.

List of available diseases:
{disease_list}

If no match is clearly appropriate, set "disease_name" to null.

{format_instructions}
"""
    )

    # Prompt -> model -> structured parse; raises if the model output does
    # not validate against DiseaseSelectionOutput.
    chain = prompt | llm | parser
    output = chain.invoke({
        "query": query,
        "disease_list": disease_list,
        "format_instructions": parser.get_format_instructions(),
    })

    return output.disease_name
51
+
52
def idsr_define(query: str, llm) -> dict:
    """Answer a clinician's question about an IDSR disease case definition.

    Selects the disease the query refers to, looks up its tagged case
    definition, and asks the LLM to phrase a grounded answer.

    Args:
        query: The clinician's free-text question (e.g. "What is cholera?").
        llm: A LangChain-compatible chat model used to phrase the answer.

    Returns:
        A dict with a single "answer" key containing the response text.
    """
    disease_name = select_disease_from_query(query, llm, tagged_documents)

    if not disease_name:
        return {
            "answer": "Sorry, I couldn't find a clear match for that disease. Please rephrase or try a different name."
        }

    # Locate the document tagged with the selected disease (None if absent).
    match = next(
        (
            doc
            for doc in tagged_documents
            if doc.metadata.get("disease_name") == disease_name
        ),
        None,
    )
    if match is None:
        # Defensive: the selector returned a name not present in the corpus.
        # (Fixed: original used a redundant f-prefix on this literal.)
        return {
            "answer": "Sorry, no case definition was found for the selected disease."
        }

    definition = match.page_content.strip()

    # Use the LLM to turn the raw case definition into a direct answer.
    prompt = f"""
You are a medical assistant helping a clinician understand disease case definitions.

Here is a user query:
"{query}"

Here is the official case definition for the relevant disease:
"{definition}"

Based on the case definition, answer the user query clearly and concisely. Do not speculate beyond the information provided.
"""
    llm_response = llm.invoke(prompt)

    return {
        "answer": llm_response.content.strip()
    }
chatlib/patient_all_data.py CHANGED
@@ -172,6 +172,9 @@ def sql_chain(query: str, llm, rag_result: str, pk_hash: str) -> dict:
172
  except (ValueError, TypeError):
173
  return "invalid date"
174
 
 
 
 
175
  row = df.iloc[0]
176
  summary = (
177
  f"Sex: {safe(row['Sex'])}\n"
@@ -180,7 +183,7 @@ def sql_chain(query: str, llm, rag_result: str, pk_hash: str) -> dict:
180
  f"Occupation: {safe(row['Occupation'])}\n"
181
  f"OnIPT: {safe(row['OnIPT'])}\n"
182
  f"ARTOutcomeDescription: {safe(row['ARTOutcomeDescription'])}\n"
183
- f"StartARTDate: {safe(row['StartARTDate'])}\n"
184
  f"Age: {calculate_age(safe(row['DOB']))}"
185
  )
186
  return summary
 
172
  except (ValueError, TypeError):
173
  return "invalid date"
174
 
175
+ df = df.copy()
176
+ df["StartARTDate"] = pd.to_datetime(df["StartARTDate"], errors="coerce")
177
+
178
  row = df.iloc[0]
179
  summary = (
180
  f"Sex: {safe(row['Sex'])}\n"
 
183
  f"Occupation: {safe(row['Occupation'])}\n"
184
  f"OnIPT: {safe(row['OnIPT'])}\n"
185
  f"ARTOutcomeDescription: {safe(row['ARTOutcomeDescription'])}\n"
186
+ f"StartARTDate: {describe_relative_date(row['StartARTDate'])}\n"
187
  f"Age: {calculate_age(safe(row['DOB']))}"
188
  )
189
  return summary
notebooks/create_patient_db.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "ddb26634",
7
  "metadata": {},
8
  "outputs": [],
@@ -10,7 +10,7 @@
10
  "import sqlite3\n",
11
  "import pandas as pd\n",
12
  "# inspect current database schema\n",
13
- "conn = sqlite3.connect('patient_slim.sqlite')\n",
14
  "cursor = conn.cursor()\n",
15
  "# list tables\n",
16
  "# pull all data from the visits table \n",
@@ -22,19 +22,41 @@
22
  },
23
  {
24
  "cell_type": "code",
25
- "execution_count": 2,
26
  "id": "cd4faa4b",
27
  "metadata": {},
28
  "outputs": [],
29
  "source": [
30
  "# let's create a new sqlite database called patient_demonstration.sqlite\n",
31
- "conn = sqlite3.connect('patient_demonstration.sqlite')\n",
32
  "cursor = conn.cursor() "
33
  ]
34
  },
35
  {
36
  "cell_type": "code",
37
  "execution_count": 3,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  "id": "f8547b78",
39
  "metadata": {},
40
  "outputs": [],
@@ -82,17 +104,17 @@
82
  },
83
  {
84
  "cell_type": "code",
85
- "execution_count": 4,
86
  "id": "9ddfa626",
87
  "metadata": {},
88
  "outputs": [
89
  {
90
  "data": {
91
  "text/plain": [
92
- "<sqlite3.Cursor at 0x7d4240c3d840>"
93
  ]
94
  },
95
- "execution_count": 4,
96
  "metadata": {},
97
  "output_type": "execute_result"
98
  }
@@ -110,7 +132,7 @@
110
  },
111
  {
112
  "cell_type": "code",
113
- "execution_count": 5,
114
  "id": "d14ef687",
115
  "metadata": {},
116
  "outputs": [],
@@ -177,12 +199,12 @@
177
  },
178
  {
179
  "cell_type": "code",
180
- "execution_count": 6,
181
  "id": "6e27bce5",
182
  "metadata": {},
183
  "outputs": [],
184
  "source": [
185
- "conn = sqlite3.connect('patient_slim.sqlite')\n",
186
  "cursor = conn.cursor()\n",
187
  "# pull all data from the lab table except for the \"key\" column \n",
188
  "cursor.execute(\"SELECT * FROM lab;\")\n",
@@ -193,19 +215,19 @@
193
  },
194
  {
195
  "cell_type": "code",
196
- "execution_count": 7,
197
  "id": "14402e96",
198
  "metadata": {},
199
  "outputs": [],
200
  "source": [
201
  "# let's create a new sqlite database called patient_demonstration.sqlite\n",
202
- "conn = sqlite3.connect('patient_demonstration.sqlite')\n",
203
  "cursor = conn.cursor() "
204
  ]
205
  },
206
  {
207
  "cell_type": "code",
208
- "execution_count": 8,
209
  "id": "540962b7",
210
  "metadata": {},
211
  "outputs": [],
@@ -235,7 +257,7 @@
235
  },
236
  {
237
  "cell_type": "code",
238
- "execution_count": 9,
239
  "id": "8df7171e",
240
  "metadata": {},
241
  "outputs": [],
@@ -260,12 +282,12 @@
260
  },
261
  {
262
  "cell_type": "code",
263
- "execution_count": 10,
264
  "id": "b66d3dbb",
265
  "metadata": {},
266
  "outputs": [],
267
  "source": [
268
- "conn = sqlite3.connect('patient_slim.sqlite')\n",
269
  "cursor = conn.cursor()\n",
270
  "# pull all data from the lab table except for the \"key\" column \n",
271
  "cursor.execute(\"SELECT * FROM pharmacy;\")\n",
@@ -276,19 +298,19 @@
276
  },
277
  {
278
  "cell_type": "code",
279
- "execution_count": 11,
280
  "id": "435b8d4e",
281
  "metadata": {},
282
  "outputs": [],
283
  "source": [
284
  "# let's create a new sqlite database called patient_demonstration.sqlite\n",
285
- "conn = sqlite3.connect('patient_demonstration.sqlite')\n",
286
  "cursor = conn.cursor() "
287
  ]
288
  },
289
  {
290
  "cell_type": "code",
291
- "execution_count": 12,
292
  "id": "b3753eeb",
293
  "metadata": {},
294
  "outputs": [],
@@ -322,7 +344,7 @@
322
  },
323
  {
324
  "cell_type": "code",
325
- "execution_count": 13,
326
  "id": "8b8ed08a",
327
  "metadata": {},
328
  "outputs": [],
@@ -348,12 +370,12 @@
348
  },
349
  {
350
  "cell_type": "code",
351
- "execution_count": 14,
352
  "id": "2de65432",
353
  "metadata": {},
354
  "outputs": [],
355
  "source": [
356
- "conn = sqlite3.connect('patient_slim.sqlite')\n",
357
  "cursor = conn.cursor()\n",
358
  "# pull all data from the lab table except for the \"key\" column \n",
359
  "cursor.execute(\"SELECT * FROM demographics;\")\n",
@@ -364,19 +386,454 @@
364
  },
365
  {
366
  "cell_type": "code",
367
- "execution_count": 15,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  "id": "a7a10f4f",
369
  "metadata": {},
370
  "outputs": [],
371
  "source": [
372
  "# let's create a new sqlite database called patient_demonstration.sqlite\n",
373
- "conn = sqlite3.connect('patient_demonstration.sqlite')\n",
374
  "cursor = conn.cursor() "
375
  ]
376
  },
377
  {
378
  "cell_type": "code",
379
- "execution_count": 16,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  "id": "947c63d5",
381
  "metadata": {},
382
  "outputs": [],
@@ -386,6 +843,7 @@
386
  "cursor.execute('DROP TABLE IF EXISTS demographics;')\n",
387
  "cursor.execute('''\n",
388
  "CREATE TABLE demographics (\n",
 
389
  " PatientPKHash TEXT,\n",
390
  " MFLCode TEXT,\n",
391
  " FacilityName TEXT,\n",
@@ -403,14 +861,13 @@
403
  " AsOfDate TEXT,\n",
404
  " LoadDate TEXT,\n",
405
  " StartARTDate TEXT,\n",
406
- " DOB TEXT,\n",
407
- " key TEXT\n",
408
  ");\n",
409
  "''')\n",
410
  "\n",
411
  "# let's now populate the table with the rows variable that contains all the data from the visits table\n",
412
  "cursor.executemany('''\n",
413
- "INSERT INTO demographics (PatientPKHash, MFLCode, FacilityName, County, SubCounty, PartnerName, AgencyName, Sex, MaritalStatus, EducationLevel, Occupation, OnIPT, AgeGroup, ARTOutcomeDescription, AsOfDate, LoadDate, StartARTDate, DOB, key)\n",
414
  "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n",
415
  "''', rows)\n",
416
  "conn.commit()"
@@ -418,7 +875,7 @@
418
  },
419
  {
420
  "cell_type": "code",
421
- "execution_count": 17,
422
  "id": "9cff0d90",
423
  "metadata": {},
424
  "outputs": [],
@@ -458,7 +915,7 @@
458
  ],
459
  "metadata": {
460
  "kernelspec": {
461
- "display_name": ".venv",
462
  "language": "python",
463
  "name": "python3"
464
  },
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 19,
6
  "id": "ddb26634",
7
  "metadata": {},
8
  "outputs": [],
 
10
  "import sqlite3\n",
11
  "import pandas as pd\n",
12
  "# inspect current database schema\n",
13
+ "conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
14
  "cursor = conn.cursor()\n",
15
  "# list tables\n",
16
  "# pull all data from the visits table \n",
 
22
  },
23
  {
24
  "cell_type": "code",
25
+ "execution_count": 20,
26
  "id": "cd4faa4b",
27
  "metadata": {},
28
  "outputs": [],
29
  "source": [
30
  "# let's create a new sqlite database called patient_demonstration.sqlite\n",
31
+ "conn = sqlite3.connect('../data/processed/patient_demonstration.sqlite')\n",
32
  "cursor = conn.cursor() "
33
  ]
34
  },
35
  {
36
  "cell_type": "code",
37
  "execution_count": 3,
38
+ "id": "866e707d",
39
+ "metadata": {},
40
+ "outputs": [
41
+ {
42
+ "name": "stdout",
43
+ "output_type": "stream",
44
+ "text": [
45
+ "(271, 25)\n"
46
+ ]
47
+ }
48
+ ],
49
+ "source": [
50
+ "# extract everything from the visits table\n",
51
+ "cursor.execute(\"SELECT * FROM clinical_visits;\")\n",
52
+ "rows = cursor.fetchall()\n",
53
+ "visits_df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
54
+ "print(visits_df.shape)"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 4,
60
  "id": "f8547b78",
61
  "metadata": {},
62
  "outputs": [],
 
104
  },
105
  {
106
  "cell_type": "code",
107
+ "execution_count": 5,
108
  "id": "9ddfa626",
109
  "metadata": {},
110
  "outputs": [
111
  {
112
  "data": {
113
  "text/plain": [
114
+ "<sqlite3.Cursor at 0x71e4721907c0>"
115
  ]
116
  },
117
+ "execution_count": 5,
118
  "metadata": {},
119
  "output_type": "execute_result"
120
  }
 
132
  },
133
  {
134
  "cell_type": "code",
135
+ "execution_count": 6,
136
  "id": "d14ef687",
137
  "metadata": {},
138
  "outputs": [],
 
199
  },
200
  {
201
  "cell_type": "code",
202
+ "execution_count": 7,
203
  "id": "6e27bce5",
204
  "metadata": {},
205
  "outputs": [],
206
  "source": [
207
+ "conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
208
  "cursor = conn.cursor()\n",
209
  "# pull all data from the lab table except for the \"key\" column \n",
210
  "cursor.execute(\"SELECT * FROM lab;\")\n",
 
215
  },
216
  {
217
  "cell_type": "code",
218
+ "execution_count": 8,
219
  "id": "14402e96",
220
  "metadata": {},
221
  "outputs": [],
222
  "source": [
223
  "# let's create a new sqlite database called patient_demonstration.sqlite\n",
224
+ "conn = sqlite3.connect('../data/processed/patient_demonstration.sqlite')\n",
225
  "cursor = conn.cursor() "
226
  ]
227
  },
228
  {
229
  "cell_type": "code",
230
+ "execution_count": 9,
231
  "id": "540962b7",
232
  "metadata": {},
233
  "outputs": [],
 
257
  },
258
  {
259
  "cell_type": "code",
260
+ "execution_count": 10,
261
  "id": "8df7171e",
262
  "metadata": {},
263
  "outputs": [],
 
282
  },
283
  {
284
  "cell_type": "code",
285
+ "execution_count": 11,
286
  "id": "b66d3dbb",
287
  "metadata": {},
288
  "outputs": [],
289
  "source": [
290
+ "conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
291
  "cursor = conn.cursor()\n",
292
  "# pull all data from the lab table except for the \"key\" column \n",
293
  "cursor.execute(\"SELECT * FROM pharmacy;\")\n",
 
298
  },
299
  {
300
  "cell_type": "code",
301
+ "execution_count": 12,
302
  "id": "435b8d4e",
303
  "metadata": {},
304
  "outputs": [],
305
  "source": [
306
  "# let's create a new sqlite database called patient_demonstration.sqlite\n",
307
+ "conn = sqlite3.connect('../data/processed/patient_demonstration.sqlite')\n",
308
  "cursor = conn.cursor() "
309
  ]
310
  },
311
  {
312
  "cell_type": "code",
313
+ "execution_count": 13,
314
  "id": "b3753eeb",
315
  "metadata": {},
316
  "outputs": [],
 
344
  },
345
  {
346
  "cell_type": "code",
347
+ "execution_count": 14,
348
  "id": "8b8ed08a",
349
  "metadata": {},
350
  "outputs": [],
 
370
  },
371
  {
372
  "cell_type": "code",
373
+ "execution_count": 24,
374
  "id": "2de65432",
375
  "metadata": {},
376
  "outputs": [],
377
  "source": [
378
+ "conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
379
  "cursor = conn.cursor()\n",
380
  "# pull all data from the lab table except for the \"key\" column \n",
381
  "cursor.execute(\"SELECT * FROM demographics;\")\n",
 
386
  },
387
  {
388
  "cell_type": "code",
389
+ "execution_count": 27,
390
+ "id": "f3a11ac1",
391
+ "metadata": {},
392
+ "outputs": [
393
+ {
394
+ "data": {
395
+ "text/html": [
396
+ "<div>\n",
397
+ "<style scoped>\n",
398
+ " .dataframe tbody tr th:only-of-type {\n",
399
+ " vertical-align: middle;\n",
400
+ " }\n",
401
+ "\n",
402
+ " .dataframe tbody tr th {\n",
403
+ " vertical-align: top;\n",
404
+ " }\n",
405
+ "\n",
406
+ " .dataframe thead th {\n",
407
+ " text-align: right;\n",
408
+ " }\n",
409
+ "</style>\n",
410
+ "<table border=\"1\" class=\"dataframe\">\n",
411
+ " <thead>\n",
412
+ " <tr style=\"text-align: right;\">\n",
413
+ " <th></th>\n",
414
+ " <th>key</th>\n",
415
+ " <th>PatientPKHash</th>\n",
416
+ " <th>MFLCode</th>\n",
417
+ " <th>FacilityName</th>\n",
418
+ " <th>County</th>\n",
419
+ " <th>SubCounty</th>\n",
420
+ " <th>PartnerName</th>\n",
421
+ " <th>AgencyName</th>\n",
422
+ " <th>Sex</th>\n",
423
+ " <th>MaritalStatus</th>\n",
424
+ " <th>EducationLevel</th>\n",
425
+ " <th>Occupation</th>\n",
426
+ " <th>OnIPT</th>\n",
427
+ " <th>AgeGroup</th>\n",
428
+ " <th>ARTOutcomeDescription</th>\n",
429
+ " <th>AsOfDate</th>\n",
430
+ " <th>LoadDate</th>\n",
431
+ " <th>StartARTDate</th>\n",
432
+ " <th>DOB</th>\n",
433
+ " </tr>\n",
434
+ " </thead>\n",
435
+ " <tbody>\n",
436
+ " <tr>\n",
437
+ " <th>0</th>\n",
438
+ " <td>07149C6735AA9A2B3EFB198A5DB19825E3DA3DBCDE8CB8...</td>\n",
439
+ " <td>3</td>\n",
440
+ " <td>13703</td>\n",
441
+ " <td>Kisii Teaching and Referral Hospital (Level 6)</td>\n",
442
+ " <td>Kisii</td>\n",
443
+ " <td>Kitutu Chache South</td>\n",
444
+ " <td>LVCT Vukisha 95</td>\n",
445
+ " <td>CDC</td>\n",
446
+ " <td>Female</td>\n",
447
+ " <td>Single</td>\n",
448
+ " <td>NULL</td>\n",
449
+ " <td>NULL</td>\n",
450
+ " <td>NULL</td>\n",
451
+ " <td>NULL</td>\n",
452
+ " <td>LOST IN HMIS</td>\n",
453
+ " <td>20088</td>\n",
454
+ " <td>20161</td>\n",
455
+ " <td>2012-04-12 00:00:00.000</td>\n",
456
+ " <td>2010-05-10 00:00:00.000</td>\n",
457
+ " </tr>\n",
458
+ " <tr>\n",
459
+ " <th>1</th>\n",
460
+ " <td>290D316E1B41A21F58E780026971F4D86DBB3BF043A77B...</td>\n",
461
+ " <td>4</td>\n",
462
+ " <td>13028</td>\n",
463
+ " <td>Kibera Community Health Centre - Amref</td>\n",
464
+ " <td>Nairobi</td>\n",
465
+ " <td>Kibra</td>\n",
466
+ " <td>CIHEB CONNECT</td>\n",
467
+ " <td>CDC</td>\n",
468
+ " <td>Female</td>\n",
469
+ " <td>MARRIED MONOGAMOUS</td>\n",
470
+ " <td>SECONDARY</td>\n",
471
+ " <td>Trader</td>\n",
472
+ " <td>NULL</td>\n",
473
+ " <td>NULL</td>\n",
474
+ " <td>ACTIVE</td>\n",
475
+ " <td>20088</td>\n",
476
+ " <td>20161</td>\n",
477
+ " <td>2009-05-12 00:00:00.000</td>\n",
478
+ " <td>1970-08-25 00:00:00.000</td>\n",
479
+ " </tr>\n",
480
+ " <tr>\n",
481
+ " <th>2</th>\n",
482
+ " <td>45889B18F2C615A78371E1DAFC2680C0A36284C6195885...</td>\n",
483
+ " <td>9</td>\n",
484
+ " <td>15834</td>\n",
485
+ " <td>Busia County Referral Hospital</td>\n",
486
+ " <td>Busia</td>\n",
487
+ " <td>Matayos</td>\n",
488
+ " <td>USAID Dumisha Afya</td>\n",
489
+ " <td>USAID</td>\n",
490
+ " <td>Female</td>\n",
491
+ " <td>NULL</td>\n",
492
+ " <td>NULL</td>\n",
493
+ " <td>NULL</td>\n",
494
+ " <td>NULL</td>\n",
495
+ " <td>NULL</td>\n",
496
+ " <td>ACTIVE</td>\n",
497
+ " <td>20088</td>\n",
498
+ " <td>20161</td>\n",
499
+ " <td>2014-08-12 00:00:00.000</td>\n",
500
+ " <td>1972-04-13 00:00:00.000</td>\n",
501
+ " </tr>\n",
502
+ " <tr>\n",
503
+ " <th>3</th>\n",
504
+ " <td>9C9BFF8365B05D99D4F6A62716DD1353875B8A9280A772...</td>\n",
505
+ " <td>7</td>\n",
506
+ " <td>14831</td>\n",
507
+ " <td>Kericho District Hospital</td>\n",
508
+ " <td>Kericho</td>\n",
509
+ " <td>Ainamoi</td>\n",
510
+ " <td>HJF-South Rift Valley</td>\n",
511
+ " <td>DOD</td>\n",
512
+ " <td>Female</td>\n",
513
+ " <td>MARRIED MONOGAMOUS</td>\n",
514
+ " <td>NULL</td>\n",
515
+ " <td>Trader</td>\n",
516
+ " <td>NULL</td>\n",
517
+ " <td>NULL</td>\n",
518
+ " <td>ACTIVE</td>\n",
519
+ " <td>20088</td>\n",
520
+ " <td>20161</td>\n",
521
+ " <td>2023-05-10 00:00:00.000</td>\n",
522
+ " <td>1989-06-15 00:00:00.000</td>\n",
523
+ " </tr>\n",
524
+ " <tr>\n",
525
+ " <th>4</th>\n",
526
+ " <td>A51AEA4EC14F999A52AF53B4B531F760992ADA406B620C...</td>\n",
527
+ " <td>1</td>\n",
528
+ " <td>11259</td>\n",
529
+ " <td>Bomu Medical Centre (Likoni)</td>\n",
530
+ " <td>Mombasa</td>\n",
531
+ " <td>Likoni</td>\n",
532
+ " <td>Mkomani Clinic society</td>\n",
533
+ " <td>CDC</td>\n",
534
+ " <td>Female</td>\n",
535
+ " <td>MARRIED MONOGAMOUS</td>\n",
536
+ " <td>PRIMARY</td>\n",
537
+ " <td>Trader</td>\n",
538
+ " <td>NULL</td>\n",
539
+ " <td>NULL</td>\n",
540
+ " <td>UNDOCUMENTED LOSS</td>\n",
541
+ " <td>20088</td>\n",
542
+ " <td>20161</td>\n",
543
+ " <td>2018-05-22 00:00:00.000</td>\n",
544
+ " <td>1995-05-21 00:00:00.000</td>\n",
545
+ " </tr>\n",
546
+ " </tbody>\n",
547
+ "</table>\n",
548
+ "</div>"
549
+ ],
550
+ "text/plain": [
551
+ " key PatientPKHash MFLCode \\\n",
552
+ "0 07149C6735AA9A2B3EFB198A5DB19825E3DA3DBCDE8CB8... 3 13703 \n",
553
+ "1 290D316E1B41A21F58E780026971F4D86DBB3BF043A77B... 4 13028 \n",
554
+ "2 45889B18F2C615A78371E1DAFC2680C0A36284C6195885... 9 15834 \n",
555
+ "3 9C9BFF8365B05D99D4F6A62716DD1353875B8A9280A772... 7 14831 \n",
556
+ "4 A51AEA4EC14F999A52AF53B4B531F760992ADA406B620C... 1 11259 \n",
557
+ "\n",
558
+ " FacilityName County \\\n",
559
+ "0 Kisii Teaching and Referral Hospital (Level 6) Kisii \n",
560
+ "1 Kibera Community Health Centre - Amref Nairobi \n",
561
+ "2 Busia County Referral Hospital Busia \n",
562
+ "3 Kericho District Hospital Kericho \n",
563
+ "4 Bomu Medical Centre (Likoni) Mombasa \n",
564
+ "\n",
565
+ " SubCounty PartnerName AgencyName Sex \\\n",
566
+ "0 Kitutu Chache South LVCT Vukisha 95 CDC Female \n",
567
+ "1 Kibra CIHEB CONNECT CDC Female \n",
568
+ "2 Matayos USAID Dumisha Afya USAID Female \n",
569
+ "3 Ainamoi HJF-South Rift Valley DOD Female \n",
570
+ "4 Likoni Mkomani Clinic society CDC Female \n",
571
+ "\n",
572
+ " MaritalStatus EducationLevel Occupation OnIPT AgeGroup \\\n",
573
+ "0 Single NULL NULL NULL NULL \n",
574
+ "1 MARRIED MONOGAMOUS SECONDARY Trader NULL NULL \n",
575
+ "2 NULL NULL NULL NULL NULL \n",
576
+ "3 MARRIED MONOGAMOUS NULL Trader NULL NULL \n",
577
+ "4 MARRIED MONOGAMOUS PRIMARY Trader NULL NULL \n",
578
+ "\n",
579
+ " ARTOutcomeDescription AsOfDate LoadDate StartARTDate \\\n",
580
+ "0 LOST IN HMIS 20088 20161 2012-04-12 00:00:00.000 \n",
581
+ "1 ACTIVE 20088 20161 2009-05-12 00:00:00.000 \n",
582
+ "2 ACTIVE 20088 20161 2014-08-12 00:00:00.000 \n",
583
+ "3 ACTIVE 20088 20161 2023-05-10 00:00:00.000 \n",
584
+ "4 UNDOCUMENTED LOSS 20088 20161 2018-05-22 00:00:00.000 \n",
585
+ "\n",
586
+ " DOB \n",
587
+ "0 2010-05-10 00:00:00.000 \n",
588
+ "1 1970-08-25 00:00:00.000 \n",
589
+ "2 1972-04-13 00:00:00.000 \n",
590
+ "3 1989-06-15 00:00:00.000 \n",
591
+ "4 1995-05-21 00:00:00.000 "
592
+ ]
593
+ },
594
+ "execution_count": 27,
595
+ "metadata": {},
596
+ "output_type": "execute_result"
597
+ }
598
+ ],
599
+ "source": [
600
+ "df.head()"
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "code",
605
+ "execution_count": 30,
606
  "id": "a7a10f4f",
607
  "metadata": {},
608
  "outputs": [],
609
  "source": [
610
  "# let's create a new sqlite database called patient_demonstration.sqlite\n",
611
+ "conn = sqlite3.connect('../data/processed/patient_demonstration.sqlite')\n",
612
  "cursor = conn.cursor() "
613
  ]
614
  },
615
  {
616
  "cell_type": "code",
617
+ "execution_count": 32,
618
+ "id": "07296631",
619
+ "metadata": {},
620
+ "outputs": [
621
+ {
622
+ "data": {
623
+ "text/html": [
624
+ "<div>\n",
625
+ "<style scoped>\n",
626
+ " .dataframe tbody tr th:only-of-type {\n",
627
+ " vertical-align: middle;\n",
628
+ " }\n",
629
+ "\n",
630
+ " .dataframe tbody tr th {\n",
631
+ " vertical-align: top;\n",
632
+ " }\n",
633
+ "\n",
634
+ " .dataframe thead th {\n",
635
+ " text-align: right;\n",
636
+ " }\n",
637
+ "</style>\n",
638
+ "<table border=\"1\" class=\"dataframe\">\n",
639
+ " <thead>\n",
640
+ " <tr style=\"text-align: right;\">\n",
641
+ " <th></th>\n",
642
+ " <th>key</th>\n",
643
+ " <th>PatientPKHash</th>\n",
644
+ " <th>MFLCode</th>\n",
645
+ " <th>FacilityName</th>\n",
646
+ " <th>County</th>\n",
647
+ " <th>SubCounty</th>\n",
648
+ " <th>PartnerName</th>\n",
649
+ " <th>AgencyName</th>\n",
650
+ " <th>Sex</th>\n",
651
+ " <th>MaritalStatus</th>\n",
652
+ " <th>EducationLevel</th>\n",
653
+ " <th>Occupation</th>\n",
654
+ " <th>OnIPT</th>\n",
655
+ " <th>AgeGroup</th>\n",
656
+ " <th>ARTOutcomeDescription</th>\n",
657
+ " <th>AsOfDate</th>\n",
658
+ " <th>LoadDate</th>\n",
659
+ " <th>StartARTDate</th>\n",
660
+ " <th>DOB</th>\n",
661
+ " </tr>\n",
662
+ " </thead>\n",
663
+ " <tbody>\n",
664
+ " <tr>\n",
665
+ " <th>0</th>\n",
666
+ " <td>07149C6735AA9A2B3EFB198A5DB19825E3DA3DBCDE8CB8...</td>\n",
667
+ " <td>3</td>\n",
668
+ " <td>13703</td>\n",
669
+ " <td>Kisii Teaching and Referral Hospital (Level 6)</td>\n",
670
+ " <td>Kisii</td>\n",
671
+ " <td>Kitutu Chache South</td>\n",
672
+ " <td>LVCT Vukisha 95</td>\n",
673
+ " <td>CDC</td>\n",
674
+ " <td>Female</td>\n",
675
+ " <td>Single</td>\n",
676
+ " <td>NULL</td>\n",
677
+ " <td>NULL</td>\n",
678
+ " <td>NULL</td>\n",
679
+ " <td>NULL</td>\n",
680
+ " <td>LOST IN HMIS</td>\n",
681
+ " <td>20088</td>\n",
682
+ " <td>20161</td>\n",
683
+ " <td>2012-04-12 00:00:00.000</td>\n",
684
+ " <td>2010-05-10 00:00:00.000</td>\n",
685
+ " </tr>\n",
686
+ " <tr>\n",
687
+ " <th>1</th>\n",
688
+ " <td>290D316E1B41A21F58E780026971F4D86DBB3BF043A77B...</td>\n",
689
+ " <td>4</td>\n",
690
+ " <td>13028</td>\n",
691
+ " <td>Kibera Community Health Centre - Amref</td>\n",
692
+ " <td>Nairobi</td>\n",
693
+ " <td>Kibra</td>\n",
694
+ " <td>CIHEB CONNECT</td>\n",
695
+ " <td>CDC</td>\n",
696
+ " <td>Female</td>\n",
697
+ " <td>MARRIED MONOGAMOUS</td>\n",
698
+ " <td>SECONDARY</td>\n",
699
+ " <td>Trader</td>\n",
700
+ " <td>NULL</td>\n",
701
+ " <td>NULL</td>\n",
702
+ " <td>ACTIVE</td>\n",
703
+ " <td>20088</td>\n",
704
+ " <td>20161</td>\n",
705
+ " <td>2009-05-12 00:00:00.000</td>\n",
706
+ " <td>1970-08-25 00:00:00.000</td>\n",
707
+ " </tr>\n",
708
+ " <tr>\n",
709
+ " <th>2</th>\n",
710
+ " <td>45889B18F2C615A78371E1DAFC2680C0A36284C6195885...</td>\n",
711
+ " <td>9</td>\n",
712
+ " <td>15834</td>\n",
713
+ " <td>Busia County Referral Hospital</td>\n",
714
+ " <td>Busia</td>\n",
715
+ " <td>Matayos</td>\n",
716
+ " <td>USAID Dumisha Afya</td>\n",
717
+ " <td>USAID</td>\n",
718
+ " <td>Female</td>\n",
719
+ " <td>NULL</td>\n",
720
+ " <td>NULL</td>\n",
721
+ " <td>NULL</td>\n",
722
+ " <td>NULL</td>\n",
723
+ " <td>NULL</td>\n",
724
+ " <td>ACTIVE</td>\n",
725
+ " <td>20088</td>\n",
726
+ " <td>20161</td>\n",
727
+ " <td>2014-08-12 00:00:00.000</td>\n",
728
+ " <td>1972-04-13 00:00:00.000</td>\n",
729
+ " </tr>\n",
730
+ " <tr>\n",
731
+ " <th>3</th>\n",
732
+ " <td>9C9BFF8365B05D99D4F6A62716DD1353875B8A9280A772...</td>\n",
733
+ " <td>7</td>\n",
734
+ " <td>14831</td>\n",
735
+ " <td>Kericho District Hospital</td>\n",
736
+ " <td>Kericho</td>\n",
737
+ " <td>Ainamoi</td>\n",
738
+ " <td>HJF-South Rift Valley</td>\n",
739
+ " <td>DOD</td>\n",
740
+ " <td>Female</td>\n",
741
+ " <td>MARRIED MONOGAMOUS</td>\n",
742
+ " <td>NULL</td>\n",
743
+ " <td>Trader</td>\n",
744
+ " <td>NULL</td>\n",
745
+ " <td>NULL</td>\n",
746
+ " <td>ACTIVE</td>\n",
747
+ " <td>20088</td>\n",
748
+ " <td>20161</td>\n",
749
+ " <td>2023-05-10 00:00:00.000</td>\n",
750
+ " <td>1989-06-15 00:00:00.000</td>\n",
751
+ " </tr>\n",
752
+ " <tr>\n",
753
+ " <th>4</th>\n",
754
+ " <td>A51AEA4EC14F999A52AF53B4B531F760992ADA406B620C...</td>\n",
755
+ " <td>1</td>\n",
756
+ " <td>11259</td>\n",
757
+ " <td>Bomu Medical Centre (Likoni)</td>\n",
758
+ " <td>Mombasa</td>\n",
759
+ " <td>Likoni</td>\n",
760
+ " <td>Mkomani Clinic society</td>\n",
761
+ " <td>CDC</td>\n",
762
+ " <td>Female</td>\n",
763
+ " <td>MARRIED MONOGAMOUS</td>\n",
764
+ " <td>PRIMARY</td>\n",
765
+ " <td>Trader</td>\n",
766
+ " <td>NULL</td>\n",
767
+ " <td>NULL</td>\n",
768
+ " <td>UNDOCUMENTED LOSS</td>\n",
769
+ " <td>20088</td>\n",
770
+ " <td>20161</td>\n",
771
+ " <td>2018-05-22 00:00:00.000</td>\n",
772
+ " <td>1995-05-21 00:00:00.000</td>\n",
773
+ " </tr>\n",
774
+ " </tbody>\n",
775
+ "</table>\n",
776
+ "</div>"
777
+ ],
778
+ "text/plain": [
779
+ " key PatientPKHash MFLCode \\\n",
780
+ "0 07149C6735AA9A2B3EFB198A5DB19825E3DA3DBCDE8CB8... 3 13703 \n",
781
+ "1 290D316E1B41A21F58E780026971F4D86DBB3BF043A77B... 4 13028 \n",
782
+ "2 45889B18F2C615A78371E1DAFC2680C0A36284C6195885... 9 15834 \n",
783
+ "3 9C9BFF8365B05D99D4F6A62716DD1353875B8A9280A772... 7 14831 \n",
784
+ "4 A51AEA4EC14F999A52AF53B4B531F760992ADA406B620C... 1 11259 \n",
785
+ "\n",
786
+ " FacilityName County \\\n",
787
+ "0 Kisii Teaching and Referral Hospital (Level 6) Kisii \n",
788
+ "1 Kibera Community Health Centre - Amref Nairobi \n",
789
+ "2 Busia County Referral Hospital Busia \n",
790
+ "3 Kericho District Hospital Kericho \n",
791
+ "4 Bomu Medical Centre (Likoni) Mombasa \n",
792
+ "\n",
793
+ " SubCounty PartnerName AgencyName Sex \\\n",
794
+ "0 Kitutu Chache South LVCT Vukisha 95 CDC Female \n",
795
+ "1 Kibra CIHEB CONNECT CDC Female \n",
796
+ "2 Matayos USAID Dumisha Afya USAID Female \n",
797
+ "3 Ainamoi HJF-South Rift Valley DOD Female \n",
798
+ "4 Likoni Mkomani Clinic society CDC Female \n",
799
+ "\n",
800
+ " MaritalStatus EducationLevel Occupation OnIPT AgeGroup \\\n",
801
+ "0 Single NULL NULL NULL NULL \n",
802
+ "1 MARRIED MONOGAMOUS SECONDARY Trader NULL NULL \n",
803
+ "2 NULL NULL NULL NULL NULL \n",
804
+ "3 MARRIED MONOGAMOUS NULL Trader NULL NULL \n",
805
+ "4 MARRIED MONOGAMOUS PRIMARY Trader NULL NULL \n",
806
+ "\n",
807
+ " ARTOutcomeDescription AsOfDate LoadDate StartARTDate \\\n",
808
+ "0 LOST IN HMIS 20088 20161 2012-04-12 00:00:00.000 \n",
809
+ "1 ACTIVE 20088 20161 2009-05-12 00:00:00.000 \n",
810
+ "2 ACTIVE 20088 20161 2014-08-12 00:00:00.000 \n",
811
+ "3 ACTIVE 20088 20161 2023-05-10 00:00:00.000 \n",
812
+ "4 UNDOCUMENTED LOSS 20088 20161 2018-05-22 00:00:00.000 \n",
813
+ "\n",
814
+ " DOB \n",
815
+ "0 2010-05-10 00:00:00.000 \n",
816
+ "1 1970-08-25 00:00:00.000 \n",
817
+ "2 1972-04-13 00:00:00.000 \n",
818
+ "3 1989-06-15 00:00:00.000 \n",
819
+ "4 1995-05-21 00:00:00.000 "
820
+ ]
821
+ },
822
+ "execution_count": 32,
823
+ "metadata": {},
824
+ "output_type": "execute_result"
825
+ }
826
+ ],
827
+ "source": [
828
+ "cursor.execute(\"select * from demographics;\")\n",
829
+ "rows = cursor.fetchall()\n",
830
+ "df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
831
+ "df.head()"
832
+ ]
833
+ },
834
+ {
835
+ "cell_type": "code",
836
+ "execution_count": 31,
837
  "id": "947c63d5",
838
  "metadata": {},
839
  "outputs": [],
 
843
  "cursor.execute('DROP TABLE IF EXISTS demographics;')\n",
844
  "cursor.execute('''\n",
845
  "CREATE TABLE demographics (\n",
846
+ " key TEXT,\n",
847
  " PatientPKHash TEXT,\n",
848
  " MFLCode TEXT,\n",
849
  " FacilityName TEXT,\n",
 
861
  " AsOfDate TEXT,\n",
862
  " LoadDate TEXT,\n",
863
  " StartARTDate TEXT,\n",
864
+ " DOB TEXT\n",
 
865
  ");\n",
866
  "''')\n",
867
  "\n",
868
  "# let's now populate the table with the rows variable that contains all the data from the visits table\n",
869
  "cursor.executemany('''\n",
870
+ "INSERT INTO demographics (key, PatientPKHash, MFLCode, FacilityName, County, SubCounty, PartnerName, AgencyName, Sex, MaritalStatus, EducationLevel, Occupation, OnIPT, AgeGroup, ARTOutcomeDescription, AsOfDate, LoadDate, StartARTDate, DOB)\n",
871
  "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n",
872
  "''', rows)\n",
873
  "conn.commit()"
 
875
  },
876
  {
877
  "cell_type": "code",
878
+ "execution_count": 18,
879
  "id": "9cff0d90",
880
  "metadata": {},
881
  "outputs": [],
 
915
  ],
916
  "metadata": {
917
  "kernelspec": {
918
+ "display_name": "clinician-assistant-lg",
919
  "language": "python",
920
  "name": "python3"
921
  },
notebooks/create_slim_patient_db.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 20,
6
  "id": "c867740b",
7
  "metadata": {},
8
  "outputs": [],
@@ -10,7 +10,7 @@
10
  "import sqlite3\n",
11
  "import pandas as pd\n",
12
  "# inspect current database schema\n",
13
- "conn = sqlite3.connect('iit_test.sqlite')\n",
14
  "cursor = conn.cursor()\n",
15
  "# list tables\n",
16
  "# pull all data from the visits table \n",
@@ -22,7 +22,7 @@
22
  },
23
  {
24
  "cell_type": "code",
25
- "execution_count": 21,
26
  "id": "f424fcf6",
27
  "metadata": {},
28
  "outputs": [
@@ -30,7 +30,7 @@
30
  "name": "stderr",
31
  "output_type": "stream",
32
  "text": [
33
- "/tmp/ipykernel_2997/3546205200.py:11: SettingWithCopyWarning: \n",
34
  "A value is trying to be set on a copy of a slice from a DataFrame.\n",
35
  "Try using .loc[row_indexer,col_indexer] = value instead\n",
36
  "\n",
@@ -53,35 +53,14 @@
53
  "sampled_df['PatientPKHash'] = sampled_df['PatientPKHash'].map(key_to_number)\n",
54
  "\n",
55
  "# save sampled_df back to iit_test.sqlite as a new table called sampled_visits\n",
56
- "sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
57
  "sampled_df.to_sql('visits', sampled_conn, if_exists='replace', index=False)\n",
58
  "sampled_conn.close()"
59
  ]
60
  },
61
  {
62
  "cell_type": "code",
63
- "execution_count": 23,
64
- "id": "8615f9fa",
65
- "metadata": {},
66
- "outputs": [
67
- {
68
- "data": {
69
- "text/plain": [
70
- "(271, 25)"
71
- ]
72
- },
73
- "execution_count": 23,
74
- "metadata": {},
75
- "output_type": "execute_result"
76
- }
77
- ],
78
- "source": [
79
- "sampled_df.shape"
80
- ]
81
- },
82
- {
83
- "cell_type": "code",
84
- "execution_count": 24,
85
  "id": "1bad1098",
86
  "metadata": {},
87
  "outputs": [
@@ -89,7 +68,7 @@
89
  "name": "stderr",
90
  "output_type": "stream",
91
  "text": [
92
- "/tmp/ipykernel_2997/4153193150.py:11: SettingWithCopyWarning: \n",
93
  "A value is trying to be set on a copy of a slice from a DataFrame.\n",
94
  "Try using .loc[row_indexer,col_indexer] = value instead\n",
95
  "\n",
@@ -100,7 +79,7 @@
100
  ],
101
  "source": [
102
  "# now, read in pharmacy table from iit_test.sqlite\n",
103
- "conn = sqlite3.connect('iit_test.sqlite')\n",
104
  "cursor = conn.cursor()\n",
105
  "cursor.execute(\"SELECT * FROM pharmacy;\")\n",
106
  "rows = cursor.fetchall()\n",
@@ -110,46 +89,14 @@
110
  "# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_pharmacy\n",
111
  "sampled_pharmacy_df = pharmacy_df[pharmacy_df['PatientPKHash'].isin(sampled_keys)]\n",
112
  "sampled_pharmacy_df['PatientPKHash'] = sampled_pharmacy_df['PatientPKHash'].map(key_to_number)\n",
113
- "sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
114
  "sampled_pharmacy_df.to_sql('pharmacy', sampled_conn, if_exists='replace', index=False)\n",
115
  "sampled_conn.close()\n"
116
  ]
117
  },
118
  {
119
  "cell_type": "code",
120
- "execution_count": 25,
121
- "id": "bc8fac93",
122
- "metadata": {},
123
- "outputs": [
124
- {
125
- "data": {
126
- "text/plain": [
127
- "PatientPKHash\n",
128
- "1 14\n",
129
- "2 24\n",
130
- "3 24\n",
131
- "4 9\n",
132
- "5 40\n",
133
- "6 1\n",
134
- "7 15\n",
135
- "8 1\n",
136
- "9 64\n",
137
- "10 14\n",
138
- "dtype: int64"
139
- ]
140
- },
141
- "execution_count": 25,
142
- "metadata": {},
143
- "output_type": "execute_result"
144
- }
145
- ],
146
- "source": [
147
- "sampled_pharmacy_df.groupby('PatientPKHash').size()"
148
- ]
149
- },
150
- {
151
- "cell_type": "code",
152
- "execution_count": 26,
153
  "id": "df01b886",
154
  "metadata": {},
155
  "outputs": [
@@ -157,7 +104,7 @@
157
  "name": "stderr",
158
  "output_type": "stream",
159
  "text": [
160
- "/tmp/ipykernel_2997/3478231606.py:11: SettingWithCopyWarning: \n",
161
  "A value is trying to be set on a copy of a slice from a DataFrame.\n",
162
  "Try using .loc[row_indexer,col_indexer] = value instead\n",
163
  "\n",
@@ -168,7 +115,7 @@
168
  ],
169
  "source": [
170
  "# repeat the process above for lab table\n",
171
- "conn = sqlite3.connect('iit_test.sqlite')\n",
172
  "cursor = conn.cursor()\n",
173
  "cursor.execute(\"SELECT * FROM lab;\")\n",
174
  "rows = cursor.fetchall()\n",
@@ -178,46 +125,14 @@
178
  "# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_lab\n",
179
  "sampled_lab_df = lab_df[lab_df['PatientPKHash'].isin(sampled_keys)]\n",
180
  "sampled_lab_df['PatientPKHash'] = sampled_lab_df['PatientPKHash'].map(key_to_number)\n",
181
- "sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
182
  "sampled_lab_df.to_sql('lab', sampled_conn, if_exists='replace', index=False)\n",
183
  "sampled_conn.close()\n"
184
  ]
185
  },
186
  {
187
  "cell_type": "code",
188
- "execution_count": 27,
189
- "id": "2578bf85",
190
- "metadata": {},
191
- "outputs": [
192
- {
193
- "data": {
194
- "text/plain": [
195
- "PatientPKHash\n",
196
- "1 6\n",
197
- "2 2\n",
198
- "3 17\n",
199
- "4 22\n",
200
- "5 23\n",
201
- "6 1\n",
202
- "7 2\n",
203
- "8 10\n",
204
- "9 13\n",
205
- "10 12\n",
206
- "dtype: int64"
207
- ]
208
- },
209
- "execution_count": 27,
210
- "metadata": {},
211
- "output_type": "execute_result"
212
- }
213
- ],
214
- "source": [
215
- "sampled_lab_df.groupby('PatientPKHash').size()"
216
- ]
217
- },
218
- {
219
- "cell_type": "code",
220
- "execution_count": 28,
221
  "id": "ebf358c5",
222
  "metadata": {},
223
  "outputs": [
@@ -225,7 +140,7 @@
225
  "name": "stderr",
226
  "output_type": "stream",
227
  "text": [
228
- "/tmp/ipykernel_2997/3867144072.py:11: SettingWithCopyWarning: \n",
229
  "A value is trying to be set on a copy of a slice from a DataFrame.\n",
230
  "Try using .loc[row_indexer,col_indexer] = value instead\n",
231
  "\n",
@@ -236,7 +151,7 @@
236
  ],
237
  "source": [
238
  "# now, from dem table\n",
239
- "conn = sqlite3.connect('iit_test.sqlite')\n",
240
  "cursor = conn.cursor()\n",
241
  "cursor.execute(\"SELECT * FROM dem;\")\n",
242
  "rows = cursor.fetchall()\n",
@@ -246,47 +161,15 @@
246
  "# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_dem\n",
247
  "sampled_dem_df = dem_df[dem_df['PatientPKHash'].isin(sampled_keys)]\n",
248
  "sampled_dem_df['PatientPKHash'] = sampled_dem_df['PatientPKHash'].map(key_to_number)\n",
249
- "sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
250
  "sampled_dem_df.to_sql('demographics', sampled_conn, if_exists='replace', index=False)\n",
251
  "sampled_conn.close()"
252
  ]
253
- },
254
- {
255
- "cell_type": "code",
256
- "execution_count": 29,
257
- "id": "527420fa",
258
- "metadata": {},
259
- "outputs": [
260
- {
261
- "data": {
262
- "text/plain": [
263
- "PatientPKHash\n",
264
- "1 1\n",
265
- "2 1\n",
266
- "3 1\n",
267
- "4 1\n",
268
- "5 1\n",
269
- "6 1\n",
270
- "7 1\n",
271
- "8 1\n",
272
- "9 1\n",
273
- "10 1\n",
274
- "dtype: int64"
275
- ]
276
- },
277
- "execution_count": 29,
278
- "metadata": {},
279
- "output_type": "execute_result"
280
- }
281
- ],
282
- "source": [
283
- "sampled_dem_df.groupby('PatientPKHash').size()"
284
- ]
285
  }
286
  ],
287
  "metadata": {
288
  "kernelspec": {
289
- "display_name": ".venv",
290
  "language": "python",
291
  "name": "python3"
292
  },
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 4,
6
  "id": "c867740b",
7
  "metadata": {},
8
  "outputs": [],
 
10
  "import sqlite3\n",
11
  "import pandas as pd\n",
12
  "# inspect current database schema\n",
13
+ "conn = sqlite3.connect('../data/raw/iit_test.sqlite')\n",
14
  "cursor = conn.cursor()\n",
15
  "# list tables\n",
16
  "# pull all data from the visits table \n",
 
22
  },
23
  {
24
  "cell_type": "code",
25
+ "execution_count": 5,
26
  "id": "f424fcf6",
27
  "metadata": {},
28
  "outputs": [
 
30
  "name": "stderr",
31
  "output_type": "stream",
32
  "text": [
33
+ "/tmp/ipykernel_12725/435846127.py:11: SettingWithCopyWarning: \n",
34
  "A value is trying to be set on a copy of a slice from a DataFrame.\n",
35
  "Try using .loc[row_indexer,col_indexer] = value instead\n",
36
  "\n",
 
53
  "sampled_df['PatientPKHash'] = sampled_df['PatientPKHash'].map(key_to_number)\n",
54
  "\n",
55
  "# save sampled_df back to iit_test.sqlite as a new table called sampled_visits\n",
56
+ "sampled_conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
57
  "sampled_df.to_sql('visits', sampled_conn, if_exists='replace', index=False)\n",
58
  "sampled_conn.close()"
59
  ]
60
  },
61
  {
62
  "cell_type": "code",
63
+ "execution_count": 6,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  "id": "1bad1098",
65
  "metadata": {},
66
  "outputs": [
 
68
  "name": "stderr",
69
  "output_type": "stream",
70
  "text": [
71
+ "/tmp/ipykernel_12725/2381592446.py:11: SettingWithCopyWarning: \n",
72
  "A value is trying to be set on a copy of a slice from a DataFrame.\n",
73
  "Try using .loc[row_indexer,col_indexer] = value instead\n",
74
  "\n",
 
79
  ],
80
  "source": [
81
  "# now, read in pharmacy table from iit_test.sqlite\n",
82
+ "conn = sqlite3.connect('../data/raw/iit_test.sqlite')\n",
83
  "cursor = conn.cursor()\n",
84
  "cursor.execute(\"SELECT * FROM pharmacy;\")\n",
85
  "rows = cursor.fetchall()\n",
 
89
  "# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_pharmacy\n",
90
  "sampled_pharmacy_df = pharmacy_df[pharmacy_df['PatientPKHash'].isin(sampled_keys)]\n",
91
  "sampled_pharmacy_df['PatientPKHash'] = sampled_pharmacy_df['PatientPKHash'].map(key_to_number)\n",
92
+ "sampled_conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
93
  "sampled_pharmacy_df.to_sql('pharmacy', sampled_conn, if_exists='replace', index=False)\n",
94
  "sampled_conn.close()\n"
95
  ]
96
  },
97
  {
98
  "cell_type": "code",
99
+ "execution_count": 7,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  "id": "df01b886",
101
  "metadata": {},
102
  "outputs": [
 
104
  "name": "stderr",
105
  "output_type": "stream",
106
  "text": [
107
+ "/tmp/ipykernel_12725/4028870248.py:11: SettingWithCopyWarning: \n",
108
  "A value is trying to be set on a copy of a slice from a DataFrame.\n",
109
  "Try using .loc[row_indexer,col_indexer] = value instead\n",
110
  "\n",
 
115
  ],
116
  "source": [
117
  "# repeat the process above for lab table\n",
118
+ "conn = sqlite3.connect('../data/raw/iit_test.sqlite')\n",
119
  "cursor = conn.cursor()\n",
120
  "cursor.execute(\"SELECT * FROM lab;\")\n",
121
  "rows = cursor.fetchall()\n",
 
125
  "# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_lab\n",
126
  "sampled_lab_df = lab_df[lab_df['PatientPKHash'].isin(sampled_keys)]\n",
127
  "sampled_lab_df['PatientPKHash'] = sampled_lab_df['PatientPKHash'].map(key_to_number)\n",
128
+ "sampled_conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
129
  "sampled_lab_df.to_sql('lab', sampled_conn, if_exists='replace', index=False)\n",
130
  "sampled_conn.close()\n"
131
  ]
132
  },
133
  {
134
  "cell_type": "code",
135
+ "execution_count": 8,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  "id": "ebf358c5",
137
  "metadata": {},
138
  "outputs": [
 
140
  "name": "stderr",
141
  "output_type": "stream",
142
  "text": [
143
+ "/tmp/ipykernel_12725/696424165.py:11: SettingWithCopyWarning: \n",
144
  "A value is trying to be set on a copy of a slice from a DataFrame.\n",
145
  "Try using .loc[row_indexer,col_indexer] = value instead\n",
146
  "\n",
 
151
  ],
152
  "source": [
153
  "# now, from dem table\n",
154
+ "conn = sqlite3.connect('../data/raw/iit_test.sqlite')\n",
155
  "cursor = conn.cursor()\n",
156
  "cursor.execute(\"SELECT * FROM dem;\")\n",
157
  "rows = cursor.fetchall()\n",
 
161
  "# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_dem\n",
162
  "sampled_dem_df = dem_df[dem_df['PatientPKHash'].isin(sampled_keys)]\n",
163
  "sampled_dem_df['PatientPKHash'] = sampled_dem_df['PatientPKHash'].map(key_to_number)\n",
164
+ "sampled_conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
165
  "sampled_dem_df.to_sql('demographics', sampled_conn, if_exists='replace', index=False)\n",
166
  "sampled_conn.close()"
167
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  }
169
  ],
170
  "metadata": {
171
  "kernelspec": {
172
+ "display_name": "clinician-assistant-lg",
173
  "language": "python",
174
  "name": "python3"
175
  },
chat.py → scripts/chat.py RENAMED
File without changes
scripts/evaluate_trulens.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from dotenv import load_dotenv
4
+ import os
5
+
6
+ from llama_index.core import StorageContext, load_index_from_storage, QueryBundle
7
+ from llama_index.core.retrievers import VectorIndexRetriever
8
+ from llama_index.core.postprocessor import LLMRerank
9
+ from llama_index.embeddings.openai import OpenAIEmbedding
10
+ from llama_index.llms.openai import OpenAI
11
+
12
+ from langchain_openai import ChatOpenAI
13
+ from langchain_core.prompts import ChatPromptTemplate
14
+
15
+ from trulens_eval import Tru
16
+ from trulens.core import Feedback
17
+ from trulens.providers.openai import OpenAI as OpenAIFeedbackProvider
18
+ from trulens_eval.tru_app import TruLlama
19
+
20
+ # Load environment
21
+ if os.path.exists("config.env"):
22
+ load_dotenv("config.env")
23
+
24
+ # Load vectorstore metadata
25
+ embeddings = np.load("data/processed/lp/summary_embeddings/embeddings.npy")
26
+ df = pd.read_csv("data/processed/lp/summary_embeddings/index.tsv", sep="\t")
27
+
28
+ # LLMs and components
29
+ embedding_model = OpenAIEmbedding()
30
+ llm_llama = OpenAI(model="gpt-4o", temperature=0.0)
31
+ reranker = LLMRerank(llm=llm_llama, top_n=3)
32
+
33
+ # langchain summarize LLM
34
+ llm = ChatOpenAI(model="gpt-4o", temperature=0.0)
35
+
36
+ grounded = Feedback(Groundedness()).on_input().on_context().with_name("faithfulness")
37
+ context_rel = Feedback(Relevance()).on_input().on_context().with_name("context_relevance")
38
+ answer_rel = Feedback(AnswerRelevance()).on_input().on_output().with_name("answer_relevance")
39
+
40
+
41
+ # Prompt for query expansion
42
+ query_expansion_prompt = ChatPromptTemplate.from_messages([
43
+ ("system", "You are an expert in HIV medicine."),
44
+ ("user", (
45
+ "Given the query below, provide a concise, comma-separated list of related terms and synonyms "
46
+ "useful for document retrieval. Return only the list, no explanations.\n\n"
47
+ "Query: {query}"
48
+ ))
49
+ ])
50
+
51
+ # ---------- Functions ----------
52
+
53
+ def cosine_similarity_numpy(query_vec: np.ndarray, matrix: np.ndarray) -> np.ndarray:
54
+ query_norm = query_vec / np.linalg.norm(query_vec)
55
+ matrix_norm = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
56
+ return matrix_norm @ query_norm
57
+
58
+ def expand_query(query, llm, prompt_template):
59
+ messages = prompt_template.format_messages(query=query)
60
+ return llm.invoke(messages).content.strip()
61
+
62
+ def retrieve_contexts(expanded_query, embeddings, df, embedding_model):
63
+ query_vec = embedding_model.get_text_embedding(expanded_query)
64
+ sims = cosine_similarity_numpy(query_vec, embeddings)
65
+ top_indices = sims.argsort()[-3:][::-1]
66
+ paths = df.loc[top_indices, "vectorestore_path"].tolist()
67
+
68
+ all_nodes = []
69
+ for path in paths:
70
+ ctx = StorageContext.from_defaults(persist_dir=path)
71
+ index = load_index_from_storage(ctx)
72
+ retriever = VectorIndexRetriever(index=index, similarity_top_k=3)
73
+ all_nodes.extend(retriever.retrieve(expanded_query))
74
+
75
+ reranked = reranker.postprocess_nodes(all_nodes, QueryBundle(expanded_query))
76
+ return [n.text for n in reranked]
77
+
78
+ def summarize(query, contexts, llm):
79
+ prompt = (
80
+ "You're a clinical assistant helping a provider answer a question using HIV/AIDS guidelines.\n\n"
81
+ f"Question: {query}\n\n"
82
+ "Provide a detailed summary of the most relevant points to the user question from the following source texts. Use bullet points.\n\n"
83
+ + "\n\n".join([f"Source {i+1}: {t}" for i, t in enumerate(contexts)])
84
+ )
85
+ return llm.invoke(prompt).content.strip()
86
+
87
+ # ---------- RAG Pipeline ----------
88
+
89
+ def custom_rag_app(query):
90
+ expanded = expand_query(query, llm, query_expansion_prompt)
91
+ contexts = retrieve_contexts(expanded, embeddings, df, embedding_model)
92
+ answer = summarize(query, contexts, llm)
93
+ return {
94
+ "question": query,
95
+ "expanded_query": expanded,
96
+ "contexts": contexts,
97
+ "answer": answer
98
+ }
99
+
100
+
101
+ # ---------- Feedbacks ----------
102
+
103
+ provider = OpenAIFeedbackProvider()
104
+
105
+ f_grounded = Feedback(provider.groundedness).on_input().on_context().with_name("faithfulness")
106
+ f_context_rel = Feedback(provider.context_relevance).on_input().on_context().with_name("context_relevance")
107
+ f_answer_rel = Feedback(provider.relevance).on_input().on_output().with_name("answer_relevance")
108
+
109
+ # ---------- TruLens App ----------
110
+
111
+ tru_llama = TruLlama(
112
+ app=custom_rag_app,
113
+ feedbacks=[f_grounded, f_context_rel, f_answer_rel],
114
+ app_id="evaluate-trulens-llama-v2"
115
+ )
116
+
117
+ tru = Tru()
118
+
119
+ # ---------- Run Evaluation ----------
120
+
121
+ test_queries = [
122
+ "What are important drug interactions with dolutegravir?",
123
+ "How should PrEP be provided to adolescent girls?",
124
+ "When is cotrimoxazole prophylaxis indicated?",
125
+ "What are the guidelines for ART failure?",
126
+ "How do you manage HIV in pregnancy?"
127
+ ]
128
+
129
+ records = []
130
+
131
+ for q in test_queries:
132
+ record = tru_llama.run_with_record(question=q)
133
+ fb = record["feedback"]
134
+ records.append({
135
+ "question": q,
136
+ "answer": record["output"],
137
+ "contexts": record["context"],
138
+ "faithfulness_score": fb["faithfulness"].get("score"),
139
+ "context_relevance_score": fb["context_relevance"].get("score"),
140
+ "answer_relevance_score": fb["answer_relevance"].get("score"),
141
+ "faithfulness_justification": fb["faithfulness"].get("justification", "")
142
+ })
143
+
144
+ df = pd.DataFrame(records)
145
+ df.to_csv("trulens_llama_eval_results.csv", index=False)
146
+ print("✅ Evaluation complete. Saved to trulens_llama_eval_results.csv")
147
+ print(df)
{chatlib → scripts}/patient_sql_agent.py RENAMED
File without changes
scripts/ragas_eval.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # custom_rag_with_ragas.py
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from datasets import Dataset
6
+ from ragas.evaluation import evaluate
7
+ from ragas.metrics import (
8
+ faithfulness,
9
+ answer_relevancy,
10
+ context_precision,
11
+ context_recall
12
+ )
13
+ from llama_index.core import StorageContext, load_index_from_storage, QueryBundle
14
+ from llama_index.core.retrievers import VectorIndexRetriever
15
+ from llama_index.core.postprocessor import LLMRerank
16
+ from llama_index.embeddings.openai import OpenAIEmbedding
17
+ from langchain_core.prompts import ChatPromptTemplate
18
+ from langchain.chat_models import ChatOpenAI
19
+ from llama_index.llms.openai import OpenAI
20
+ import os
21
+ from dotenv import load_dotenv
22
+ if os.path.exists("config.env"):
23
+ load_dotenv("config.env")
24
+
25
+ embeddings = np.load("data/processed/lp/summary_embeddings/embeddings.npy")
26
+ df = pd.read_csv("data/processed/lp/summary_embeddings/index.tsv", sep="\t")
27
+
28
+ embedding_model = OpenAIEmbedding()
29
+
30
+ # Define your reranker-compatible LLM
31
+ llm_llama = OpenAI(model="gpt-4o", temperature=0.0)
32
+
33
+ # Create LLM reranker
34
+ reranker = LLMRerank(llm=llm_llama, top_n=3)
35
+
36
+ # summarizer LLM
37
+ llm = ChatOpenAI(temperature=0.0, model="gpt-4o")
38
+
39
+ # Define a prompt template for query expansion
40
+ query_expansion_prompt = ChatPromptTemplate.from_messages([
41
+ ("system", "You are an expert in HIV medicine."),
42
+ ("user", (
43
+ "Given the query below, provide a concise, comma-separated list of related terms and synonyms "
44
+ "useful for document retrieval. Return only the list, no explanations.\n\n"
45
+ "Query: {query}"
46
+ ))
47
+ ])
48
+
49
+ def cosine_similarity_numpy(query_vec: np.ndarray, matrix: np.ndarray) -> np.ndarray:
50
+ query_norm = query_vec / np.linalg.norm(query_vec)
51
+ matrix_norm = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
52
+ return matrix_norm @ query_norm
53
+
54
+
55
+ def expand_query(query, llm, prompt_template):
56
+ messages = prompt_template.format_messages(query=query)
57
+ return llm.invoke(messages).content.strip()
58
+
59
+ def retrieve_contexts(expanded_query, embeddings, df, embedding_model):
60
+ query_vec = embedding_model.get_text_embedding(expanded_query)
61
+ similarities = cosine_similarity_numpy(query_vec, embeddings)
62
+ top_indices = similarities.argsort()[-3:][::-1]
63
+ paths = df.loc[top_indices, "vectorestore_path"].tolist()
64
+ print(paths)
65
+ all_nodes = []
66
+ for path in paths:
67
+ ctx = StorageContext.from_defaults(persist_dir=path)
68
+ index = load_index_from_storage(ctx)
69
+ retriever = VectorIndexRetriever(index=index, similarity_top_k=3)
70
+ all_nodes.extend(retriever.retrieve(expanded_query))
71
+
72
+ return [n.text for n in LLMRerank(llm=llm_llama, top_n=3).postprocess_nodes(all_nodes, QueryBundle(expanded_query))]
73
+
74
+ def summarize(query, contexts, llm):
75
+ prompt = (
76
+ "You're a clinical assistant helping a provider answer a question using HIV/AIDS guidelines.\n\n"
77
+ f"Question: {query}\n\n"
78
+ "Provide a detailed summary of the most relevant points from the following source texts using bullet points.\n\n"
79
+ + "\n\n".join([f"Source {i+1}: {text}" for i, text in enumerate(contexts)])
80
+ )
81
+ return llm.invoke(prompt).content.strip()
82
+
83
+ # Run on test queries
84
+ test_queries = [
85
+ "What are important drug interactions with dolutegravir?",
86
+ "How should PrEP be provided to adolescent girls?",
87
+ "When is cotrimoxazole prophylaxis indicated?",
88
+ "What are the guidelines for ART failure?",
89
+ "How do you manage HIV in pregnancy?"
90
+ ]
91
+ results = []
92
+
93
+ for q in test_queries:
94
+ print(f"⏳ Processing: {q}")
95
+ expanded = expand_query(q, llm, query_expansion_prompt)
96
+ contexts = retrieve_contexts(expanded, embeddings, df, embedding_model)
97
+ answer = summarize(q, contexts, llm)
98
+ results.append({
99
+ "question": q,
100
+ "contexts": contexts,
101
+ "answer": answer
102
+ })
103
+
104
+ # --- Ragas Evaluation ---
105
+ print("✅ Running Ragas evaluation...")
106
+
107
+ ragas_data = Dataset.from_list(results)
108
+
109
+ eval_results = evaluate(
110
+ ragas_data,
111
+ metrics=[faithfulness, answer_relevancy]
112
+ )
113
+
114
+ df_eval = eval_results.to_pandas()
115
+ df_eval.to_csv("ragas_eval_results.csv", index=False)
116
+
117
+ print("✅ Evaluation complete. Saved to ragas_eval_results.csv")
118
+ print(df_eval)