hhhhmmmm committed on
Commit
8cc78a9
·
verified ·
1 Parent(s): 6353325

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +213 -0
agent.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
+ from langchain_core.messages import SystemMessage, HumanMessage
15
+ from langchain_core.tools import tool
16
+ from langchain.tools.retriever import create_retriever_tool
17
+ from supabase.client import Client, create_client
18
+ from google.colab import userdata
19
+
20
+ load_dotenv()
21
+
22
+ #os.environ["SUPABASE_URL"] = userdata.get('SUPABASE_URL')
23
+ #os.environ["SUPABASE_SERVICE_KEY"] = userdata.get('SUPABASE_SERVICE_KEY')
24
+ #os.environ["GOOGLE_API_KEY"] = userdata.get('GOOGLE_API_KEY')
25
+ #os.environ["TAVILY_API_KEY"] = userdata.get('TAVILY_API_KEY')
26
+ #os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
27
+
28
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
36
+
37
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
46
+
47
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
56
+
57
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Raises:
        ValueError: If b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # True division always yields a float, even for int inputs —
    # hence the corrected return annotation (was `-> int`).
    return a / b
68
+
69
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
78
+
79
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        Dict with a single "wiki_results" key whose value is the matching
        documents rendered as <Document .../> blocks separated by `---`.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # .get() instead of ["source"]: a missing metadata key must not crash
    # the tool call mid-conversation. Annotation fixed to dict (was -> str,
    # but a dict has always been returned).
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
92
+
93
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        Dict with a single "web_results" key whose value is the results
        rendered as <Document .../> blocks separated by `---`.
    """
    # TavilySearchResults.invoke returns a list of plain dicts
    # (keys include "url" and "content"), not Document objects.
    search_docs = TavilySearchResults(max_results=3).invoke(input=query)
    # BUG FIX: the comprehension previously iterated `search_web_docs`,
    # an undefined name, so every call raised NameError.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.get("url", "N/A")}" page="{doc.get("page", "")}"/>\n{doc.get("content", "No content available")}\n</Document>'
        for doc in search_docs
    )
    return {"web_results": formatted_search_docs}
111
+
112
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        Dict with a single "arvix_results" key whose value is the matching
        papers (first 1000 chars each) rendered as <Document .../> blocks.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # .get() instead of ["source"]: ArxivLoader metadata is not guaranteed
    # to carry a "source" key, and a KeyError here would abort the tool
    # call. Annotation fixed to dict (was -> str, but a dict is returned).
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return {"arvix_results": formatted_search_docs}
125
+
126
+
127
+
128
# Read the system prompt shipped alongside this module; it seeds every run.
with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()

# Wrapped once here so build_graph can prepend it to each LLM invocation.
sys_msg = SystemMessage(content=system_prompt)
134
+
135
+ # build a retriever
136
+ #embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
137
+ #supabase: Client = create_client(
138
+ # os.environ.get("SUPABASE_URL"),
139
+ # os.environ.get("SUPABASE_SERVICE_KEY"))
140
+ #vector_store = SupabaseVectorStore(
141
+ # client=supabase,
142
+ # embedding= embeddings,
143
+ # table_name="documents",
144
+ # query_name="match_documents_langchain",
145
+ #)
146
+ #create_retriever_tool = create_retriever_tool(
147
+ # retriever=vector_store.as_retriever(),
148
+ # name="Question Search",
149
+ # description="A tool to retrieve similar questions from a vector store.",
150
+ #)
151
+
152
+
153
+
154
# Registry of every tool exposed to the agent; build_graph binds this list
# to the LLM and to the graph's ToolNode.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
]
164
+
165
# Build graph function
def build_graph(provider: str = "huggingface"):
    """Build and compile the agent's LangGraph state graph.

    Args:
        provider: Which chat-model backend to use — 'google', 'groq'
            or 'huggingface' (default).

    Returns:
        The compiled graph: assistant node -> conditional tools node,
        looping back to the assistant until no more tool calls are made.

    Raises:
        ValueError: If `provider` is not one of the supported names.
    """
    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional : qwen-qwq-32b gemma2-9b-it
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
    # Bind the module-level tool registry so the model can emit tool calls.
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Assistant node: prepend the system prompt and invoke the model."""
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    # def retriever(state: MessagesState):
    #     """Retriever node"""
    #     similar_question = vector_store.similarity_search(state["messages"][0].content)
    #     example_msg = HumanMessage(
    #         content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
    #     )
    #     return {"messages": [sys_msg] + state["messages"] + [example_msg]}

    builder = StateGraph(MessagesState)
    #builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    #builder.add_edge(START, "retriever")
    #builder.add_edge("retriever", "assistant")
    builder.add_edge(START, "assistant")
    # tools_condition routes to "tools" when the last message contains
    # tool calls, otherwise ends the run.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()