JDFPalladium committed on
Commit
5e988c3
·
1 Parent(s): 9beda2c

adding guidance rag agent

Browse files
chat.ipynb CHANGED
@@ -20,7 +20,7 @@
20
  },
21
  {
22
  "cell_type": "code",
23
- "execution_count": 6,
24
  "id": "73bd3df7",
25
  "metadata": {},
26
  "outputs": [
@@ -40,7 +40,7 @@
40
  "# Load index for retrieval\n",
41
  "storage_context = StorageContext.from_defaults(persist_dir=\"arv_metadata\")\n",
42
  "index = load_index_from_storage(storage_context)\n",
43
- "retriever = index.as_retriever(similarity_top_k=10,\n",
44
  " # Similarity threshold for filtering\n",
45
  " similarity_threshold=0.5)\n",
46
  "\n",
 
20
  },
21
  {
22
  "cell_type": "code",
23
+ "execution_count": null,
24
  "id": "73bd3df7",
25
  "metadata": {},
26
  "outputs": [
 
40
  "# Load index for retrieval\n",
41
  "storage_context = StorageContext.from_defaults(persist_dir=\"arv_metadata\")\n",
42
  "index = load_index_from_storage(storage_context)\n",
43
+ "retriever = index.as_retriever(similarity_top_k=5,\n",
44
  " # Similarity threshold for filtering\n",
45
  " similarity_threshold=0.5)\n",
46
  "\n",
chatlib/__init__.py ADDED
File without changes
chatlib/guidlines_rag_agent_li.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from llama_index.core import StorageContext, load_index_from_storage
from .state_types import State

# Build the guideline retriever once at import time so every query
# reuses the same loaded index.
# NOTE(review): assumes the "arv_metadata" persist directory exists
# relative to the current working directory — confirm deployment layout.
storage_context = StorageContext.from_defaults(persist_dir="arv_metadata")
index = load_index_from_storage(storage_context)

# Return the 5 most similar chunks, dropping matches scoring below 0.5.
retriever = index.as_retriever(
    similarity_top_k=5,
    similarity_threshold=0.5,
)
11
def rag_retrieve(state: State) -> State:
    """Perform a RAG search of the repository containing authoritative
    information on HIV/AIDS in Kenya.

    Args:
        state: Conversation state; the "question" key holds the user prompt.

    Returns:
        A new state dict (the input is not mutated) with "rag_result" set
        to the numbered source excerpts retrieved for the question.
    """
    user_prompt = state["question"]
    sources = retriever.retrieve(user_prompt)
    if sources:
        # Number each retrieved chunk so the LLM can cite individual sources.
        retrieved_text = "\n\n".join(
            f"Source {i + 1}: {source.text}" for i, source in enumerate(sources)
        )
    else:
        # Make the empty case explicit instead of emitting a bare prefix
        # with nothing after it, which reads as a malformed tool result.
        retrieved_text = "No relevant guideline passages were found."
    return {**state, "rag_result": "RAG search results for: " + retrieved_text}
20
+
21
if __name__ == "__main__":
    # Smoke test: run the retriever against a sample clinical question.
    sample_state = State(
        messages=[],
        question="What are the first-line treatments for HIV in Kenya?",
        rag_result="",
        query="",
        result="",
        answer="",
    )
    print(rag_retrieve(sample_state)["rag_result"])
chatlib/patient_sql_agent.py ADDED
File without changes
chatlib/state_types.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing_extensions import TypedDict, Annotated
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages


class State(TypedDict):
    """Shared conversation state passed between the LangGraph agent nodes."""

    # Chat history; add_messages appends new messages on update
    # instead of overwriting the list.
    messages: Annotated[list[AnyMessage], add_messages]
    # The user's current question.
    question: str
    # Text returned by the guidelines RAG retrieval agent.
    rag_result: str
    # SQL query produced by the SQL agent for the question.
    query: str
    # Raw result of executing that SQL query.
    result: str
    # Final natural-language answer presented to the user.
    answer: str
main.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Build and run a LangGraph ReAct agent that helps clinicians access
information from HIV clinical guidelines via a RAG retrieval tool."""
from dotenv import load_dotenv
import os
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState, START, StateGraph
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver

# Checkpointer: keeps per-thread conversation history between invocations.
memory = MemorySaver()

# Load API keys from config.env into the process environment.
load_dotenv("config.env")
# The original `os.environ.get(...)` calls discarded their results and
# validated nothing. Fail fast on the key ChatOpenAI actually requires;
# LANGSMITH_API_KEY is optional (tracing only), so it is not enforced.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY is not set; add it to config.env")

from chatlib.state_types import State
from chatlib.guidlines_rag_agent_li import rag_retrieve

tools = [rag_retrieve]
llm = ChatOpenAI(temperature=0.0, model="gpt-4o")
# Bind the same list handed to the ToolNode so the two cannot drift apart
# (the original bound a second, hand-written [rag_retrieve] list).
llm_with_tools = llm.bind_tools(tools)

# System message framing the assistant's role.
sys_msg = SystemMessage(content="""
You are a helpful assistant tasked with helping clinicians
access information from HIV clinical guidelines.
"""
)

# Assistant node: run the tool-aware LLM over the accumulated history.
def assistant(state: MessagesState):
    return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

# Graph wiring: assistant decides, tools execute, control returns to assistant.
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

builder.add_edge(START, "assistant")
# tools_condition routes to "tools" when the assistant's latest message is a
# tool call, and to END otherwise.
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
react_graph = builder.compile(checkpointer=memory)


if __name__ == "__main__":
    # Guarded so importing this module does not trigger an API call.
    # The thread id scopes the checkpointer's saved conversation.
    config = {"configurable": {"thread_id": "11"}}
    user_messages = [HumanMessage(content="What are the first-line treatments for HIV in Kenya?")]
    # Use a distinct name for the result instead of shadowing the input list.
    final_state = react_graph.invoke({"messages": user_messages}, config)
    for m in final_state["messages"]:
        m.pretty_print()
requirements.txt CHANGED
@@ -14,4 +14,5 @@ langgraph-cli[inmem]
14
  llama_index==0.12.34
15
  pylint
16
  black
17
- pytest
 
 
14
  llama_index==0.12.34
15
  pylint
16
  black
17
+ pytest
18
+ python-dotenv
sql_agent.ipynb CHANGED
@@ -41,37 +41,7 @@
41
  },
42
  {
43
  "cell_type": "code",
44
- "execution_count": 3,
45
- "id": "40ddb630",
46
- "metadata": {},
47
- "outputs": [
48
- {
49
- "data": {
50
- "text/plain": [
51
- "[QuerySQLDatabaseTool(description=\"Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.\", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x707a407ab3e0>),\n",
52
- " InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x707a407ab3e0>),\n",
53
- " ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x707a407ab3e0>),\n",
54
- " QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x707a407ab3e0>, llm=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x707a4004e300>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x707a33d0d370>, root_client=<openai.OpenAI object at 0x707a40194050>, root_async_client=<openai.AsyncOpenAI object at 0x707a4004df70>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\\n{query}\\nDouble check the {dialect} query above for common mistakes, including:\\n- Using NOT IN with NULL values\\n- Using UNION when UNION ALL should have been used\\n- Using BETWEEN for exclusive ranges\\n- Data type mismatch in predicates\\n- Properly quoting identifiers\\n- Using the correct number of arguments for functions\\n- Casting to the correct data type\\n- Using the proper columns for joins\\n\\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\\n\\nOutput the final SQL query only.\\n\\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x707a4004e300>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x707a33d0d370>, root_client=<openai.OpenAI object at 0x707a40194050>, root_async_client=<openai.AsyncOpenAI object at 0x707a4004df70>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), output_parser=StrOutputParser(), llm_kwargs={}))]"
55
- ]
56
- },
57
- "execution_count": 3,
58
- "metadata": {},
59
- "output_type": "execute_result"
60
- }
61
- ],
62
- "source": [
63
- "from langchain_community.agent_toolkits import SQLDatabaseToolkit\n",
64
- "\n",
65
- "toolkit = SQLDatabaseToolkit(db=db, llm=llm)\n",
66
- "\n",
67
- "tools = toolkit.get_tools()\n",
68
- "\n",
69
- "tools"
70
- ]
71
- },
72
- {
73
- "cell_type": "code",
74
- "execution_count": 4,
75
  "id": "f9c96976",
76
  "metadata": {},
77
  "outputs": [
@@ -131,8 +101,8 @@
131
  " [(\"system\", system_message), (\"user\", user_prompt)]\n",
132
  ")\n",
133
  "\n",
134
- "for message in query_prompt_template.messages:\n",
135
- " message.pretty_print()"
136
  ]
137
  },
138
  {
@@ -349,7 +319,7 @@
349
  },
350
  {
351
  "cell_type": "code",
352
- "execution_count": 16,
353
  "id": "5fb11ed6",
354
  "metadata": {},
355
  "outputs": [
@@ -357,9 +327,7 @@
357
  "name": "stdout",
358
  "output_type": "stream",
359
  "text": [
360
- "{'assistant': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_fFrF2w9DXG12kOFn4Ox5wnGJ', 'function': {'arguments': '{\"state\":{\"messages\":[{\"content\":\"how many unique regimens are there?\",\"type\":\"human\"}],\"question\":\"how many unique regimens are there?\",\"query\":\"SELECT COUNT(DISTINCT regimen) FROM patient_records;\",\"result\":\"42\",\"answer\":\"There are 42 unique regimens.\"}}', 'name': 'sql_chain'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 71, 'prompt_tokens': 3149, 'total_tokens': 3220, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_07871e2ad8', 'id': 'chatcmpl-BkBc7T5CPXYLG1qFVcWony1uByR15', 'service_tier': 'default', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run--65558f43-fd67-4e1c-83dc-068693dcda8e-0', tool_calls=[{'name': 'sql_chain', 'args': {'state': {'messages': [{'content': 'how many unique regimens are there?', 'type': 'human'}], 'question': 'how many unique regimens are there?', 'query': 'SELECT COUNT(DISTINCT regimen) FROM patient_records;', 'result': '42', 'answer': 'There are 42 unique regimens.'}}, 'id': 'call_fFrF2w9DXG12kOFn4Ox5wnGJ', 'type': 'tool_call'}], usage_metadata={'input_tokens': 3149, 'output_tokens': 71, 'total_tokens': 3220, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}}\n",
361
- "{'tools': {'messages': [ToolMessage(content=\"{'messages': [HumanMessage(content='how many unique regimens are there?', additional_kwargs={}, response_metadata={})], 'question': 'how many unique regimens are there?', 'query': 'SELECT COUNT(DISTINCT CurrentRegimen) AS unique_regimen_count FROM clinical_visits;', 'result': '[(12,)]', 'answer': 'There are 12 unique regimens.'}\", name='sql_chain', id='80f0822d-ee23-4c9d-a28c-359a1200a173', tool_call_id='call_fFrF2w9DXG12kOFn4Ox5wnGJ')]}}\n",
362
- "{'assistant': {'messages': [AIMessage(content='There are 12 unique regimens.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 9, 'prompt_tokens': 3309, 'total_tokens': 3318, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 3072}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_07871e2ad8', 'id': 'chatcmpl-BkBcBVVLMPEjPvVpK7Ap8L24bNDYH', 'service_tier': 'default', 'finish_reason': 'stop', 'logprobs': None}, id='run--4dea7411-1b2b-4f3b-b8e0-09b2407221fd-0', usage_metadata={'input_tokens': 3309, 'output_tokens': 9, 'total_tokens': 3318, 'input_token_details': {'audio': 0, 'cache_read': 3072}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}}\n"
363
  ]
364
  }
365
  ],
@@ -367,7 +335,7 @@
367
  "# Specify a thread\n",
368
  "config = {\"configurable\": {\"thread_id\": \"4\"}}\n",
369
  "\n",
370
- "user_prompt = \"sorry, i meant to ask how many unique regimens are there?\"\n",
371
  "input_state = {\n",
372
  " \"messages\": [HumanMessage(content=user_prompt)],\n",
373
  " \"question\": user_prompt,\n",
 
41
  },
42
  {
43
  "cell_type": "code",
44
+ "execution_count": null,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  "id": "f9c96976",
46
  "metadata": {},
47
  "outputs": [
 
101
  " [(\"system\", system_message), (\"user\", user_prompt)]\n",
102
  ")\n",
103
  "\n",
104
+ "# for message in query_prompt_template.messages:\n",
105
+ "# message.pretty_print()"
106
  ]
107
  },
108
  {
 
319
  },
320
  {
321
  "cell_type": "code",
322
+ "execution_count": 17,
323
  "id": "5fb11ed6",
324
  "metadata": {},
325
  "outputs": [
 
327
  "name": "stdout",
328
  "output_type": "stream",
329
  "text": [
330
+ "{'assistant': {'messages': [AIMessage(content=\"You're welcome! If you have any more questions, feel free to ask.\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 16, 'prompt_tokens': 3327, 'total_tokens': 3343, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_07871e2ad8', 'id': 'chatcmpl-BkBjCRm7HUlUQZV3ntGa2Sbpz8UPJ', 'service_tier': 'default', 'finish_reason': 'stop', 'logprobs': None}, id='run--45e27132-7994-4010-91e7-cd18113d643c-0', usage_metadata={'input_tokens': 3327, 'output_tokens': 16, 'total_tokens': 3343, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})]}}\n"
 
 
331
  ]
332
  }
333
  ],
 
335
  "# Specify a thread\n",
336
  "config = {\"configurable\": {\"thread_id\": \"4\"}}\n",
337
  "\n",
338
+ "user_prompt = \"thanks!?\"\n",
339
  "input_state = {\n",
340
  " \"messages\": [HumanMessage(content=user_prompt)],\n",
341
  " \"question\": user_prompt,\n",