# JDFPalladium
# adding idsr define tool and reflecting tweaks to other scripts and notebooks
# 35274a7
from dotenv import load_dotenv
import os
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver
# In-memory checkpointer; it is handed to builder.compile() below, so
# conversation checkpoints are kept only for the lifetime of this process.
memory = MemorySaver()

# Load credentials from config.env. This runs before the chatlib imports that
# follow, presumably so those modules see the variables at import time.
load_dotenv("config.env")

# Previously two bare os.environ.get(...) calls sat here; their results were
# discarded, so nothing was actually checked. Warn up front, with a clear
# message, instead of failing later deep inside a client library.
import warnings

_missing_env = [
    key
    for key in ("OPENAI_API_KEY", "LANGSMITH_API_KEY")
    if not os.environ.get(key)
]
if _missing_env:
    warnings.warn(
        "Missing environment variable(s): "
        + ", ".join(_missing_env)
        + " (expected in config.env or the process environment)",
        stacklevel=1,
    )
from chatlib.state_types import AppState
from chatlib.guidlines_rag_agent_li import rag_retrieve
from chatlib.patient_all_data import sql_chain
from chatlib.idsr_check import idsr_check
# Single source of truth for the tools: the same list is bound to the LLM here
# and passed to ToolNode when the graph is built, so the tools the model can
# request and the tools the graph can execute cannot drift apart.
tools = [rag_retrieve, sql_chain, idsr_check]
# temperature=0.0 to reduce sampling variability in tool routing and answers.
llm = ChatOpenAI(temperature=0.0, model="gpt-4o")
llm_with_tools = llm.bind_tools(tools)
# System prompt prepended to every assistant turn.
# The tools are bound with llm.bind_tools() and routed with tools_condition, so
# the model must issue tool calls through the native tool-calling interface and
# must not write them as JSON text. A JSON-text "tool call" would be routed to
# END and never executed. After tool results arrive, the model answers in prose.
sys_msg = SystemMessage(
    content="""
You are a helpful assistant supporting clinicians during patient visits. You have three tools:
- rag_retrieve: to access HIV clinical guidelines
- sql_chain: to query patient data from the SQL database
- idsr_check: to check if the patient case description matches any known diseases
There are three types of questions you may receive:
1. Questions about patients (e.g., "When should this patient switch regimens?" or "What is their viral load history?")
2. Questions about HIV clinical guidelines (e.g., "What are the latest guidelines for changing ART regimens?")
3. Questions about disease identification based on patient case descriptions (e.g., "Should I be concerned about certain diseases with this patient?")
When a clinician asks about patients, first use rag_retrieve to get relevant guideline context, then use sql_chain to query the patient's data, combining information as needed.
When a clinician asks about guidelines, use rag_retrieve to provide the latest HIV clinical guidelines.
When a clinician asks about disease identification, use idsr_check to match case descriptions against disease definitions.
Call tools only through the tool-calling interface; do not write tool names or arguments as JSON in your reply.
Once you have the tool results you need, answer the clinician in plain prose based on those results.
Keep responses concise and focused. The clinician is a healthcare professional; do not suggest consulting one.
If the clinician's question is unclear, ask for clarification.
"""
)
def assistant(state: AppState) -> AppState:
    """Run one tool-enabled LLM turn and record its output on ``state``.

    The prompt is built in three parts:
    - the module-level system prompt;
    - if ``state`` has a truthy ``pk_hash``, an extra system message stating
      the patient identifier;
    - the stored conversation in ``state["messages"]``.

    The tool-bound model is invoked on that prompt and its reply is appended
    to ``state["messages"]``. ``state["question"]`` is set to the content of
    the most recent ``HumanMessage``, or to ``""`` if there is none.

    ``state`` is mutated in place and then returned.
    """
    preamble = [sys_msg]
    patient_id = state.get("pk_hash", None)
    if patient_id:
        preamble.append(
            SystemMessage(
                content=f"The patient identifier (pk_hash) is: {patient_id}"
            )
        )

    history = state["messages"]
    prompt = preamble + history

    reply = llm_with_tools.invoke(prompt)

    # Walk the prompt from newest to oldest; the system messages at the front
    # are never HumanMessages, so this finds the latest user question.
    question = next(
        (m.content for m in reversed(prompt) if isinstance(m, HumanMessage)),
        "",
    )

    state["messages"] = history + [reply]  # type: ignore
    state["question"] = question  # type: ignore
    return state
# Graph: a ReAct-style loop. START enters "assistant"; tools_condition either
# sends the turn to "tools" (when the reply requests tools) or ends the run;
# "tools" always hands control back to "assistant".
builder = StateGraph(AppState)
builder.add_node("tools", ToolNode(tools))
builder.add_node("assistant", assistant)
builder.add_edge(START, "assistant")
builder.add_edge("tools", "assistant")
builder.add_conditional_edges("assistant", tools_condition)
# The checkpointer persists state per thread_id across invocations.
react_graph = builder.compile(checkpointer=memory)
if __name__ == "__main__":
    # Example run. It is guarded so that importing this module (e.g. to reuse
    # react_graph) does not trigger an LLM call or print a transcript;
    # running the file as a script behaves as before.
    # thread_id selects which checkpointed conversation thread to use/continue.
    config = {"configurable": {"thread_id": "30"}}
    input_state: AppState = {
        "messages": [HumanMessage(content="summarize the patient's clinical visits")],
        "question": "",
        "rag_result": "",
        "answer": "",
        "pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73",  # type: ignore
    }
    message_output = react_graph.invoke(input_state, config)  # type: ignore
    for m in message_output["messages"]:
        m.pretty_print()