JDFPalladium committed
Commit 389c5f0 · 1 Parent(s): 7bc9486

cleaning up organization of scripts and data and updating filepaths in app to processed data

chatlib/guidlines_rag_agent_li.py CHANGED
@@ -9,8 +9,8 @@ import pandas as pd
 from llama_index.embeddings.openai import OpenAIEmbedding
 
 # load vectorstore summaries
-embeddings = np.load("guidance_docs/lp/summary_embeddings/embeddings.npy")
-df = pd.read_csv("guidance_docs/lp/summary_embeddings/index.tsv", sep="\t")
+embeddings = np.load("data/processed/lp/summary_embeddings/embeddings.npy")
+df = pd.read_csv("data/processed/lp/summary_embeddings/index.tsv", sep="\t")
 
 embedding_model = OpenAIEmbedding()
 
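For orientation, a minimal sketch of how the relocated summary embeddings and index might be queried; the cosine-similarity helper below is an illustration under assumed array shapes, not code from this commit:

import numpy as np
import pandas as pd
from llama_index.embeddings.openai import OpenAIEmbedding

# Assumes embeddings.npy holds one row per summary, aligned with index.tsv rows.
embeddings = np.load("data/processed/lp/summary_embeddings/embeddings.npy")
df = pd.read_csv("data/processed/lp/summary_embeddings/index.tsv", sep="\t")
embedding_model = OpenAIEmbedding()

def top_k_summaries(query: str, k: int = 3) -> pd.DataFrame:
    # Embed the query, then rank stored summary vectors by cosine similarity.
    q = np.asarray(embedding_model.get_query_embedding(query))
    sims = embeddings @ q / (np.linalg.norm(embeddings, axis=1) * np.linalg.norm(q))
    return df.iloc[np.argsort(sims)[::-1][:k]]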
chatlib/idsr_check.py CHANGED
@@ -14,17 +14,17 @@ import sqlite3
 # import os
 
 
-with open("./guidance_docs/idsr_keywords.txt", "r", encoding="utf-8") as f:
+with open("./data/processed/idsr_keywords.txt", "r", encoding="utf-8") as f:
     keywords = [line.strip() for line in f if line.strip()]
 
 vectorstore = FAISS.load_local(
-    "./guidance_docs/disease_vectorstore",
+    "./data/processed/disease_vectorstore",
     OpenAIEmbeddings(),
     allow_dangerous_deserialization=True,
 )
 
 
-with open("./guidance_docs/tagged_documents.json", "r", encoding="utf-8") as f:
+with open("./data/processed/tagged_documents.json", "r", encoding="utf-8") as f:
     doc_dicts = json.load(f)
 
 tagged_documents = [Document(**d) for d in doc_dicts]
@@ -138,7 +138,7 @@ def idsr_check(query: str, llm, sitecode) -> AppState:
     # first, get sitecode from environment variable
     # sitecode = os.environ.get("SITECODE")
     # next, connect to location database and get county where code = sitecode
-    conn = sqlite3.connect("data/location_data.sqlite")
+    conn = sqlite3.connect("data/processed/location_data.sqlite")
     county_cursor = conn.cursor()
     county_cursor.execute(
         "SELECT County FROM sitecode_county_xwalk WHERE Code = ?", (sitecode,)
chatlib/patient_all_data.py CHANGED
@@ -41,7 +41,7 @@ def sql_chain(query: str, llm, rag_result: str, pk_hash: str) -> dict:
     if not pk_hash:
         raise ValueError("pk_hash is required in state for SQL queries.")
 
-    conn = sqlite3.connect("data/patient_demonstration.sqlite")
+    conn = sqlite3.connect("data/processed/patient_demonstration.sqlite")
     cursor = conn.cursor()
 
     cursor.execute(
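A hedged sketch of the per-patient query pattern sql_chain builds on; the table and column names come from the notebooks below, while the fetch logic here is illustrative:

import sqlite3

def patient_visits(pk_hash: str, db_path="data/processed/patient_demonstration.sqlite"):
    if not pk_hash:
        raise ValueError("pk_hash is required in state for SQL queries.")
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # Placeholder binding keeps the hashed identifier out of the SQL string.
        cursor.execute(
            "SELECT * FROM clinical_visits WHERE PatientPKHash = ?", (pk_hash,)
        )
        return cursor.fetchall()
    finally:
        conn.close()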
chatlib/phi_filter.py CHANGED
@@ -4,7 +4,7 @@ import re
 from .helpers import dateparser_detect, describe_relative_date
 
 
-def load_kenyan_names(filepath="data/kenyan_names.txt"):
+def load_kenyan_names(filepath="data/processed/kenyan_names.txt"):
     if not Path(filepath).exists():
         return set()
     with open(filepath, "r", encoding="utf-8") as f:
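The hunk ends before the function returns; a plausible completion, assuming one name per line and lower-case normalization (both assumptions, not confirmed by the diff):

from pathlib import Path

def load_kenyan_names(filepath="data/processed/kenyan_names.txt"):
    if not Path(filepath).exists():
        return set()  # degrade gracefully when the name list is absent
    with open(filepath, "r", encoding="utf-8") as f:
        # Assumed: one name per line, normalized for case-insensitive matching.
        return {line.strip().lower() for line in f if line.strip()}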
notebooks/create_location_db.ipynb ADDED
@@ -0,0 +1,89 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "1c8c38eb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sqlite3\n",
+ "import pandas as pd\n",
+ "\n",
+ "# read in kenya_disease_county_matrix.csv and sitecode_county_xwalk.csv\n",
+ "disease_df = pd.read_csv('kenya_disease_county_matrix.csv')\n",
+ "xwalk_df = pd.read_csv('sitecode_county_xwalk.csv')\n",
+ "rainy_df = pd.read_csv('kenya_counties_rainy_seasons.csv')\n",
+ "who_df = pd.read_csv('who_bulletin.csv')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "f0c63494",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create sqlite database\n",
+ "conn = sqlite3.connect('location_data.sqlite')\n",
+ "cursor = conn.cursor()\n",
+ "\n",
+ "# add each dataframe to a table in the database\n",
+ "disease_df.to_sql('county_disease_info', conn, if_exists='replace', index=False)\n",
+ "xwalk_df.to_sql('sitecode_county_xwalk', conn, if_exists='replace', index=False)\n",
+ "rainy_df.to_sql('county_rainy_seasons', conn, if_exists='replace', index=False)\n",
+ "who_df.to_sql('who_bulletin', conn, if_exists='replace', index=False)\n",
+ "\n",
+ "# commit changes and close connection\n",
+ "conn.commit()\n",
+ "conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "c12e58cf",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['County', 'Disease', 'Prevalence Level', 'Notes']\n"
+ ]
+ }
+ ],
+ "source": [
+ "# get table in location_data.sqlite and show column names\n",
+ "import sqlite3\n",
+ "conn = sqlite3.connect('location_data.sqlite')\n",
+ "cursor = conn.cursor()\n",
+ "cursor.execute(\"SELECT * FROM county_disease_info;\")\n",
+ "tables = cursor.fetchall()\n",
+ "columns = [column[0] for column in cursor.description]\n",
+ "print(columns)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
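A quick way to verify the four tables landed in the database built by this notebook; a sketch using the data/processed path this commit points the app at (adjust if running from the notebook's own directory):

import sqlite3

conn = sqlite3.connect("data/processed/location_data.sqlite")
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
print([row[0] for row in cursor.fetchall()])
# Expected: county_disease_info, sitecode_county_xwalk,
# county_rainy_seasons, who_bulletin
conn.close()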
notebooks/create_patient_db.ipynb ADDED
@@ -0,0 +1,480 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "ddb26634",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sqlite3\n",
+ "import pandas as pd\n",
+ "# inspect current database schema\n",
+ "conn = sqlite3.connect('patient_slim.sqlite')\n",
+ "cursor = conn.cursor()\n",
+ "# list tables\n",
+ "# pull all data from the visits table \n",
+ "cursor.execute(\"SELECT * FROM visits;\")\n",
+ "rows = cursor.fetchall()\n",
+ "df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
+ "conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "cd4faa4b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# let's create a new sqlite database called patient_demonstration.sqlite\n",
+ "conn = sqlite3.connect('patient_demonstration.sqlite')\n",
+ "cursor = conn.cursor() "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "f8547b78",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create a table called clinical_visits with the column names pulled above\n",
+ "# overwrite the table if it already exists\n",
+ "cursor.execute('DROP TABLE IF EXISTS clinical_visits;')\n",
+ "cursor.execute('''\n",
+ "CREATE TABLE clinical_visits (\n",
+ " PatientPKHash TEXT,\n",
+ " SiteCode TEXT,\n",
+ " VisitDate TEXT,\n",
+ " VisitType TEXT,\n",
+ " VisitBy TEXT,\n",
+ " NextAppointmentDate TEXT,\n",
+ " TCAReason TEXT,\n",
+ " Pregnant TEXT,\n",
+ " Breastfeeding TEXT,\n",
+ " StabilityAssessment TEXT,\n",
+ " DifferentiatedCare TEXT,\n",
+ " WHOStage INTEGER,\n",
+ " WHOStagingOI TEXT,\n",
+ " Height REAL,\n",
+ " Weight REAL, \n",
+ " EMR TEXT,\n",
+ " Project TEXT,\n",
+ " Adherence TEXT,\n",
+ " AdherenceCategory TEXT,\n",
+ " BP TEXT,\n",
+ " OI TEXT,\n",
+ " OIDate DATE,\n",
+ " CurrentRegimen TEXT,\n",
+ " AppointmentReminderWillingness TEXT,\n",
+ " key TEXT\n",
+ ");\n",
+ "''')\n",
+ "\n",
+ "# let's now populate the table with the rows variable that contains all the data from the visits table\n",
+ "cursor.executemany('''\n",
+ "INSERT INTO clinical_visits (PatientPKHash, SiteCode, VisitDate, VisitType, VisitBy, NextAppointmentDate, TCAReason, Pregnant, Breastfeeding, StabilityAssessment, DifferentiatedCare, WHOStage, WHOStagingOI, Height, Weight, EMR, Project, Adherence, AdherenceCategory, BP, OI, OIDate, CurrentRegimen, AppointmentReminderWillingness, key)\n",
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n",
+ "''', rows)\n",
+ "conn.commit()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "9ddfa626",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<sqlite3.Cursor at 0x7d4240c3d840>"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# now let's create a data dictionary\n",
+ "cursor.execute('DROP TABLE IF EXISTS data_dictionary;')\n",
+ "cursor.execute('''\n",
+ "CREATE TABLE data_dictionary (\n",
+ " table_name TEXT,\n",
+ " column_name TEXT,\n",
+ " description TEXT);\n",
+ "''')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "d14ef687",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# populate the data dictionary with descriptions for each column in the clinical_visits table\n",
+ "cursor.execute('''\n",
+ "INSERT INTO data_dictionary (table_name, column_name, description) VALUES\n",
+ "('clinical_visits', 'PatientPKHash', 'Hashed patient identifier'),\n",
+ "('clinical_visits', 'SiteCode', 'Code for the clinical site'),\n",
+ "('clinical_visits', 'VisitDate', 'Date of the patient visit'),\n",
+ "('clinical_visits', 'VisitType', 'Type of the patient visit. Values include Unknown, SCHEDULED VISIT, UNSCHEDULED VISIT LATE,\n",
+ " UNSCHEDULED VISIT EARLY, Unscheduled, Scheduled. These should typically be grouped as Scheduled and Unscheduled'),\n",
+ "('clinical_visits', 'VisitBy', 'Provider of the visit. Values include , Self, Treatment supporter, Refill visit documentation, Other'),\n",
+ "('clinical_visits', 'NextAppointmentDate', 'Date of the next scheduled clinical appointment set during VisitDate. \n",
+ " This is typically a date in the future after VisitDate.'),\n",
+ "('clinical_visits', 'TCAReason', 'Reason for the TCA (To Come Again) status. Values include, Follow up, Lab tests, Pharmacy Refill, Counseling,Other'),\n",
+ "('clinical_visits', 'Pregnant', 'Is the patient pregnant? Values include Yes and No.'),\n",
+ "('clinical_visits', 'Breastfeeding', 'Is the patient breastfeeding? Values include Yes, No and N/A'),\n",
+ "('clinical_visits', 'StabilityAssessment', 'Stability assessment result. Values include Stable, Unstable, and not stable.\n",
+ " typically, this should be grouped as Stable and Unstable (including not stable)'),\n",
+ "('clinical_visits', 'DifferentiatedCare', 'Differentiated care model. Values include Fast Track, Standard Care,\n",
+ " Community ART Distribution peer led,\n",
+ " Facility ART distribution Group,\n",
+ " Community ART Distribution HCW Led'),\n",
+ "('clinical_visits', 'WHOStage', 'WHO stage of the patient, either 1, 2, 3, or 4'),\n",
+ "('clinical_visits', 'WHOStagingOI', 'Opportunistic infection observed during WHO staging. Values include\n",
+ " Asymptomatic, Oral hairy leukoplakia,\n",
+ " Unexplained severe weight loss, Pulmonary tuberculosis,\n",
+ " Extra pulmonary tuberculosis,\n",
+ " Unexplained severe weight loss,Pulmonary tuberculosis,\n",
+ " Recurrent upper respiratory tract infections,\n",
+ " Asymptomatic,Persistent generalized lymphadenopathy),\n",
+ " Symptomatic HIV-associated nephropathy,\n",
+ " Cryptococcal meningitis, Herpes zoster,\n",
+ " Unexplained severe weight loss,Recurrent upper respiratory tract infections,\n",
+ " Persistent generalized lymphadenopathy),\n",
+ " Minor mucocutaneous manifestations,\n",
+ " Unexplained severe weight loss,Unexplained persistent fever,Pulmonary tuberculosis,\n",
+ " Recurrent oral ulcerations, Unexplained moderate malnutrition,\n",
+ " Oral candidiasis, HIV wasting syndrome,\n",
+ " Pulmonary tuberculosis,Oral candidiasis,\n",
+ " Unexplained persistent fever'),\n",
+ "('clinical_visits', 'Height', 'Height of the patient in centimeters'),\n",
+ "('clinical_visits', 'Weight', 'Weight of the patient in kilograms'),\n",
+ "('clinical_visits', 'EMR', 'Electronic medical record information. Values include AMRS, KenyaEMR, ECARE, DREAMS'),\n",
+ "('clinical_visits', 'Project', 'Project associated with the visit. Values include Ampath Plus, Kenya HMIS II, EDARP, DREAM Kenya Trusts'),\n",
+ "('clinical_visits', 'Adherence', 'Adherence to treatment. Values include Good, , Fair, Good|, Good|Good, Poor, Poor|Poor,\n",
+ " Poor|, 0, Poor|Good, Good|Poor. This variable will typically be used in combination with AdherenceCategory, and | here should align\n",
+ " with | in that variable, indicating two values for two categories.'),\n",
+ "('clinical_visits', 'AdherenceCategory', 'Category of adherence. Values include GOOD, , FAIR, ART|CTX, ARV. \n",
+ " GOOD and FAIR are erroneous and should be dropped when the variable is used. ART and ARV should be\n",
+ " considered as ART.'),\n",
+ "('clinical_visits', 'BP', 'Blood pressure readings. Value reported as systolic/diastolic in mmHg, e.g., 120/80.'),\n",
+ "('clinical_visits', 'OI', 'Opportunistic infections present. Values include Asymptomatic, Lymphadenopathy,\n",
+ " Respiratory Tract Infections, Moderate Weight Loss'),\n",
+ "('clinical_visits', 'OIDate', 'Date of opportunistic infection diagnosis'),\n",
+ "('clinical_visits', 'CurrentRegimen', 'Current treatment regimen. Value includes two or three digit descriptions of molecules separated by / signs'),\n",
+ "('clinical_visits', 'AppointmentReminderWillingness', 'Willingness to receive appointment reminders. Values include Yes and No'),\n",
+ "('clinical_visits', 'key', 'Unique key for PatientPKHash and SiteCode combination');\n",
+ "''')\n",
+ "conn.commit()\n",
+ "conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "6e27bce5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "conn = sqlite3.connect('patient_slim.sqlite')\n",
+ "cursor = conn.cursor()\n",
+ "# pull all data from the lab table \n",
+ "cursor.execute(\"SELECT * FROM lab;\")\n",
+ "rows = cursor.fetchall()\n",
+ "df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
+ "conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "14402e96",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# reconnect to patient_demonstration.sqlite\n",
+ "conn = sqlite3.connect('patient_demonstration.sqlite')\n",
+ "cursor = conn.cursor() "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "540962b7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create a table called lab with the column names pulled above\n",
+ "# overwrite the table if it already exists\n",
+ "cursor.execute('DROP TABLE IF EXISTS lab;')\n",
+ "cursor.execute('''\n",
+ "CREATE TABLE lab (\n",
+ " PatientPKHash TEXT,\n",
+ " SiteCode TEXT,\n",
+ " OrderedbyDate TEXT,\n",
+ " ReportedbyDate TEXT,\n",
+ " TestName TEXT,\n",
+ " TestResult TEXT,\n",
+ " key TEXT\n",
+ ");\n",
+ "''')\n",
+ "\n",
+ "# let's now populate the table with the rows variable that contains all the data from the lab table\n",
+ "cursor.executemany('''\n",
+ "INSERT INTO lab (PatientPKHash, SiteCode, OrderedbyDate, ReportedbyDate, TestName, TestResult, key)\n",
+ "VALUES (?, ?, ?, ?, ?, ?, ?);\n",
+ "''', rows)\n",
+ "conn.commit()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "8df7171e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# now, add lab table to the data dictionary\n",
+ "cursor.execute('''\n",
+ "INSERT INTO data_dictionary (table_name, column_name, description) VALUES\n",
+ "('lab', 'PatientPKHash', 'Hashed patient identifier'),\n",
+ "('lab', 'SiteCode', 'Code for the clinical site'),\n",
+ "('lab', 'OrderedbyDate', 'Date when the lab test was ordered'),\n",
+ "('lab', 'ReportedbyDate', 'Date when the lab test result was reported'),\n",
+ "('lab', 'TestName', 'Name of the lab test conducted, including CD4 Count for adults,\n",
+ " CD4 Percentage for children, and Viral Load'),\n",
+ "('lab', 'TestResult', 'Result of the lab test. This will sometimes appear as numeric value\n",
+ " and sometimes as text. Typically, when text, the value will be \"LDL\", meaning low \n",
+ " detectable level, or low HIV viral load.'),\n",
+ "('lab', 'key', 'Unique key for PatientPKHash and SiteCode combination');\n",
+ "''')\n",
+ "conn.commit()\n",
+ "conn.close() "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "b66d3dbb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "conn = sqlite3.connect('patient_slim.sqlite')\n",
+ "cursor = conn.cursor()\n",
+ "# pull all data from the pharmacy table \n",
+ "cursor.execute(\"SELECT * FROM pharmacy;\")\n",
+ "rows = cursor.fetchall()\n",
+ "df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
+ "conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "435b8d4e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# reconnect to patient_demonstration.sqlite\n",
+ "conn = sqlite3.connect('patient_demonstration.sqlite')\n",
+ "cursor = conn.cursor() "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "b3753eeb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create a table called pharmacy with the column names pulled above\n",
+ "# overwrite the table if it already exists\n",
+ "cursor.execute('DROP TABLE IF EXISTS pharmacy;')\n",
+ "cursor.execute('''\n",
+ "CREATE TABLE pharmacy (\n",
+ " PatientPKHash TEXT,\n",
+ " SiteCode TEXT,\n",
+ " Drug TEXT,\n",
+ " DispenseDate TEXT,\n",
+ " ExpectedReturn TEXT,\n",
+ " Duration INTEGER,\n",
+ " TreatmentType TEXT,\n",
+ " RegimenLine TEXT,\n",
+ " RegimenChangedSwitched TEXT,\n",
+ " RegimenChangeSwitchedReason TEXT,\n",
+ " key TEXT\n",
+ ");\n",
+ "''')\n",
+ "\n",
+ "# let's now populate the table with the rows variable that contains all the data from the pharmacy table\n",
+ "cursor.executemany('''\n",
+ "INSERT INTO pharmacy (PatientPKHash, SiteCode, Drug, DispenseDate, ExpectedReturn, Duration, TreatmentType, RegimenLine, RegimenChangedSwitched, RegimenChangeSwitchedReason, key)\n",
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n",
+ "''', rows)\n",
+ "conn.commit()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "8b8ed08a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# now, add pharmacy table to the data dictionary\n",
+ "cursor.execute('''\n",
+ "INSERT INTO data_dictionary (table_name, column_name, description) VALUES\n",
+ "('pharmacy', 'PatientPKHash', 'Hashed patient identifier'),\n",
+ "('pharmacy', 'SiteCode', 'Code for the clinical site'),\n",
+ "('pharmacy', 'Drug', 'Description of the drug prescribed, reported as collection of molecules (e.g. 3TC+DTG+TDF). Most common are ARVs for HIV'),\n",
+ "('pharmacy', 'DispenseDate', 'Date when the drug was dispensed'),\n",
+ "('pharmacy', 'ExpectedReturn', 'Expected return date for the next pharmacy visit'),\n",
+ "('pharmacy', 'Duration', 'Duration in number of days for which the drug is prescribed. Any duration of 60 days or greater is considered a multi-month dispensing (MMD).'),\n",
+ "('pharmacy', 'TreatmentType', 'Type of treatment. Values include ARV, PMTCT, Prophylaxis.'),\n",
+ "('pharmacy', 'RegimenLine', 'Line of treatment regimen. Valid values include First Line, Second Line, Third Line'),\n",
+ "('pharmacy', 'RegimenChangedSwitched', 'Indicates if the regimen was changed or switched. Valid values are Switch and Substition. Otherwise, regimen was not changed.'),\n",
+ "('pharmacy', 'RegimenChangeSwitchedReason', 'Reason for changing or switching the regimen. Valid values include New drug available, Virological failure, Drugs out of stock, Drug toxicity, New Diagnosis of tuberculosis, and Other.'),\n",
+ "('pharmacy', 'key', 'Unique key for PatientPKHash and SiteCode combination');\n",
+ "''')\n",
+ "conn.commit()\n",
+ "conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "2de65432",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "conn = sqlite3.connect('patient_slim.sqlite')\n",
+ "cursor = conn.cursor()\n",
+ "# pull all data from the demographics table \n",
+ "cursor.execute(\"SELECT * FROM demographics;\")\n",
+ "rows = cursor.fetchall()\n",
+ "df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
+ "conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "a7a10f4f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# reconnect to patient_demonstration.sqlite\n",
+ "conn = sqlite3.connect('patient_demonstration.sqlite')\n",
+ "cursor = conn.cursor() "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "947c63d5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create a table called demographics with the column names pulled above\n",
+ "# overwrite the table if it already exists\n",
+ "cursor.execute('DROP TABLE IF EXISTS demographics;')\n",
+ "cursor.execute('''\n",
+ "CREATE TABLE demographics (\n",
+ " PatientPKHash TEXT,\n",
+ " MFLCode TEXT,\n",
+ " FacilityName TEXT,\n",
+ " County TEXT,\n",
+ " SubCounty TEXT,\n",
+ " PartnerName TEXT,\n",
+ " AgencyName TEXT,\n",
+ " Sex TEXT,\n",
+ " MaritalStatus TEXT,\n",
+ " EducationLevel TEXT,\n",
+ " Occupation TEXT,\n",
+ " OnIPT TEXT,\n",
+ " AgeGroup TEXT,\n",
+ " ARTOutcomeDescription TEXT,\n",
+ " AsOfDate TEXT,\n",
+ " LoadDate TEXT,\n",
+ " StartARTDate TEXT,\n",
+ " DOB TEXT,\n",
+ " key TEXT\n",
+ ");\n",
+ "''')\n",
+ "\n",
+ "# let's now populate the table with the rows variable that contains all the data from the demographics table\n",
+ "cursor.executemany('''\n",
+ "INSERT INTO demographics (PatientPKHash, MFLCode, FacilityName, County, SubCounty, PartnerName, AgencyName, Sex, MaritalStatus, EducationLevel, Occupation, OnIPT, AgeGroup, ARTOutcomeDescription, AsOfDate, LoadDate, StartARTDate, DOB, key)\n",
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n",
+ "''', rows)\n",
+ "conn.commit()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "9cff0d90",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# now, add demographics table to the data dictionary\n",
+ "cursor.execute('''\n",
+ "INSERT INTO data_dictionary (table_name, column_name, description) VALUES\n",
+ "('demographics', 'PatientPKHash', 'Hashed patient identifier'),\n",
+ "('demographics', 'MFLCode', 'Code for the clinical site, same as SiteCode'),\n",
+ "('demographics', 'FacilityName', 'Name of the clinical facility'),\n",
+ "('demographics', 'County', 'County where the patient is located'),\n",
+ "('demographics', 'SubCounty', 'Sub-county where the patient is located'),\n",
+ "('demographics', 'PartnerName', 'Name of the implementing partner that manages the facility'),\n",
+ "('demographics', 'AgencyName', 'Name of the agency that supports the facility'),\n",
+ "('demographics', 'Sex', 'Sex of the patient. Valid values are male and female. Capitalization is not standardized so always set to lower case.'),\n",
+ "('demographics', 'MaritalStatus', 'Marital status of the patient. Valid values include married monogamous,\n",
+ " married polygamous, single, divorced, widowed, cohabiting, separated. There are also some erroneous values \n",
+ " that should be ignored and treated as missing.'),\n",
+ "('demographics', 'EducationLevel', 'Education level of the patient. Valid values primary, secondary, tertiary, none. \n",
+ " there is a value for NULL that should be treated as missing.'),\n",
+ "('demographics', 'Occupation', 'Occupation of the patient. Valid values include farmer, trader, none (for unemployed),\n",
+ " student, self employed, professional, employee, driver, and NULL that should be treated as missing.'),\n",
+ "('demographics', 'OnIPT', 'Indicates if the patient is on IPT. This is all null.'),\n",
+ "('demographics', 'AgeGroup', 'Age group of the patient. This is all null.'),\n",
+ "('demographics', 'ARTOutcomeDescription', 'Description of the ART outcome. Valid values include active, dead,\n",
+ " loss to follow up, transferred out, undocumented loss, and lost in hmis.'),\n",
+ "('demographics', 'AsOfDate', 'Date as of which the data is reported'),\n",
+ "('demographics', 'LoadDate', 'Date when the data was loaded'),\n",
+ "('demographics', 'StartARTDate', 'Date when the patient started ART'),\n",
+ "('demographics', 'DOB', 'Date of birth of the patient'),\n",
+ "('demographics', 'key', 'Unique key for PatientPKHash and MFLCode combination');\n",
+ "''')\n",
+ "conn.commit()\n",
+ "conn.close()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
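Since the notebook pairs every table with data_dictionary rows, a small sanity check can confirm the dictionary covers all four tables; a sketch, assuming the database sits at the path the app now uses:

import sqlite3
import pandas as pd

conn = sqlite3.connect("data/processed/patient_demonstration.sqlite")
dd = pd.read_sql_query(
    "SELECT table_name, column_name, description FROM data_dictionary", conn
)
# One row per documented column; expect clinical_visits, lab,
# pharmacy, and demographics to all appear.
print(dd.groupby("table_name").size())
conn.close()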
notebooks/create_slim_patient_db.ipynb ADDED
@@ -0,0 +1,308 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "c867740b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sqlite3\n",
+ "import pandas as pd\n",
+ "# inspect current database schema\n",
+ "conn = sqlite3.connect('iit_test.sqlite')\n",
+ "cursor = conn.cursor()\n",
+ "# list tables\n",
+ "# pull all data from the visits table \n",
+ "cursor.execute(\"SELECT * FROM visits;\")\n",
+ "rows = cursor.fetchall()\n",
+ "df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
+ "conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "f424fcf6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_2997/3546205200.py:11: SettingWithCopyWarning: \n",
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
+ "\n",
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
+ "  sampled_df['PatientPKHash'] = sampled_df['PatientPKHash'].map(key_to_number)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# pick ten unique keys at random from df without replacement\n",
+ "sampled_keys = df['PatientPKHash'].drop_duplicates().sample(n=10, random_state=42).tolist()\n",
+ "\n",
+ "# filter dataframe to only include sampled keys\n",
+ "sampled_df = df[df['PatientPKHash'].isin(sampled_keys)]\n",
+ "\n",
+ "# create a dict with key as key and numbers 1-10 as values\n",
+ "key_to_number = {key: i+1 for i, key in enumerate(sampled_keys)}\n",
+ "\n",
+ "# replace key column in sampled_df with corresponding number from key_to_number\n",
+ "sampled_df['PatientPKHash'] = sampled_df['PatientPKHash'].map(key_to_number)\n",
+ "\n",
+ "# save sampled_df to patient_slim.sqlite as the visits table\n",
+ "sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
+ "sampled_df.to_sql('visits', sampled_conn, if_exists='replace', index=False)\n",
+ "sampled_conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "8615f9fa",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(271, 25)"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sampled_df.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "1bad1098",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_2997/4153193150.py:11: SettingWithCopyWarning: \n",
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
+ "\n",
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
+ "  sampled_pharmacy_df['PatientPKHash'] = sampled_pharmacy_df['PatientPKHash'].map(key_to_number)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# now, read in pharmacy table from iit_test.sqlite\n",
+ "conn = sqlite3.connect('iit_test.sqlite')\n",
+ "cursor = conn.cursor()\n",
+ "cursor.execute(\"SELECT * FROM pharmacy;\")\n",
+ "rows = cursor.fetchall()\n",
+ "pharmacy_df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
+ "conn.close()\n",
+ "\n",
+ "# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as the pharmacy table\n",
+ "sampled_pharmacy_df = pharmacy_df[pharmacy_df['PatientPKHash'].isin(sampled_keys)]\n",
+ "sampled_pharmacy_df['PatientPKHash'] = sampled_pharmacy_df['PatientPKHash'].map(key_to_number)\n",
+ "sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
+ "sampled_pharmacy_df.to_sql('pharmacy', sampled_conn, if_exists='replace', index=False)\n",
+ "sampled_conn.close()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "bc8fac93",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PatientPKHash\n",
+ "1     14\n",
+ "2     24\n",
+ "3     24\n",
+ "4      9\n",
+ "5     40\n",
+ "6      1\n",
+ "7     15\n",
+ "8      1\n",
+ "9     64\n",
+ "10    14\n",
+ "dtype: int64"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sampled_pharmacy_df.groupby('PatientPKHash').size()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "df01b886",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_2997/3478231606.py:11: SettingWithCopyWarning: \n",
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
+ "\n",
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
+ "  sampled_lab_df['PatientPKHash'] = sampled_lab_df['PatientPKHash'].map(key_to_number)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# repeat the process above for lab table\n",
+ "conn = sqlite3.connect('iit_test.sqlite')\n",
+ "cursor = conn.cursor()\n",
+ "cursor.execute(\"SELECT * FROM lab;\")\n",
+ "rows = cursor.fetchall()\n",
+ "lab_df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
+ "conn.close()\n",
+ "\n",
+ "# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as the lab table\n",
+ "sampled_lab_df = lab_df[lab_df['PatientPKHash'].isin(sampled_keys)]\n",
+ "sampled_lab_df['PatientPKHash'] = sampled_lab_df['PatientPKHash'].map(key_to_number)\n",
+ "sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
+ "sampled_lab_df.to_sql('lab', sampled_conn, if_exists='replace', index=False)\n",
+ "sampled_conn.close()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "2578bf85",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PatientPKHash\n",
+ "1      6\n",
+ "2      2\n",
+ "3     17\n",
+ "4     22\n",
+ "5     23\n",
+ "6      1\n",
+ "7      2\n",
+ "8     10\n",
+ "9     13\n",
+ "10    12\n",
+ "dtype: int64"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sampled_lab_df.groupby('PatientPKHash').size()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "ebf358c5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_2997/3867144072.py:11: SettingWithCopyWarning: \n",
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
+ "\n",
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
+ "  sampled_dem_df['PatientPKHash'] = sampled_dem_df['PatientPKHash'].map(key_to_number)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# now, from dem table\n",
+ "conn = sqlite3.connect('iit_test.sqlite')\n",
+ "cursor = conn.cursor()\n",
+ "cursor.execute(\"SELECT * FROM dem;\")\n",
+ "rows = cursor.fetchall()\n",
+ "dem_df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
+ "conn.close() \n",
+ "\n",
+ "# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as the demographics table\n",
+ "sampled_dem_df = dem_df[dem_df['PatientPKHash'].isin(sampled_keys)]\n",
+ "sampled_dem_df['PatientPKHash'] = sampled_dem_df['PatientPKHash'].map(key_to_number)\n",
+ "sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
+ "sampled_dem_df.to_sql('demographics', sampled_conn, if_exists='replace', index=False)\n",
+ "sampled_conn.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "527420fa",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PatientPKHash\n",
+ "1     1\n",
+ "2     1\n",
+ "3     1\n",
+ "4     1\n",
+ "5     1\n",
+ "6     1\n",
+ "7     1\n",
+ "8     1\n",
+ "9     1\n",
+ "10    1\n",
+ "dtype: int64"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sampled_dem_df.groupby('PatientPKHash').size()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
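Each sampling cell above raises SettingWithCopyWarning because the filtered frame is a view. Taking an explicit copy before remapping the identifiers avoids the warning; a sketch of the pattern:

import pandas as pd

def anonymize_keys(frame: pd.DataFrame, keys, key_to_number) -> pd.DataFrame:
    # .copy() makes the assignment below act on an independent frame,
    # avoiding pandas' SettingWithCopyWarning seen in the cell outputs.
    out = frame[frame["PatientPKHash"].isin(keys)].copy()
    out["PatientPKHash"] = out["PatientPKHash"].map(key_to_number)
    return out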
notebooks/create_textrag.ipynb ADDED
@@ -0,0 +1,139 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1d13fafe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "import os\n",
+ "import asyncio\n",
+ "from llama_parse import LlamaParse\n",
+ "from llama_index.core import VectorStoreIndex\n",
+ "from llama_index.core.node_parser import SimpleNodeParser\n",
+ "from llama_index.core.schema import Document\n",
+ "import nest_asyncio\n",
+ "nest_asyncio.apply()\n",
+ "\n",
+ "from dotenv import load_dotenv\n",
+ "load_dotenv(\"../config.env\")\n",
+ "os.environ.get(\"OPENAI_API_KEY\")\n",
+ "os.environ.get(\"LLAMAPARSE_API_KEY\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0a79afb5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# instantiate LlamaParse\n",
+ "parser = LlamaParse(\n",
+ "    api_key=os.environ.get(\"LLAMAPARSE_API_KEY\"),\n",
+ "    result_type=\"markdown\",  # or \"text\"\n",
+ "    extract_charts=True,\n",
+ "    auto_mode=True,\n",
+ "    auto_mode_trigger_on_image_in_page=True,\n",
+ "    auto_mode_trigger_on_table_in_page=True,\n",
+ "    bbox_top=0.05,\n",
+ "    bbox_bottom=0.1,\n",
+ "    verbose=True\n",
+ ")\n",
+ "\n",
+ "# documents = parser.load_data(f\"GuidelinesSections/Kenya-ARV-Guidelines-2022-HepB-HepC-Coinfection.pdf\")\n",
+ "# # Write the output to a file\n",
+ "# with open(\"output.md\", \"w\", encoding=\"utf-8\") as f:\n",
+ "#     for doc in documents:\n",
+ "#         f.write(doc.text)\n",
+ "# filename=\"GuidelinesSections/Kenya-ARV-Guidelines-2022-HepB-HepC-Coinfection.pdf\"\n",
+ "# full_text = \"\\n\\n\".join(doc.text for doc in documents)\n",
+ "# combined_doc = Document(text=full_text)\n",
+ "# node_parser = SimpleNodeParser()\n",
+ "# nodes = node_parser.get_nodes_from_documents([combined_doc])\n",
+ "# # create the index\n",
+ "# index = VectorStoreIndex(nodes)\n",
+ "# # remove \"Kenya-ARV-Guidelines-2022-\" from filename\n",
+ "# short_filename = filename.replace(\"GuidelinesSections/Kenya-ARV-Guidelines-2022-\",\"\").replace(\".pdf\", \"\")\n",
+ "# # persist the index\n",
+ "# index.storage_context.persist(f\"lp/indices/{short_filename}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4e94da2b",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3ea85ed0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# iterate through all files in guidance_docs/GuidelinesSections\n",
+ "# first, load the data using the parser\n",
+ "# then, flatten the data in each doc to create a single large doc per section\n",
+ "# finally, chunk the data using SentenceSplitter (tight size control)\n",
+ "async def parse_docs():\n",
+ "    for filename in os.listdir(\"GuidelinesSections\"):\n",
+ "        if filename.endswith(\".pdf\"):\n",
+ "            documents = parser.load_data(f\"GuidelinesSections/{filename}\")\n",
+ "            full_text = \"\\n\\n\".join(doc.text for doc in documents)\n",
+ "            combined_doc = Document(text=full_text)\n",
+ "            node_parser = SimpleNodeParser()\n",
+ "            nodes = node_parser.get_nodes_from_documents([combined_doc])\n",
+ "            # create the index\n",
+ "            index = VectorStoreIndex(nodes)\n",
+ "            # remove \"Kenya-ARV-Guidelines-2022-\" from filename\n",
+ "            short_filename = filename.replace(\"Kenya-ARV-Guidelines-2022-\",\"\").replace(\".pdf\", \"\")\n",
+ "            # persist the index\n",
+ "            index.storage_context.persist(f\"lp/indices/{short_filename}\")\n",
+ "    \n",
+ "await parse_docs()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7135ce0d",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bfa61623",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "clinician-assistant-lg",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
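Once parse_docs has persisted a per-section index, it can be loaded back for querying; a sketch using llama_index's storage API (the persist_dir shown is derived from the commented-out example filename, and the query string is illustrative):

from llama_index.core import StorageContext, load_index_from_storage

# Reload one persisted section index and query it.
storage_context = StorageContext.from_defaults(
    persist_dir="lp/indices/HepB-HepC-Coinfection"
)
index = load_index_from_storage(storage_context)
response = index.as_query_engine().query("How is HBV coinfection managed?")
print(response)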
notebooks/gen_idsr_rag.ipynb ADDED
@@ -0,0 +1,562 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "da62e982",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "from pprint import pprint\n",
+ "import os\n",
+ "from dotenv import load_dotenv\n",
+ "\n",
+ "load_dotenv(\"../config.env\")\n",
+ "os.environ.get(\"OPENAI_API_KEY\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7b2b560b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# read in IDSR.txt\n",
+ "with open(\"IDSR.txt\", encoding=\"utf-8\") as f:\n",
+ "    text = f.read()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "50d72066",
+ "metadata": {},
+ "source": [
+ "Extract Keywords"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "75a4c7bf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt = \"\"\"\n",
+ "You are a helpful assistant. Extract a list of 30–50 key symptoms, signs, or diagnostic terms from the following disease descriptions.\n",
+ "\n",
+ "Focus on words or phrases that are likely to appear in clinical case definitions or user queries — such as \"fever\", \"skin lesions\", \"swollen lymph nodes\", \"positive blood smear\", etc.\n",
+ "\n",
+ "Only return the keywords or short phrases — one per line.\n",
+ "\n",
+ "Text:\n",
+ "\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4f704812",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from openai import OpenAI\n",
+ "\n",
+ "client = OpenAI()\n",
+ "response = client.chat.completions.create(\n",
+ "    model=\"gpt-4o\",\n",
+ "    messages=[\n",
+ "        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
+ "        {\"role\": \"user\", \"content\": prompt + text}\n",
+ "    ],\n",
+ "    temperature=0.0\n",
+ ")\n",
+ "keywords = [line.strip() for line in response.choices[0].message.content.splitlines() if line.strip()]\n",
+ "print(\"Extracted Keywords:\")\n",
+ "for keyword in keywords:\n",
+ "    print(\"-\", keyword)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9f698154",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# remove dashes and normalize keywords\n",
+ "def normalize_kw(kw):\n",
+ "    return kw.lstrip(\"-• \").strip().lower() \n",
+ "keywords = [normalize_kw(kw) for kw in keywords]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "11324098",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# save keywords to file\n",
+ "with open(\"idsr_keywords.txt\", \"w\", encoding=\"utf-8\") as f:\n",
+ "    for keyword in keywords:\n",
+ "        f.write(f\"{keyword}\\n\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "add8c3fe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load file\n",
+ "with open(\"idsr_keywords.txt\", \"r\", encoding=\"utf-8\") as f:\n",
+ "    keywords = [line.strip() for line in f if line.strip()]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1d12b253",
+ "metadata": {},
+ "source": [
+ "Prep each disease as a document"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2923ecab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# we need to split the text into a list of dictionaries:\n",
+ "# the text is structured as follows:\n",
+ "# the section for each disease starts after an empty line.\n",
+ "# the disease name itself takes up the first line.\n",
+ "# following the disease name, there will be subsections, each one beginning with an \"-\", some text, and then a colon. \n",
+ "# what is between the \"-\" and the colon is the name of the subsection. the name of each subsection takes up one line.\n",
+ "# following this, the next few lines contain the text for that subsection. however many lines it takes up,\n",
+ "# this should be the value for the subsection key in the dictionary, condensed to a single string.\n",
+ "# some diseases have multiple subsections, while others have only one.\n",
+ "# when we encounter an empty line, it indicates the start of a new disease section.\n",
+ "# what we should produce is one dictionary per disease, with a key called disease_name and value being the name of the disease. \n",
+ "# the other keys should be the subsections, with the value being the text that follows the subsection name.\n",
+ "\n",
+ "def parse_disease_text(text):\n",
+ "    diseases = []\n",
+ "    lines = text.strip().splitlines()\n",
+ "    \n",
+ "    current_disease = None\n",
+ "    current_subsection = None\n",
+ "    buffer = []\n",
+ "\n",
+ "    def finalize_subsection():\n",
+ "        if current_disease is not None and current_subsection and buffer:\n",
+ "            content = \" \".join(line.strip() for line in buffer).strip()\n",
+ "            current_disease[current_subsection] = content\n",
+ "\n",
+ "    subsection_pattern = re.compile(r\"^-\\s*(.+):\\s*$\")\n",
+ "\n",
+ "    for line in lines + [\"\"]:  # Extra empty line to trigger final save\n",
+ "        if not line.strip():\n",
+ "            finalize_subsection()\n",
+ "            if current_disease:\n",
+ "                diseases.append(current_disease)\n",
+ "            current_disease = None\n",
+ "            current_subsection = None\n",
+ "            buffer = []\n",
+ "            continue\n",
+ "\n",
+ "        if current_disease is None:\n",
+ "            current_disease = {\"disease_name\": line.strip()}\n",
+ "            continue\n",
+ "\n",
+ "        match = subsection_pattern.match(line)\n",
+ "        if match:\n",
+ "            finalize_subsection()\n",
+ "            current_subsection = match.group(1).strip()\n",
+ "            buffer = []\n",
+ "        else:\n",
+ "            buffer.append(line.rstrip())\n",
+ "\n",
+ "    return diseases\n",
+ "\n",
+ "\n",
+ "\n",
+ "disease_dicts = parse_disease_text(text)\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2fd83b33",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain_core.documents import Document\n",
+ "\n",
+ "def convert_disease_dicts_to_documents(disease_dicts):\n",
+ "    docs = []\n",
+ "    for disease in disease_dicts:\n",
+ "        disease_name = disease.get(\"disease_name\", \"\")\n",
+ "        subsections = [f\"{key}:\\n{value}\" for key, value in disease.items() if key != \"disease_name\"]\n",
+ "        full_text = f\"Disease: {disease_name}\\n\\n\" + \"\\n\\n\".join(subsections)\n",
+ "        docs.append(Document(page_content=full_text, metadata={\"disease_name\": disease_name}))\n",
+ "    return docs\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "19baadb4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Step 2: Convert to LangChain documents\n",
+ "documents = convert_disease_dicts_to_documents(disease_dicts)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "15bc8f40",
+ "metadata": {},
+ "source": [
+ "Tag each document with keywords"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "33d70fff",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from rapidfuzz import fuzz\n",
+ "\n",
+ "def tag_documents_with_keywords(documents, keywords, threshold=85):\n",
+ "    \"\"\"\n",
+ "    Tags each Document in the list with a 'matched_keywords' metadata field\n",
+ "    using fuzzy matching (e.g., RapidFuzz partial ratio).\n",
+ "\n",
+ "    Parameters:\n",
+ "        documents (list): List of langchain `Document` objects.\n",
+ "        keywords (list): List of predefined clinical keywords (e.g. from GPT).\n",
+ "        threshold (int): Similarity threshold (0–100) for fuzzy matching.\n",
+ "\n",
+ "    Returns:\n",
+ "        List of tagged Document objects with updated metadata.\n",
+ "    \"\"\"\n",
+ "    tagged = []\n",
+ "\n",
+ "    for doc in documents:\n",
+ "        content = doc.page_content.lower()\n",
+ "\n",
+ "        # Match keywords against document content\n",
+ "        matched = []\n",
+ "        for kw in keywords:\n",
+ "            kw_lower = kw.lower()\n",
+ "            if fuzz.partial_ratio(kw_lower, content) >= threshold:\n",
+ "                matched.append(kw)\n",
+ "\n",
+ "        # Add tags to metadata\n",
+ "        doc.metadata[\"matched_keywords\"] = matched\n",
+ "        tagged.append(doc)\n",
+ "\n",
+ "    return tagged\n",
+ "\n",
+ "tagged_documents = tag_documents_with_keywords(documents, keywords)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b588f56e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "# Convert Document objects to dicts\n",
+ "doc_dicts = [doc.dict() for doc in tagged_documents]\n",
+ "\n",
+ "with open(\"tagged_documents.json\", \"w\", encoding=\"utf-8\") as f:\n",
+ "    json.dump(doc_dicts, f, ensure_ascii=False, indent=2)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "166513b4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load tagged documents from file\n",
+ "import json\n",
+ "from langchain_core.documents import Document\n",
+ "with open(\"tagged_documents.json\", \"r\", encoding=\"utf-8\") as f:\n",
+ "    tagged_documents = [Document(**doc) for doc in json.load(f)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7f586616",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tagged_documents[50]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "39882d72",
+ "metadata": {},
+ "source": [
+ "Fuzzy-match query to keywords"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "db127464",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from rapidfuzz import fuzz\n",
+ "\n",
+ "def find_keywords_in_prompt(prompt, keywords, threshold=80):\n",
+ "    \"\"\"\n",
+ "    Returns all keywords that appear in the prompt using fuzzy matching.\n",
+ "    \n",
+ "    Args:\n",
+ "        prompt (str): The user prompt.\n",
+ "        keywords (list): List of keywords to match.\n",
+ "        threshold (int): Fuzzy match threshold (0-100).\n",
+ "    \n",
+ "    Returns:\n",
+ "        list: Matched keywords.\n",
+ "    \"\"\"\n",
+ "    prompt_lower = prompt.lower()\n",
+ "    matched = []\n",
+ "    for kw in keywords:\n",
+ "        kw_lower = kw.lower()\n",
+ "        # Use partial_ratio for substring-like matching\n",
+ "        if fuzz.partial_ratio(kw_lower, prompt_lower) >= threshold:\n",
+ "            matched.append(kw)\n",
+ "    return matched\n",
+ "\n",
+ "# Example usage:\n",
+ "# keywords = [\"fever\", \"skin lesions\", \"swollen lymph nodes\"]\n",
+ "# prompt = \"The patient presents with fever and swollen nodes.\"\n",
+ "# print(find_keywords_in_prompt(prompt, keywords))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e51dd2f1",
+ "metadata": {},
+ "source": [
+ "GPT to match query to keywords"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d51d699e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from typing import List\n",
+ "from pydantic import BaseModel, Field\n",
+ "from langchain_core.output_parsers import PydanticOutputParser\n",
+ "from langchain.prompts import PromptTemplate\n",
+ "from langchain.chat_models import ChatOpenAI\n",
+ "from langchain.chains import LLMChain\n",
+ "\n",
+ "class KeywordsOutput(BaseModel):\n",
+ "    keywords: List[str] = Field(description=\"List of relevant keywords extracted from the query\")\n",
+ "\n",
+ "def extract_keywords_with_gpt(query: str, known_keywords: List[str]) -> List[str]:\n",
+ "    parser = PydanticOutputParser(pydantic_object=KeywordsOutput)\n",
+ "\n",
+ "    prompt = PromptTemplate(\n",
+ "        template=\"\"\"\n",
+ "You are helping identify relevant medical concepts. \n",
+ "Given this query: \"{query}\"\n",
+ "\n",
+ "Select the most relevant keywords from this list:\n",
+ "{keyword_list}\n",
+ "\n",
+ "Return the matching keywords as a JSON object with a single key \"keywords\" whose value is a list of strings.\n",
+ "\n",
+ "{format_instructions}\n",
+ "\"\"\",\n",
+ "        input_variables=[\"query\", \"keyword_list\"],\n",
+ "        partial_variables={\"format_instructions\": parser.get_format_instructions()},\n",
+ "    )\n",
+ "\n",
+ "    chain = LLMChain(\n",
+ "        llm=ChatOpenAI(temperature=0, model=\"gpt-4o\"),\n",
+ "        prompt=prompt,\n",
+ "        output_parser=parser,\n",
+ "    )\n",
+ "\n",
+ "    output = chain.run(query=query, keyword_list=\", \".join(known_keywords))\n",
+ "\n",
+ "    # output is parsed into a KeywordsOutput instance; return its keywords list\n",
+ "    return output.keywords\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "45fdb67b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# matched_keywords = extract_keywords_with_gpt(query = \"child presenting with lesions\", known_keywords = keywords)\n",
+ "# print(\"Matched Keywords:\", matched_keywords)\n",
+ "type(matched_keywords)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d9c4c9bc",
+ "metadata": {},
+ "source": [
+ "Hybrid search using matched keywords"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2e59aa39",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def hybrid_search_with_query_keywords(query, vectorstore, documents, keyword_list, top_k=5):\n",
+ "    # Step 1: Semantic search\n",
+ "    semantic_hits = vectorstore.similarity_search(query, k=top_k)\n",
+ "\n",
+ "    # Step 2: Use GPT to extract keywords from the query\n",
+ "    matched_keywords = extract_keywords_with_gpt(query, keyword_list)\n",
+ "\n",
+ "    # Step 3: Filter docs whose metadata has any of those keywords\n",
+ "    keyword_hits = [\n",
+ "        doc for doc in documents\n",
+ "        if any(\n",
+ "            normalize_kw(kw1) == normalize_kw(kw2)\n",
+ "            for kw1 in doc.metadata.get(\"matched_keywords\", [])\n",
+ "            for kw2 in matched_keywords\n",
+ "        )\n",
457
+ " )\n",
458
+ " ]\n",
459
+ "\n",
460
+ " for kw in matched_keywords:\n",
461
+ " print(f\"Matched keyword: {kw}\")\n",
462
+ "\n",
463
+ " # print metadata of keyword_hits\n",
464
+ " for doc in keyword_hits:\n",
465
+ " print(doc.metadata.get(\"disease_name\"))\n",
466
+ " print(doc.metadata.get(\"matched_keywords\"))\n",
467
+ " print(doc.page_content)\n",
468
+ "\n",
469
+ " # Step 4: Merge by unique content\n",
470
+ " merged = {doc.page_content: doc for doc in semantic_hits + keyword_hits}\n",
471
+ " return list(merged.values()), matched_keywords\n"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "code",
476
+ "execution_count": null,
477
+ "id": "b215b0fb",
478
+ "metadata": {},
479
+ "outputs": [],
480
+ "source": [
481
+ "from langchain_openai import OpenAIEmbeddings\n",
482
+ "from langchain.vectorstores import FAISS\n",
483
+ "\n",
484
+ "embedding_model = OpenAIEmbeddings()\n",
485
+ "\n",
486
+ "# `documents` is the list of LangChain Document objects from before\n",
487
+ "vectorstore = FAISS.from_documents(tagged_documents, embedding_model)\n",
488
+ "\n",
489
+ "vectorstore.save_local(\"disease_vectorstore\")"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": null,
495
+ "id": "96ffa9b2",
496
+ "metadata": {},
497
+ "outputs": [],
498
+ "source": [
499
+ "# Startup:\n",
500
+ "from langchain.vectorstores import FAISS\n",
501
+ "from langchain_openai import OpenAIEmbeddings\n",
502
+ "vectorstore = FAISS.load_local(\"disease_vectorstore\", OpenAIEmbeddings(),allow_dangerous_deserialization=True)\n",
503
+ "\n",
504
+ "# Query time:\n",
505
+ "query = \"child presenting with lesions\"\n",
506
+ "results, matched = hybrid_search_with_query_keywords(query, vectorstore, tagged_documents, keywords)\n",
507
+ "\n",
508
+ "# print(\"Matched keywords:\", matched)\n",
509
+ "# for doc in results:\n",
510
+ "# print(\"---\")\n",
511
+ "# print(doc.metadata.get(\"disease_name\"))\n",
512
+ "# print(doc.metadata.get(\"matched_keywords\"))\n",
513
+ "# print(doc.page_content)\n",
514
+ "\n",
515
+ "\n"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": null,
521
+ "id": "38fb3c90",
522
+ "metadata": {},
523
+ "outputs": [],
524
+ "source": [
525
+ "# doc=tagged_documents[0].metadata.get(\"matched_keywords\")\n",
526
+ "doc\n",
527
+ "# matched_keywords\n",
528
+ "# doc in matched_keywords\n",
529
+ "\n"
530
+ ]
531
+ },
532
+ {
533
+ "cell_type": "code",
534
+ "execution_count": null,
535
+ "id": "ed6a99f4",
536
+ "metadata": {},
537
+ "outputs": [],
538
+ "source": []
539
+ }
540
+ ],
541
+ "metadata": {
542
+ "kernelspec": {
543
+ "display_name": ".venv",
544
+ "language": "python",
545
+ "name": "python3"
546
+ },
547
+ "language_info": {
548
+ "codemirror_mode": {
549
+ "name": "ipython",
550
+ "version": 3
551
+ },
552
+ "file_extension": ".py",
553
+ "mimetype": "text/x-python",
554
+ "name": "python",
555
+ "nbconvert_exporter": "python",
556
+ "pygments_lexer": "ipython3",
557
+ "version": "3.12.1"
558
+ }
559
+ },
560
+ "nbformat": 4,
561
+ "nbformat_minor": 5
562
+ }
scripts/build_location_db.py ADDED
@@ -0,0 +1,36 @@
1
+ import os
2
+ import sqlite3
3
+ import pandas as pd
4
+
5
+ # Define file paths (data/ lives at the repo root, one level above scripts/)
6
+ base_dir = os.path.dirname(__file__)
7
+ raw_dir = os.path.abspath(os.path.join(base_dir, "..", "data", "raw"))
8
+ processed_dir = os.path.abspath(os.path.join(base_dir, "..", "data", "processed"))
9
+ os.makedirs(processed_dir, exist_ok=True)
10
+
11
+ # Input CSVs
12
+ disease_path = os.path.join(raw_dir, "kenya_disease_county_matrix.csv")
13
+ xwalk_path = os.path.join(raw_dir, "sitecode_county_xwalk.csv")
14
+ rainy_path = os.path.join(raw_dir, "kenya_counties_rainy_seasons.csv")
15
+ who_path = os.path.join(raw_dir, "who_bulletin.csv")
16
+
17
+ # Output DB
18
+ db_path = os.path.join(processed_dir, "location_data.sqlite")
19
+
20
+ # Read CSVs
21
+ disease_df = pd.read_csv(disease_path)
22
+ xwalk_df = pd.read_csv(xwalk_path)
23
+ rainy_df = pd.read_csv(rainy_path)
24
+ who_df = pd.read_csv(who_path)
25
+
26
+ # Write to SQLite
27
+ conn = sqlite3.connect(db_path)
28
+ disease_df.to_sql('county_disease_info', conn, if_exists='replace', index=False)
29
+ xwalk_df.to_sql('sitecode_county_xwalk', conn, if_exists='replace', index=False)
30
+ rainy_df.to_sql('county_rainy_seasons', conn, if_exists='replace', index=False)
31
+ who_df.to_sql('who_bulletin', conn, if_exists='replace', index=False)
32
+
33
+ conn.commit()
34
+ conn.close()
35
+
36
+ print(f"SQLite database written to: {db_path}")
scripts/parse_guidelines.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import asyncio
3
+ from dotenv import load_dotenv
4
+
5
+ from llama_parse import LlamaParse
6
+ from llama_index.core import VectorStoreIndex
7
+ from llama_index.core.node_parser import SimpleNodeParser
8
+ from llama_index.core.schema import Document
9
+
10
+ # Load environment variables
11
+ load_dotenv("config.env")
12
+
13
+ # Set up LlamaParse
14
+ parser = LlamaParse(
15
+ api_key=os.environ.get("LLAMAPARSE_API_KEY"),
16
+ result_type="markdown",
17
+ extract_charts=True,
18
+ auto_mode=True,
19
+ auto_mode_trigger_on_image_in_page=True,
20
+ auto_mode_trigger_on_table_in_page=True,
21
+ bbox_top=0.05,
22
+ bbox_bottom=0.1,
23
+ verbose=True,
24
+ )
25
+
26
+ # Create output directory if it doesn't exist
27
+ os.makedirs("data/processed/lp/indices", exist_ok=True)
28
+
29
+ async def parse_docs():
30
+ for filename in os.listdir("data/raw/GuidelinesSections"):
31
+ if filename.endswith(".pdf"):
32
+ filepath = f"data/raw/GuidelinesSections/{filename}"
33
+ print(f"Processing: {filepath}")
34
+
35
+ try:
36
+ documents = await parser.aload_data(filepath)
37
+ except Exception as e:
38
+ print(f"❌ Failed to parse {filename}: {e}")
39
+ continue
40
+
41
+ full_text = "\n\n".join(doc.text for doc in documents)
42
+ combined_doc = Document(text=full_text)
43
+
44
+ node_parser = SimpleNodeParser()
45
+ nodes = node_parser.get_nodes_from_documents([combined_doc])
46
+
47
+ index = VectorStoreIndex(nodes)
48
+
49
+ short_filename = (
50
+ filename.replace("Kenya-ARV-Guidelines-2022-", "")
51
+ .replace(".pdf", "")
52
+ )
53
+
54
+ index.storage_context.persist(persist_dir=f"data/processed/lp/indices/{short_filename}")
55
+ print(f"βœ… Saved index for {short_filename}")
56
+
57
+ if __name__ == "__main__":
58
+ asyncio.run(parse_docs())
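
Each persisted index can later be rehydrated with llama_index's storage utilities. A minimal sketch, assuming OPENAI_API_KEY is set; the persist_dir leaf name ("Section1") is hypothetical, since the actual names depend on the source PDF filenames:

from llama_index.core import StorageContext, load_index_from_storage

# Reload one per-section index persisted by parse_guidelines.py
storage_context = StorageContext.from_defaults(
    persist_dir="data/processed/lp/indices/Section1"  # hypothetical section name
)
index = load_index_from_storage(storage_context)

# Querying needs an LLM/embedding backend (OpenAI by default)
query_engine = index.as_query_engine()
print(query_engine.query("When should a viral load test be repeated?"))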
scripts/prep_summaries.py ADDED
@@ -0,0 +1,28 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from llama_index.embeddings.openai import OpenAIEmbedding
4
+ import os
5
+ from dotenv import load_dotenv
6
+ load_dotenv("config.env")
7
+ assert os.environ.get("OPENAI_API_KEY"), "OPENAI_API_KEY is not set in config.env"
8
+
9
+ # load guideline section summaries
10
+ df = pd.read_csv("data/raw/guidelines_summaries.csv")
11
+
12
+ # Embed summaries
13
+ embedding_model = OpenAIEmbedding()
14
+ summary_embeddings = []
15
+
16
+ for summary in df["summary"]:
17
+ emb = embedding_model.get_text_embedding(summary)
18
+ summary_embeddings.append(emb)
19
+
20
+ summary_embeddings = np.vstack(summary_embeddings)
21
+
22
+ # Save embeddings and metadata
23
+ os.makedirs("data/processed/lp/summary_embeddings", exist_ok=True)
24
+
25
+ np.save("data/processed/lp/summary_embeddings/embeddings.npy", summary_embeddings)
26
+ df.to_csv("data/processed/lp/summary_embeddings/index.tsv", sep="\t", index=False)
27
+
28
+ print("βœ… Saved embeddings and index.")
scripts/process_idsr.py ADDED
@@ -0,0 +1,138 @@
1
+ import os
2
+ import re
3
+ import json
4
+ from dotenv import load_dotenv
5
+ from openai import OpenAI
6
+ from langchain_core.documents import Document
7
+ from langchain_openai import OpenAIEmbeddings
8
+ from langchain.vectorstores import FAISS
9
+ from rapidfuzz import fuzz
10
+
11
+ # === Setup ===
12
+ base_dir = os.path.dirname(__file__)
13
+ raw_path = os.path.abspath(os.path.join(base_dir, "..", "data", "raw"))
14
+ processed_path = os.path.abspath(os.path.join(base_dir, "..", "data", "processed"))
15
+ os.makedirs(processed_path, exist_ok=True)
16
+
17
+ load_dotenv(os.path.join(base_dir, "..", "config.env"))  # config.env sits at the repo root
18
+ api_key = os.environ.get("OPENAI_API_KEY")
19
+
20
+ # === Step 1: Read IDSR Text ===
21
+ with open(os.path.join(raw_path, "IDSR.txt"), encoding="utf-8") as f:
22
+ text = f.read()
23
+
24
+ # === Step 2: Extract Keywords via GPT ===
25
+ prompt = """
26
+ You are a helpful assistant. Extract a list of 30–50 key symptoms, signs, or diagnostic terms from the following disease descriptions.
27
+
28
+ Focus on words or phrases that are likely to appear in clinical case definitions or user queries β€” such as "fever", "skin lesions", "swollen lymph nodes", "positive blood smear", etc.
29
+
30
+ Only return the keywords or short phrases β€” one per line.
31
+
32
+ Text:
33
+ """
34
+
35
+ client = OpenAI()
36
+ response = client.chat.completions.create(
37
+ model="gpt-4o",
38
+ messages=[
39
+ {"role": "system", "content": "You are a helpful assistant."},
40
+ {"role": "user", "content": prompt + text}
41
+ ],
42
+ temperature=0.0
43
+ )
44
+
45
+ # Normalize keywords
46
+ keywords = [line.strip() for line in response.choices[0].message.content.splitlines() if line.strip()]
47
+ def normalize_kw(kw):
48
+ return kw.lstrip("-β€’ ").strip().lower()
49
+ keywords = [normalize_kw(kw) for kw in keywords]
50
+
51
+ # Save keywords
52
+ kw_path = os.path.join(processed_path, "idsr_keywords.txt")
53
+ with open(kw_path, "w", encoding="utf-8") as f:
54
+ for keyword in keywords:
55
+ f.write(f"{keyword}\n")
56
+
57
+ print(f"βœ… Saved keywords to {kw_path}")
58
+
59
+ # === Step 3: Parse Disease Sections ===
60
+ def parse_disease_text(text):
61
+ diseases = []
62
+ lines = text.strip().splitlines()
63
+
64
+ current_disease = None
65
+ current_subsection = None
66
+ buffer = []
67
+
68
+ def finalize_subsection():
69
+ if current_disease is not None and current_subsection and buffer:
70
+ content = " ".join(line.strip() for line in buffer).strip()
71
+ current_disease[current_subsection] = content
72
+
73
+ subsection_pattern = re.compile(r"^-\s*(.+):\s*$")
74
+
75
+ for line in lines + [""]:
76
+ if not line.strip():
77
+ finalize_subsection()
78
+ if current_disease:
79
+ diseases.append(current_disease)
80
+ current_disease = None
81
+ current_subsection = None
82
+ buffer = []
83
+ continue
84
+
85
+ if current_disease is None:
86
+ current_disease = {"disease_name": line.strip()}
87
+ continue
88
+
89
+ match = subsection_pattern.match(line)
90
+ if match:
91
+ finalize_subsection()
92
+ current_subsection = match.group(1).strip()
93
+ buffer = []
94
+ else:
95
+ buffer.append(line.rstrip())
96
+
97
+ return diseases
98
+
99
+ disease_dicts = parse_disease_text(text)
100
+
101
+ # === Step 4: Convert to LangChain Documents ===
102
+ def convert_disease_dicts_to_documents(disease_dicts):
103
+ docs = []
104
+ for disease in disease_dicts:
105
+ disease_name = disease.get("disease_name", "")
106
+ subsections = [f"{key}:\n{value}" for key, value in disease.items() if key != "disease_name"]
107
+ full_text = f"Disease: {disease_name}\n\n" + "\n\n".join(subsections)
108
+ docs.append(Document(page_content=full_text, metadata={"disease_name": disease_name}))
109
+ return docs
110
+
111
+ documents = convert_disease_dicts_to_documents(disease_dicts)
112
+
113
+ # === Step 5: Tag Documents with Keywords ===
114
+ def tag_documents_with_keywords(documents, keywords, threshold=85):
115
+ tagged = []
116
+ for doc in documents:
117
+ content = doc.page_content.lower()
118
+ matched = [kw for kw in keywords if fuzz.partial_ratio(kw.lower(), content) >= threshold]
119
+ doc.metadata["matched_keywords"] = matched
120
+ tagged.append(doc)
121
+ return tagged
122
+
123
+ tagged_documents = tag_documents_with_keywords(documents, keywords)
124
+
125
+ # Save JSON version
126
+ json_path = os.path.join(processed_path, "tagged_documents.json")
127
+ with open(json_path, "w", encoding="utf-8") as f:
128
+ json.dump([doc.dict() for doc in tagged_documents], f, ensure_ascii=False, indent=2)
129
+
130
+ print(f"βœ… Saved tagged documents to {json_path}")
131
+
132
+ # === Step 6: Build and Save FAISS Vectorstore ===
133
+ embedding_model = OpenAIEmbeddings()
134
+ vectorstore = FAISS.from_documents(tagged_documents, embedding_model)
135
+ vs_path = os.path.join(processed_path, "disease_vectorstore")
136
+ vectorstore.save_local(vs_path)
137
+
138
+ print(f"βœ… Saved FAISS vectorstore to {vs_path}")