Spaces:
Sleeping
Sleeping
JDFPalladium committed on
Commit ·
6e5b890
1
Parent(s): 5e988c3
adding sql agent
Browse files- chatlib/guidlines_rag_agent_li.py +2 -0
- chatlib/patient_sql_agent.py +86 -0
- main.py +8 -5
chatlib/guidlines_rag_agent_li.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from llama_index.core import StorageContext, load_index_from_storage
|
|
|
|
| 2 |
from .state_types import State
|
| 3 |
|
| 4 |
# Load index for retrieval
|
|
@@ -8,6 +9,7 @@ retriever = index.as_retriever(similarity_top_k=5,
|
|
| 8 |
# Similarity threshold for filtering
|
| 9 |
similarity_threshold=0.5)
|
| 10 |
|
|
|
|
| 11 |
def rag_retrieve(state:State) -> State:
|
| 12 |
"""Perform RAG search of repository containing authoritative information on HIV/AIDS in Kenya.
|
| 13 |
|
|
|
|
| 1 |
from llama_index.core import StorageContext, load_index_from_storage
|
| 2 |
+
from langchain_core.tools import tool
|
| 3 |
from .state_types import State
|
| 4 |
|
| 5 |
# Load index for retrieval
|
|
|
|
| 9 |
# Similarity threshold for filtering
|
| 10 |
similarity_threshold=0.5)
|
| 11 |
|
| 12 |
+
@tool
|
| 13 |
def rag_retrieve(state:State) -> State:
|
| 14 |
"""Perform RAG search of repository containing authoritative information on HIV/AIDS in Kenya.
|
| 15 |
|
chatlib/patient_sql_agent.py
CHANGED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.utilities import SQLDatabase
|
| 2 |
+
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
|
| 3 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 4 |
+
from langchain_core.tools import tool
|
| 5 |
+
from langchain_openai import ChatOpenAI
|
| 6 |
+
from typing_extensions import TypedDict, Annotated
|
| 7 |
+
|
| 8 |
+
from .state_types import State
|
| 9 |
+
# Patient database and model shared by the SQL question-answering chain.
# NOTE(review): path is relative — assumes the process is started from the
# repository root; confirm against the deployment entry point.
db = SQLDatabase.from_uri("sqlite:///data/patient_demonstration.sqlite")
llm = ChatOpenAI(temperature=0.0, model="gpt-4o")

# setup template for sql query tool
# Placeholders: {dialect} = SQL dialect, {top_k} = row limit,
# {table_info} = schema description injected at invoke time.
system_message = """
Given an input question, create a syntactically correct {dialect} query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
{table_info}
"""

user_prompt = "Question: {input}"

# Two-message chat prompt: system instructions + the raw user question.
query_prompt_template = ChatPromptTemplate(
    [("system", system_message), ("user", user_prompt)]
)
|
| 36 |
+
|
| 37 |
+
class QueryOutput(TypedDict):
    """Generated SQL query."""

    # The Annotated description is surfaced to the LLM by
    # with_structured_output, so it is part of the runtime contract.
    query: Annotated[str, ..., "Syntactically valid SQL query."]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def write_query(state: State) -> State:
    """Generate SQL query to fetch information.

    Fills the shared prompt template with the database dialect, schema
    summary, and the user's question, then asks the LLM (constrained to
    the QueryOutput schema) for a single SQL statement. Returns a new
    state dict with the generated query under "query".
    """
    prompt_inputs = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    prompt = query_prompt_template.invoke(prompt_inputs)
    # Structured output guarantees the response parses into QueryOutput.
    sql_writer = llm.with_structured_output(QueryOutput)
    generated = sql_writer.invoke(prompt)
    return {**state, "query": generated["query"]}
|
| 56 |
+
|
| 57 |
+
def execute_query(state: State) -> State:
    """Execute SQL query.

    Runs state["query"] against the module-level database and returns a
    new state dict with the raw tool output under "result".
    """
    runner = QuerySQLDatabaseTool(db=db)
    query_result = runner.invoke(state["query"])
    return {**state, "result": query_result}
|
| 61 |
+
|
| 62 |
+
def generate_answer(state: State) -> State:
    """Answer question using retrieved information as context.

    Builds a plain-text prompt from the question, the generated SQL, and
    the raw SQL result, then returns a new state dict with the model's
    reply under "answer".
    """
    question = state["question"]
    sql_query = state["query"]
    sql_result = state["result"]
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f"Question: {question}\n"
        f"SQL Query: {sql_query}\n"
        f"SQL Result: {sql_result}"
    )
    response = llm.invoke(prompt)
    return {**state, "answer": response.content}
|
| 73 |
+
|
| 74 |
+
# now define a stateful tool that does the same thing
|
| 75 |
+
@tool
def sql_chain(state: State) -> State:
    """
    Answer a question about patient data stored in a SQL database.

    Takes the conversation state containing the user's question, writes a
    SQL query to retrieve relevant data, executes the query, and generates
    a natural language answer based on the query results. Returns the
    updated state carrying the generated "query", raw "result", and final
    "answer".
    """
    # Pipeline: question -> SQL -> rows -> natural-language answer.
    state = write_query(state)
    state = execute_query(state)
    state = generate_answer(state)
    return state
|
main.py
CHANGED
|
@@ -13,15 +13,18 @@ os.environ.get("LANGSMITH_API_KEY")
|
|
| 13 |
|
| 14 |
from chatlib.state_types import State
|
| 15 |
from chatlib.guidlines_rag_agent_li import rag_retrieve
|
|
|
|
| 16 |
|
| 17 |
-
tools = [rag_retrieve]
|
| 18 |
llm = ChatOpenAI(temperature = 0.0, model="gpt-4o")
|
| 19 |
-
llm_with_tools = llm.bind_tools([rag_retrieve])
|
| 20 |
|
| 21 |
# System message
|
| 22 |
sys_msg = SystemMessage(content="""
|
| 23 |
You are a helpful assistant tasked with helping clinicians
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
"""
|
| 26 |
)
|
| 27 |
|
|
@@ -48,9 +51,9 @@ builder.add_edge("tools", "assistant")
|
|
| 48 |
react_graph = builder.compile(checkpointer=memory)
|
| 49 |
|
| 50 |
# Specify a thread
|
| 51 |
-
config = {"configurable": {"thread_id": "
|
| 52 |
|
| 53 |
-
messages = [HumanMessage(content="
|
| 54 |
messages = react_graph.invoke({"messages": messages}, config)
|
| 55 |
for m in messages['messages']:
|
| 56 |
m.pretty_print()
|
|
|
|
| 13 |
|
| 14 |
from chatlib.state_types import State
|
| 15 |
from chatlib.guidlines_rag_agent_li import rag_retrieve
|
| 16 |
+
from chatlib.patient_sql_agent import sql_chain
|
| 17 |
|
| 18 |
+
# Both agent tools: RAG over HIV guidelines and the patient-data SQL chain.
tools = [rag_retrieve, sql_chain]
llm = ChatOpenAI(temperature=0.0, model="gpt-4o")
llm_with_tools = llm.bind_tools(tools)

# System message
sys_msg = SystemMessage(content="""
You are a helpful assistant tasked with helping clinicians
meeting with patients. You have two tools available,
one to access information from HIV clinical guidelines, the other is
a SQL tool to access patient data.
"""
)
|
| 30 |
|
|
|
|
| 51 |
# Compile the agent graph with conversation memory attached.
react_graph = builder.compile(checkpointer=memory)

# Specify a thread
config = {"configurable": {"thread_id": "13"}}

# Seed the conversation with one clinical question and run the agent.
messages = [HumanMessage(content="what is the proper course of treatment for someone with opportunistic infections?")]
result = react_graph.invoke({"messages": messages}, config)
for msg in result['messages']:
    msg.pretty_print()
|