Spaces:
Sleeping
Sleeping
JDFPalladium committed on
Commit ·
97facdb
1
Parent(s): 7023043
adding sql pull from start
Browse files- chatlib/patient_all_data.py +162 -0
- chatlib/state_types.py +0 -13
- iit_test.sqlite +0 -0
- main.py +14 -21
- patient_demonstration.sqlite +0 -0
chatlib/patient_all_data.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
from langchain_openai import ChatOpenAI
|
| 5 |
+
# Module-level chat model (gpt-4o, temperature 0.0); sql_chain invokes it to
# turn the assembled patient summaries into the final answer.
llm = ChatOpenAI(temperature = 0.0, model="gpt-4o")
|
| 6 |
+
|
| 7 |
+
from .state_types import AppState
|
| 8 |
+
|
| 9 |
+
# Define the SQL query tool
|
| 10 |
+
def sql_chain(state: AppState) -> AppState:
    """
    Answer the clinician's question using all data stored for one patient.

    Reads the patient's rows from the clinical_visits, pharmacy, lab and
    demographics tables of the SQLite database, condenses them into short
    text summaries (the five most recent visits, dispenses and lab orders,
    plus the first demographics row), and passes those summaries together
    with the question and any guideline context to the module-level ``llm``.

    The state should contain the following fields:
    - question: str - the question seeking information on patient data
    - pk_hash: str - the patient identifier to query the database
    - rag_result: str - context information from the guidelines retrieval
      (optional; a placeholder is used in the prompt when it is absent)

    Returns:
        The same state object, with ``answer`` set to the model's reply.

    Raises:
        ValueError: if ``pk_hash`` is missing from the state or empty.
    """
    pk_hash = state.get("pk_hash")
    if not pk_hash:
        raise ValueError("pk_hash is required in state for SQL queries.")

    conn = sqlite3.connect('data/patient_demonstration.sqlite')
    try:
        # pk_hash is always bound as a named parameter, never interpolated.
        visits_data = _fetch_patient_rows(
            conn, "SELECT * FROM clinical_visits WHERE PatientPKHash = :pk_hash", pk_hash
        )
        pharmacy_data = _fetch_patient_rows(
            conn, "SELECT * FROM pharmacy WHERE PatientPKHash = :pk_hash", pk_hash
        )
        lab_data = _fetch_patient_rows(
            conn, "SELECT * FROM lab WHERE PatientPKHash = :pk_hash", pk_hash
        )
        demographic_data = _fetch_patient_rows(
            conn, "SELECT * FROM demographics WHERE PatientPKHash = :pk_hash", pk_hash
        )
    finally:
        # Release the database handle even when a query fails.
        conn.close()

    # The summaries hold patient-level clinical data, so they are only placed
    # in the prompt and are deliberately not echoed to stdout.
    visits_summary = _summarize_visits(visits_data)
    pharmacy_summary = _summarize_pharmacy(pharmacy_data)
    lab_summary = _summarize_lab(lab_data)
    demographic_summary = _summarize_demographics(demographic_data)

    prompt = (
        "Given the following user question, contextual clinical guidance, "
        "patient clinical data, patient lab data, patient pharmacy data, "
        "patient demographic data, answer the user question. "
        "Try to answer based on the provided data. "
        "If there is essential patient information missing that you need in order to answer, "
        "do not provide an answer and instead explain what information is missing. \n\n"
        f'Question: {state["question"]}\n'
        f'Context: {state.get("rag_result", "No guidelines provided.")}\n'
        f'Patient Clinical Visits: {visits_summary}\n'
        f'Patient Pharmacy Data: {pharmacy_summary}\n'
        f'Patient Lab Data: {lab_summary}\n'
        f'Patient Demographic Data: {demographic_summary}\n'
    )
    response = llm.invoke(prompt)
    state["answer"] = response.content
    return state


def _fetch_patient_rows(conn, query, pk_hash):
    """Run ``query`` with ``pk_hash`` bound to :pk_hash and return the rows as a DataFrame."""
    cursor = conn.cursor()
    cursor.execute(query, {"pk_hash": pk_hash})
    rows = cursor.fetchall()
    return pd.DataFrame(rows, columns=[column[0] for column in cursor.description])


def _safe(val):
    """Return 'missing' for null, empty-string or literal "NULL" values, else ``val`` unchanged."""
    if pd.isnull(val) or val in ("", "NULL"):
        return 'missing'
    return val


def _latest_rows(df, date_column, limit=5):
    """Yield the ``limit`` rows with the greatest ``date_column`` value, newest first."""
    for _, row in df.sort_values(date_column, ascending=False).head(limit).iterrows():
        yield row


def _summarize_visits(df):
    """Render one bullet line per recent clinical visit, or a placeholder when there are none."""
    if df.empty:
        return "No clinical visit data available."

    return "\n".join(
        f"- {row['VisitDate']}: WHO Stage {_safe(row['WHOStage'])}, Weight {_safe(row['Weight'])}kg, "
        f"NextAppointmentDate {_safe(row['NextAppointmentDate'])}, VisitType {_safe(row['VisitType'])}, "
        f"VisitBy {_safe(row['VisitBy'])}, Pregnant {_safe(row['Pregnant'])}, Breastfeeding {_safe(row['Breastfeeding'])}, "
        f"WHOStage {_safe(row['WHOStage'])}, StabilityAssessment {_safe(row['StabilityAssessment'])}, "
        f"DifferentiatedCare {_safe(row['DifferentiatedCare'])}, WHOStagingOI {_safe(row['WHOStagingOI'])}, "
        f"Height {_safe(row['Height'])}cm, Adherence {_safe(row['Adherence'])}, BP {_safe(row['BP'])}, "
        f"OI {_safe(row['OI'])}, CurrentRegimen {_safe(row['CurrentRegimen'])}"
        for row in _latest_rows(df, "VisitDate")
    )


def _summarize_pharmacy(df):
    """Render one bullet line per recent pharmacy dispense, or a placeholder when there are none."""
    if df.empty:
        return "No pharmacy data available."

    return "\n".join(
        f"- {row['DispenseDate']}: ExpectedReturn {_safe(row['ExpectedReturn'])}, Drug {_safe(row['Drug'])}, "
        f"Duration {_safe(row['Duration'])}, TreatmentType {_safe(row['TreatmentType'])}, "
        f"RegimenLine {_safe(row['RegimenLine'])}, "
        f"RegimenChangedSwitched {_safe(row['RegimenChangedSwitched'])}, "
        f"RegimenChangeSwitchedReason {_safe(row['RegimenChangeSwitchedReason'])}, "
        for row in _latest_rows(df, "DispenseDate")
    )


def _summarize_lab(df):
    """Render one bullet line per recent lab order, or a placeholder when there are none."""
    if df.empty:
        return "No lab data available."

    return "\n".join(
        f"- {row['OrderedbyDate']}: TestName {_safe(row['TestName'])}, TestResult {_safe(row['TestResult'])},"
        for row in _latest_rows(df, "OrderedbyDate")
    )


def _summarize_demographics(df):
    """Render the first demographics row as 'Label: value' lines, or a placeholder when empty."""
    if df.empty:
        return "No demographic data available."

    # df.iloc[0] is a Series, so row[...] already yields the scalar cell value;
    # calling .values[0] on it raised AttributeError.
    row = df.iloc[0]
    return (
        f"Sex: {_safe(row['Sex'])}\n"
        f"MaritalStatus: {_safe(row['MaritalStatus'])}\n"
        f"EducationLevel: {_safe(row['EducationLevel'])}\n"
        f"Occupation: {_safe(row['Occupation'])}\n"
        f"OnIPT: {_safe(row['OnIPT'])}\n"
        f"ARTOutcomeDescription: {_safe(row['ARTOutcomeDescription'])}\n"
        f"StartARTDate: {_safe(row['StartARTDate'])}\n"
        f"Date Of Birth: {_safe(row['DOB'])}"
    )
|
chatlib/state_types.py
CHANGED
|
@@ -27,18 +27,5 @@ class AppState(TypedDict):
|
|
| 27 |
messages: Annotated[list[AnyMessage], add_messages]
|
| 28 |
question: str
|
| 29 |
rag_result: str
|
| 30 |
-
query: str
|
| 31 |
-
result: str
|
| 32 |
answer: str
|
| 33 |
pk_hash: str
|
| 34 |
-
|
| 35 |
-
# initialize state with patient pk hash
|
| 36 |
-
# input_state:State = {
|
| 37 |
-
# "messages": [HumanMessage(content="was this person typically late or on time to their visits?")],
|
| 38 |
-
# "question": "",
|
| 39 |
-
# "rag_result": "",
|
| 40 |
-
# "query": "",
|
| 41 |
-
# "result": "",
|
| 42 |
-
# "answer": "",
|
| 43 |
-
# "pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73"
|
| 44 |
-
# }
|
|
|
|
| 27 |
messages: Annotated[list[AnyMessage], add_messages]
|
| 28 |
question: str
|
| 29 |
rag_result: str
|
|
|
|
|
|
|
| 30 |
answer: str
|
| 31 |
pk_hash: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
iit_test.sqlite
ADDED
|
File without changes
|
main.py
CHANGED
|
@@ -5,6 +5,7 @@ from langgraph.graph import START, StateGraph
|
|
| 5 |
from langchain_core.messages import HumanMessage, SystemMessage
|
| 6 |
from langgraph.prebuilt import tools_condition, ToolNode
|
| 7 |
from langgraph.checkpoint.memory import MemorySaver
|
|
|
|
| 8 |
memory = MemorySaver()
|
| 9 |
|
| 10 |
load_dotenv("config.env")
|
|
@@ -13,7 +14,7 @@ os.environ.get("LANGSMITH_API_KEY")
|
|
| 13 |
|
| 14 |
from chatlib.state_types import AppState
|
| 15 |
from chatlib.guidlines_rag_agent_li import rag_retrieve
|
| 16 |
-
from chatlib.
|
| 17 |
|
| 18 |
# from langchain_ollama.chat_models import ChatOllama
|
| 19 |
# llm = ChatOllama(model="mistral:latest", temperature=0)
|
|
@@ -27,25 +28,19 @@ sys_msg = SystemMessage(content="""
|
|
| 27 |
You are a helpful assistant tasked with helping clinicians
|
| 28 |
meeting with patients. You have two tools available,
|
| 29 |
rag_retrieve to access information from HIV clinical guidelines,
|
| 30 |
-
and sql_chain to access patient data.
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
In these cases, first call rag_retrieve to get the relevant information,
|
| 34 |
-
then call sql_chain to get the patient data, and finally combine the results
|
| 35 |
-
to provide a complete answer. For example, if the question is about whether
|
| 36 |
-
a patient is on the correct treatment, first retrieve the treatment guidelines
|
| 37 |
-
using rag_retrieve, then check the patient's treatment history using sql_chain.
|
| 38 |
-
Another example is if the question is about when they should have their next viral load test,
|
| 39 |
-
first retrieve the guidelines for viral load testing using rag_retrieve,
|
| 40 |
-
then check the patient's last viral load test date and result using sql_chain.
|
| 41 |
|
| 42 |
You must respond only with a JSON object specifying the tool to call and its arguments.
|
| 43 |
-
Do not generate any SQL queries, results or answers yourself. Only the sql_chain
|
| 44 |
-
tool should do that.
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
"""
|
| 51 |
)
|
|
@@ -90,15 +85,13 @@ builder.add_edge("tools", "assistant")
|
|
| 90 |
react_graph = builder.compile(checkpointer=memory)
|
| 91 |
|
| 92 |
# Specify a thread
|
| 93 |
-
config = {"configurable": {"thread_id": "
|
| 94 |
|
| 95 |
# initialize state with patient pk hash
|
| 96 |
input_state:AppState = {
|
| 97 |
-
"messages": [HumanMessage(content="
|
| 98 |
"question": "",
|
| 99 |
"rag_result": "",
|
| 100 |
-
"query": "",
|
| 101 |
-
"result": "",
|
| 102 |
"answer": "",
|
| 103 |
"pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73"
|
| 104 |
}
|
|
|
|
| 5 |
from langchain_core.messages import HumanMessage, SystemMessage
|
| 6 |
from langgraph.prebuilt import tools_condition, ToolNode
|
| 7 |
from langgraph.checkpoint.memory import MemorySaver
|
| 8 |
+
|
| 9 |
memory = MemorySaver()
|
| 10 |
|
| 11 |
load_dotenv("config.env")
|
|
|
|
| 14 |
|
| 15 |
from chatlib.state_types import AppState
|
| 16 |
from chatlib.guidlines_rag_agent_li import rag_retrieve
|
| 17 |
+
from chatlib.patient_all_data import sql_chain
|
| 18 |
|
| 19 |
# from langchain_ollama.chat_models import ChatOllama
|
| 20 |
# llm = ChatOllama(model="mistral:latest", temperature=0)
|
|
|
|
| 28 |
You are a helpful assistant tasked with helping clinicians
|
| 29 |
meeting with patients. You have two tools available,
|
| 30 |
rag_retrieve to access information from HIV clinical guidelines,
|
| 31 |
+
and sql_chain to access patient data. When a clinican asks a question about a patient,
|
| 32 |
+
you should first run rag_retrieve to get contextual information from the guidelines,
|
| 33 |
+
then use sql_chain to query the patient's data from the SQL database.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
You must respond only with a JSON object specifying the tool to call and its arguments.
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
Keep your responses concise and focused on the task at hand. Remember, you are
|
| 38 |
+
talking to a clinician who needs quick and accurate information about their patient.
|
| 39 |
+
Do not tell them to consult a healthcare professional - they are the healthcare professional.
|
| 40 |
+
|
| 41 |
+
If the clinican questions is not clear, ask for clarification or more information.
|
| 42 |
+
If the clinican asks a question that is not related to the patient, then use the rag_retrieve tool
|
| 43 |
+
to provide general information about HIV clinical guidelines.
|
| 44 |
|
| 45 |
"""
|
| 46 |
)
|
|
|
|
| 85 |
react_graph = builder.compile(checkpointer=memory)
|
| 86 |
|
| 87 |
# Specify a thread
|
| 88 |
+
config = {"configurable": {"thread_id": "30"}}
|
| 89 |
|
| 90 |
# initialize state with patient pk hash
|
| 91 |
input_state:AppState = {
|
| 92 |
+
"messages": [HumanMessage(content="the patient is 30 and is not pregnant or breastfeeding?")],
|
| 93 |
"question": "",
|
| 94 |
"rag_result": "",
|
|
|
|
|
|
|
| 95 |
"answer": "",
|
| 96 |
"pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73"
|
| 97 |
}
|
patient_demonstration.sqlite
ADDED
|
File without changes
|