wlchee committed
Commit c1b0d5c · verified · 1 Parent(s): 53c8dc9

Upload agent.py

Files changed (1)
agent.py +203 -90
agent.py CHANGED
@@ -1,101 +1,214 @@
- from transformers import Tool, HfAgent
- from huggingface_hub import list_models
- import requests
- from typing import Optional, List
- import random
-
- # First, let's define some custom tools the agent can use
-
- class WebSearchTool(Tool):
-     name = "web_search"
-     description = ("A tool that performs a web search using a search engine API. "
-                    "Input should be a search query. Output will be search results.")
-
-     inputs = ["text"]
-     outputs = ["text"]
-
-     def __call__(self, query: str):
-         # In a real implementation, you would call a search API here
-         # For demonstration, we'll return mock results
-         return f"Search results for '{query}': 1. Relevant result 1, 2. Relevant result 2"
-
- class CalculatorTool(Tool):
-     name = "calculator"
-     description = ("A tool for performing mathematical calculations. "
-                    "Input should be a mathematical expression. Output will be the result.")
-
-     inputs = ["text"]
-     outputs = ["text"]
-
-     def __call__(self, expression: str):
-         try:
-             result = eval(expression)  # Note: In production, use a safer eval method
-             return str(result)
-         except:
-             return "Error: Could not evaluate the expression"
-
- class CurrentTimeTool(Tool):
-     name = "get_current_time"
-     description = ("A tool that returns the current time in UTC. "
-                    "No input needed. Output will be the current time.")
-
-     inputs = []
-     outputs = ["text"]
-
-     def __call__(self):
-         from datetime import datetime
-         return datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC")
-
- class WikipediaTool(Tool):
-     name = "wikipedia_search"
-     description = ("A tool that searches Wikipedia. "
-                    "Input should be a search term. Output will be a summary from Wikipedia.")
-
-     inputs = ["text"]
-     outputs = ["text"]
-
-     def __call__(self, term: str):
-         # Mock implementation - in real use, you'd call the Wikipedia API
-         return f"Wikipedia summary for '{term}': This is a summary about {term}."
-
- # Now let's create the agent with these tools and some pre-trained tools
-
- def create_agent():
-     # Load pre-trained tools from the Hub
-     agent = HfAgent(
-         "https://api-inference.huggingface.co/models/bigcode/starcoder",
-         additional_tools=[
-             WebSearchTool(),
-             CalculatorTool(),
-             CurrentTimeTool(),
-             WikipediaTool()
-         ],
-         # These parameters help with performance
-         max_new_tokens=200,
-         temperature=0.7,
-         top_p=0.9,
      )
-     return agent
-
- # Example usage of the agent
  if __name__ == "__main__":
-     agent = create_agent()
-
-     # Test the agent with some sample queries
-     queries = [
-         "What's the capital of France?",
-         "Calculate 123 * 45",
-         "What time is it now?",
-         "Tell me about Albert Einstein",
-         "Search for the latest news about AI"
-     ]
-
-     for query in queries:
-         print(f"Query: {query}")
-         result = agent.run(query)
-         print(f"Result: {result}\n")
-
-     # To evaluate on the benchmark, you would use:
-     # from transformers.benchmarks import evaluate_agent
-     # benchmark_score = evaluate_agent(agent)
-     # print(f"Benchmark score: {benchmark_score}")
+ """LangGraph Agent"""
+ import os
+ from dotenv import load_dotenv
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition
+ from langgraph.prebuilt import ToolNode
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_groq import ChatGroq
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader
+ from langchain_community.document_loaders import ArxivLoader
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_core.tools import tool
+ from langchain.tools.retriever import create_retriever_tool
+ from supabase.client import Client, create_client
+
+ load_dotenv()
+
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a * b
+
+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a + b
+
+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a - b
+
+ @tool
+ def divide(a: int, b: int) -> float:
+     """Divide two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Get the modulus of two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a % b
+
+ @tool
+ def wiki_search(query: str) -> str:
+     """Search Wikipedia for a query and return a maximum of 2 results.
+
+     Args:
+         query: The search query."""
+     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ])
+     return formatted_search_docs
+
+ @tool
+ def web_search(query: str) -> str:
+     """Search Tavily for a query and return a maximum of 3 results.
+
+     Args:
+         query: The search query."""
+     search_results = TavilySearchResults(max_results=3).invoke(query)
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{res["url"]}"/>\n{res["content"]}\n</Document>'
+             for res in search_results
+         ])
+     return formatted_search_docs
+
+ @tool
+ def arvix_search(query: str) -> str:
+     """Search arXiv for a query and return a maximum of 3 results.
+
+     Args:
+         query: The search query."""
+     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
+             for doc in search_docs
+         ])
+     return formatted_search_docs
+
+
+ # load the system prompt from the file
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
+     system_prompt = f.read()
+
+ # System message
+ sys_msg = SystemMessage(content=system_prompt)
+
+ # build a retriever
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
+ supabase: Client = create_client(
+     os.environ.get("SUPABASE_URL"),
+     os.environ.get("SUPABASE_SERVICE_KEY"))
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="documents",
+     query_name="match_documents_langchain",
+ )
+ retriever_tool = create_retriever_tool(
+     retriever=vector_store.as_retriever(),
+     name="Question Search",
+     description="A tool to retrieve similar questions from a vector store.",
+ )
+
+ tools = [
+     multiply,
+     add,
+     subtract,
+     divide,
+     modulus,
+     wiki_search,
+     web_search,
+     arvix_search,
+ ]
+
+ # Build graph function
+ def build_graph(provider: str = "groq"):
+     """Build the graph"""
+     if provider == "google":
+         # Google Gemini
+         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
+     elif provider == "groq":
+         # Groq https://console.groq.com/docs/models
+         llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # alternatives: qwen-qwq-32b, gemma2-9b-it
+     elif provider == "huggingface":
+         llm = ChatHuggingFace(
+             llm=HuggingFaceEndpoint(
+                 endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                 temperature=0,
+             ),
+         )
+     else:
+         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
+     # Bind tools to LLM
+     llm_with_tools = llm.bind_tools(tools)
+
+     # Node
+     def assistant(state: MessagesState):
+         """Assistant node"""
+         return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+     def retriever(state: MessagesState):
+         """Retriever node"""
+         similar_question = vector_store.similarity_search(state["messages"][0].content)
+         example_msg = HumanMessage(
+             content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
+         )
+         return {"messages": [sys_msg] + state["messages"] + [example_msg]}
+
+     builder = StateGraph(MessagesState)
+     builder.add_node("retriever", retriever)
+     builder.add_node("assistant", assistant)
+     builder.add_node("tools", ToolNode(tools))
+     builder.add_edge(START, "retriever")
+     builder.add_edge("retriever", "assistant")
+     builder.add_conditional_edges(
+         "assistant",
+         tools_condition,
      )
+     builder.add_edge("tools", "assistant")
+
+     # Compile graph
+     return builder.compile()
+
+ # test
  if __name__ == "__main__":
+     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
+     # Build the graph
+     graph = build_graph(provider="groq")
+     # Run the graph
+     messages = [HumanMessage(content=question)]
+     messages = graph.invoke({"messages": messages})
+     for m in messages["messages"]:
+         m.pretty_print()
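
Note on running the new module: it opens system_prompt.txt from the working directory and reads SUPABASE_URL and SUPABASE_SERVICE_KEY at import time, while the Groq, Tavily, and Google clients look up their own keys from the environment. A minimal pre-flight sketch, assuming the conventional variable names GROQ_API_KEY, TAVILY_API_KEY, and GOOGLE_API_KEY documented by those libraries (this check is illustrative and not part of the commit):

import os

# Illustrative pre-flight check: fail fast if the secrets agent.py depends on
# are missing. GROQ_API_KEY and TAVILY_API_KEY cover the default provider and
# the web_search tool; add GOOGLE_API_KEY when calling build_graph(provider="google").
required = ["SUPABASE_URL", "SUPABASE_SERVICE_KEY", "GROQ_API_KEY", "TAVILY_API_KEY"]
missing = [name for name in required if not os.environ.get(name)]
if missing:
    raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")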