yushnitp committed on
Commit
2bbb8d4
·
verified ·
1 Parent(s): 698c93a

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +138 -0
agent.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """""LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_huggingface import HuggingFaceEmbeddings
9
+ from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
11
+ from langchain_community.vectorstores import SupabaseVectorStore
12
+ from langchain_core.messages import SystemMessage, HumanMessage
13
+ from langchain_core.tools import tool
14
+ from langchain.tools.retriever import create_retriever_tool
15
+ from supabase.client import Client, create_client
16
+
17
+ load_dotenv()
18
+
19
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product.

    The docstring doubles as the tool description shown to the LLM;
    `@tool` requires one (it raises at import time otherwise).
    """
    return a * b
22
+
23
@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the sum.

    The docstring doubles as the tool description shown to the LLM;
    `@tool` requires one (it raises at import time otherwise).
    """
    return a + b
26
+
27
@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a and return the difference.

    The docstring doubles as the tool description shown to the LLM;
    `@tool` requires one (it raises at import time otherwise).
    """
    return a - b
30
+
31
@tool
def divide(a: int, b: int) -> float:
    """Divide a by b and return the quotient.

    Returns a float (true division), so the annotation is `float`,
    not `int` as originally written.

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
36
+
37
@tool
def modulus(a: int, b: int) -> int:
    """Return a modulo b (the remainder of a divided by b).

    The docstring doubles as the tool description shown to the LLM;
    `@tool` requires one (it raises at import time otherwise).
    """
    return a % b
40
+
41
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia and return up to 2 matching documents.

    Args:
        query: free-text search query.

    Returns:
        dict with key "wiki_results" mapping to the formatted documents.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # .get() guards against documents whose metadata lacks these keys;
    # the tag is opened with ">" (not "/>") so the closing </Document> matches.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
50
+
51
@tool
def web_search(query: str) -> dict:
    """Search the web via Tavily and return up to 3 results.

    Args:
        query: free-text search query.

    Returns:
        dict with key "web_results" mapping to the formatted results.
    """
    # Tool input is positional (`invoke(query=query)` would fail), and
    # TavilySearchResults returns a list of dicts with "url"/"content"
    # keys, not Document objects with metadata.
    search_results = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{res.get("url", "")}">\n{res.get("content", "")}\n</Document>'
        for res in search_results
    )
    return {"web_results": formatted_search_docs}
60
+
61
@tool
def arvix_search(query: str) -> dict:
    """Search arXiv and return up to 3 matching papers (truncated).

    Note: the tool keeps its original (misspelled) name so existing
    references in the `tools` list keep working.

    Args:
        query: free-text search query.

    Returns:
        dict with key "arvix_results" mapping to the formatted documents,
        each truncated to the first 1000 characters.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # ArxivLoader metadata does not reliably contain "source"/"page";
    # .get() avoids a KeyError and falls back to the paper title.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" page="{doc.metadata.get("page", "")}">\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return {"arvix_results": formatted_search_docs}
70
+
71
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
72
+ system_prompt = f.read()
73
+
74
+ sys_msg = SystemMessage(content=system_prompt)
75
+
76
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
77
+ supabase: Client = create_client(
78
+ os.environ.get("SUPABASE_URL"),
79
+ os.environ.get("SUPABASE_SERVICE_KEY"))
80
+ vector_store = SupabaseVectorStore(
81
+ client=supabase,
82
+ embedding=embeddings,
83
+ table_name="documents",
84
+ query_name="match_documents_langchain",
85
+ )
86
+ create_retriever_tool = create_retriever_tool(
87
+ retriever=vector_store.as_retriever(),
88
+ name="Question Search",
89
+ description="A tool to retrieve similar questions from a vector store.",
90
+ )
91
+
92
# Tools bound to the LLM inside build_graph().
tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
102
+
103
+ def build_graph(provider: str = "openai"):
104
+ if provider == "openai":
105
+ llm = ChatOpenAI(model="gpt-4", temperature=0)
106
+ else:
107
+ raise ValueError("Invalid provider. Only 'openai' is supported in this version.")
108
+
109
+ llm_with_tools = llm.bind_tools(tools)
110
+
111
+ def assistant(state: MessagesState):
112
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
113
+
114
+ def retriever(state: MessagesState):
115
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
116
+ example_msg = HumanMessage(
117
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}"
118
+ )
119
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
120
+
121
+ builder = StateGraph(MessagesState)
122
+ builder.add_node("retriever", retriever)
123
+ builder.add_node("assistant", assistant)
124
+ builder.add_node("tools", ToolNode(tools))
125
+ builder.add_edge(START, "retriever")
126
+ builder.add_edge("retriever", "assistant")
127
+ builder.add_conditional_edges("assistant", tools_condition)
128
+ builder.add_edge("tools", "assistant")
129
+
130
+ return builder.compile()
131
+
132
+ if __name__ == "__main__":
133
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
134
+ graph = build_graph(provider="openai")
135
+ messages = [HumanMessage(content=question)]
136
+ messages = graph.invoke({"messages": messages})
137
+ for m in messages["messages"]:
138
+ m.pretty_print()