JDFPalladium commited on
Commit
6e5b890
·
1 Parent(s): 5e988c3

adding sql agent

Browse files
chatlib/guidlines_rag_agent_li.py CHANGED
@@ -1,4 +1,5 @@
1
  from llama_index.core import StorageContext, load_index_from_storage
 
2
  from .state_types import State
3
 
4
  # Load index for retrieval
@@ -8,6 +9,7 @@ retriever = index.as_retriever(similarity_top_k=5,
8
  # Similarity threshold for filtering
9
  similarity_threshold=0.5)
10
 
 
11
  def rag_retrieve(state:State) -> State:
12
  """Perform RAG search of repository containing authoritative information on HIV/AIDS in Kenya.
13
 
 
1
  from llama_index.core import StorageContext, load_index_from_storage
2
+ from langchain_core.tools import tool
3
  from .state_types import State
4
 
5
  # Load index for retrieval
 
9
  # Similarity threshold for filtering
10
  similarity_threshold=0.5)
11
 
12
+ @tool
13
  def rag_retrieve(state:State) -> State:
14
  """Perform RAG search of repository containing authoritative information on HIV/AIDS in Kenya.
15
 
chatlib/patient_sql_agent.py CHANGED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.utilities import SQLDatabase
2
+ from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
3
+ from langchain_core.prompts import ChatPromptTemplate
4
+ from langchain_core.tools import tool
5
+ from langchain_openai import ChatOpenAI
6
+ from typing_extensions import TypedDict, Annotated
7
+
8
+ from .state_types import State
9
+ db = SQLDatabase.from_uri("sqlite:///data/patient_demonstration.sqlite")
10
+ llm = ChatOpenAI(temperature = 0.0, model="gpt-4o")
11
+
12
+ # setup template for sql query tool
13
+ system_message = """
14
+ Given an input question, create a syntactically correct {dialect} query to
15
+ run to help find the answer. Unless the user specifies in his question a
16
+ specific number of examples they wish to obtain, always limit your query to
17
+ at most {top_k} results. You can order the results by a relevant column to
18
+ return the most interesting examples in the database.
19
+
20
+ Never query for all the columns from a specific table, only ask for a the
21
+ few relevant columns given the question.
22
+
23
+ Pay attention to use only the column names that you can see in the schema
24
+ description. Be careful to not query for columns that do not exist. Also,
25
+ pay attention to which column is in which table.
26
+
27
+ Only use the following tables:
28
+ {table_info}
29
+ """
30
+
31
+ user_prompt = "Question: {input}"
32
+
33
+ query_prompt_template = ChatPromptTemplate(
34
+ [("system", system_message), ("user", user_prompt)]
35
+ )
36
+
37
+ class QueryOutput(TypedDict):
38
+ """Generated SQL query."""
39
+
40
+ query: Annotated[str, ..., "Syntactically valid SQL query."]
41
+
42
+
43
+ def write_query(state: State) -> State:
44
+ """Generate SQL query to fetch information."""
45
+ prompt = query_prompt_template.invoke(
46
+ {
47
+ "dialect": db.dialect,
48
+ "top_k": 10,
49
+ "table_info": db.get_table_info(),
50
+ "input": state["question"],
51
+ }
52
+ )
53
+ structured_llm = llm.with_structured_output(QueryOutput)
54
+ result = structured_llm.invoke(prompt)
55
+ return {**state, "query": result["query"]}
56
+
57
+ def execute_query(state: State) -> State:
58
+ """Execute SQL query."""
59
+ execute_query_tool = QuerySQLDatabaseTool(db=db)
60
+ return {**state, "result": execute_query_tool.invoke(state["query"])}
61
+
62
+ def generate_answer(state: State) -> State:
63
+ """Answer question using retrieved information as context."""
64
+ prompt = (
65
+ "Given the following user question, corresponding SQL query, "
66
+ "and SQL result, answer the user question.\n\n"
67
+ f'Question: {state["question"]}\n'
68
+ f'SQL Query: {state["query"]}\n'
69
+ f'SQL Result: {state["result"]}'
70
+ )
71
+ response = llm.invoke(prompt)
72
+ return {**state, "answer": response.content}
73
+
74
+ # now define a stateful tool that does the same thing
75
+ @tool
76
+ def sql_chain(state: State) -> State:
77
+ """
78
+ Annotated function that takes a question string seeking information on patient data
79
+ from a SQL database, writes an SQL query to retrieve relevant data, executes the query,
80
+ and generates a natural language answer based on the query results.
81
+ Returns the final answer as a string.
82
+ """
83
+ state = write_query(state)
84
+ state = execute_query(state)
85
+ state = generate_answer(state)
86
+ return state
main.py CHANGED
@@ -13,15 +13,18 @@ os.environ.get("LANGSMITH_API_KEY")
13
 
14
  from chatlib.state_types import State
15
  from chatlib.guidlines_rag_agent_li import rag_retrieve
 
16
 
17
- tools = [rag_retrieve]
18
  llm = ChatOpenAI(temperature = 0.0, model="gpt-4o")
19
- llm_with_tools = llm.bind_tools([rag_retrieve])
20
 
21
  # System message
22
  sys_msg = SystemMessage(content="""
23
  You are a helpful assistant tasked with helping clinicians
24
- access information from HIV clinical guidelines.
 
 
25
  """
26
  )
27
 
@@ -48,9 +51,9 @@ builder.add_edge("tools", "assistant")
48
  react_graph = builder.compile(checkpointer=memory)
49
 
50
  # Specify a thread
51
- config = {"configurable": {"thread_id": "11"}}
52
 
53
- messages = [HumanMessage(content="What are the first-line treatments for HIV in Kenya?")]
54
  messages = react_graph.invoke({"messages": messages}, config)
55
  for m in messages['messages']:
56
  m.pretty_print()
 
13
 
14
  from chatlib.state_types import State
15
  from chatlib.guidlines_rag_agent_li import rag_retrieve
16
+ from chatlib.patient_sql_agent import sql_chain
17
 
18
+ tools = [rag_retrieve, sql_chain]
19
  llm = ChatOpenAI(temperature = 0.0, model="gpt-4o")
20
+ llm_with_tools = llm.bind_tools([rag_retrieve, sql_chain])
21
 
22
  # System message
23
  sys_msg = SystemMessage(content="""
24
  You are a helpful assistant tasked with helping clinicians
25
+ meeting with patients. You have two tools available,
26
+ one to access information from HIV clinical guidelines, the other is
27
+ a SQL tool to access patient data.
28
  """
29
  )
30
 
 
51
  react_graph = builder.compile(checkpointer=memory)
52
 
53
  # Specify a thread
54
+ config = {"configurable": {"thread_id": "13"}}
55
 
56
+ messages = [HumanMessage(content="what is the proper course of treatment for someone with opportunistic infections?")]
57
  messages = react_graph.invoke({"messages": messages}, config)
58
  for m in messages['messages']:
59
  m.pretty_print()