JDFPalladium committed on
Commit
97facdb
·
1 Parent(s): 7023043

adding sql pull from start

Browse files
chatlib/patient_all_data.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import pandas as pd
3
+
4
+ from langchain_openai import ChatOpenAI
5
# Module-level chat model shared by sql_chain; temperature is pinned to 0.0.
llm = ChatOpenAI(model="gpt-4o", temperature=0.0)
6
+
7
+ from .state_types import AppState
8
+
9
+ # Define the SQL query tool
10
# Location of the demo SQLite database, relative to the process working directory.
_PATIENT_DB_PATH = 'data/patient_demonstration.sqlite'

# Only this many of the most recent rows per time-ordered table go into the prompt.
_RECENT_ROW_LIMIT = 5

# One query per patient table. pk_hash is always bound as a named parameter,
# never interpolated into the SQL text.
_PATIENT_QUERIES = {
    "visits": "SELECT * FROM clinical_visits WHERE PatientPKHash = :pk_hash",
    "pharmacy": "SELECT * FROM pharmacy WHERE PatientPKHash = :pk_hash",
    "lab": "SELECT * FROM lab WHERE PatientPKHash = :pk_hash",
    "demographics": "SELECT * FROM demographics WHERE PatientPKHash = :pk_hash",
}

# (label shown to the model, column name, unit suffix) for each summarised field.
_VISIT_FIELDS = (
    ("WHO Stage", "WHOStage", ""),
    ("Weight", "Weight", "kg"),
    ("NextAppointmentDate", "NextAppointmentDate", ""),
    ("VisitType", "VisitType", ""),
    ("VisitBy", "VisitBy", ""),
    ("Pregnant", "Pregnant", ""),
    ("Breastfeeding", "Breastfeeding", ""),
    ("StabilityAssessment", "StabilityAssessment", ""),
    ("DifferentiatedCare", "DifferentiatedCare", ""),
    ("WHOStagingOI", "WHOStagingOI", ""),
    ("Height", "Height", "cm"),
    ("Adherence", "Adherence", ""),
    ("BP", "BP", ""),
    ("OI", "OI", ""),
    ("CurrentRegimen", "CurrentRegimen", ""),
)

_PHARMACY_FIELDS = (
    ("ExpectedReturn", "ExpectedReturn", ""),
    ("Drug", "Drug", ""),
    ("Duration", "Duration", ""),
    ("TreatmentType", "TreatmentType", ""),
    ("RegimenLine", "RegimenLine", ""),
    ("RegimenChangedSwitched", "RegimenChangedSwitched", ""),
    ("RegimenChangeSwitchedReason", "RegimenChangeSwitchedReason", ""),
)

_LAB_FIELDS = (
    ("TestName", "TestName", ""),
    ("TestResult", "TestResult", ""),
)

# (label shown to the model, column name) for the single demographics row.
_DEMOGRAPHIC_FIELDS = (
    ("Sex", "Sex"),
    ("MaritalStatus", "MaritalStatus"),
    ("EducationLevel", "EducationLevel"),
    ("Occupation", "Occupation"),
    ("OnIPT", "OnIPT"),
    ("ARTOutcomeDescription", "ARTOutcomeDescription"),
    ("StartARTDate", "StartARTDate"),
    ("Date Of Birth", "DOB"),
)


def _is_missing(val):
    """Return True for NULL/NaN values and for the placeholder strings '' and 'NULL'."""
    return bool(pd.isnull(val)) or val in ("", "NULL")


def _display(val):
    """Return *val* unchanged, or the word 'missing' when it carries no information."""
    return 'missing' if _is_missing(val) else val


def _format_field(label, value, unit=""):
    """Render one 'Label value[unit]' fragment; the unit is dropped for missing values."""
    if _is_missing(value):
        return f"{label} missing"
    return f"{label} {value}{unit}"


def _fetch_frame(cursor, query, pk_hash):
    """Run a parameterised per-patient query and return the rows as a DataFrame."""
    cursor.execute(query, {"pk_hash": pk_hash})
    rows = cursor.fetchall()
    return pd.DataFrame(rows, columns=[column[0] for column in cursor.description])


def _summarize_recent(df, date_column, fields, empty_message):
    """Summarise the most recent rows of *df* (newest first), one bullet line per row."""
    if df.empty:
        return empty_message

    recent = df.sort_values(date_column, ascending=False).head(_RECENT_ROW_LIMIT)
    lines = []
    for _, row in recent.iterrows():
        details = ", ".join(
            _format_field(label, row[column], unit) for label, column, unit in fields
        )
        lines.append(f"- {_display(row[date_column])}: {details}")
    return "\n".join(lines)


def _summarize_demographics(df):
    """Summarise the first demographics row as 'Label: value' lines."""
    if df.empty:
        return "No demographic data available."

    # df.iloc[0] is a Series, so row[column] is already a scalar; the previous
    # `.values[0]` access on that scalar raised AttributeError for every patient
    # that actually had a demographics row.
    row = df.iloc[0]
    return "\n".join(
        f"{label}: {_display(row[column])}" for label, column in _DEMOGRAPHIC_FIELDS
    )


# Define the SQL query tool
def sql_chain(state: AppState) -> AppState:
    """
    Answer a question about one patient using that patient's stored records.

    Loads every row for the patient from the clinical_visits, pharmacy, lab and
    demographics tables of the SQLite database, condenses each table into a short
    text summary (the five most recent rows for the time-ordered tables, the first
    row for demographics), and asks the module-level ``llm`` to answer the question
    from those summaries plus any guideline context.

    The state should contain the following fields:
    - question: str - the question seeking information on patient data
    - pk_hash: str - the patient identifier used to filter every table
    - rag_result: str - optional context from the guidelines retrieval; a
      placeholder sentence is used when it is absent

    Returns the same state object with ``answer`` set to the model's reply.

    Raises:
        ValueError: if ``pk_hash`` is missing or empty.
        KeyError: if ``question`` is missing from the state.
    """
    pk_hash = state.get("pk_hash")
    if not pk_hash:
        raise ValueError("pk_hash is required in state for SQL queries.")

    conn = sqlite3.connect(_PATIENT_DB_PATH)
    try:
        cursor = conn.cursor()
        frames = {
            name: _fetch_frame(cursor, query, pk_hash)
            for name, query in _PATIENT_QUERIES.items()
        }
    finally:
        # Close the connection even when a query fails (e.g. a missing table).
        conn.close()

    visits_summary = _summarize_recent(
        frames["visits"], "VisitDate", _VISIT_FIELDS,
        "No clinical visit data available.",
    )
    pharmacy_summary = _summarize_recent(
        frames["pharmacy"], "DispenseDate", _PHARMACY_FIELDS,
        "No pharmacy data available.",
    )
    lab_summary = _summarize_recent(
        frames["lab"], "OrderedbyDate", _LAB_FIELDS,
        "No lab data available.",
    )
    demographic_summary = _summarize_demographics(frames["demographics"])

    # Echo each summary to stdout, as the original implementation did.
    print(visits_summary)
    print(pharmacy_summary)
    print(lab_summary)
    print(demographic_summary)

    prompt = (
        "Given the following user question, contextual clinical guidance, "
        "patient clinical data, patient lab data, patient pharmacy data, "
        "patient demographic data, answer the user question. "
        "Try to answer based on the provided data. "
        "If there is essential patient information missing that you need in order to answer, "
        "do not provide an answer and instead explain what information is missing. \n\n"
        f'Question: {state["question"]}\n'
        f'Context: {state.get("rag_result", "No guidelines provided.")}\n'
        f'Patient Clinical Visits: {visits_summary}\n'
        f'Patient Pharmacy Data: {pharmacy_summary}\n'
        f'Patient Lab Data: {lab_summary}\n'
        f'Patient Demographic Data: {demographic_summary}\n'
    )
    response = llm.invoke(prompt)
    state["answer"] = response.content
    return state
chatlib/state_types.py CHANGED
@@ -27,18 +27,5 @@ class AppState(TypedDict):
27
  messages: Annotated[list[AnyMessage], add_messages]
28
  question: str
29
  rag_result: str
30
- query: str
31
- result: str
32
  answer: str
33
  pk_hash: str
34
-
35
- # initialize state with patient pk hash
36
- # input_state:State = {
37
- # "messages": [HumanMessage(content="was this person typically late or on time to their visits?")],
38
- # "question": "",
39
- # "rag_result": "",
40
- # "query": "",
41
- # "result": "",
42
- # "answer": "",
43
- # "pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73"
44
- # }
 
27
  messages: Annotated[list[AnyMessage], add_messages]
28
  question: str
29
  rag_result: str
 
 
30
  answer: str
31
  pk_hash: str
 
 
 
 
 
 
 
 
 
 
 
iit_test.sqlite ADDED
File without changes
main.py CHANGED
@@ -5,6 +5,7 @@ from langgraph.graph import START, StateGraph
5
  from langchain_core.messages import HumanMessage, SystemMessage
6
  from langgraph.prebuilt import tools_condition, ToolNode
7
  from langgraph.checkpoint.memory import MemorySaver
 
8
  memory = MemorySaver()
9
 
10
  load_dotenv("config.env")
@@ -13,7 +14,7 @@ os.environ.get("LANGSMITH_API_KEY")
13
 
14
  from chatlib.state_types import AppState
15
  from chatlib.guidlines_rag_agent_li import rag_retrieve
16
- from chatlib.patient_sql_agent import sql_chain
17
 
18
  # from langchain_ollama.chat_models import ChatOllama
19
  # llm = ChatOllama(model="mistral:latest", temperature=0)
@@ -27,25 +28,19 @@ sys_msg = SystemMessage(content="""
27
  You are a helpful assistant tasked with helping clinicians
28
  meeting with patients. You have two tools available,
29
  rag_retrieve to access information from HIV clinical guidelines,
30
- and sql_chain to access patient data.
31
-
32
- In most cases, you should use both tools to answer a question.
33
- In these cases, first call rag_retrieve to get the relevant information,
34
- then call sql_chain to get the patient data, and finally combine the results
35
- to provide a complete answer. For example, if the question is about whether
36
- a patient is on the correct treatment, first retrieve the treatment guidelines
37
- using rag_retrieve, then check the patient's treatment history using sql_chain.
38
- Another example is if the question is about when they should have their next viral load test,
39
- first retrieve the guidelines for viral load testing using rag_retrieve,
40
- then check the patient's last viral load test date and result using sql_chain.
41
 
42
  You must respond only with a JSON object specifying the tool to call and its arguments.
43
- Do not generate any SQL queries, results or answers yourself. Only the sql_chain
44
- tool should do that.
45
 
46
- When calling a tool, provide only the necessary fields required for that tool to run.
47
- Do not include the full state or raw query results in the tool call arguments.
48
- For example, include the question and pk_hash, but exclude the query or result.
 
 
 
 
49
 
50
  """
51
  )
@@ -90,15 +85,13 @@ builder.add_edge("tools", "assistant")
90
  react_graph = builder.compile(checkpointer=memory)
91
 
92
  # Specify a thread
93
- config = {"configurable": {"thread_id": "25"}}
94
 
95
  # initialize state with patient pk hash
96
  input_state:AppState = {
97
- "messages": [HumanMessage(content="my patient is complaining about feeling headaches. should i consider switching their regimen?")],
98
  "question": "",
99
  "rag_result": "",
100
- "query": "",
101
- "result": "",
102
  "answer": "",
103
  "pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73"
104
  }
 
5
  from langchain_core.messages import HumanMessage, SystemMessage
6
  from langgraph.prebuilt import tools_condition, ToolNode
7
  from langgraph.checkpoint.memory import MemorySaver
8
+
9
  memory = MemorySaver()
10
 
11
  load_dotenv("config.env")
 
14
 
15
  from chatlib.state_types import AppState
16
  from chatlib.guidlines_rag_agent_li import rag_retrieve
17
+ from chatlib.patient_all_data import sql_chain
18
 
19
  # from langchain_ollama.chat_models import ChatOllama
20
  # llm = ChatOllama(model="mistral:latest", temperature=0)
 
28
  You are a helpful assistant tasked with helping clinicians
29
  meeting with patients. You have two tools available,
30
  rag_retrieve to access information from HIV clinical guidelines,
31
+ and sql_chain to access patient data. When a clinican asks a question about a patient,
32
+ you should first run rag_retrieve to get contextual information from the guidelines,
33
+ then use sql_chain to query the patient's data from the SQL database.
 
 
 
 
 
 
 
 
34
 
35
  You must respond only with a JSON object specifying the tool to call and its arguments.
 
 
36
 
37
+ Keep your responses concise and focused on the task at hand. Remember, you are
38
+ talking to a clinician who needs quick and accurate information about their patient.
39
+ Do not tell them to consult a healthcare professional - they are the healthcare professional.
40
+
41
+ If the clinican questions is not clear, ask for clarification or more information.
42
+ If the clinican asks a question that is not related to the patient, then use the rag_retrieve tool
43
+ to provide general information about HIV clinical guidelines.
44
 
45
  """
46
  )
 
85
  react_graph = builder.compile(checkpointer=memory)
86
 
87
  # Specify a thread
88
+ config = {"configurable": {"thread_id": "30"}}
89
 
90
  # initialize state with patient pk hash
91
  input_state:AppState = {
92
+ "messages": [HumanMessage(content="the patient is 30 and is not pregnant or breastfeeding?")],
93
  "question": "",
94
  "rag_result": "",
 
 
95
  "answer": "",
96
  "pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73"
97
  }
patient_demonstration.sqlite ADDED
File without changes