File size: 3,820 Bytes
5e988c3
 
 
249458d
5e988c3
 
 
97facdb
5e988c3
 
 
 
 
 
b459a9c
5e988c3
97facdb
4935ec0
b459a9c
4935ec0
e23fefd
4935ec0
5e988c3
4c42be0
e23fefd
 
4935ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e23fefd
 
5e988c3
4c42be0
b459a9c
3b5f033
 
5486ae5
 
e23fefd
 
 
5486ae5
 
 
 
 
e23fefd
5486ae5
 
 
 
 
b459a9c
e23fefd
 
b459a9c
5e988c3
e23fefd
5e988c3
b459a9c
5e988c3
4c42be0
5e988c3
 
 
4c42be0
5e988c3
b459a9c
5e988c3
 
 
4c42be0
97facdb
b459a9c
4c42be0
e23fefd
e887897
3b5f033
 
 
e23fefd
5486ae5
 
3b5f033
e23fefd
5e988c3
e23fefd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from dotenv import load_dotenv
import os
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver

memory = MemorySaver()

# Populate os.environ from config.env. ChatOpenAI (and LangSmith tracing, if
# enabled) presumably pick up OPENAI_API_KEY / LANGSMITH_API_KEY from the
# environment; the previous bare `os.environ.get(...)` calls discarded their
# results and had no effect, so they are not repeated here.
load_dotenv("config.env")

# NOTE(review): these project modules are imported only after load_dotenv();
# presumably they read configuration from the environment at import time —
# confirm before moving them to the top of the file.
from chatlib.state_types import AppState
from chatlib.guidlines_rag_agent_li import rag_retrieve
from chatlib.patient_all_data import sql_chain
from chatlib.idsr_check import idsr_check

# Single source of truth for the agent's tools: the same list is advertised to
# the model via bind_tools and executed by the graph's ToolNode, so the two
# cannot drift apart when a tool is added or removed.
tools = [rag_retrieve, sql_chain, idsr_check]
llm = ChatOpenAI(temperature=0.0, model="gpt-4o")
llm_with_tools = llm.bind_tools(tools)


# System prompt prepended to every assistant turn.
# Tools are invoked through the model's native tool-calling interface: they are
# bound with `bind_tools`, and the graph routes on `tools_condition`, which
# only follows AI messages that carry `tool_calls`. The prompt therefore must
# not demand hand-written JSON. A JSON-text reply has no `tool_calls`, so the
# graph would end without running any tool. It would also leave no way to
# give a final answer or to ask for clarification in plain text.
sys_msg = SystemMessage(
    content="""
You are a helpful assistant supporting clinicians during patient visits. You have three tools:

- rag_retrieve: to access HIV clinical guidelines
- sql_chain: to query patient data from the SQL database
- idsr_check: to check if the patient case description matches any known diseases

There are three types of questions you may receive:
1. Questions about patients (e.g., "When should this patient switch regimens?" or "What is their viral load history?")
2. Questions about HIV clinical guidelines (e.g., "What are the latest guidelines for changing ART regimens?")
3. Questions about disease identification based on patient case descriptions (e.g., "Should I be concerned about certain diseases with this patient?")

When a clinician asks about patients, first use rag_retrieve to get relevant guideline context, then use sql_chain to query the patient's data, combining information as needed.

When a clinician asks about guidelines, use rag_retrieve to provide the latest HIV clinical guidelines.

When a clinician asks about disease identification, use idsr_check to match case descriptions against disease definitions.

Call tools through the tool-calling interface; do not write out tool calls as text. Once you have the tool results you need, answer the clinician directly in plain text.

Keep responses concise and focused. The clinician is a healthcare professional; do not suggest consulting one.

If the clinician's question is unclear, ask for clarification.
"""
)


def assistant(state: AppState) -> AppState:
    """Run one model turn of the tool-calling loop.

    Builds the prompt as the module-level system prompt, then, when
    ``state`` has a truthy ``pk_hash``, a second system message naming the
    patient, then the conversation so far. It invokes the tool-bound model
    on that prompt.

    Args:
        state: Current graph state. ``messages`` is read (and must support
            ``list + list``); ``pk_hash`` is read if present.

    Returns:
        A new state mapping with every key of ``state``. ``messages`` is
        extended with the model's reply. ``question`` is set to the content
        of the most recent ``HumanMessage`` ("" if there is none). The
        caller's ``state`` dict is not modified.
    """
    history = state["messages"]
    pk_hash = state.get("pk_hash")

    prompt = [sys_msg]
    if pk_hash:
        prompt.append(
            SystemMessage(content=f"The patient identifier (pk_hash) is: {pk_hash}")
        )
    prompt.extend(history)

    new_message = llm_with_tools.invoke(prompt)

    # Only human turns count as the "question"; system, AI and tool messages
    # are skipped. Searching newest-first picks the latest one.
    latest_question = next(
        (msg.content for msg in reversed(history) if isinstance(msg, HumanMessage)),
        "",
    )

    # Return a fresh mapping rather than mutating the input dict. Keep the full
    # history plus the reply, as before: AppState's `messages` reducer is not
    # visible here, and a plain-list channel would otherwise lose history.
    return {  # type: ignore[return-value]
        **state,
        "messages": history + [new_message],
        "question": latest_question,
    }


# Graph: START -> assistant, then loop assistant -> tools -> assistant.
# After each assistant turn, LangGraph's prebuilt `tools_condition` routes to
# "tools" when the reply requests tool calls, and to END otherwise.
builder = StateGraph(AppState)
for node_name, node in (("assistant", assistant), ("tools", ToolNode(tools))):
    builder.add_node(node_name, node)

builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

# The MemorySaver checkpointer stores state per config["configurable"]["thread_id"].
react_graph = builder.compile(checkpointer=memory)


# Demo run (executes at import time): ask one question about a fixed patient,
# checkpointed under thread "30", and print every resulting message.
config = {"configurable": {"thread_id": "30"}}

input_state: AppState = dict(  # type: ignore
    messages=[HumanMessage(content="summarize the patient's clinical visits")],
    question="",
    rag_result="",
    answer="",
    pk_hash="962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73",
)

message_output = react_graph.invoke(input_state, config)  # type: ignore

for message in message_output["messages"]:
    message.pretty_print()