Spaces:
Sleeping
Sleeping
JDFPalladium committed on
Commit ·
5e988c3
1
Parent(s): 9beda2c
adding guidance rag agent
Browse files- chat.ipynb +2 -2
- chatlib/__init__.py +0 -0
- chatlib/guidlines_rag_agent_li.py +32 -0
- chatlib/patient_sql_agent.py +0 -0
- chatlib/state_types.py +11 -0
- main.py +56 -0
- requirements.txt +2 -1
- sql_agent.ipynb +6 -38
chat.ipynb
CHANGED
|
@@ -20,7 +20,7 @@
|
|
| 20 |
},
|
| 21 |
{
|
| 22 |
"cell_type": "code",
|
| 23 |
-
"execution_count":
|
| 24 |
"id": "73bd3df7",
|
| 25 |
"metadata": {},
|
| 26 |
"outputs": [
|
|
@@ -40,7 +40,7 @@
|
|
| 40 |
"# Load index for retrieval\n",
|
| 41 |
"storage_context = StorageContext.from_defaults(persist_dir=\"arv_metadata\")\n",
|
| 42 |
"index = load_index_from_storage(storage_context)\n",
|
| 43 |
-
"retriever = index.as_retriever(similarity_top_k=
|
| 44 |
" # Similarity threshold for filtering\n",
|
| 45 |
" similarity_threshold=0.5)\n",
|
| 46 |
"\n",
|
|
|
|
| 20 |
},
|
| 21 |
{
|
| 22 |
"cell_type": "code",
|
| 23 |
+
"execution_count": null,
|
| 24 |
"id": "73bd3df7",
|
| 25 |
"metadata": {},
|
| 26 |
"outputs": [
|
|
|
|
| 40 |
"# Load index for retrieval\n",
|
| 41 |
"storage_context = StorageContext.from_defaults(persist_dir=\"arv_metadata\")\n",
|
| 42 |
"index = load_index_from_storage(storage_context)\n",
|
| 43 |
+
"retriever = index.as_retriever(similarity_top_k=5,\n",
|
| 44 |
" # Similarity threshold for filtering\n",
|
| 45 |
" similarity_threshold=0.5)\n",
|
| 46 |
"\n",
|
chatlib/__init__.py
ADDED
|
File without changes
|
chatlib/guidlines_rag_agent_li.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from llama_index.core import StorageContext, load_index_from_storage
from .state_types import State

# Load the persisted vector index once at import time so every call to
# rag_retrieve reuses the same retriever instance.
storage_context = StorageContext.from_defaults(persist_dir="arv_metadata")
index = load_index_from_storage(storage_context)
retriever = index.as_retriever(similarity_top_k=5,
                               # Similarity threshold for filtering
                               similarity_threshold=0.5)


def rag_retrieve(state: State) -> State:
    """Perform a RAG search of the repository containing authoritative
    information on HIV/AIDS in Kenya.

    Args:
        state: Graph state; the user's query is read from ``state["question"]``.

    Returns:
        A new state dict with ``rag_result`` set to the concatenated retrieved
        passages, or to an explicit notice when nothing relevant was found.
    """
    user_prompt = state["question"]
    sources = retriever.retrieve(user_prompt)
    # Guard the empty-retrieval case so downstream consumers see a clear
    # signal instead of a bare prefix string with no content.
    if not sources:
        return {**state, "rag_result": "RAG search found no relevant sources."}
    retrieved_text = "\n\n".join(
        f"Source {i+1}: {source.text}" for i, source in enumerate(sources)
    )
    return {**state, "rag_result": "RAG search results for: " + retrieved_text}


if __name__ == "__main__":
    # Smoke test: run a single retrieval against the local index.
    test_state = State(
        messages=[],
        question="What are the first-line treatments for HIV in Kenya?",
        rag_result="",
        query="",
        result="",
        answer="",
    )
    updated_state = rag_retrieve(test_state)
    print(updated_state["rag_result"])
|
chatlib/patient_sql_agent.py
ADDED
|
File without changes
|
chatlib/state_types.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing_extensions import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages

class State(TypedDict):
    """Shared graph state passed between the agent nodes.

    ``messages`` uses the ``add_messages`` reducer, which merges newly
    returned messages into the running history rather than replacing it.
    """
    # Conversation history (accumulated across graph steps).
    messages: Annotated[list[AnyMessage], add_messages]
    # The raw user question for the current turn.
    question: str
    # Text retrieved by the guidelines RAG agent.
    rag_result: str
    # SQL query produced by the SQL agent — presumably; confirm against
    # chatlib/patient_sql_agent.py once it has content.
    query: str
    # Raw result of executing the query.
    result: str
    # Final natural-language answer for the user.
    answer: str
|
main.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dotenv import load_dotenv
import os
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState, START, StateGraph
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver

memory = MemorySaver()

# Load API keys from the local env file. The original bare
# os.environ.get(...) calls discarded their results and validated nothing,
# so fail fast here when the required OpenAI key is missing.
load_dotenv("config.env")
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set; check config.env")
# LANGSMITH_API_KEY is optional — tracing simply stays disabled without it.

from chatlib.state_types import State
from chatlib.guidlines_rag_agent_li import rag_retrieve

# Single tool list bound both to the LLM and the ToolNode, so the two
# cannot drift apart (previously bind_tools received a duplicate literal).
tools = [rag_retrieve]
llm = ChatOpenAI(temperature=0.0, model="gpt-4o")
llm_with_tools = llm.bind_tools(tools)

# System message
sys_msg = SystemMessage(content="""
You are a helpful assistant tasked with helping clinicians
access information from HIV clinical guidelines.
"""
)

# Assistant Node
def assistant(state: MessagesState):
    """Invoke the tool-aware LLM on the running message history."""
    return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

# Graph
builder = StateGraph(MessagesState)

# Define nodes: these do the work
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

# Define edges: these determine how the control flow moves
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
    # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
    tools_condition,
)
builder.add_edge("tools", "assistant")
react_graph = builder.compile(checkpointer=memory)

if __name__ == "__main__":
    # Guarded so importing this module builds the graph without firing a
    # live LLM call (the invoke previously ran at import time).
    # Specify a thread so MemorySaver can checkpoint the conversation.
    config = {"configurable": {"thread_id": "11"}}

    messages = [HumanMessage(content="What are the first-line treatments for HIV in Kenya?")]
    messages = react_graph.invoke({"messages": messages}, config)
    for m in messages["messages"]:
        m.pretty_print()
|
requirements.txt
CHANGED
|
@@ -14,4 +14,5 @@ langgraph-cli[inmem]
|
|
| 14 |
llama_index==0.12.34
|
| 15 |
pylint
|
| 16 |
black
|
| 17 |
-
pytest
|
|
|
|
|
|
| 14 |
llama_index==0.12.34
|
| 15 |
pylint
|
| 16 |
black
|
| 17 |
+
pytest
|
| 18 |
+
python-dotenv
|
sql_agent.ipynb
CHANGED
|
@@ -41,37 +41,7 @@
|
|
| 41 |
},
|
| 42 |
{
|
| 43 |
"cell_type": "code",
|
| 44 |
-
"execution_count":
|
| 45 |
-
"id": "40ddb630",
|
| 46 |
-
"metadata": {},
|
| 47 |
-
"outputs": [
|
| 48 |
-
{
|
| 49 |
-
"data": {
|
| 50 |
-
"text/plain": [
|
| 51 |
-
"[QuerySQLDatabaseTool(description=\"Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.\", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x707a407ab3e0>),\n",
|
| 52 |
-
" InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x707a407ab3e0>),\n",
|
| 53 |
-
" ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x707a407ab3e0>),\n",
|
| 54 |
-
" QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x707a407ab3e0>, llm=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x707a4004e300>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x707a33d0d370>, root_client=<openai.OpenAI object at 0x707a40194050>, root_async_client=<openai.AsyncOpenAI object at 0x707a4004df70>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\\n{query}\\nDouble check the {dialect} query above for common mistakes, including:\\n- Using NOT IN with NULL values\\n- Using UNION when UNION ALL should have been used\\n- Using BETWEEN for exclusive ranges\\n- Data type mismatch in predicates\\n- Properly quoting identifiers\\n- Using the correct number of arguments for functions\\n- Casting to the correct data type\\n- Using the proper columns for joins\\n\\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\\n\\nOutput the final SQL query only.\\n\\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x707a4004e300>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x707a33d0d370>, root_client=<openai.OpenAI object at 0x707a40194050>, root_async_client=<openai.AsyncOpenAI object at 0x707a4004df70>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), output_parser=StrOutputParser(), llm_kwargs={}))]"
|
| 55 |
-
]
|
| 56 |
-
},
|
| 57 |
-
"execution_count": 3,
|
| 58 |
-
"metadata": {},
|
| 59 |
-
"output_type": "execute_result"
|
| 60 |
-
}
|
| 61 |
-
],
|
| 62 |
-
"source": [
|
| 63 |
-
"from langchain_community.agent_toolkits import SQLDatabaseToolkit\n",
|
| 64 |
-
"\n",
|
| 65 |
-
"toolkit = SQLDatabaseToolkit(db=db, llm=llm)\n",
|
| 66 |
-
"\n",
|
| 67 |
-
"tools = toolkit.get_tools()\n",
|
| 68 |
-
"\n",
|
| 69 |
-
"tools"
|
| 70 |
-
]
|
| 71 |
-
},
|
| 72 |
-
{
|
| 73 |
-
"cell_type": "code",
|
| 74 |
-
"execution_count": 4,
|
| 75 |
"id": "f9c96976",
|
| 76 |
"metadata": {},
|
| 77 |
"outputs": [
|
|
@@ -131,8 +101,8 @@
|
|
| 131 |
" [(\"system\", system_message), (\"user\", user_prompt)]\n",
|
| 132 |
")\n",
|
| 133 |
"\n",
|
| 134 |
-
"for message in query_prompt_template.messages:\n",
|
| 135 |
-
"
|
| 136 |
]
|
| 137 |
},
|
| 138 |
{
|
|
@@ -349,7 +319,7 @@
|
|
| 349 |
},
|
| 350 |
{
|
| 351 |
"cell_type": "code",
|
| 352 |
-
"execution_count":
|
| 353 |
"id": "5fb11ed6",
|
| 354 |
"metadata": {},
|
| 355 |
"outputs": [
|
|
@@ -357,9 +327,7 @@
|
|
| 357 |
"name": "stdout",
|
| 358 |
"output_type": "stream",
|
| 359 |
"text": [
|
| 360 |
-
"{'assistant': {'messages': [AIMessage(content=
|
| 361 |
-
"{'tools': {'messages': [ToolMessage(content=\"{'messages': [HumanMessage(content='how many unique regimens are there?', additional_kwargs={}, response_metadata={})], 'question': 'how many unique regimens are there?', 'query': 'SELECT COUNT(DISTINCT CurrentRegimen) AS unique_regimen_count FROM clinical_visits;', 'result': '[(12,)]', 'answer': 'There are 12 unique regimens.'}\", name='sql_chain', id='80f0822d-ee23-4c9d-a28c-359a1200a173', tool_call_id='call_fFrF2w9DXG12kOFn4Ox5wnGJ')]}}\n",
|
| 362 |
-
"{'assistant': {'messages': [AIMessage(content='There are 12 unique regimens.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 9, 'prompt_tokens': 3309, 'total_tokens': 3318, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 3072}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_07871e2ad8', 'id': 'chatcmpl-BkBcBVVLMPEjPvVpK7Ap8L24bNDYH', 'service_tier': 'default', 'finish_reason': 'stop', 'logprobs': None}, id='run--4dea7411-1b2b-4f3b-b8e0-09b2407221fd-0', usage_metadata={'input_tokens': 3309, 'output_tokens': 9, 'total_tokens': 3318, 'input_token_details': {'audio': 0, 'cache_read': 3072}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}}\n"
|
| 363 |
]
|
| 364 |
}
|
| 365 |
],
|
|
@@ -367,7 +335,7 @@
|
|
| 367 |
"# Specify a thread\n",
|
| 368 |
"config = {\"configurable\": {\"thread_id\": \"4\"}}\n",
|
| 369 |
"\n",
|
| 370 |
-
"user_prompt = \"
|
| 371 |
"input_state = {\n",
|
| 372 |
" \"messages\": [HumanMessage(content=user_prompt)],\n",
|
| 373 |
" \"question\": user_prompt,\n",
|
|
|
|
| 41 |
},
|
| 42 |
{
|
| 43 |
"cell_type": "code",
|
| 44 |
+
"execution_count": null,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
"id": "f9c96976",
|
| 46 |
"metadata": {},
|
| 47 |
"outputs": [
|
|
|
|
| 101 |
" [(\"system\", system_message), (\"user\", user_prompt)]\n",
|
| 102 |
")\n",
|
| 103 |
"\n",
|
| 104 |
+
"# for message in query_prompt_template.messages:\n",
|
| 105 |
+
"# message.pretty_print()"
|
| 106 |
]
|
| 107 |
},
|
| 108 |
{
|
|
|
|
| 319 |
},
|
| 320 |
{
|
| 321 |
"cell_type": "code",
|
| 322 |
+
"execution_count": 17,
|
| 323 |
"id": "5fb11ed6",
|
| 324 |
"metadata": {},
|
| 325 |
"outputs": [
|
|
|
|
| 327 |
"name": "stdout",
|
| 328 |
"output_type": "stream",
|
| 329 |
"text": [
|
| 330 |
+
"{'assistant': {'messages': [AIMessage(content=\"You're welcome! If you have any more questions, feel free to ask.\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 16, 'prompt_tokens': 3327, 'total_tokens': 3343, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_07871e2ad8', 'id': 'chatcmpl-BkBjCRm7HUlUQZV3ntGa2Sbpz8UPJ', 'service_tier': 'default', 'finish_reason': 'stop', 'logprobs': None}, id='run--45e27132-7994-4010-91e7-cd18113d643c-0', usage_metadata={'input_tokens': 3327, 'output_tokens': 16, 'total_tokens': 3343, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}}\n"
|
|
|
|
|
|
|
| 331 |
]
|
| 332 |
}
|
| 333 |
],
|
|
|
|
| 335 |
"# Specify a thread\n",
|
| 336 |
"config = {\"configurable\": {\"thread_id\": \"4\"}}\n",
|
| 337 |
"\n",
|
| 338 |
+
"user_prompt = \"thanks!?\"\n",
|
| 339 |
"input_state = {\n",
|
| 340 |
" \"messages\": [HumanMessage(content=user_prompt)],\n",
|
| 341 |
" \"question\": user_prompt,\n",
|