from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict, Annotated


from .state_types import AppState

db = SQLDatabase.from_uri("sqlite:///data/patient_demonstration.sqlite")
llm = ChatOpenAI(temperature=0.0, model="gpt-4o")


system_message = """
Given an input question, create a syntactically correct {dialect} query to
run to help find the answer. The database contains the following tables and columns:

table: clinical_visits
columns: PatientPKHash, VisitDate, VisitType, VisitBy, NextAppointmentDate, TCAReason, Pregnant, Breastfeeding,
    StabilityAssessment, DifferentiatedCare, WHOStage, WHOStagingOI, Height, Weight,
    EMR, Project, Adherence, AdherenceCategory, BP, OI, OIDate, CurrentRegimen, AppointmentReminderWillingness

table: lab
columns: PatientPKHash, SiteCode, OrderedByDate, TestName, TestResult 

table: pharmacy
columns: PatientPKHash, SiteCode, DispenseDate, Drug, ExpectedReturn, Duration, TreatmentType,
    RegimenLine, RegimenChangedSwitched, RegimenChangeSwitchedReason

table: demographics
columns: PatientPKHash, MFLCode, FacilityName, County, SubCounty, PartnerName, AgencyName, Sex,
    MaritalStatus, EducationLevel, Occupation, OnIPT, AgeGroup, ARTOutcomeDescription, AsOfDate, LoadDate, StartARTDate, DOB

To understand what each column means, refer to the following data dictionary: {table_info}.    

Filter PatientPKHash column using exactly the provided value: {pk_hash} if the value is provided
to get information about the patient with whom the clinician is meeting. 

If provided, create the query based on the following authoritative context from HIV clinical guidelines:
{guidelines}.

Never query for all the columns from a specific table; only ask for the
few relevant columns given the question. Use LIMIT 10 unless otherwise specified.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table. Do not join or select from any tables not listed above. 
Do not select columns not listed in the schema.

When looking for a patient's regimen information, use CurrentRegimen from clinical_visits table to see
what regimen they are on and join with the pharmacy table for other regimen details, including regimen line
and regimen switch.

Here's an example of how to do this in SQL:
SELECT v.CurrentRegimen, p.RegimenChangedSwitched, p.RegimenChangeSwitchedReason
FROM pharmacy p
JOIN clinical_visits v ON p.PatientPKHash = v.PatientPKHash
WHERE p.PatientPKHash = '{pk_hash}'
ORDER BY p.DispenseDate DESC
LIMIT 10;


When checking if a patient was late for an appointment, for each visit, compare the NextAppointmentDate from the previous visit to the VisitDate of the current visit.
Do not compare NextAppointmentDate to the VisitDate in the same row. Use SQL to find, for each patient, the next VisitDate after a given VisitDate, and compare it to the NextAppointmentDate from the previous visit.

Here is an example of how to do this in SQL:
SELECT
v1.PatientPKHash,
v1.VisitDate AS PreviousVisitDate,
v1.NextAppointmentDate,
v2.VisitDate AS NextVisitDate,
CASE
    WHEN v2.VisitDate <= v1.NextAppointmentDate THEN 'On time'
    ELSE 'Late'
END AS AttendanceStatus
FROM clinical_visits v1
JOIN clinical_visits v2
ON v1.PatientPKHash = v2.PatientPKHash
AND v2.VisitDate > v1.VisitDate
WHERE NOT EXISTS (
SELECT 1 FROM clinical_visits v3
WHERE v3.PatientPKHash = v1.PatientPKHash
    AND v3.VisitDate > v1.VisitDate
    AND v3.VisitDate < v2.VisitDate
)
ORDER BY v1.PatientPKHash, v1.VisitDate;
"""

user_prompt = "Question: {input}"

query_prompt_template = ChatPromptTemplate(
    [("system", system_message), ("user", user_prompt)]
)


class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]


def write_query(state: AppState) -> AppState:
    """Generate SQL query to fetch information."""

    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "table_info": db.run("SELECT * FROM data_dictionary;"),
            "input": state["question"],
            "guidelines": state.get("rag_result", "No guidelines provided."),
            "pk_hash": state.get("pk_hash", ""),
        }
    )

    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    state["query"] = result["query"]  # type: ignore
    return state


def execute_query(state: AppState) -> AppState:
    """Execute SQL query."""

    execute_query_tool = QuerySQLDatabaseTool(db=db)
    state["result"] = execute_query_tool.invoke(state["query"])  # type: ignore
    return state


def generate_answer(state: AppState) -> AppState:
    """
    Answer question using retrieved information as context.
    For awareness, NextAppointmentDate is set during the VisitDate of the same entry.
    To determine if the patient came on time to their next appointment, compare NextAppointmentDate
    with the next recorded VisitDate. For example, if a patient has a VisitDate of
    2023-01-01 and a NextAppointmentDate of 2023-01-15, check if the next VisitDate is on or before
    2023-01-15 to determine if the patient came on time.

    """

    prompt = (
        "Given the following user question, context information, corresponding SQL query, "
        "and SQL result, answer the user question. If the SQL result is empty, then the SQL query was not able to retrieve any information. "
        "In that case, ignore the SQL query too and generate an answer based only on the context. \n\n"
        f'Question: {state["question"]}\n'
        f'Context: {state.get("rag_result", "No guidelines provided.")}\n'
        f'SQL Query: {state["query"]}\n'  # type: ignore
        f'SQL Result: {state["result"]}'  # type: ignore
    )
    response = llm.invoke(prompt)
    state["answer"] = response.content  # type: ignore
    return state


@tool
def sql_chain(state: AppState) -> dict:
    """
    Tool that takes a state containing a question about patient data held in a
    SQL database, writes an SQL query to retrieve relevant data, executes the query,
    and generates a natural language answer based on the query results.
    Returns the updated state, with the final answer stored under the 'answer' key.
    """
    state = write_query(state)
    state = execute_query(state)
    state = generate_answer(state)

    return state  # type: ignore
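

# --- Illustrative sketch (not part of the original module) -------------------
# The attendance example in `system_message` pairs each visit with the next
# recorded visit for the same patient and compares that visit's date with the
# earlier row's NextAppointmentDate. The hypothetical helper below runs the
# same query against an in-memory SQLite table holding made-up rows. It assumes
# dates are stored as ISO-8601 text, so string comparison matches date order;
# the real database may store them differently. `ATTENDANCE_SQL` and
# `demo_attendance_status` are names introduced here for illustration only.
import sqlite3

ATTENDANCE_SQL = """
SELECT
    v1.PatientPKHash,
    v1.VisitDate AS PreviousVisitDate,
    v1.NextAppointmentDate,
    v2.VisitDate AS NextVisitDate,
    CASE
        WHEN v2.VisitDate <= v1.NextAppointmentDate THEN 'On time'
        ELSE 'Late'
    END AS AttendanceStatus
FROM clinical_visits v1
JOIN clinical_visits v2
    ON v1.PatientPKHash = v2.PatientPKHash
    AND v2.VisitDate > v1.VisitDate
WHERE NOT EXISTS (
    SELECT 1 FROM clinical_visits v3
    WHERE v3.PatientPKHash = v1.PatientPKHash
        AND v3.VisitDate > v1.VisitDate
        AND v3.VisitDate < v2.VisitDate
)
ORDER BY v1.PatientPKHash, v1.VisitDate;
"""


def demo_attendance_status() -> list:
    """Run the attendance query on three made-up visits for one patient."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE clinical_visits "
        "(PatientPKHash TEXT, VisitDate TEXT, NextAppointmentDate TEXT)"
    )
    # Made-up rows: the second visit is on time, the third is six days late.
    conn.executemany(
        "INSERT INTO clinical_visits VALUES (?, ?, ?)",
        [
            ("abc", "2023-01-01", "2023-01-15"),
            ("abc", "2023-01-14", "2023-02-14"),
            ("abc", "2023-02-20", "2023-03-20"),
        ],
    )
    rows = conn.execute(ATTENDANCE_SQL).fetchall()
    conn.close()
    # Keep (previous visit, next visit, status) for readability.
    return [(row[1], row[3], row[4]) for row in rows]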