Spaces:
Sleeping
Sleeping
JDFPalladium committed on
Commit ·
35274a7
1
Parent(s): 1255a5e
adding idsr define tool and reflecting tweaks to other scripts and notebooks
Browse files- app.py +27 -5
- chatlib/idsr_definition.py +85 -0
- chatlib/patient_all_data.py +4 -1
- notebooks/create_patient_db.ipynb +487 -30
- notebooks/create_slim_patient_db.ipynb +18 -135
- chat.py → scripts/chat.py +0 -0
- scripts/evaluate_trulens.py +147 -0
- {chatlib → scripts}/patient_sql_agent.py +0 -0
- scripts/ragas_eval.py +118 -0
app.py
CHANGED
|
@@ -22,6 +22,7 @@ from chatlib.state_types import AppState
|
|
| 22 |
from chatlib.guidlines_rag_agent_li import rag_retrieve
|
| 23 |
from chatlib.patient_all_data import sql_chain
|
| 24 |
from chatlib.idsr_check import idsr_check
|
|
|
|
| 25 |
from chatlib.phi_filter import detect_and_redact_phi
|
| 26 |
from chatlib.assistant_node import assistant
|
| 27 |
|
|
@@ -52,8 +53,15 @@ def idsr_check_tool(query, sitecode):
|
|
| 52 |
"context": result.get("context", None),
|
| 53 |
}
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
tools = [rag_retrieve_tool, sql_chain_tool, idsr_check_tool]
|
| 57 |
llm_with_tools = llm.bind_tools(tools)
|
| 58 |
|
| 59 |
|
|
@@ -61,11 +69,12 @@ sys_msg = SystemMessage(
|
|
| 61 |
content="""
|
| 62 |
You are a helpful assistant supporting clinicians during patient visits. When a patient ID is provided, the clinician is meeting with that HIV-positive patient and may inquire about their history, lab results, or medications. If no patient ID is provided, the clinician may be asking general HIV clinical questions or presenting symptoms for a new patient.
|
| 63 |
|
| 64 |
-
You have access to
|
| 65 |
|
| 66 |
-
-
|
| 67 |
-
-
|
| 68 |
-
-
|
|
|
|
| 69 |
|
| 70 |
When a tool is needed, respond only with a JSON object specifying the tool to call and its minimal arguments, for example:
|
| 71 |
{
|
|
@@ -107,6 +116,19 @@ For example:
|
|
| 107 |
}
|
| 108 |
}
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
There are only two cases where a tool is not needed:
|
| 111 |
1. If the clinician's question is a simple greeting, farewell, or acknowledgement.
|
| 112 |
2. The answer is clearly and completely present in the prior conversation turns.
|
|
|
|
| 22 |
from chatlib.guidlines_rag_agent_li import rag_retrieve
|
| 23 |
from chatlib.patient_all_data import sql_chain
|
| 24 |
from chatlib.idsr_check import idsr_check
|
| 25 |
+
from chatlib.idsr_definition import idsr_define
|
| 26 |
from chatlib.phi_filter import detect_and_redact_phi
|
| 27 |
from chatlib.assistant_node import assistant
|
| 28 |
|
|
|
|
| 53 |
"context": result.get("context", None),
|
| 54 |
}
|
| 55 |
|
| 56 |
+
def idsr_define_tool(query):
    """Retrieve disease definition based on the query."""
    # Delegate to the idsr_define chain; it always returns a dict with "answer".
    definition_result = idsr_define(query, llm=llm)
    answer_text = definition_result.get("answer", "")
    # Tag the result so downstream state tracking knows which tool produced it.
    return {"answer": answer_text, "last_tool": "idsr_define"}
|
| 63 |
|
| 64 |
+
tools = [rag_retrieve_tool, sql_chain_tool, idsr_check_tool, idsr_define_tool]
|
| 65 |
llm_with_tools = llm.bind_tools(tools)
|
| 66 |
|
| 67 |
|
|
|
|
| 69 |
content="""
|
| 70 |
You are a helpful assistant supporting clinicians during patient visits. When a patient ID is provided, the clinician is meeting with that HIV-positive patient and may inquire about their history, lab results, or medications. If no patient ID is provided, the clinician may be asking general HIV clinical questions or presenting symptoms for a new patient.
|
| 71 |
|
| 72 |
+
You have access to four tools to help you answer the clinician's questions.
|
| 73 |
|
| 74 |
+
- rag_retrieve_tool: to access HIV clinical guidelines
|
| 75 |
+
- sql_chain_tool: to access HIV data about the patient with whom the clinician is meeting. For straightforward factual questions about the patient, you may call sql_chain directly. For questions requiring clinical interpretation or classification, first call rag_retrieve to get relevant clinical guideline context, then include that context when calling sql_chain.
|
| 76 |
+
- idsr_check_tool: to check if the patient case description matches any known diseases.
|
| 77 |
+
- idsr_define_tool: to retrieve the official case definition of a disease when the clinician asks about it (e.g., “What is the description of cholera?”). Do not use this tool for analyzing symptom descriptions — use `idsr_check_tool` for that.
|
| 78 |
|
| 79 |
When a tool is needed, respond only with a JSON object specifying the tool to call and its minimal arguments, for example:
|
| 80 |
{
|
|
|
|
| 116 |
}
|
| 117 |
}
|
| 118 |
|
| 119 |
+
When calling the "idsr_define_tool" tool, always include the following arguments in the JSON response:
|
| 120 |
+
|
| 121 |
+
- "query": the clinician's question
|
| 122 |
+
|
| 123 |
+
For example:
|
| 124 |
+
|
| 125 |
+
{
|
| 126 |
+
"tool": "idsr_define_tool",
|
| 127 |
+
"args": {
|
| 128 |
+
"query": "What is the description of cholera?"
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
There are only two cases where a tool is not needed:
|
| 133 |
1. If the clinician's question is a simple greeting, farewell, or acknowledgement.
|
| 134 |
2. The answer is clearly and completely present in the prior conversation turns.
|
chatlib/idsr_definition.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 2 |
+
from langchain_core.output_parsers import PydanticOutputParser
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from langchain_core.documents import Document
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
# Load the pre-tagged IDSR disease documents once at import time.
# NOTE(review): path is relative to the process working directory — confirm the
# app is always launched from the repo root.
with open("./data/processed/tagged_documents.json", "r", encoding="utf-8") as f:
    doc_dicts = json.load(f)

# Rehydrate the serialized dicts into LangChain Document objects; each is
# expected to carry a "disease_name" metadata entry (see usage below).
tagged_documents = [Document(**d) for d in doc_dicts]
| 12 |
+
|
| 13 |
+
class DiseaseSelectionOutput(BaseModel):
    """Structured LLM output: the single disease a query refers to, if any."""

    # None/null signals that the model found no confident match.
    disease_name: Optional[str] = Field(
        description="The most likely disease the user is asking about, or null if no match is confident"
    )
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def select_disease_from_query(query: str, llm, tagged_docs: list[Document]) -> Optional[str]:
    """Ask the LLM which IDSR disease a free-text query refers to.

    Args:
        query: The clinician's question (e.g. "What is the description of cholera?").
        llm: A LangChain chat model used to classify the query.
        tagged_docs: Documents whose metadata carries a "disease_name" entry.

    Returns:
        The matched disease name, or None when the LLM reports no confident match.
    """
    # Skip documents missing the "disease_name" metadata key so the prompt
    # never shows a spurious "- None" option.
    disease_names = [
        name
        for doc in tagged_docs
        if (name := doc.metadata.get("disease_name")) is not None
    ]
    disease_list = "\n".join(f"- {name}" for name in disease_names)

    # Parser both supplies the format instructions and validates the response.
    parser = PydanticOutputParser(pydantic_object=DiseaseSelectionOutput)

    prompt = ChatPromptTemplate.from_template(
        """
    You are helping a clinician retrieve a disease definition from a list of IDSR diseases.

    Given the following query:
    "{query}"

    Select the single disease from the list below that the query most likely refers to.

    List of available diseases:
    {disease_list}

    If no match is clearly appropriate, set "disease_name" to null.

    {format_instructions}
    """
    )

    # Prompt -> model -> structured parse; result is a DiseaseSelectionOutput.
    chain = prompt | llm | parser
    output = chain.invoke({
        "query": query,
        "disease_list": disease_list,
        "format_instructions": parser.get_format_instructions(),
    })

    return output.disease_name
|
| 51 |
+
|
| 52 |
+
def idsr_define(query: str, llm) -> dict:
    """Answer a clinician's question using the official IDSR case definition.

    Selects the disease the query refers to, looks up its case definition in
    the module-level ``tagged_documents``, and asks the LLM to answer the
    query strictly from that definition.

    Args:
        query: The clinician's question.
        llm: A LangChain chat model (must expose ``invoke`` returning a message).

    Returns:
        A dict with a single "answer" key: either the grounded answer or an
        apology when no disease or definition could be matched.
    """
    disease_name = select_disease_from_query(query, llm, tagged_documents)

    if not disease_name:
        return {
            "answer": "Sorry, I couldn't find a clear match for that disease. Please rephrase or try a different name."
        }

    # Locate the tagged document for the selected disease (None when absent).
    matched_doc = next(
        (doc for doc in tagged_documents
         if doc.metadata.get("disease_name") == disease_name),
        None,
    )
    if matched_doc is None:
        # Was an f-string with no placeholders in the original (stray `f`).
        return {
            "answer": "Sorry, no case definition was found for the selected disease."
        }

    definition = matched_doc.page_content.strip()

    # Ground the LLM's answer in the retrieved case definition only.
    prompt = f"""
You are a medical assistant helping a clinician understand disease case definitions.

Here is a user query:
"{query}"

Here is the official case definition for the relevant disease:
"{definition}"

Based on the case definition, answer the user query clearly and concisely. Do not speculate beyond the information provided.
"""
    llm_response = llm.invoke(prompt)

    return {"answer": llm_response.content.strip()}
|
chatlib/patient_all_data.py
CHANGED
|
@@ -172,6 +172,9 @@ def sql_chain(query: str, llm, rag_result: str, pk_hash: str) -> dict:
|
|
| 172 |
except (ValueError, TypeError):
|
| 173 |
return "invalid date"
|
| 174 |
|
|
|
|
|
|
|
|
|
|
| 175 |
row = df.iloc[0]
|
| 176 |
summary = (
|
| 177 |
f"Sex: {safe(row['Sex'])}\n"
|
|
@@ -180,7 +183,7 @@ def sql_chain(query: str, llm, rag_result: str, pk_hash: str) -> dict:
|
|
| 180 |
f"Occupation: {safe(row['Occupation'])}\n"
|
| 181 |
f"OnIPT: {safe(row['OnIPT'])}\n"
|
| 182 |
f"ARTOutcomeDescription: {safe(row['ARTOutcomeDescription'])}\n"
|
| 183 |
-
f"StartARTDate: {
|
| 184 |
f"Age: {calculate_age(safe(row['DOB']))}"
|
| 185 |
)
|
| 186 |
return summary
|
|
|
|
| 172 |
except (ValueError, TypeError):
|
| 173 |
return "invalid date"
|
| 174 |
|
| 175 |
+
df = df.copy()
|
| 176 |
+
df["StartARTDate"] = pd.to_datetime(df["StartARTDate"], errors="coerce")
|
| 177 |
+
|
| 178 |
row = df.iloc[0]
|
| 179 |
summary = (
|
| 180 |
f"Sex: {safe(row['Sex'])}\n"
|
|
|
|
| 183 |
f"Occupation: {safe(row['Occupation'])}\n"
|
| 184 |
f"OnIPT: {safe(row['OnIPT'])}\n"
|
| 185 |
f"ARTOutcomeDescription: {safe(row['ARTOutcomeDescription'])}\n"
|
| 186 |
+
f"StartARTDate: {describe_relative_date(row['StartARTDate'])}\n"
|
| 187 |
f"Age: {calculate_age(safe(row['DOB']))}"
|
| 188 |
)
|
| 189 |
return summary
|
notebooks/create_patient_db.ipynb
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "ddb26634",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [],
|
|
@@ -10,7 +10,7 @@
|
|
| 10 |
"import sqlite3\n",
|
| 11 |
"import pandas as pd\n",
|
| 12 |
"# inspect current database schema\n",
|
| 13 |
-
"conn = sqlite3.connect('patient_slim.sqlite')\n",
|
| 14 |
"cursor = conn.cursor()\n",
|
| 15 |
"# list tables\n",
|
| 16 |
"# pull all data from the visits table \n",
|
|
@@ -22,19 +22,41 @@
|
|
| 22 |
},
|
| 23 |
{
|
| 24 |
"cell_type": "code",
|
| 25 |
-
"execution_count":
|
| 26 |
"id": "cd4faa4b",
|
| 27 |
"metadata": {},
|
| 28 |
"outputs": [],
|
| 29 |
"source": [
|
| 30 |
"# let's create a new sqlite database called patient_demonstration.sqlite\n",
|
| 31 |
-
"conn = sqlite3.connect('patient_demonstration.sqlite')\n",
|
| 32 |
"cursor = conn.cursor() "
|
| 33 |
]
|
| 34 |
},
|
| 35 |
{
|
| 36 |
"cell_type": "code",
|
| 37 |
"execution_count": 3,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"id": "f8547b78",
|
| 39 |
"metadata": {},
|
| 40 |
"outputs": [],
|
|
@@ -82,17 +104,17 @@
|
|
| 82 |
},
|
| 83 |
{
|
| 84 |
"cell_type": "code",
|
| 85 |
-
"execution_count":
|
| 86 |
"id": "9ddfa626",
|
| 87 |
"metadata": {},
|
| 88 |
"outputs": [
|
| 89 |
{
|
| 90 |
"data": {
|
| 91 |
"text/plain": [
|
| 92 |
-
"<sqlite3.Cursor at
|
| 93 |
]
|
| 94 |
},
|
| 95 |
-
"execution_count":
|
| 96 |
"metadata": {},
|
| 97 |
"output_type": "execute_result"
|
| 98 |
}
|
|
@@ -110,7 +132,7 @@
|
|
| 110 |
},
|
| 111 |
{
|
| 112 |
"cell_type": "code",
|
| 113 |
-
"execution_count":
|
| 114 |
"id": "d14ef687",
|
| 115 |
"metadata": {},
|
| 116 |
"outputs": [],
|
|
@@ -177,12 +199,12 @@
|
|
| 177 |
},
|
| 178 |
{
|
| 179 |
"cell_type": "code",
|
| 180 |
-
"execution_count":
|
| 181 |
"id": "6e27bce5",
|
| 182 |
"metadata": {},
|
| 183 |
"outputs": [],
|
| 184 |
"source": [
|
| 185 |
-
"conn = sqlite3.connect('patient_slim.sqlite')\n",
|
| 186 |
"cursor = conn.cursor()\n",
|
| 187 |
"# pull all data from the lab table except for the \"key\" column \n",
|
| 188 |
"cursor.execute(\"SELECT * FROM lab;\")\n",
|
|
@@ -193,19 +215,19 @@
|
|
| 193 |
},
|
| 194 |
{
|
| 195 |
"cell_type": "code",
|
| 196 |
-
"execution_count":
|
| 197 |
"id": "14402e96",
|
| 198 |
"metadata": {},
|
| 199 |
"outputs": [],
|
| 200 |
"source": [
|
| 201 |
"# let's create a new sqlite database called patient_demonstration.sqlite\n",
|
| 202 |
-
"conn = sqlite3.connect('patient_demonstration.sqlite')\n",
|
| 203 |
"cursor = conn.cursor() "
|
| 204 |
]
|
| 205 |
},
|
| 206 |
{
|
| 207 |
"cell_type": "code",
|
| 208 |
-
"execution_count":
|
| 209 |
"id": "540962b7",
|
| 210 |
"metadata": {},
|
| 211 |
"outputs": [],
|
|
@@ -235,7 +257,7 @@
|
|
| 235 |
},
|
| 236 |
{
|
| 237 |
"cell_type": "code",
|
| 238 |
-
"execution_count":
|
| 239 |
"id": "8df7171e",
|
| 240 |
"metadata": {},
|
| 241 |
"outputs": [],
|
|
@@ -260,12 +282,12 @@
|
|
| 260 |
},
|
| 261 |
{
|
| 262 |
"cell_type": "code",
|
| 263 |
-
"execution_count":
|
| 264 |
"id": "b66d3dbb",
|
| 265 |
"metadata": {},
|
| 266 |
"outputs": [],
|
| 267 |
"source": [
|
| 268 |
-
"conn = sqlite3.connect('patient_slim.sqlite')\n",
|
| 269 |
"cursor = conn.cursor()\n",
|
| 270 |
"# pull all data from the lab table except for the \"key\" column \n",
|
| 271 |
"cursor.execute(\"SELECT * FROM pharmacy;\")\n",
|
|
@@ -276,19 +298,19 @@
|
|
| 276 |
},
|
| 277 |
{
|
| 278 |
"cell_type": "code",
|
| 279 |
-
"execution_count":
|
| 280 |
"id": "435b8d4e",
|
| 281 |
"metadata": {},
|
| 282 |
"outputs": [],
|
| 283 |
"source": [
|
| 284 |
"# let's create a new sqlite database called patient_demonstration.sqlite\n",
|
| 285 |
-
"conn = sqlite3.connect('patient_demonstration.sqlite')\n",
|
| 286 |
"cursor = conn.cursor() "
|
| 287 |
]
|
| 288 |
},
|
| 289 |
{
|
| 290 |
"cell_type": "code",
|
| 291 |
-
"execution_count":
|
| 292 |
"id": "b3753eeb",
|
| 293 |
"metadata": {},
|
| 294 |
"outputs": [],
|
|
@@ -322,7 +344,7 @@
|
|
| 322 |
},
|
| 323 |
{
|
| 324 |
"cell_type": "code",
|
| 325 |
-
"execution_count":
|
| 326 |
"id": "8b8ed08a",
|
| 327 |
"metadata": {},
|
| 328 |
"outputs": [],
|
|
@@ -348,12 +370,12 @@
|
|
| 348 |
},
|
| 349 |
{
|
| 350 |
"cell_type": "code",
|
| 351 |
-
"execution_count":
|
| 352 |
"id": "2de65432",
|
| 353 |
"metadata": {},
|
| 354 |
"outputs": [],
|
| 355 |
"source": [
|
| 356 |
-
"conn = sqlite3.connect('patient_slim.sqlite')\n",
|
| 357 |
"cursor = conn.cursor()\n",
|
| 358 |
"# pull all data from the lab table except for the \"key\" column \n",
|
| 359 |
"cursor.execute(\"SELECT * FROM demographics;\")\n",
|
|
@@ -364,19 +386,454 @@
|
|
| 364 |
},
|
| 365 |
{
|
| 366 |
"cell_type": "code",
|
| 367 |
-
"execution_count":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
"id": "a7a10f4f",
|
| 369 |
"metadata": {},
|
| 370 |
"outputs": [],
|
| 371 |
"source": [
|
| 372 |
"# let's create a new sqlite database called patient_demonstration.sqlite\n",
|
| 373 |
-
"conn = sqlite3.connect('patient_demonstration.sqlite')\n",
|
| 374 |
"cursor = conn.cursor() "
|
| 375 |
]
|
| 376 |
},
|
| 377 |
{
|
| 378 |
"cell_type": "code",
|
| 379 |
-
"execution_count":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
"id": "947c63d5",
|
| 381 |
"metadata": {},
|
| 382 |
"outputs": [],
|
|
@@ -386,6 +843,7 @@
|
|
| 386 |
"cursor.execute('DROP TABLE IF EXISTS demographics;')\n",
|
| 387 |
"cursor.execute('''\n",
|
| 388 |
"CREATE TABLE demographics (\n",
|
|
|
|
| 389 |
" PatientPKHash TEXT,\n",
|
| 390 |
" MFLCode TEXT,\n",
|
| 391 |
" FacilityName TEXT,\n",
|
|
@@ -403,14 +861,13 @@
|
|
| 403 |
" AsOfDate TEXT,\n",
|
| 404 |
" LoadDate TEXT,\n",
|
| 405 |
" StartARTDate TEXT,\n",
|
| 406 |
-
" DOB TEXT
|
| 407 |
-
" key TEXT\n",
|
| 408 |
");\n",
|
| 409 |
"''')\n",
|
| 410 |
"\n",
|
| 411 |
"# let's now populate the table with the rows variable that contains all the data from the visits table\n",
|
| 412 |
"cursor.executemany('''\n",
|
| 413 |
-
"INSERT INTO demographics (PatientPKHash, MFLCode, FacilityName, County, SubCounty, PartnerName, AgencyName, Sex, MaritalStatus, EducationLevel, Occupation, OnIPT, AgeGroup, ARTOutcomeDescription, AsOfDate, LoadDate, StartARTDate, DOB
|
| 414 |
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n",
|
| 415 |
"''', rows)\n",
|
| 416 |
"conn.commit()"
|
|
@@ -418,7 +875,7 @@
|
|
| 418 |
},
|
| 419 |
{
|
| 420 |
"cell_type": "code",
|
| 421 |
-
"execution_count":
|
| 422 |
"id": "9cff0d90",
|
| 423 |
"metadata": {},
|
| 424 |
"outputs": [],
|
|
@@ -458,7 +915,7 @@
|
|
| 458 |
],
|
| 459 |
"metadata": {
|
| 460 |
"kernelspec": {
|
| 461 |
-
"display_name": "
|
| 462 |
"language": "python",
|
| 463 |
"name": "python3"
|
| 464 |
},
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 19,
|
| 6 |
"id": "ddb26634",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [],
|
|
|
|
| 10 |
"import sqlite3\n",
|
| 11 |
"import pandas as pd\n",
|
| 12 |
"# inspect current database schema\n",
|
| 13 |
+
"conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
|
| 14 |
"cursor = conn.cursor()\n",
|
| 15 |
"# list tables\n",
|
| 16 |
"# pull all data from the visits table \n",
|
|
|
|
| 22 |
},
|
| 23 |
{
|
| 24 |
"cell_type": "code",
|
| 25 |
+
"execution_count": 20,
|
| 26 |
"id": "cd4faa4b",
|
| 27 |
"metadata": {},
|
| 28 |
"outputs": [],
|
| 29 |
"source": [
|
| 30 |
"# let's create a new sqlite database called patient_demonstration.sqlite\n",
|
| 31 |
+
"conn = sqlite3.connect('../data/processed/patient_demonstration.sqlite')\n",
|
| 32 |
"cursor = conn.cursor() "
|
| 33 |
]
|
| 34 |
},
|
| 35 |
{
|
| 36 |
"cell_type": "code",
|
| 37 |
"execution_count": 3,
|
| 38 |
+
"id": "866e707d",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"outputs": [
|
| 41 |
+
{
|
| 42 |
+
"name": "stdout",
|
| 43 |
+
"output_type": "stream",
|
| 44 |
+
"text": [
|
| 45 |
+
"(271, 25)\n"
|
| 46 |
+
]
|
| 47 |
+
}
|
| 48 |
+
],
|
| 49 |
+
"source": [
|
| 50 |
+
"# extract everything from the visits table\n",
|
| 51 |
+
"cursor.execute(\"SELECT * FROM clinical_visits;\")\n",
|
| 52 |
+
"rows = cursor.fetchall()\n",
|
| 53 |
+
"visits_df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
|
| 54 |
+
"print(visits_df.shape)"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": 4,
|
| 60 |
"id": "f8547b78",
|
| 61 |
"metadata": {},
|
| 62 |
"outputs": [],
|
|
|
|
| 104 |
},
|
| 105 |
{
|
| 106 |
"cell_type": "code",
|
| 107 |
+
"execution_count": 5,
|
| 108 |
"id": "9ddfa626",
|
| 109 |
"metadata": {},
|
| 110 |
"outputs": [
|
| 111 |
{
|
| 112 |
"data": {
|
| 113 |
"text/plain": [
|
| 114 |
+
"<sqlite3.Cursor at 0x71e4721907c0>"
|
| 115 |
]
|
| 116 |
},
|
| 117 |
+
"execution_count": 5,
|
| 118 |
"metadata": {},
|
| 119 |
"output_type": "execute_result"
|
| 120 |
}
|
|
|
|
| 132 |
},
|
| 133 |
{
|
| 134 |
"cell_type": "code",
|
| 135 |
+
"execution_count": 6,
|
| 136 |
"id": "d14ef687",
|
| 137 |
"metadata": {},
|
| 138 |
"outputs": [],
|
|
|
|
| 199 |
},
|
| 200 |
{
|
| 201 |
"cell_type": "code",
|
| 202 |
+
"execution_count": 7,
|
| 203 |
"id": "6e27bce5",
|
| 204 |
"metadata": {},
|
| 205 |
"outputs": [],
|
| 206 |
"source": [
|
| 207 |
+
"conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
|
| 208 |
"cursor = conn.cursor()\n",
|
| 209 |
"# pull all data from the lab table except for the \"key\" column \n",
|
| 210 |
"cursor.execute(\"SELECT * FROM lab;\")\n",
|
|
|
|
| 215 |
},
|
| 216 |
{
|
| 217 |
"cell_type": "code",
|
| 218 |
+
"execution_count": 8,
|
| 219 |
"id": "14402e96",
|
| 220 |
"metadata": {},
|
| 221 |
"outputs": [],
|
| 222 |
"source": [
|
| 223 |
"# let's create a new sqlite database called patient_demonstration.sqlite\n",
|
| 224 |
+
"conn = sqlite3.connect('../data/processed/patient_demonstration.sqlite')\n",
|
| 225 |
"cursor = conn.cursor() "
|
| 226 |
]
|
| 227 |
},
|
| 228 |
{
|
| 229 |
"cell_type": "code",
|
| 230 |
+
"execution_count": 9,
|
| 231 |
"id": "540962b7",
|
| 232 |
"metadata": {},
|
| 233 |
"outputs": [],
|
|
|
|
| 257 |
},
|
| 258 |
{
|
| 259 |
"cell_type": "code",
|
| 260 |
+
"execution_count": 10,
|
| 261 |
"id": "8df7171e",
|
| 262 |
"metadata": {},
|
| 263 |
"outputs": [],
|
|
|
|
| 282 |
},
|
| 283 |
{
|
| 284 |
"cell_type": "code",
|
| 285 |
+
"execution_count": 11,
|
| 286 |
"id": "b66d3dbb",
|
| 287 |
"metadata": {},
|
| 288 |
"outputs": [],
|
| 289 |
"source": [
|
| 290 |
+
"conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
|
| 291 |
"cursor = conn.cursor()\n",
|
| 292 |
"# pull all data from the lab table except for the \"key\" column \n",
|
| 293 |
"cursor.execute(\"SELECT * FROM pharmacy;\")\n",
|
|
|
|
| 298 |
},
|
| 299 |
{
|
| 300 |
"cell_type": "code",
|
| 301 |
+
"execution_count": 12,
|
| 302 |
"id": "435b8d4e",
|
| 303 |
"metadata": {},
|
| 304 |
"outputs": [],
|
| 305 |
"source": [
|
| 306 |
"# let's create a new sqlite database called patient_demonstration.sqlite\n",
|
| 307 |
+
"conn = sqlite3.connect('../data/processed/patient_demonstration.sqlite')\n",
|
| 308 |
"cursor = conn.cursor() "
|
| 309 |
]
|
| 310 |
},
|
| 311 |
{
|
| 312 |
"cell_type": "code",
|
| 313 |
+
"execution_count": 13,
|
| 314 |
"id": "b3753eeb",
|
| 315 |
"metadata": {},
|
| 316 |
"outputs": [],
|
|
|
|
| 344 |
},
|
| 345 |
{
|
| 346 |
"cell_type": "code",
|
| 347 |
+
"execution_count": 14,
|
| 348 |
"id": "8b8ed08a",
|
| 349 |
"metadata": {},
|
| 350 |
"outputs": [],
|
|
|
|
| 370 |
},
|
| 371 |
{
|
| 372 |
"cell_type": "code",
|
| 373 |
+
"execution_count": 24,
|
| 374 |
"id": "2de65432",
|
| 375 |
"metadata": {},
|
| 376 |
"outputs": [],
|
| 377 |
"source": [
|
| 378 |
+
"conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
|
| 379 |
"cursor = conn.cursor()\n",
|
| 380 |
"# pull all data from the lab table except for the \"key\" column \n",
|
| 381 |
"cursor.execute(\"SELECT * FROM demographics;\")\n",
|
|
|
|
| 386 |
},
|
| 387 |
{
|
| 388 |
"cell_type": "code",
|
| 389 |
+
"execution_count": 27,
|
| 390 |
+
"id": "f3a11ac1",
|
| 391 |
+
"metadata": {},
|
| 392 |
+
"outputs": [
|
| 393 |
+
{
|
| 394 |
+
"data": {
|
| 395 |
+
"text/html": [
|
| 396 |
+
"<div>\n",
|
| 397 |
+
"<style scoped>\n",
|
| 398 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 399 |
+
" vertical-align: middle;\n",
|
| 400 |
+
" }\n",
|
| 401 |
+
"\n",
|
| 402 |
+
" .dataframe tbody tr th {\n",
|
| 403 |
+
" vertical-align: top;\n",
|
| 404 |
+
" }\n",
|
| 405 |
+
"\n",
|
| 406 |
+
" .dataframe thead th {\n",
|
| 407 |
+
" text-align: right;\n",
|
| 408 |
+
" }\n",
|
| 409 |
+
"</style>\n",
|
| 410 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 411 |
+
" <thead>\n",
|
| 412 |
+
" <tr style=\"text-align: right;\">\n",
|
| 413 |
+
" <th></th>\n",
|
| 414 |
+
" <th>key</th>\n",
|
| 415 |
+
" <th>PatientPKHash</th>\n",
|
| 416 |
+
" <th>MFLCode</th>\n",
|
| 417 |
+
" <th>FacilityName</th>\n",
|
| 418 |
+
" <th>County</th>\n",
|
| 419 |
+
" <th>SubCounty</th>\n",
|
| 420 |
+
" <th>PartnerName</th>\n",
|
| 421 |
+
" <th>AgencyName</th>\n",
|
| 422 |
+
" <th>Sex</th>\n",
|
| 423 |
+
" <th>MaritalStatus</th>\n",
|
| 424 |
+
" <th>EducationLevel</th>\n",
|
| 425 |
+
" <th>Occupation</th>\n",
|
| 426 |
+
" <th>OnIPT</th>\n",
|
| 427 |
+
" <th>AgeGroup</th>\n",
|
| 428 |
+
" <th>ARTOutcomeDescription</th>\n",
|
| 429 |
+
" <th>AsOfDate</th>\n",
|
| 430 |
+
" <th>LoadDate</th>\n",
|
| 431 |
+
" <th>StartARTDate</th>\n",
|
| 432 |
+
" <th>DOB</th>\n",
|
| 433 |
+
" </tr>\n",
|
| 434 |
+
" </thead>\n",
|
| 435 |
+
" <tbody>\n",
|
| 436 |
+
" <tr>\n",
|
| 437 |
+
" <th>0</th>\n",
|
| 438 |
+
" <td>07149C6735AA9A2B3EFB198A5DB19825E3DA3DBCDE8CB8...</td>\n",
|
| 439 |
+
" <td>3</td>\n",
|
| 440 |
+
" <td>13703</td>\n",
|
| 441 |
+
" <td>Kisii Teaching and Referral Hospital (Level 6)</td>\n",
|
| 442 |
+
" <td>Kisii</td>\n",
|
| 443 |
+
" <td>Kitutu Chache South</td>\n",
|
| 444 |
+
" <td>LVCT Vukisha 95</td>\n",
|
| 445 |
+
" <td>CDC</td>\n",
|
| 446 |
+
" <td>Female</td>\n",
|
| 447 |
+
" <td>Single</td>\n",
|
| 448 |
+
" <td>NULL</td>\n",
|
| 449 |
+
" <td>NULL</td>\n",
|
| 450 |
+
" <td>NULL</td>\n",
|
| 451 |
+
" <td>NULL</td>\n",
|
| 452 |
+
" <td>LOST IN HMIS</td>\n",
|
| 453 |
+
" <td>20088</td>\n",
|
| 454 |
+
" <td>20161</td>\n",
|
| 455 |
+
" <td>2012-04-12 00:00:00.000</td>\n",
|
| 456 |
+
" <td>2010-05-10 00:00:00.000</td>\n",
|
| 457 |
+
" </tr>\n",
|
| 458 |
+
" <tr>\n",
|
| 459 |
+
" <th>1</th>\n",
|
| 460 |
+
" <td>290D316E1B41A21F58E780026971F4D86DBB3BF043A77B...</td>\n",
|
| 461 |
+
" <td>4</td>\n",
|
| 462 |
+
" <td>13028</td>\n",
|
| 463 |
+
" <td>Kibera Community Health Centre - Amref</td>\n",
|
| 464 |
+
" <td>Nairobi</td>\n",
|
| 465 |
+
" <td>Kibra</td>\n",
|
| 466 |
+
" <td>CIHEB CONNECT</td>\n",
|
| 467 |
+
" <td>CDC</td>\n",
|
| 468 |
+
" <td>Female</td>\n",
|
| 469 |
+
" <td>MARRIED MONOGAMOUS</td>\n",
|
| 470 |
+
" <td>SECONDARY</td>\n",
|
| 471 |
+
" <td>Trader</td>\n",
|
| 472 |
+
" <td>NULL</td>\n",
|
| 473 |
+
" <td>NULL</td>\n",
|
| 474 |
+
" <td>ACTIVE</td>\n",
|
| 475 |
+
" <td>20088</td>\n",
|
| 476 |
+
" <td>20161</td>\n",
|
| 477 |
+
" <td>2009-05-12 00:00:00.000</td>\n",
|
| 478 |
+
" <td>1970-08-25 00:00:00.000</td>\n",
|
| 479 |
+
" </tr>\n",
|
| 480 |
+
" <tr>\n",
|
| 481 |
+
" <th>2</th>\n",
|
| 482 |
+
" <td>45889B18F2C615A78371E1DAFC2680C0A36284C6195885...</td>\n",
|
| 483 |
+
" <td>9</td>\n",
|
| 484 |
+
" <td>15834</td>\n",
|
| 485 |
+
" <td>Busia County Referral Hospital</td>\n",
|
| 486 |
+
" <td>Busia</td>\n",
|
| 487 |
+
" <td>Matayos</td>\n",
|
| 488 |
+
" <td>USAID Dumisha Afya</td>\n",
|
| 489 |
+
" <td>USAID</td>\n",
|
| 490 |
+
" <td>Female</td>\n",
|
| 491 |
+
" <td>NULL</td>\n",
|
| 492 |
+
" <td>NULL</td>\n",
|
| 493 |
+
" <td>NULL</td>\n",
|
| 494 |
+
" <td>NULL</td>\n",
|
| 495 |
+
" <td>NULL</td>\n",
|
| 496 |
+
" <td>ACTIVE</td>\n",
|
| 497 |
+
" <td>20088</td>\n",
|
| 498 |
+
" <td>20161</td>\n",
|
| 499 |
+
" <td>2014-08-12 00:00:00.000</td>\n",
|
| 500 |
+
" <td>1972-04-13 00:00:00.000</td>\n",
|
| 501 |
+
" </tr>\n",
|
| 502 |
+
" <tr>\n",
|
| 503 |
+
" <th>3</th>\n",
|
| 504 |
+
" <td>9C9BFF8365B05D99D4F6A62716DD1353875B8A9280A772...</td>\n",
|
| 505 |
+
" <td>7</td>\n",
|
| 506 |
+
" <td>14831</td>\n",
|
| 507 |
+
" <td>Kericho District Hospital</td>\n",
|
| 508 |
+
" <td>Kericho</td>\n",
|
| 509 |
+
" <td>Ainamoi</td>\n",
|
| 510 |
+
" <td>HJF-South Rift Valley</td>\n",
|
| 511 |
+
" <td>DOD</td>\n",
|
| 512 |
+
" <td>Female</td>\n",
|
| 513 |
+
" <td>MARRIED MONOGAMOUS</td>\n",
|
| 514 |
+
" <td>NULL</td>\n",
|
| 515 |
+
" <td>Trader</td>\n",
|
| 516 |
+
" <td>NULL</td>\n",
|
| 517 |
+
" <td>NULL</td>\n",
|
| 518 |
+
" <td>ACTIVE</td>\n",
|
| 519 |
+
" <td>20088</td>\n",
|
| 520 |
+
" <td>20161</td>\n",
|
| 521 |
+
" <td>2023-05-10 00:00:00.000</td>\n",
|
| 522 |
+
" <td>1989-06-15 00:00:00.000</td>\n",
|
| 523 |
+
" </tr>\n",
|
| 524 |
+
" <tr>\n",
|
| 525 |
+
" <th>4</th>\n",
|
| 526 |
+
" <td>A51AEA4EC14F999A52AF53B4B531F760992ADA406B620C...</td>\n",
|
| 527 |
+
" <td>1</td>\n",
|
| 528 |
+
" <td>11259</td>\n",
|
| 529 |
+
" <td>Bomu Medical Centre (Likoni)</td>\n",
|
| 530 |
+
" <td>Mombasa</td>\n",
|
| 531 |
+
" <td>Likoni</td>\n",
|
| 532 |
+
" <td>Mkomani Clinic society</td>\n",
|
| 533 |
+
" <td>CDC</td>\n",
|
| 534 |
+
" <td>Female</td>\n",
|
| 535 |
+
" <td>MARRIED MONOGAMOUS</td>\n",
|
| 536 |
+
" <td>PRIMARY</td>\n",
|
| 537 |
+
" <td>Trader</td>\n",
|
| 538 |
+
" <td>NULL</td>\n",
|
| 539 |
+
" <td>NULL</td>\n",
|
| 540 |
+
" <td>UNDOCUMENTED LOSS</td>\n",
|
| 541 |
+
" <td>20088</td>\n",
|
| 542 |
+
" <td>20161</td>\n",
|
| 543 |
+
" <td>2018-05-22 00:00:00.000</td>\n",
|
| 544 |
+
" <td>1995-05-21 00:00:00.000</td>\n",
|
| 545 |
+
" </tr>\n",
|
| 546 |
+
" </tbody>\n",
|
| 547 |
+
"</table>\n",
|
| 548 |
+
"</div>"
|
| 549 |
+
],
|
| 550 |
+
"text/plain": [
|
| 551 |
+
" key PatientPKHash MFLCode \\\n",
|
| 552 |
+
"0 07149C6735AA9A2B3EFB198A5DB19825E3DA3DBCDE8CB8... 3 13703 \n",
|
| 553 |
+
"1 290D316E1B41A21F58E780026971F4D86DBB3BF043A77B... 4 13028 \n",
|
| 554 |
+
"2 45889B18F2C615A78371E1DAFC2680C0A36284C6195885... 9 15834 \n",
|
| 555 |
+
"3 9C9BFF8365B05D99D4F6A62716DD1353875B8A9280A772... 7 14831 \n",
|
| 556 |
+
"4 A51AEA4EC14F999A52AF53B4B531F760992ADA406B620C... 1 11259 \n",
|
| 557 |
+
"\n",
|
| 558 |
+
" FacilityName County \\\n",
|
| 559 |
+
"0 Kisii Teaching and Referral Hospital (Level 6) Kisii \n",
|
| 560 |
+
"1 Kibera Community Health Centre - Amref Nairobi \n",
|
| 561 |
+
"2 Busia County Referral Hospital Busia \n",
|
| 562 |
+
"3 Kericho District Hospital Kericho \n",
|
| 563 |
+
"4 Bomu Medical Centre (Likoni) Mombasa \n",
|
| 564 |
+
"\n",
|
| 565 |
+
" SubCounty PartnerName AgencyName Sex \\\n",
|
| 566 |
+
"0 Kitutu Chache South LVCT Vukisha 95 CDC Female \n",
|
| 567 |
+
"1 Kibra CIHEB CONNECT CDC Female \n",
|
| 568 |
+
"2 Matayos USAID Dumisha Afya USAID Female \n",
|
| 569 |
+
"3 Ainamoi HJF-South Rift Valley DOD Female \n",
|
| 570 |
+
"4 Likoni Mkomani Clinic society CDC Female \n",
|
| 571 |
+
"\n",
|
| 572 |
+
" MaritalStatus EducationLevel Occupation OnIPT AgeGroup \\\n",
|
| 573 |
+
"0 Single NULL NULL NULL NULL \n",
|
| 574 |
+
"1 MARRIED MONOGAMOUS SECONDARY Trader NULL NULL \n",
|
| 575 |
+
"2 NULL NULL NULL NULL NULL \n",
|
| 576 |
+
"3 MARRIED MONOGAMOUS NULL Trader NULL NULL \n",
|
| 577 |
+
"4 MARRIED MONOGAMOUS PRIMARY Trader NULL NULL \n",
|
| 578 |
+
"\n",
|
| 579 |
+
" ARTOutcomeDescription AsOfDate LoadDate StartARTDate \\\n",
|
| 580 |
+
"0 LOST IN HMIS 20088 20161 2012-04-12 00:00:00.000 \n",
|
| 581 |
+
"1 ACTIVE 20088 20161 2009-05-12 00:00:00.000 \n",
|
| 582 |
+
"2 ACTIVE 20088 20161 2014-08-12 00:00:00.000 \n",
|
| 583 |
+
"3 ACTIVE 20088 20161 2023-05-10 00:00:00.000 \n",
|
| 584 |
+
"4 UNDOCUMENTED LOSS 20088 20161 2018-05-22 00:00:00.000 \n",
|
| 585 |
+
"\n",
|
| 586 |
+
" DOB \n",
|
| 587 |
+
"0 2010-05-10 00:00:00.000 \n",
|
| 588 |
+
"1 1970-08-25 00:00:00.000 \n",
|
| 589 |
+
"2 1972-04-13 00:00:00.000 \n",
|
| 590 |
+
"3 1989-06-15 00:00:00.000 \n",
|
| 591 |
+
"4 1995-05-21 00:00:00.000 "
|
| 592 |
+
]
|
| 593 |
+
},
|
| 594 |
+
"execution_count": 27,
|
| 595 |
+
"metadata": {},
|
| 596 |
+
"output_type": "execute_result"
|
| 597 |
+
}
|
| 598 |
+
],
|
| 599 |
+
"source": [
|
| 600 |
+
"df.head()"
|
| 601 |
+
]
|
| 602 |
+
},
|
| 603 |
+
{
|
| 604 |
+
"cell_type": "code",
|
| 605 |
+
"execution_count": 30,
|
| 606 |
"id": "a7a10f4f",
|
| 607 |
"metadata": {},
|
| 608 |
"outputs": [],
|
| 609 |
"source": [
|
| 610 |
"# let's create a new sqlite database called patient_demonstration.sqlite\n",
|
| 611 |
+
"conn = sqlite3.connect('../data/processed/patient_demonstration.sqlite')\n",
|
| 612 |
"cursor = conn.cursor() "
|
| 613 |
]
|
| 614 |
},
|
| 615 |
{
|
| 616 |
"cell_type": "code",
|
| 617 |
+
"execution_count": 32,
|
| 618 |
+
"id": "07296631",
|
| 619 |
+
"metadata": {},
|
| 620 |
+
"outputs": [
|
| 621 |
+
{
|
| 622 |
+
"data": {
|
| 623 |
+
"text/html": [
|
| 624 |
+
"<div>\n",
|
| 625 |
+
"<style scoped>\n",
|
| 626 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 627 |
+
" vertical-align: middle;\n",
|
| 628 |
+
" }\n",
|
| 629 |
+
"\n",
|
| 630 |
+
" .dataframe tbody tr th {\n",
|
| 631 |
+
" vertical-align: top;\n",
|
| 632 |
+
" }\n",
|
| 633 |
+
"\n",
|
| 634 |
+
" .dataframe thead th {\n",
|
| 635 |
+
" text-align: right;\n",
|
| 636 |
+
" }\n",
|
| 637 |
+
"</style>\n",
|
| 638 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 639 |
+
" <thead>\n",
|
| 640 |
+
" <tr style=\"text-align: right;\">\n",
|
| 641 |
+
" <th></th>\n",
|
| 642 |
+
" <th>key</th>\n",
|
| 643 |
+
" <th>PatientPKHash</th>\n",
|
| 644 |
+
" <th>MFLCode</th>\n",
|
| 645 |
+
" <th>FacilityName</th>\n",
|
| 646 |
+
" <th>County</th>\n",
|
| 647 |
+
" <th>SubCounty</th>\n",
|
| 648 |
+
" <th>PartnerName</th>\n",
|
| 649 |
+
" <th>AgencyName</th>\n",
|
| 650 |
+
" <th>Sex</th>\n",
|
| 651 |
+
" <th>MaritalStatus</th>\n",
|
| 652 |
+
" <th>EducationLevel</th>\n",
|
| 653 |
+
" <th>Occupation</th>\n",
|
| 654 |
+
" <th>OnIPT</th>\n",
|
| 655 |
+
" <th>AgeGroup</th>\n",
|
| 656 |
+
" <th>ARTOutcomeDescription</th>\n",
|
| 657 |
+
" <th>AsOfDate</th>\n",
|
| 658 |
+
" <th>LoadDate</th>\n",
|
| 659 |
+
" <th>StartARTDate</th>\n",
|
| 660 |
+
" <th>DOB</th>\n",
|
| 661 |
+
" </tr>\n",
|
| 662 |
+
" </thead>\n",
|
| 663 |
+
" <tbody>\n",
|
| 664 |
+
" <tr>\n",
|
| 665 |
+
" <th>0</th>\n",
|
| 666 |
+
" <td>07149C6735AA9A2B3EFB198A5DB19825E3DA3DBCDE8CB8...</td>\n",
|
| 667 |
+
" <td>3</td>\n",
|
| 668 |
+
" <td>13703</td>\n",
|
| 669 |
+
" <td>Kisii Teaching and Referral Hospital (Level 6)</td>\n",
|
| 670 |
+
" <td>Kisii</td>\n",
|
| 671 |
+
" <td>Kitutu Chache South</td>\n",
|
| 672 |
+
" <td>LVCT Vukisha 95</td>\n",
|
| 673 |
+
" <td>CDC</td>\n",
|
| 674 |
+
" <td>Female</td>\n",
|
| 675 |
+
" <td>Single</td>\n",
|
| 676 |
+
" <td>NULL</td>\n",
|
| 677 |
+
" <td>NULL</td>\n",
|
| 678 |
+
" <td>NULL</td>\n",
|
| 679 |
+
" <td>NULL</td>\n",
|
| 680 |
+
" <td>LOST IN HMIS</td>\n",
|
| 681 |
+
" <td>20088</td>\n",
|
| 682 |
+
" <td>20161</td>\n",
|
| 683 |
+
" <td>2012-04-12 00:00:00.000</td>\n",
|
| 684 |
+
" <td>2010-05-10 00:00:00.000</td>\n",
|
| 685 |
+
" </tr>\n",
|
| 686 |
+
" <tr>\n",
|
| 687 |
+
" <th>1</th>\n",
|
| 688 |
+
" <td>290D316E1B41A21F58E780026971F4D86DBB3BF043A77B...</td>\n",
|
| 689 |
+
" <td>4</td>\n",
|
| 690 |
+
" <td>13028</td>\n",
|
| 691 |
+
" <td>Kibera Community Health Centre - Amref</td>\n",
|
| 692 |
+
" <td>Nairobi</td>\n",
|
| 693 |
+
" <td>Kibra</td>\n",
|
| 694 |
+
" <td>CIHEB CONNECT</td>\n",
|
| 695 |
+
" <td>CDC</td>\n",
|
| 696 |
+
" <td>Female</td>\n",
|
| 697 |
+
" <td>MARRIED MONOGAMOUS</td>\n",
|
| 698 |
+
" <td>SECONDARY</td>\n",
|
| 699 |
+
" <td>Trader</td>\n",
|
| 700 |
+
" <td>NULL</td>\n",
|
| 701 |
+
" <td>NULL</td>\n",
|
| 702 |
+
" <td>ACTIVE</td>\n",
|
| 703 |
+
" <td>20088</td>\n",
|
| 704 |
+
" <td>20161</td>\n",
|
| 705 |
+
" <td>2009-05-12 00:00:00.000</td>\n",
|
| 706 |
+
" <td>1970-08-25 00:00:00.000</td>\n",
|
| 707 |
+
" </tr>\n",
|
| 708 |
+
" <tr>\n",
|
| 709 |
+
" <th>2</th>\n",
|
| 710 |
+
" <td>45889B18F2C615A78371E1DAFC2680C0A36284C6195885...</td>\n",
|
| 711 |
+
" <td>9</td>\n",
|
| 712 |
+
" <td>15834</td>\n",
|
| 713 |
+
" <td>Busia County Referral Hospital</td>\n",
|
| 714 |
+
" <td>Busia</td>\n",
|
| 715 |
+
" <td>Matayos</td>\n",
|
| 716 |
+
" <td>USAID Dumisha Afya</td>\n",
|
| 717 |
+
" <td>USAID</td>\n",
|
| 718 |
+
" <td>Female</td>\n",
|
| 719 |
+
" <td>NULL</td>\n",
|
| 720 |
+
" <td>NULL</td>\n",
|
| 721 |
+
" <td>NULL</td>\n",
|
| 722 |
+
" <td>NULL</td>\n",
|
| 723 |
+
" <td>NULL</td>\n",
|
| 724 |
+
" <td>ACTIVE</td>\n",
|
| 725 |
+
" <td>20088</td>\n",
|
| 726 |
+
" <td>20161</td>\n",
|
| 727 |
+
" <td>2014-08-12 00:00:00.000</td>\n",
|
| 728 |
+
" <td>1972-04-13 00:00:00.000</td>\n",
|
| 729 |
+
" </tr>\n",
|
| 730 |
+
" <tr>\n",
|
| 731 |
+
" <th>3</th>\n",
|
| 732 |
+
" <td>9C9BFF8365B05D99D4F6A62716DD1353875B8A9280A772...</td>\n",
|
| 733 |
+
" <td>7</td>\n",
|
| 734 |
+
" <td>14831</td>\n",
|
| 735 |
+
" <td>Kericho District Hospital</td>\n",
|
| 736 |
+
" <td>Kericho</td>\n",
|
| 737 |
+
" <td>Ainamoi</td>\n",
|
| 738 |
+
" <td>HJF-South Rift Valley</td>\n",
|
| 739 |
+
" <td>DOD</td>\n",
|
| 740 |
+
" <td>Female</td>\n",
|
| 741 |
+
" <td>MARRIED MONOGAMOUS</td>\n",
|
| 742 |
+
" <td>NULL</td>\n",
|
| 743 |
+
" <td>Trader</td>\n",
|
| 744 |
+
" <td>NULL</td>\n",
|
| 745 |
+
" <td>NULL</td>\n",
|
| 746 |
+
" <td>ACTIVE</td>\n",
|
| 747 |
+
" <td>20088</td>\n",
|
| 748 |
+
" <td>20161</td>\n",
|
| 749 |
+
" <td>2023-05-10 00:00:00.000</td>\n",
|
| 750 |
+
" <td>1989-06-15 00:00:00.000</td>\n",
|
| 751 |
+
" </tr>\n",
|
| 752 |
+
" <tr>\n",
|
| 753 |
+
" <th>4</th>\n",
|
| 754 |
+
" <td>A51AEA4EC14F999A52AF53B4B531F760992ADA406B620C...</td>\n",
|
| 755 |
+
" <td>1</td>\n",
|
| 756 |
+
" <td>11259</td>\n",
|
| 757 |
+
" <td>Bomu Medical Centre (Likoni)</td>\n",
|
| 758 |
+
" <td>Mombasa</td>\n",
|
| 759 |
+
" <td>Likoni</td>\n",
|
| 760 |
+
" <td>Mkomani Clinic society</td>\n",
|
| 761 |
+
" <td>CDC</td>\n",
|
| 762 |
+
" <td>Female</td>\n",
|
| 763 |
+
" <td>MARRIED MONOGAMOUS</td>\n",
|
| 764 |
+
" <td>PRIMARY</td>\n",
|
| 765 |
+
" <td>Trader</td>\n",
|
| 766 |
+
" <td>NULL</td>\n",
|
| 767 |
+
" <td>NULL</td>\n",
|
| 768 |
+
" <td>UNDOCUMENTED LOSS</td>\n",
|
| 769 |
+
" <td>20088</td>\n",
|
| 770 |
+
" <td>20161</td>\n",
|
| 771 |
+
" <td>2018-05-22 00:00:00.000</td>\n",
|
| 772 |
+
" <td>1995-05-21 00:00:00.000</td>\n",
|
| 773 |
+
" </tr>\n",
|
| 774 |
+
" </tbody>\n",
|
| 775 |
+
"</table>\n",
|
| 776 |
+
"</div>"
|
| 777 |
+
],
|
| 778 |
+
"text/plain": [
|
| 779 |
+
" key PatientPKHash MFLCode \\\n",
|
| 780 |
+
"0 07149C6735AA9A2B3EFB198A5DB19825E3DA3DBCDE8CB8... 3 13703 \n",
|
| 781 |
+
"1 290D316E1B41A21F58E780026971F4D86DBB3BF043A77B... 4 13028 \n",
|
| 782 |
+
"2 45889B18F2C615A78371E1DAFC2680C0A36284C6195885... 9 15834 \n",
|
| 783 |
+
"3 9C9BFF8365B05D99D4F6A62716DD1353875B8A9280A772... 7 14831 \n",
|
| 784 |
+
"4 A51AEA4EC14F999A52AF53B4B531F760992ADA406B620C... 1 11259 \n",
|
| 785 |
+
"\n",
|
| 786 |
+
" FacilityName County \\\n",
|
| 787 |
+
"0 Kisii Teaching and Referral Hospital (Level 6) Kisii \n",
|
| 788 |
+
"1 Kibera Community Health Centre - Amref Nairobi \n",
|
| 789 |
+
"2 Busia County Referral Hospital Busia \n",
|
| 790 |
+
"3 Kericho District Hospital Kericho \n",
|
| 791 |
+
"4 Bomu Medical Centre (Likoni) Mombasa \n",
|
| 792 |
+
"\n",
|
| 793 |
+
" SubCounty PartnerName AgencyName Sex \\\n",
|
| 794 |
+
"0 Kitutu Chache South LVCT Vukisha 95 CDC Female \n",
|
| 795 |
+
"1 Kibra CIHEB CONNECT CDC Female \n",
|
| 796 |
+
"2 Matayos USAID Dumisha Afya USAID Female \n",
|
| 797 |
+
"3 Ainamoi HJF-South Rift Valley DOD Female \n",
|
| 798 |
+
"4 Likoni Mkomani Clinic society CDC Female \n",
|
| 799 |
+
"\n",
|
| 800 |
+
" MaritalStatus EducationLevel Occupation OnIPT AgeGroup \\\n",
|
| 801 |
+
"0 Single NULL NULL NULL NULL \n",
|
| 802 |
+
"1 MARRIED MONOGAMOUS SECONDARY Trader NULL NULL \n",
|
| 803 |
+
"2 NULL NULL NULL NULL NULL \n",
|
| 804 |
+
"3 MARRIED MONOGAMOUS NULL Trader NULL NULL \n",
|
| 805 |
+
"4 MARRIED MONOGAMOUS PRIMARY Trader NULL NULL \n",
|
| 806 |
+
"\n",
|
| 807 |
+
" ARTOutcomeDescription AsOfDate LoadDate StartARTDate \\\n",
|
| 808 |
+
"0 LOST IN HMIS 20088 20161 2012-04-12 00:00:00.000 \n",
|
| 809 |
+
"1 ACTIVE 20088 20161 2009-05-12 00:00:00.000 \n",
|
| 810 |
+
"2 ACTIVE 20088 20161 2014-08-12 00:00:00.000 \n",
|
| 811 |
+
"3 ACTIVE 20088 20161 2023-05-10 00:00:00.000 \n",
|
| 812 |
+
"4 UNDOCUMENTED LOSS 20088 20161 2018-05-22 00:00:00.000 \n",
|
| 813 |
+
"\n",
|
| 814 |
+
" DOB \n",
|
| 815 |
+
"0 2010-05-10 00:00:00.000 \n",
|
| 816 |
+
"1 1970-08-25 00:00:00.000 \n",
|
| 817 |
+
"2 1972-04-13 00:00:00.000 \n",
|
| 818 |
+
"3 1989-06-15 00:00:00.000 \n",
|
| 819 |
+
"4 1995-05-21 00:00:00.000 "
|
| 820 |
+
]
|
| 821 |
+
},
|
| 822 |
+
"execution_count": 32,
|
| 823 |
+
"metadata": {},
|
| 824 |
+
"output_type": "execute_result"
|
| 825 |
+
}
|
| 826 |
+
],
|
| 827 |
+
"source": [
|
| 828 |
+
"cursor.execute(\"select * from demographics;\")\n",
|
| 829 |
+
"rows = cursor.fetchall()\n",
|
| 830 |
+
"df = pd.DataFrame(rows, columns=[column[0] for column in cursor.description])\n",
|
| 831 |
+
"df.head()"
|
| 832 |
+
]
|
| 833 |
+
},
|
| 834 |
+
{
|
| 835 |
+
"cell_type": "code",
|
| 836 |
+
"execution_count": 31,
|
| 837 |
"id": "947c63d5",
|
| 838 |
"metadata": {},
|
| 839 |
"outputs": [],
|
|
|
|
| 843 |
"cursor.execute('DROP TABLE IF EXISTS demographics;')\n",
|
| 844 |
"cursor.execute('''\n",
|
| 845 |
"CREATE TABLE demographics (\n",
|
| 846 |
+
" key TEXT,\n",
|
| 847 |
" PatientPKHash TEXT,\n",
|
| 848 |
" MFLCode TEXT,\n",
|
| 849 |
" FacilityName TEXT,\n",
|
|
|
|
| 861 |
" AsOfDate TEXT,\n",
|
| 862 |
" LoadDate TEXT,\n",
|
| 863 |
" StartARTDate TEXT,\n",
|
| 864 |
+
" DOB TEXT\n",
|
|
|
|
| 865 |
");\n",
|
| 866 |
"''')\n",
|
| 867 |
"\n",
|
| 868 |
"# let's now populate the table with the rows variable that contains all the data from the visits table\n",
|
| 869 |
"cursor.executemany('''\n",
|
| 870 |
+
"INSERT INTO demographics (key, PatientPKHash, MFLCode, FacilityName, County, SubCounty, PartnerName, AgencyName, Sex, MaritalStatus, EducationLevel, Occupation, OnIPT, AgeGroup, ARTOutcomeDescription, AsOfDate, LoadDate, StartARTDate, DOB)\n",
|
| 871 |
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);\n",
|
| 872 |
"''', rows)\n",
|
| 873 |
"conn.commit()"
|
|
|
|
| 875 |
},
|
| 876 |
{
|
| 877 |
"cell_type": "code",
|
| 878 |
+
"execution_count": 18,
|
| 879 |
"id": "9cff0d90",
|
| 880 |
"metadata": {},
|
| 881 |
"outputs": [],
|
|
|
|
| 915 |
],
|
| 916 |
"metadata": {
|
| 917 |
"kernelspec": {
|
| 918 |
+
"display_name": "clinician-assistant-lg",
|
| 919 |
"language": "python",
|
| 920 |
"name": "python3"
|
| 921 |
},
|
notebooks/create_slim_patient_db.ipynb
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "c867740b",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [],
|
|
@@ -10,7 +10,7 @@
|
|
| 10 |
"import sqlite3\n",
|
| 11 |
"import pandas as pd\n",
|
| 12 |
"# inspect current database schema\n",
|
| 13 |
-
"conn = sqlite3.connect('iit_test.sqlite')\n",
|
| 14 |
"cursor = conn.cursor()\n",
|
| 15 |
"# list tables\n",
|
| 16 |
"# pull all data from the visits table \n",
|
|
@@ -22,7 +22,7 @@
|
|
| 22 |
},
|
| 23 |
{
|
| 24 |
"cell_type": "code",
|
| 25 |
-
"execution_count":
|
| 26 |
"id": "f424fcf6",
|
| 27 |
"metadata": {},
|
| 28 |
"outputs": [
|
|
@@ -30,7 +30,7 @@
|
|
| 30 |
"name": "stderr",
|
| 31 |
"output_type": "stream",
|
| 32 |
"text": [
|
| 33 |
-
"/tmp/
|
| 34 |
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
| 35 |
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
| 36 |
"\n",
|
|
@@ -53,35 +53,14 @@
|
|
| 53 |
"sampled_df['PatientPKHash'] = sampled_df['PatientPKHash'].map(key_to_number)\n",
|
| 54 |
"\n",
|
| 55 |
"# save sampled_df back to iit_test.sqlite as a new table called sampled_visits\n",
|
| 56 |
-
"sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
|
| 57 |
"sampled_df.to_sql('visits', sampled_conn, if_exists='replace', index=False)\n",
|
| 58 |
"sampled_conn.close()"
|
| 59 |
]
|
| 60 |
},
|
| 61 |
{
|
| 62 |
"cell_type": "code",
|
| 63 |
-
"execution_count":
|
| 64 |
-
"id": "8615f9fa",
|
| 65 |
-
"metadata": {},
|
| 66 |
-
"outputs": [
|
| 67 |
-
{
|
| 68 |
-
"data": {
|
| 69 |
-
"text/plain": [
|
| 70 |
-
"(271, 25)"
|
| 71 |
-
]
|
| 72 |
-
},
|
| 73 |
-
"execution_count": 23,
|
| 74 |
-
"metadata": {},
|
| 75 |
-
"output_type": "execute_result"
|
| 76 |
-
}
|
| 77 |
-
],
|
| 78 |
-
"source": [
|
| 79 |
-
"sampled_df.shape"
|
| 80 |
-
]
|
| 81 |
-
},
|
| 82 |
-
{
|
| 83 |
-
"cell_type": "code",
|
| 84 |
-
"execution_count": 24,
|
| 85 |
"id": "1bad1098",
|
| 86 |
"metadata": {},
|
| 87 |
"outputs": [
|
|
@@ -89,7 +68,7 @@
|
|
| 89 |
"name": "stderr",
|
| 90 |
"output_type": "stream",
|
| 91 |
"text": [
|
| 92 |
-
"/tmp/
|
| 93 |
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
| 94 |
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
| 95 |
"\n",
|
|
@@ -100,7 +79,7 @@
|
|
| 100 |
],
|
| 101 |
"source": [
|
| 102 |
"# now, read in pharmacy table from iit_test.sqlite\n",
|
| 103 |
-
"conn = sqlite3.connect('iit_test.sqlite')\n",
|
| 104 |
"cursor = conn.cursor()\n",
|
| 105 |
"cursor.execute(\"SELECT * FROM pharmacy;\")\n",
|
| 106 |
"rows = cursor.fetchall()\n",
|
|
@@ -110,46 +89,14 @@
|
|
| 110 |
"# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_pharmacy\n",
|
| 111 |
"sampled_pharmacy_df = pharmacy_df[pharmacy_df['PatientPKHash'].isin(sampled_keys)]\n",
|
| 112 |
"sampled_pharmacy_df['PatientPKHash'] = sampled_pharmacy_df['PatientPKHash'].map(key_to_number)\n",
|
| 113 |
-
"sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
|
| 114 |
"sampled_pharmacy_df.to_sql('pharmacy', sampled_conn, if_exists='replace', index=False)\n",
|
| 115 |
"sampled_conn.close()\n"
|
| 116 |
]
|
| 117 |
},
|
| 118 |
{
|
| 119 |
"cell_type": "code",
|
| 120 |
-
"execution_count":
|
| 121 |
-
"id": "bc8fac93",
|
| 122 |
-
"metadata": {},
|
| 123 |
-
"outputs": [
|
| 124 |
-
{
|
| 125 |
-
"data": {
|
| 126 |
-
"text/plain": [
|
| 127 |
-
"PatientPKHash\n",
|
| 128 |
-
"1 14\n",
|
| 129 |
-
"2 24\n",
|
| 130 |
-
"3 24\n",
|
| 131 |
-
"4 9\n",
|
| 132 |
-
"5 40\n",
|
| 133 |
-
"6 1\n",
|
| 134 |
-
"7 15\n",
|
| 135 |
-
"8 1\n",
|
| 136 |
-
"9 64\n",
|
| 137 |
-
"10 14\n",
|
| 138 |
-
"dtype: int64"
|
| 139 |
-
]
|
| 140 |
-
},
|
| 141 |
-
"execution_count": 25,
|
| 142 |
-
"metadata": {},
|
| 143 |
-
"output_type": "execute_result"
|
| 144 |
-
}
|
| 145 |
-
],
|
| 146 |
-
"source": [
|
| 147 |
-
"sampled_pharmacy_df.groupby('PatientPKHash').size()"
|
| 148 |
-
]
|
| 149 |
-
},
|
| 150 |
-
{
|
| 151 |
-
"cell_type": "code",
|
| 152 |
-
"execution_count": 26,
|
| 153 |
"id": "df01b886",
|
| 154 |
"metadata": {},
|
| 155 |
"outputs": [
|
|
@@ -157,7 +104,7 @@
|
|
| 157 |
"name": "stderr",
|
| 158 |
"output_type": "stream",
|
| 159 |
"text": [
|
| 160 |
-
"/tmp/
|
| 161 |
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
| 162 |
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
| 163 |
"\n",
|
|
@@ -168,7 +115,7 @@
|
|
| 168 |
],
|
| 169 |
"source": [
|
| 170 |
"# repeat the process above for lab table\n",
|
| 171 |
-
"conn = sqlite3.connect('iit_test.sqlite')\n",
|
| 172 |
"cursor = conn.cursor()\n",
|
| 173 |
"cursor.execute(\"SELECT * FROM lab;\")\n",
|
| 174 |
"rows = cursor.fetchall()\n",
|
|
@@ -178,46 +125,14 @@
|
|
| 178 |
"# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_lab\n",
|
| 179 |
"sampled_lab_df = lab_df[lab_df['PatientPKHash'].isin(sampled_keys)]\n",
|
| 180 |
"sampled_lab_df['PatientPKHash'] = sampled_lab_df['PatientPKHash'].map(key_to_number)\n",
|
| 181 |
-
"sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
|
| 182 |
"sampled_lab_df.to_sql('lab', sampled_conn, if_exists='replace', index=False)\n",
|
| 183 |
"sampled_conn.close()\n"
|
| 184 |
]
|
| 185 |
},
|
| 186 |
{
|
| 187 |
"cell_type": "code",
|
| 188 |
-
"execution_count":
|
| 189 |
-
"id": "2578bf85",
|
| 190 |
-
"metadata": {},
|
| 191 |
-
"outputs": [
|
| 192 |
-
{
|
| 193 |
-
"data": {
|
| 194 |
-
"text/plain": [
|
| 195 |
-
"PatientPKHash\n",
|
| 196 |
-
"1 6\n",
|
| 197 |
-
"2 2\n",
|
| 198 |
-
"3 17\n",
|
| 199 |
-
"4 22\n",
|
| 200 |
-
"5 23\n",
|
| 201 |
-
"6 1\n",
|
| 202 |
-
"7 2\n",
|
| 203 |
-
"8 10\n",
|
| 204 |
-
"9 13\n",
|
| 205 |
-
"10 12\n",
|
| 206 |
-
"dtype: int64"
|
| 207 |
-
]
|
| 208 |
-
},
|
| 209 |
-
"execution_count": 27,
|
| 210 |
-
"metadata": {},
|
| 211 |
-
"output_type": "execute_result"
|
| 212 |
-
}
|
| 213 |
-
],
|
| 214 |
-
"source": [
|
| 215 |
-
"sampled_lab_df.groupby('PatientPKHash').size()"
|
| 216 |
-
]
|
| 217 |
-
},
|
| 218 |
-
{
|
| 219 |
-
"cell_type": "code",
|
| 220 |
-
"execution_count": 28,
|
| 221 |
"id": "ebf358c5",
|
| 222 |
"metadata": {},
|
| 223 |
"outputs": [
|
|
@@ -225,7 +140,7 @@
|
|
| 225 |
"name": "stderr",
|
| 226 |
"output_type": "stream",
|
| 227 |
"text": [
|
| 228 |
-
"/tmp/
|
| 229 |
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
| 230 |
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
| 231 |
"\n",
|
|
@@ -236,7 +151,7 @@
|
|
| 236 |
],
|
| 237 |
"source": [
|
| 238 |
"# now, from dem table\n",
|
| 239 |
-
"conn = sqlite3.connect('iit_test.sqlite')\n",
|
| 240 |
"cursor = conn.cursor()\n",
|
| 241 |
"cursor.execute(\"SELECT * FROM dem;\")\n",
|
| 242 |
"rows = cursor.fetchall()\n",
|
|
@@ -246,47 +161,15 @@
|
|
| 246 |
"# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_dem\n",
|
| 247 |
"sampled_dem_df = dem_df[dem_df['PatientPKHash'].isin(sampled_keys)]\n",
|
| 248 |
"sampled_dem_df['PatientPKHash'] = sampled_dem_df['PatientPKHash'].map(key_to_number)\n",
|
| 249 |
-
"sampled_conn = sqlite3.connect('patient_slim.sqlite')\n",
|
| 250 |
"sampled_dem_df.to_sql('demographics', sampled_conn, if_exists='replace', index=False)\n",
|
| 251 |
"sampled_conn.close()"
|
| 252 |
]
|
| 253 |
-
},
|
| 254 |
-
{
|
| 255 |
-
"cell_type": "code",
|
| 256 |
-
"execution_count": 29,
|
| 257 |
-
"id": "527420fa",
|
| 258 |
-
"metadata": {},
|
| 259 |
-
"outputs": [
|
| 260 |
-
{
|
| 261 |
-
"data": {
|
| 262 |
-
"text/plain": [
|
| 263 |
-
"PatientPKHash\n",
|
| 264 |
-
"1 1\n",
|
| 265 |
-
"2 1\n",
|
| 266 |
-
"3 1\n",
|
| 267 |
-
"4 1\n",
|
| 268 |
-
"5 1\n",
|
| 269 |
-
"6 1\n",
|
| 270 |
-
"7 1\n",
|
| 271 |
-
"8 1\n",
|
| 272 |
-
"9 1\n",
|
| 273 |
-
"10 1\n",
|
| 274 |
-
"dtype: int64"
|
| 275 |
-
]
|
| 276 |
-
},
|
| 277 |
-
"execution_count": 29,
|
| 278 |
-
"metadata": {},
|
| 279 |
-
"output_type": "execute_result"
|
| 280 |
-
}
|
| 281 |
-
],
|
| 282 |
-
"source": [
|
| 283 |
-
"sampled_dem_df.groupby('PatientPKHash').size()"
|
| 284 |
-
]
|
| 285 |
}
|
| 286 |
],
|
| 287 |
"metadata": {
|
| 288 |
"kernelspec": {
|
| 289 |
-
"display_name": "
|
| 290 |
"language": "python",
|
| 291 |
"name": "python3"
|
| 292 |
},
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 4,
|
| 6 |
"id": "c867740b",
|
| 7 |
"metadata": {},
|
| 8 |
"outputs": [],
|
|
|
|
| 10 |
"import sqlite3\n",
|
| 11 |
"import pandas as pd\n",
|
| 12 |
"# inspect current database schema\n",
|
| 13 |
+
"conn = sqlite3.connect('../data/raw/iit_test.sqlite')\n",
|
| 14 |
"cursor = conn.cursor()\n",
|
| 15 |
"# list tables\n",
|
| 16 |
"# pull all data from the visits table \n",
|
|
|
|
| 22 |
},
|
| 23 |
{
|
| 24 |
"cell_type": "code",
|
| 25 |
+
"execution_count": 5,
|
| 26 |
"id": "f424fcf6",
|
| 27 |
"metadata": {},
|
| 28 |
"outputs": [
|
|
|
|
| 30 |
"name": "stderr",
|
| 31 |
"output_type": "stream",
|
| 32 |
"text": [
|
| 33 |
+
"/tmp/ipykernel_12725/435846127.py:11: SettingWithCopyWarning: \n",
|
| 34 |
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
| 35 |
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
| 36 |
"\n",
|
|
|
|
| 53 |
"sampled_df['PatientPKHash'] = sampled_df['PatientPKHash'].map(key_to_number)\n",
|
| 54 |
"\n",
|
| 55 |
"# save sampled_df back to iit_test.sqlite as a new table called sampled_visits\n",
|
| 56 |
+
"sampled_conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
|
| 57 |
"sampled_df.to_sql('visits', sampled_conn, if_exists='replace', index=False)\n",
|
| 58 |
"sampled_conn.close()"
|
| 59 |
]
|
| 60 |
},
|
| 61 |
{
|
| 62 |
"cell_type": "code",
|
| 63 |
+
"execution_count": 6,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
"id": "1bad1098",
|
| 65 |
"metadata": {},
|
| 66 |
"outputs": [
|
|
|
|
| 68 |
"name": "stderr",
|
| 69 |
"output_type": "stream",
|
| 70 |
"text": [
|
| 71 |
+
"/tmp/ipykernel_12725/2381592446.py:11: SettingWithCopyWarning: \n",
|
| 72 |
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
| 73 |
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
| 74 |
"\n",
|
|
|
|
| 79 |
],
|
| 80 |
"source": [
|
| 81 |
"# now, read in pharmacy table from iit_test.sqlite\n",
|
| 82 |
+
"conn = sqlite3.connect('../data/raw/iit_test.sqlite')\n",
|
| 83 |
"cursor = conn.cursor()\n",
|
| 84 |
"cursor.execute(\"SELECT * FROM pharmacy;\")\n",
|
| 85 |
"rows = cursor.fetchall()\n",
|
|
|
|
| 89 |
"# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_pharmacy\n",
|
| 90 |
"sampled_pharmacy_df = pharmacy_df[pharmacy_df['PatientPKHash'].isin(sampled_keys)]\n",
|
| 91 |
"sampled_pharmacy_df['PatientPKHash'] = sampled_pharmacy_df['PatientPKHash'].map(key_to_number)\n",
|
| 92 |
+
"sampled_conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
|
| 93 |
"sampled_pharmacy_df.to_sql('pharmacy', sampled_conn, if_exists='replace', index=False)\n",
|
| 94 |
"sampled_conn.close()\n"
|
| 95 |
]
|
| 96 |
},
|
| 97 |
{
|
| 98 |
"cell_type": "code",
|
| 99 |
+
"execution_count": 7,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
"id": "df01b886",
|
| 101 |
"metadata": {},
|
| 102 |
"outputs": [
|
|
|
|
| 104 |
"name": "stderr",
|
| 105 |
"output_type": "stream",
|
| 106 |
"text": [
|
| 107 |
+
"/tmp/ipykernel_12725/4028870248.py:11: SettingWithCopyWarning: \n",
|
| 108 |
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
| 109 |
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
| 110 |
"\n",
|
|
|
|
| 115 |
],
|
| 116 |
"source": [
|
| 117 |
"# repeat the process above for lab table\n",
|
| 118 |
+
"conn = sqlite3.connect('../data/raw/iit_test.sqlite')\n",
|
| 119 |
"cursor = conn.cursor()\n",
|
| 120 |
"cursor.execute(\"SELECT * FROM lab;\")\n",
|
| 121 |
"rows = cursor.fetchall()\n",
|
|
|
|
| 125 |
"# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_lab\n",
|
| 126 |
"sampled_lab_df = lab_df[lab_df['PatientPKHash'].isin(sampled_keys)]\n",
|
| 127 |
"sampled_lab_df['PatientPKHash'] = sampled_lab_df['PatientPKHash'].map(key_to_number)\n",
|
| 128 |
+
"sampled_conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
|
| 129 |
"sampled_lab_df.to_sql('lab', sampled_conn, if_exists='replace', index=False)\n",
|
| 130 |
"sampled_conn.close()\n"
|
| 131 |
]
|
| 132 |
},
|
| 133 |
{
|
| 134 |
"cell_type": "code",
|
| 135 |
+
"execution_count": 8,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
"id": "ebf358c5",
|
| 137 |
"metadata": {},
|
| 138 |
"outputs": [
|
|
|
|
| 140 |
"name": "stderr",
|
| 141 |
"output_type": "stream",
|
| 142 |
"text": [
|
| 143 |
+
"/tmp/ipykernel_12725/696424165.py:11: SettingWithCopyWarning: \n",
|
| 144 |
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
| 145 |
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
| 146 |
"\n",
|
|
|
|
| 151 |
],
|
| 152 |
"source": [
|
| 153 |
"# now, from dem table\n",
|
| 154 |
+
"conn = sqlite3.connect('../data/raw/iit_test.sqlite')\n",
|
| 155 |
"cursor = conn.cursor()\n",
|
| 156 |
"cursor.execute(\"SELECT * FROM dem;\")\n",
|
| 157 |
"rows = cursor.fetchall()\n",
|
|
|
|
| 161 |
"# filter these to the same 10 keys, replace the keys with numbers 1-10, and save to patient_slim.sqlite as a new table called sampled_dem\n",
|
| 162 |
"sampled_dem_df = dem_df[dem_df['PatientPKHash'].isin(sampled_keys)]\n",
|
| 163 |
"sampled_dem_df['PatientPKHash'] = sampled_dem_df['PatientPKHash'].map(key_to_number)\n",
|
| 164 |
+
"sampled_conn = sqlite3.connect('../data/raw/patient_slim.sqlite')\n",
|
| 165 |
"sampled_dem_df.to_sql('demographics', sampled_conn, if_exists='replace', index=False)\n",
|
| 166 |
"sampled_conn.close()"
|
| 167 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
}
|
| 169 |
],
|
| 170 |
"metadata": {
|
| 171 |
"kernelspec": {
|
| 172 |
+
"display_name": "clinician-assistant-lg",
|
| 173 |
"language": "python",
|
| 174 |
"name": "python3"
|
| 175 |
},
|
chat.py → scripts/chat.py
RENAMED
|
File without changes
|
scripts/evaluate_trulens.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from llama_index.core import StorageContext, load_index_from_storage, QueryBundle
|
| 7 |
+
from llama_index.core.retrievers import VectorIndexRetriever
|
| 8 |
+
from llama_index.core.postprocessor import LLMRerank
|
| 9 |
+
from llama_index.embeddings.openai import OpenAIEmbedding
|
| 10 |
+
from llama_index.llms.openai import OpenAI
|
| 11 |
+
|
| 12 |
+
from langchain_openai import ChatOpenAI
|
| 13 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 14 |
+
|
| 15 |
+
from trulens_eval import Tru
|
| 16 |
+
from trulens.core import Feedback
|
| 17 |
+
from trulens.providers.openai import OpenAI as OpenAIFeedbackProvider
|
| 18 |
+
from trulens_eval.tru_app import TruLlama
|
| 19 |
+
|
| 20 |
+
# ---------- Environment & shared resources ----------

# Load API credentials (e.g. OPENAI_API_KEY) from a local config file when present.
if os.path.exists("config.env"):
    load_dotenv("config.env")

# Precomputed summary embeddings plus the index mapping each embedding row
# to its on-disk vectorstore (column: "vectorestore_path").
embeddings = np.load("data/processed/lp/summary_embeddings/embeddings.npy")
df = pd.read_csv("data/processed/lp/summary_embeddings/index.tsv", sep="\t")

# LLMs and retrieval components.
embedding_model = OpenAIEmbedding()
llm_llama = OpenAI(model="gpt-4o", temperature=0.0)
reranker = LLMRerank(llm=llm_llama, top_n=3)

# LangChain LLM used for query expansion and summarization.
llm = ChatOpenAI(model="gpt-4o", temperature=0.0)

# NOTE: the former `grounded` / `context_rel` / `answer_rel` Feedback
# definitions were removed here: they referenced undefined names
# (Groundedness, Relevance, AnswerRelevance) and raised NameError at import
# time, and they duplicated the working provider-based feedbacks
# (f_grounded, f_context_rel, f_answer_rel) defined in the Feedbacks
# section further down this file.

# Prompt for query expansion: asks the LLM for retrieval-friendly synonyms.
query_expansion_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are an expert in HIV medicine."),
    ("user", (
        "Given the query below, provide a concise, comma-separated list of related terms and synonyms "
        "useful for document retrieval. Return only the list, no explanations.\n\n"
        "Query: {query}"
    ))
])
|
| 50 |
+
|
| 51 |
+
# ---------- Functions ----------
|
| 52 |
+
|
| 53 |
+
def cosine_similarity_numpy(query_vec: np.ndarray, matrix: np.ndarray) -> np.ndarray:
    """Return the cosine similarity between *query_vec* and every row of *matrix*.

    Both the query vector and each matrix row are normalized to unit length,
    so the result is a 1-D array of cosine similarities, one per row.
    """
    unit_query = query_vec / np.linalg.norm(query_vec)
    row_norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    unit_rows = matrix / row_norms
    return unit_rows @ unit_query
|
| 57 |
+
|
| 58 |
+
def expand_query(query, llm, prompt_template):
    """Expand *query* into retrieval-friendly related terms.

    Formats *prompt_template* with the query, invokes *llm* on the resulting
    messages, and returns the stripped text of the model's response.
    """
    prompt_messages = prompt_template.format_messages(query=query)
    response = llm.invoke(prompt_messages)
    return response.content.strip()
|
| 61 |
+
|
| 62 |
+
def retrieve_contexts(expanded_query, embeddings, df, embedding_model):
    """Retrieve and rerank guideline passages for *expanded_query*.

    Scores the summary *embeddings* by cosine similarity to the query,
    loads the three best-matching vectorstores listed in *df* (column
    "vectorestore_path"), pulls candidate nodes from each, then reranks the
    pooled candidates with the module-level LLM reranker. Returns the
    reranked node texts.
    """
    query_vec = embedding_model.get_text_embedding(expanded_query)
    sims = cosine_similarity_numpy(query_vec, embeddings)
    # Indices of the three highest-similarity summaries, best first.
    best_indices = sims.argsort()[-3:][::-1]
    store_paths = df.loc[best_indices, "vectorestore_path"].tolist()

    candidates = []
    for store_path in store_paths:
        storage_ctx = StorageContext.from_defaults(persist_dir=store_path)
        index = load_index_from_storage(storage_ctx)
        retriever = VectorIndexRetriever(index=index, similarity_top_k=3)
        candidates += retriever.retrieve(expanded_query)

    # Pool candidates from all stores and let the LLM reranker pick the top ones.
    top_nodes = reranker.postprocess_nodes(candidates, QueryBundle(expanded_query))
    return [node.text for node in top_nodes]
|
| 77 |
+
|
| 78 |
+
def summarize(query, contexts, llm):
    """Summarize retrieved guideline excerpts into a bulleted answer.

    Numbers each context passage as "Source N", embeds them in a clinical
    summarization prompt, and returns the model's stripped reply text.
    """
    numbered_sources = "\n\n".join(
        f"Source {idx + 1}: {passage}" for idx, passage in enumerate(contexts)
    )
    prompt = (
        "You're a clinical assistant helping a provider answer a question using HIV/AIDS guidelines.\n\n"
        f"Question: {query}\n\n"
        "Provide a detailed summary of the most relevant points to the user question from the following source texts. Use bullet points.\n\n"
        + numbered_sources
    )
    reply = llm.invoke(prompt)
    return reply.content.strip()
# ---------- RAG Pipeline ----------

def custom_rag_app(query):
    """Answer *query* end-to-end: expand, retrieve, rerank, summarize.

    Returns a dict with the original question, the expanded query string,
    the list of retrieved context passages, and the summarized answer.
    """
    expanded_query = expand_query(query, llm, query_expansion_prompt)
    retrieved_passages = retrieve_contexts(expanded_query, embeddings, df, embedding_model)
    summary = summarize(query, retrieved_passages, llm)
    return {
        "question": query,
        "expanded_query": expanded_query,
        "contexts": retrieved_passages,
        "answer": summary,
    }
# ---------- Feedbacks ----------

# OpenAI-backed grader used to score each recorded run.
provider = OpenAIFeedbackProvider()

# The standard RAG triad: groundedness of the answer in the retrieved
# context, relevance of the context to the question, and relevance of the
# answer to the question.
f_grounded = Feedback(provider.groundedness).on_input().on_context().with_name("faithfulness")
f_context_rel = Feedback(provider.context_relevance).on_input().on_context().with_name("context_relevance")
f_answer_rel = Feedback(provider.relevance).on_input().on_output().with_name("answer_relevance")

# ---------- TruLens App ----------

# NOTE(review): TruLlama is documented as a recorder for LlamaIndex query
# engines; `custom_rag_app` is a plain Python function, so TruCustomApp or
# TruBasicApp may be the intended wrapper here — confirm against the
# installed trulens version's docs.
tru_llama = TruLlama(
    app=custom_rag_app,
    feedbacks=[f_grounded, f_context_rel, f_answer_rel],
    app_id="evaluate-trulens-llama-v2"
)

tru = Tru()
# ---------- Run Evaluation ----------

# Fixed set of representative clinician questions used as the evaluation suite.
test_queries = [
    "What are important drug interactions with dolutegravir?",
    "How should PrEP be provided to adolescent girls?",
    "When is cotrimoxazole prophylaxis indicated?",
    "What are the guidelines for ART failure?",
    "How do you manage HIV in pregnancy?"
]

records = []

for q in test_queries:
    # NOTE(review): `run_with_record(question=...)` and the dict-style access
    # below (record["feedback"], record["output"], record["context"]) do not
    # match the documented TruLlama API (`with tru_llama as recording: ...`);
    # verify against the installed trulens version before relying on this.
    record = tru_llama.run_with_record(question=q)
    fb = record["feedback"]
    records.append({
        "question": q,
        "answer": record["output"],
        "contexts": record["context"],
        "faithfulness_score": fb["faithfulness"].get("score"),
        "context_relevance_score": fb["context_relevance"].get("score"),
        "answer_relevance_score": fb["answer_relevance"].get("score"),
        "faithfulness_justification": fb["faithfulness"].get("justification", "")
    })

# NOTE(review): this rebinds the module-level `df` (the summary-embedding
# index consumed by retrieve_contexts). Harmless once the loop has finished,
# but a distinct name (e.g. results_df) would be clearer.
df = pd.DataFrame(records)
df.to_csv("trulens_llama_eval_results.csv", index=False)
print("✅ Evaluation complete. Saved to trulens_llama_eval_results.csv")
print(df)
{chatlib → scripts}/patient_sql_agent.py
RENAMED
|
File without changes
|
scripts/ragas_eval.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# custom_rag_with_ragas.py
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from datasets import Dataset
|
| 6 |
+
from ragas.evaluation import evaluate
|
| 7 |
+
from ragas.metrics import (
|
| 8 |
+
faithfulness,
|
| 9 |
+
answer_relevancy,
|
| 10 |
+
context_precision,
|
| 11 |
+
context_recall
|
| 12 |
+
)
|
| 13 |
+
from llama_index.core import StorageContext, load_index_from_storage, QueryBundle
|
| 14 |
+
from llama_index.core.retrievers import VectorIndexRetriever
|
| 15 |
+
from llama_index.core.postprocessor import LLMRerank
|
| 16 |
+
from llama_index.embeddings.openai import OpenAIEmbedding
|
| 17 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 18 |
+
from langchain.chat_models import ChatOpenAI
|
| 19 |
+
from llama_index.llms.openai import OpenAI
|
| 20 |
+
import os
|
| 21 |
+
from dotenv import load_dotenv
|
| 22 |
+
if os.path.exists("config.env"):
|
| 23 |
+
load_dotenv("config.env")
|
| 24 |
+
|
| 25 |
+
# Pre-computed per-document summary embeddings and their index. The index's
# "vectorestore_path" column (sic — the typo is in the data schema, so it
# must be spelled that way here) points at the persisted per-document
# vector-store directories.
embeddings = np.load("data/processed/lp/summary_embeddings/embeddings.npy")
df = pd.read_csv("data/processed/lp/summary_embeddings/index.tsv", sep="\t")

# Embedding model used to embed the (expanded) query at retrieval time.
embedding_model = OpenAIEmbedding()

# Define your reranker-compatible LLM (LlamaIndex-native wrapper).
llm_llama = OpenAI(model="gpt-4o", temperature=0.0)

# Create LLM reranker that narrows pooled retrieval hits to the top 3.
reranker = LLMRerank(llm=llm_llama, top_n=3)

# summarizer LLM (LangChain chat model, also used for query expansion)
llm = ChatOpenAI(temperature=0.0, model="gpt-4o")

# Define a prompt template for query expansion: requests a comma-separated
# list of related terms/synonyms to broaden vector retrieval.
query_expansion_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are an expert in HIV medicine."),
    ("user", (
        "Given the query below, provide a concise, comma-separated list of related terms and synonyms "
        "useful for document retrieval. Return only the list, no explanations.\n\n"
        "Query: {query}"
    ))
])
def cosine_similarity_numpy(query_vec: np.ndarray, matrix: np.ndarray) -> np.ndarray:
    """Cosine similarity between *query_vec* and each row of *matrix*.

    Args:
        query_vec: 1-D array-like of length d (e.g. the plain Python list
            returned by ``OpenAIEmbedding.get_text_embedding``).
        matrix: 2-D array-like of shape (n, d) of row vectors.

    Returns:
        1-D ndarray of n cosine similarities, one per row of *matrix*.
    """
    # Coerce up front so plain lists are handled explicitly rather than
    # relying on NumPy's implicit scalar-broadcast conversion.
    q = np.asarray(query_vec, dtype=float)
    m = np.asarray(matrix, dtype=float)
    q_unit = q / np.linalg.norm(q)
    m_unit = m / np.linalg.norm(m, axis=1, keepdims=True)
    return m_unit @ q_unit
def expand_query(query, llm, prompt_template):
    """Return a comma-separated list of retrieval terms related to *query*.

    Formats *prompt_template* with the query, sends the messages to the chat
    model, and strips surrounding whitespace from the reply text.
    """
    formatted_messages = prompt_template.format_messages(query=query)
    reply = llm.invoke(formatted_messages)
    return reply.content.strip()
def retrieve_contexts(expanded_query, embeddings, df, embedding_model):
    """Retrieve the guideline passages most relevant to *expanded_query*.

    Embeds the query, picks the 3 most similar document summaries, loads each
    document's persisted vector store, pools the per-store top-3 hits, and
    LLM-reranks the pooled candidates.

    Returns a list of passage texts (best first).
    """
    query_vec = embedding_model.get_text_embedding(expanded_query)
    similarities = cosine_similarity_numpy(query_vec, embeddings)
    # argsort yields *positions*; index positionally with .iloc rather than
    # .loc, which is label-based and silently misbehaves on any non-default
    # DataFrame index.
    top_positions = similarities.argsort()[-3:][::-1]
    paths = df.iloc[top_positions]["vectorestore_path"].tolist()

    all_nodes = []
    for path in paths:
        ctx = StorageContext.from_defaults(persist_dir=path)
        index = load_index_from_storage(ctx)
        retriever = VectorIndexRetriever(index=index, similarity_top_k=3)
        all_nodes.extend(retriever.retrieve(expanded_query))

    # Reuse the module-level `reranker` instead of constructing a fresh
    # LLMRerank on every call (the original inline construction duplicated
    # the reranker already defined at module scope).
    reranked = reranker.postprocess_nodes(all_nodes, QueryBundle(expanded_query))
    return [node.text for node in reranked]
def summarize(query, contexts, llm):
    """Summarize retrieved guideline excerpts into a bulleted answer.

    Numbers each context passage as "Source N", embeds them in a clinical
    summarization prompt, and returns the model's stripped reply text.
    """
    numbered_sources = "\n\n".join(
        f"Source {idx + 1}: {passage}" for idx, passage in enumerate(contexts)
    )
    prompt = (
        "You're a clinical assistant helping a provider answer a question using HIV/AIDS guidelines.\n\n"
        f"Question: {query}\n\n"
        "Provide a detailed summary of the most relevant points from the following source texts using bullet points.\n\n"
        + numbered_sources
    )
    reply = llm.invoke(prompt)
    return reply.content.strip()
# Run on test queries: a fixed set of representative clinician questions.
test_queries = [
    "What are important drug interactions with dolutegravir?",
    "How should PrEP be provided to adolescent girls?",
    "When is cotrimoxazole prophylaxis indicated?",
    "What are the guidelines for ART failure?",
    "How do you manage HIV in pregnancy?"
]
results = []

# Run the expand -> retrieve -> summarize pipeline once per question and
# collect the fields the ragas evaluation expects.
for q in test_queries:
    print(f"⏳ Processing: {q}")
    expanded = expand_query(q, llm, query_expansion_prompt)
    contexts = retrieve_contexts(expanded, embeddings, df, embedding_model)
    answer = summarize(q, contexts, llm)
    results.append({
        "question": q,
        "contexts": contexts,
        "answer": answer
    })

# --- Ragas Evaluation ---
print("✅ Running Ragas evaluation...")

ragas_data = Dataset.from_list(results)

# Only reference-free metrics are run here; context_precision and
# context_recall (imported above but unused) require ground-truth answers,
# which this suite does not provide.
# NOTE(review): newer ragas releases rename the expected dataset columns
# (user_input / retrieved_contexts / response) — confirm the installed
# version still accepts question/contexts/answer.
eval_results = evaluate(
    ragas_data,
    metrics=[faithfulness, answer_relevancy]
)

df_eval = eval_results.to_pandas()
df_eval.to_csv("ragas_eval_results.csv", index=False)

print("✅ Evaluation complete. Saved to ragas_eval_results.csv")
print(df_eval)