skyliulu commited on
Commit
9bcd561
·
1 Parent(s): d93030e

tool & agent

Browse files
Files changed (2) hide show
  1. agent.py +91 -0
  2. tools.py +122 -0
agent.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import TypedDict, Annotated
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph.message import add_messages
5
+ from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
6
+ from langgraph.prebuilt import ToolNode
7
+ from langgraph.graph import START, StateGraph, MessagesState
8
+ from langgraph.prebuilt import tools_condition
9
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
10
+ from langchain_google_genai import ChatGoogleGenerativeAI
11
+ from langchain_groq import ChatGroq
12
+ from langchain_openai import ChatOpenAI
13
+
14
+ from tools import (
15
+ divide,
16
+ multiply,
17
+ modulus,
18
+ add,
19
+ subtract,
20
+ power,
21
+ square_root,
22
+ web_search,
23
+ wiki_search,
24
+ arxiv_search,
25
+ )
26
+
27
+ # load api key
28
+ load_dotenv()
29
+
30
+
31
+ def buildAgent(provider="huggingface"):
32
+ # Generate the chat interface, including the tools
33
+ if provider == "huggingface":
34
+ llm = ChatHuggingFace(
35
+ llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
36
+ )
37
+ elif provider == "groq":
38
+ llm = ChatGroq(model="qwen-qwq-32b")
39
+ elif provider == "openrouter":
40
+ llm = ChatOpenAI(
41
+ base_url="https://openrouter.ai/api/v1",
42
+ api_key=os.environ.get("OPENROUTER_API_KEY"),
43
+ model="google/gemini-2.0-flash-exp",
44
+ )
45
+
46
+ tools = [
47
+ multiply,
48
+ add,
49
+ subtract,
50
+ divide,
51
+ modulus,
52
+ power,
53
+ square_root,
54
+ web_search,
55
+ wiki_search,
56
+ arxiv_search,
57
+ ]
58
+
59
+ chat_with_tools = llm.bind_tools(tools)
60
+
61
+ # nodes
62
+ def assistant(state: MessagesState):
63
+ return {
64
+ "messages": [chat_with_tools.invoke(state["messages"])],
65
+ }
66
+
67
+ ## The graph
68
+ builder = StateGraph(MessagesState)
69
+ # Define nodes: these do the work
70
+ builder.add_node("assistant", assistant)
71
+ builder.add_node("tools", ToolNode(tools))
72
+ # Define edges: these determine how the control flow moves
73
+ builder.add_edge(START, "assistant")
74
+ builder.add_conditional_edges(
75
+ "assistant",
76
+ # If the latest message requires a tool, route to tools
77
+ # Otherwise, provide a direct response
78
+ tools_condition,
79
+ )
80
+ builder.add_edge("tools", "assistant")
81
+ return builder.compile()
82
+
83
+
84
+ if __name__ == "__main__":
85
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
86
+ graph = buildAgent(provider="groq")
87
+ messages = [HumanMessage(content=question)]
88
+ print(messages)
89
+ messages = graph.invoke({"messages": messages})
90
+ for m in messages["messages"]:
91
+ m.pretty_print()
tools.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cmath
2
+ from langchain_core.tools import tool
3
+ from langchain_community.tools.tavily_search import TavilySearchResults
4
+ from langchain_community.document_loaders import WikipediaLoader
5
+ from langchain_community.document_loaders import ArxivLoader
6
+
7
+ @tool
8
+ def multiply(a: int, b: int) -> int:
9
+ """Multiply two numbers.
10
+ Args:
11
+ a: first int
12
+ b: second int
13
+ """
14
+ return a * b
15
+
16
+ @tool
17
+ def add(a: int, b: int) -> int:
18
+ """Add two numbers.
19
+ Args:
20
+ a: first int
21
+ b: second int
22
+ """
23
+ return a - b
24
+
25
+ @tool
26
+ def subtract(a: int, b: int) -> int:
27
+ """Subtract two numbers.
28
+
29
+ Args:
30
+ a: first int
31
+ b: second int
32
+ """
33
+ return a - b
34
+
35
+ @tool
36
+ def divide(a: int, b: int) -> int:
37
+ """Divide two numbers.
38
+
39
+ Args:
40
+ a: first int
41
+ b: second int
42
+ """
43
+ if b == 0:
44
+ raise ValueError("Cannot divide by zero.")
45
+ return a / b
46
+
47
+ @tool
48
+ def modulus(a: int, b: int) -> int:
49
+ """Get the modulus of two numbers.
50
+
51
+ Args:
52
+ a: first int
53
+ b: second int
54
+ """
55
+ return a % b
56
+
57
+ @tool
58
+ def power(a: float, b: float) -> float:
59
+ """
60
+ Get the power of two numbers.
61
+ Args:
62
+ a (float): the first number
63
+ b (float): the second number
64
+ """
65
+ return a**b
66
+
67
+
68
+ @tool
69
+ def square_root(a: float) -> float | complex:
70
+ """
71
+ Get the square root of a number.
72
+ Args:
73
+ a (float): the number to get the square root of
74
+ """
75
+ if a >= 0:
76
+ return a**0.5
77
+ return cmath.sqrt(a)
78
+
79
+ @tool
80
+ def web_search(query: str) -> str:
81
+ """Search Tavily for a query and return maximum 3 results.
82
+ Args:
83
+ query: The search query.
84
+ """
85
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
86
+ formatted_search_docs = "\n\n---\n\n".join(
87
+ [
88
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
89
+ for doc in search_docs
90
+ ]
91
+ )
92
+ return {"web_results": formatted_search_docs}
93
+
94
+ @tool
95
+ def wiki_search(query: str) -> str:
96
+ """Search Wikipedia for a query and return maximum 2 results.
97
+ Args:
98
+ query: The search query.
99
+ """
100
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
101
+ formatted_search_docs = "\n\n---\n\n".join(
102
+ [
103
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
104
+ for doc in search_docs
105
+ ]
106
+ )
107
+ return {"wiki_results": formatted_search_docs}
108
+
109
+ @tool
110
+ def arxiv_search(query: str) -> str:
111
+ """Search Arxiv for a query and return maximum 3 result.
112
+ Args:
113
+ query: The search query.
114
+ """
115
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
116
+ formatted_search_docs = "\n\n---\n\n".join(
117
+ [
118
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
119
+ for doc in search_docs
120
+ ]
121
+ )
122
+ return {"arxiv_results": formatted_search_docs}