0f3dy committed on
Commit
c4ff6dd
·
verified ·
1 Parent(s): 0918d51

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -190
agent.py DELETED
@@ -1,190 +0,0 @@
1
- import os
2
- from dotenv import load_dotenv
3
- from langchain.tools import tool
4
- from langchain_core.messages import SystemMessage, HumanMessage
5
- from langchain_core.tools import tool
6
- from langchain.tools.retriever import create_retriever_tool
7
- from langchain_community.tools.tavily_search import TavilySearchResults
8
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
9
- from langchain_community.vectorstores import SupabaseVectorStore
10
- from langchain_groq import ChatGroq
11
- from langgraph.prebuilt import tools_condition, ToolNode
12
- from langgraph.graph import START, StateGraph, MessagesState
13
- from langchain_huggingface import HuggingFaceEmbeddings
14
- from langchain_huggingface import ChatHuggingFace
15
- from supabase.client import Client, create_client
16
-
17
- load_dotenv()
18
-
19
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
28
-
29
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
38
-
39
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
48
-
49
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int
    """
    # Raise a tool-friendly ValueError instead of letting ZeroDivisionError
    # propagate out of the tool call.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # True division: `a / b` yields a float, so the return annotation is
    # float (the original's `-> int` was never what this expression produces).
    return a / b
60
-
61
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
70
-
71
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Wrap each hit in a pseudo-XML <Document> tag so the downstream LLM can
    # tell results apart. WikipediaLoader sets metadata["source"]; "page" is
    # read with .get() because it may be absent.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    # Returns a dict (the original `-> str` annotation did not match this).
    return {"wiki_results": formatted_search_docs}
84
-
85
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query."""
    # BaseTool.invoke takes the tool input as its first positional argument;
    # the original `invoke(query=query)` passed a nonexistent keyword and
    # would raise TypeError before any search ran.
    search_results = TavilySearchResults(max_results=3).invoke(query)
    # TavilySearchResults returns a list of plain dicts (keys include "url"
    # and "content"), not Document objects, so format from the dict fields
    # instead of doc.metadata / doc.page_content.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{result.get("url", "")}"/>\n{result.get("content", "")}\n</Document>'
        for result in search_results
    )
    return {"web_results": formatted_search_docs}
98
-
99
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 result.

    Args:
        query: The search query."""
    # NOTE(review): ArxivLoader metadata commonly exposes keys like
    # "Published"/"Title"/"Authors"; doc.metadata["source"] may raise KeyError
    # depending on the installed langchain_community version — confirm against
    # the deployed loader before relying on this tool.
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Each abstract is truncated to 1000 characters to keep the prompt small.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ])
    # Returns a dict keyed "arvix_results" (name kept as-is, typo included,
    # for backward compatibility with existing tool bindings).
    return {"arvix_results": formatted_search_docs}
112
-
113
# load the system prompt from the file
# NOTE: executes at import time — "system_prompt.txt" must exist in the
# process working directory or importing this module raises FileNotFoundError.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message
# Prepended to the conversation by the graph's retriever node.
sys_msg = SystemMessage(content=system_prompt)

# build a retriever
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
# NOTE(review): os.environ.get returns None when a variable is unset, so a
# missing SUPABASE_URL / SUPABASE_SERVICE_KEY surfaces as a create_client
# failure rather than a clear configuration error.
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"))
# Vector store backed by the "documents" table; "get_docs" is the Postgres
# function Supabase uses for similarity queries.
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding= embeddings,
    table_name="documents",
    query_name="get_docs",
)
# NOTE(review): this assignment rebinds (shadows) the imported function
# `create_retriever_tool` with the tool instance it builds, and the resulting
# tool is never added to `tools` below — confirm whether it is meant to be
# exposed to the LLM or is intentionally unused.
create_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)


# Tools bound to the LLM in build_graph(); the retriever tool above is not
# included here (retrieval happens in a dedicated graph node instead).
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
]
148
-
149
# Build the state graph
def build_graph():
    """Compile and return the agent graph.

    Flow: START -> retriever (inject system prompt + a similar past Q&A as
    context) -> assistant (tool-calling LLM) -> tools (loops back to the
    assistant) or END, decided by `tools_condition`.
    """
    llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    llm_with_tools = llm.bind_tools(tools)

    def assistant_node(state: MessagesState):
        """Run the tool-bound LLM on the accumulated messages."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever_node(state: MessagesState):
        """Prepend the system prompt and a similar stored Q&A to the state."""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        if not similar_question:
            # Still inject the system prompt on the empty-result path; the
            # original branch dropped it, unlike the success path below.
            return {"messages": [sys_msg, HumanMessage(content="No similar questions found in the database.")]}
        example_msg = HumanMessage(
            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
        )
        return {"messages": [sys_msg] + state["messages"] + [example_msg]}

    # Dedicated local name — the original rebound `build_graph` inside its
    # own body, shadowing the function itself.
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever_node)  # fixed "retreiver" typo
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", ToolNode(tools=tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the LLM emitted tool calls, otherwise to END.
    builder.add_conditional_edges(
        "assistant",
        tools_condition
    )
    builder.add_edge("tools", "assistant")
    return builder.compile()
180
-
181
# test
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    # Assemble the compiled agent graph and run it on the sample question.
    graph = build_graph()
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    # Dump every message in the final state for inspection.
    for message in result["messages"]:
        message.pretty_print()