Spaces:
Sleeping
Sleeping
File size: 3,820 Bytes
5e988c3 249458d 5e988c3 97facdb 5e988c3 b459a9c 5e988c3 97facdb 4935ec0 b459a9c 4935ec0 e23fefd 4935ec0 5e988c3 4c42be0 e23fefd 4935ec0 e23fefd 5e988c3 4c42be0 b459a9c 3b5f033 5486ae5 e23fefd 5486ae5 e23fefd 5486ae5 b459a9c e23fefd b459a9c 5e988c3 e23fefd 5e988c3 b459a9c 5e988c3 4c42be0 5e988c3 4c42be0 5e988c3 b459a9c 5e988c3 4c42be0 97facdb b459a9c 4c42be0 e23fefd e887897 3b5f033 e23fefd 5486ae5 3b5f033 e23fefd 5e988c3 e23fefd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 | from dotenv import load_dotenv
import os
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver
# In-process checkpointer used when compiling the graph; conversation state
# lives only as long as this Python process.
memory = MemorySaver()

# Load settings (API keys) from config.env; the relative path is resolved
# against the current working directory.
load_dotenv("config.env")

# Check the keys explicitly rather than discarding the lookups.
# ChatOpenAI is constructed later without an explicit api_key, so a missing
# OpenAI key is fatal; say so clearly instead of failing deep inside the client.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; add it to config.env or the environment."
    )
# LangSmith is only needed for tracing, so a missing key is reported, not fatal.
if not os.environ.get("LANGSMITH_API_KEY"):
    import warnings

    warnings.warn(
        "LANGSMITH_API_KEY is not set; LangSmith tracing will be unavailable."
    )
from chatlib.state_types import AppState
from chatlib.guidlines_rag_agent_li import rag_retrieve
from chatlib.patient_all_data import sql_chain
from chatlib.idsr_check import idsr_check
# Single source of truth for the agent's tools. The same list is bound to the
# model (so it can emit tool calls) and is also what the graph's ToolNode is
# built from (so every call the model can make is one the graph can execute).
tools = [rag_retrieve, sql_chain, idsr_check]
llm = ChatOpenAI(temperature=0.0, model="gpt-4o")
llm_with_tools = llm.bind_tools(tools)
# System prompt prepended to every model turn.
# Tools are bound with llm.bind_tools, so the model must request them through
# the native tool-calling interface. tools_condition only routes to the
# "tools" node when the reply carries real tool calls; a tool request written
# as JSON text would end the run without executing anything. After tool
# results come back, the model answers the clinician in plain text.
sys_msg = SystemMessage(
    content="""
You are a helpful assistant supporting clinicians during patient visits. You have three tools:
- rag_retrieve: to access HIV clinical guidelines
- sql_chain: to query patient data from the SQL database
- idsr_check: to check if the patient case description matches any known diseases
There are three types of questions you may receive:
1. Questions about patients (e.g., "When should this patient switch regimens?" or "What is their viral load history?")
2. Questions about HIV clinical guidelines (e.g., "What are the latest guidelines for changing ART regimens?")
3. Questions about disease identification based on patient case descriptions (e.g., "Should I be concerned about certain diseases with this patient?")
When a clinician asks about patients, first use rag_retrieve to get relevant guideline context, then use sql_chain to query the patient's data, combining information as needed.
When a clinician asks about guidelines, use rag_retrieve to provide the latest HIV clinical guidelines.
When a clinician asks about disease identification, use idsr_check to match case descriptions against disease definitions.
Call tools through the provided tool-calling interface; do not write tool calls out as JSON text in your reply.
Once you have the tool results you need, answer the clinician's question directly in plain language.
Keep responses concise and focused. The clinician is a healthcare professional; do not suggest consulting one.
If the clinician's question is unclear, ask for clarification.
"""
)
def assistant(state: AppState) -> AppState:
    """Run one model turn of the ReAct loop.

    Prepends the module-level system prompt ``sys_msg`` to the conversation.
    When the state carries a truthy ``pk_hash``, a second system message naming
    that identifier is prepended as well. It then invokes ``llm_with_tools`` on
    the result and returns a new state mapping.

    Args:
        state: Current graph state. Reads ``messages`` and, optionally,
            ``pk_hash``.

    Returns:
        A shallow copy of ``state`` with two changes. The model's reply is
        appended to ``messages``. ``question`` is set to the content of the
        most recent ``HumanMessage``, or "" when there is none. The input
        ``state`` itself is not modified.
    """
    history = state["messages"]

    prompt = [sys_msg]
    pk_hash = state.get("pk_hash", None)
    if pk_hash:
        # NOTE(review): the identifier is surfaced to the model presumably so it
        # can pass it on to sql_chain — confirm against chatlib.patient_all_data.
        prompt.append(
            SystemMessage(content=f"The patient identifier (pk_hash) is: {pk_hash}")
        )

    new_message = llm_with_tools.invoke(prompt + history)

    # The prepended system messages are never HumanMessages, so scanning the
    # history alone yields the same question as scanning the full prompt.
    latest_question = next(
        (msg.content for msg in reversed(history) if isinstance(msg, HumanMessage)),
        "",
    )

    # Return a fresh mapping instead of mutating the input; LangGraph applies
    # a node's return value to the graph state as an update.
    return {  # type: ignore
        **state,
        "messages": history + [new_message],
        "question": latest_question,
    }
# Graph
def _build_react_graph():
    """Return an uncompiled StateGraph wiring the assistant/tools loop.

    Flow: START -> "assistant". From "assistant", LangGraph's prebuilt
    ``tools_condition`` picks the next step. "tools" always hands control back
    to "assistant".
    """
    graph_builder = StateGraph(AppState)
    graph_builder.add_node("assistant", assistant)
    graph_builder.add_node("tools", ToolNode(tools))
    graph_builder.add_edge(START, "assistant")
    graph_builder.add_conditional_edges("assistant", tools_condition)
    graph_builder.add_edge("tools", "assistant")
    return graph_builder


builder = _build_react_graph()
# Compile with the in-memory checkpointer so state persists per thread_id.
react_graph = builder.compile(checkpointer=memory)
# Demo inputs: a fixed conversation thread and patient identifier.
config = {"configurable": {"thread_id": "30"}}
input_state: AppState = {
    "messages": [HumanMessage(content="summarize the patient's clinical visits")],
    "question": "",
    "rag_result": "",
    "answer": "",
    "pk_hash": "962885FEADB7CCF19A2CC506D39818EC448D5396C4D1AEFDC59873090C7FBF73",  # type: ignore
}

# Run the demo only when executed as a script. Importing this module (e.g. to
# reuse react_graph) must not trigger a live LLM call against a fixed patient.
if __name__ == "__main__":
    message_output = react_graph.invoke(input_state, config)  # type: ignore
    for m in message_output["messages"]:
        m.pretty_print()
|