# type: ignore
import html
import os
import uuid
from typing import Optional

import gradio as gr
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition, ToolNode


# In-process checkpointer: conversation state persists per thread_id for the
# lifetime of this process only.
memory = MemorySaver()

# Load environment variables (e.g. OPENAI_API_KEY) from config.env if present.
if os.path.exists("config.env"):
    load_dotenv("config.env")

# Fail fast with a clear message if the key is missing. The original
# `os.environ.get("OPENAI_API_KEY")` discarded its result and had no effect;
# a missing key would only surface later as an opaque client error.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; add it to config.env or the environment."
    )

# temperature=0.0 for deterministic, reproducible clinical answers.
llm = ChatOpenAI(temperature=0.0, model="gpt-4o")

from chatlib.state_types import AppState
from chatlib.guidlines_rag_agent_li import rag_retrieve
from chatlib.patient_all_data import sql_chain
from chatlib.idsr_check import idsr_check
from chatlib.idsr_definition import idsr_define
from chatlib.phi_filter import detect_and_redact_phi
from chatlib.assistant_node import assistant


def rag_retrieve_tool(query):
    """Retrieve relevant HIV clinical guidelines for the given query."""
    # Delegate to the RAG helper, then normalize its output into the exact
    # keys the graph state expects downstream.
    retrieved = rag_retrieve(query, llm=llm)
    guideline_text = retrieved.get("rag_result", "")
    source_list = retrieved.get("rag_sources", [])
    return {
        "rag_result": guideline_text,
        "rag_sources": source_list,
        "last_tool": "rag_retrieve",
    }


def sql_chain_tool(query, rag_result, pk_hash):
    """Query patient data from the SQL database and summarize results."""
    # Forward the guideline context and patient key to the SQL chain; only
    # the final answer plus tool provenance flow back into graph state.
    chain_output = sql_chain(query, llm=llm, rag_result=rag_result, pk_hash=pk_hash)
    answer_text = chain_output.get("answer", "")
    return {"answer": answer_text, "last_tool": "sql_chain"}


def idsr_check_tool(query, sitecode):
    """Check if the patient case description matches any known diseases."""
    # The site code scopes the disease-surveillance lookup to a region.
    check_output = idsr_check(query, llm=llm, sitecode=sitecode)
    return {
        "answer": check_output.get("answer", ""),
        "last_tool": "idsr_check",
        # Context may legitimately be absent; default to None rather than "".
        "context": check_output.get("context", None),
    }

def idsr_define_tool(query):
    """Retrieve disease definition based on the query."""
    # Only the definition text and tool provenance are returned to the graph.
    definition = idsr_define(query, llm=llm)
    return {"answer": definition.get("answer", ""), "last_tool": "idsr_define"}

# Register the four tools with the LLM so it can emit structured tool calls;
# each function's docstring becomes the tool description sent to the model.
tools = [rag_retrieve_tool, sql_chain_tool, idsr_check_tool, idsr_define_tool]
llm_with_tools = llm.bind_tools(tools)


# System prompt for the assistant node.
# Fixes relative to the previous version: the first JSON example contained a
# trailing comma (invalid JSON the model was told to imitate); the prompt
# referred to tools as "sql_chain"/"rag_retrieve" while the registered names
# are "sql_chain_tool"/"rag_retrieve_tool"; and the idsr_check_tool example
# used a lab-results question, contradicting that tool's stated purpose.
sys_msg = SystemMessage(
    content="""
You are a helpful assistant supporting clinicians during patient visits. When a patient ID is provided, the clinician is meeting with that HIV-positive patient and may inquire about their history, lab results, or medications. If no patient ID is provided, the clinician may be asking general HIV clinical questions or presenting symptoms for a new patient.

You have access to four tools to help you answer the clinician's questions. 

- rag_retrieve_tool: to access HIV clinical guidelines
- sql_chain_tool: to access HIV data about the patient with whom the clinician is meeting. For straightforward factual questions about the patient, you may call sql_chain_tool directly. For questions requiring clinical interpretation or classification, first call rag_retrieve_tool to get relevant clinical guideline context, then include that context when calling sql_chain_tool.
- idsr_check_tool: to check if the patient case description matches any known diseases.
- idsr_define_tool: to retrieve the official case definition of a disease when the clinician asks about it (e.g., “What is the description of cholera?”). Do not use this tool for analyzing symptom descriptions — use `idsr_check_tool` for that.

When a tool is needed, respond only with a JSON object specifying the tool to call and its minimal arguments, for example:
{
  "tool": "rag_retrieve_tool",
  "args": {
    "query": "patient vaginal bleeding"
  }
}

When calling the "sql_chain_tool" tool, always include the following arguments in the JSON response:

- "query": the clinician's question
- "rag_result": the clinical guideline context obtained from rag_retrieve_tool
- "pk_hash": the patient identifier string

For example:

{
  "tool": "sql_chain_tool",
  "args": {
    "query": "What is the patient's latest lab results?",
    "rag_result": "<clinical guideline context>",
    "pk_hash": "patient123"
  }
}

When calling the "idsr_check_tool" tool, always include the following arguments in the JSON response:

- "query": the clinician's question
- "sitecode": the site code string

For example:

{
  "tool": "idsr_check_tool",
  "args": {
    "query": "Patient presents with fever and acute watery diarrhea",
    "sitecode": "32060"
  }
}

When calling the "idsr_define_tool" tool, always include the following arguments in the JSON response:

- "query": the clinician's question

For example:

{
  "tool": "idsr_define_tool",
  "args": {
    "query": "What is the description of cholera?"
  }
}

There are only two cases where a tool is not needed:
1. If the clinician's question is a simple greeting, farewell, or acknowledgement.
2. The answer is clearly and completely present in the prior conversation turns.

If a tool is not needed, respond directly in natural language.

If the clinician's question or intent is ambiguous, ask a clarifying question before invoking a tool.

Never include text outside the JSON object when invoking a tool.

Never use your general knowledge to answer medical questions.

Do not reference PHI (Protected Health Information) in your responses.

Keep responses concise and focused. The clinician is a healthcare professional; do not suggest consulting one.

If the question is outside your scope, respond with "I'm sorry, I cannot assist with that request."
"""
)


builder = StateGraph(AppState)
# The assistant node closes over the system prompt and both LLM handles
# (plain and tool-bound) so the project-level `assistant` can choose.
builder.add_node(
    "assistant", lambda state: assistant(state, sys_msg, llm, llm_with_tools)
)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
# Route to "tools" when the assistant emitted tool calls, otherwise end.
builder.add_conditional_edges("assistant", tools_condition)
# After a tool runs, hand control back to the assistant to compose the reply.
builder.add_edge("tools", "assistant")
# Compiling with the MemorySaver checkpointer gives per-thread_id memory.
react_graph = builder.compile(checkpointer=memory)


def chat_with_patient(question: str, patient_id: str, sitecode: str, thread_id: Optional[str] = None):
    """Run one clinician turn through the agent graph.

    Parameters
    ----------
    question : the clinician's question (PHI is redacted before use).
    patient_id : mock patient identifier from the dropdown (may be None/empty).
    sitecode : dropdown value like "32060 - Migori" (may be None/empty).
    thread_id : conversation thread; a fresh UUID is minted when absent.

    Returns a 5-tuple for Gradio: (assistant reply, thread id, retrieved RAG
    sources, "" to clear the question box, rendered chat-history HTML).
    """
    if thread_id is None or thread_id == "":
        thread_id = str(uuid.uuid4())

    # Strip PHI before the question reaches any model, log, or display.
    question = detect_and_redact_phi(question)["redacted_text"]

    print(question)

    # Dropdown values are "code - county"; keep only the 5-digit site code.
    if sitecode is None or sitecode == "":
        sitecode_selected = ""
    else:
        sitecode_selected = sitecode[:5]

    # Per-turn input state; the checkpointer merges it with prior turns.
    input_state: AppState = {
        "messages": [HumanMessage(content=question)],
        "pk_hash": patient_id,
        "sitecode": sitecode_selected,
    }

    config = {"configurable": {"thread_id": thread_id, "user_id": thread_id}}

    output_state = react_graph.invoke(input_state, config)  # type: ignore

    for m in output_state["messages"]:
        m.pretty_print()

    assistant_message = output_state["messages"][-1].content

    # Cleaned history shown to the user: Human + AI turns only (tool/system
    # messages are omitted).
    chat_history_html = """
    <div style='
        border:1px solid #ccc;
        border-radius:6px;
        padding:10px;
        background-color:#f9f9f9;
        max-height:300px;
        overflow-y:auto;
    '>
    """
    for m in output_state["messages"]:
        # html.escape prevents message content from injecting markup/scripts
        # into the gr.HTML component (the question is user-controlled text).
        if isinstance(m, HumanMessage):
            chat_history_html += f"<strong>You:</strong> {html.escape(str(m.content))}<br><br>"
        elif isinstance(m, AIMessage):
            chat_history_html += f"<strong>Assistant:</strong> {html.escape(str(m.content))}<br><br>"
    chat_history_html += "</div>"

    return assistant_message, thread_id, output_state.get("rag_sources", ""), "", chat_history_html

def init_session():
    """Create, log, and return a fresh session/thread identifier."""
    session_id = str(uuid.uuid4())
    print(f"New session ID: {session_id}")
    return session_id

# Gradio UI. NOTE: component creation order inside the Blocks context
# defines the page layout, so statements here must not be reordered.
with gr.Blocks() as app:
    gr.Markdown(
        """
        # Clinician Assistant
        Welcome! Enter your clinical question below. The assistant can access HIV guidelines, patient data, and disease surveillance tools.

        **Note**: This is a prototype tool. There is mock data for ten fictitious patients and a mix of counties to select from (for regional variation in IDSR symptom checking).
        """
    )

    gr.Markdown("### Select Patient Context")
    with gr.Row():
        # Mock patient IDs "1"-"10"; None means "no patient context".
        id_selected = gr.Dropdown(
            choices=[None] + [str(i) for i in range(1, 11)], label="Fake ID Number"
        )
        # "code - county" strings; chat_with_patient keeps only the first
        # five characters (the numeric site code).
        sitecode_selection = gr.Dropdown(
            choices=[None] + [
                "32060 - Migori",
                "32046 - Machakos",
                "32029 - Nairobi",
                "31660 - Mombasa",
                "31450 - Samburu",
            ],
            label="Sitecode",
        )

    gr.Markdown("### Ask a Clinical Question")
    question_input = gr.Textbox(label="Question")
    # Per-browser-session thread id, minted by init_session on page load.
    thread_id_state = gr.State(init_session)
    output_chat = gr.Textbox(label="Assistant Response")

    submit_btn = gr.Button("Ask")

    chat_history_display = gr.HTML()

    retrieved_sources_display = gr.HTML(label="Retrieved Sources (if applicable)")

    # Button click and textbox Enter trigger the same handler; the empty
    # string returned for question_input clears the box after each turn.
    submit_btn.click(  # pylint: disable=no-member
        chat_with_patient,
        inputs=[question_input, id_selected, sitecode_selection, thread_id_state],
        outputs=[output_chat, thread_id_state, retrieved_sources_display, question_input, chat_history_display],
    )
    # pylint: disable=no-member
    question_input.submit(
        chat_with_patient,
        inputs=[question_input, id_selected, sitecode_selection, thread_id_state],
        outputs=[output_chat, thread_id_state, retrieved_sources_display, question_input, chat_history_display],
    )

# Bind to all interfaces on Gradio's default port so the app is reachable
# from outside a container/VM.
app.launch(
    server_name="0.0.0.0",
    server_port=7860,
)