yooke commited on
Commit
979ff4b
·
verified ·
1 Parent(s): 8ff3469

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +113 -112
agent.py CHANGED
@@ -1,112 +1,113 @@
1
- import os
2
- from dotenv import load_dotenv
3
- from langgraph.graph import START, StateGraph, MessagesState
4
- from langgraph.prebuilt import tools_condition
5
- from langgraph.prebuilt import ToolNode
6
- from langchain_community.tools.tavily_search import TavilySearchResults
7
- from langchain_community.document_loaders import WikipediaLoader
8
- from langchain_community.document_loaders import ArxivLoader
9
- from langchain_core.messages import SystemMessage, HumanMessage
10
- from langchain_core.tools import tool
11
- # from langchain_openai import ChatOpenAI
12
- from langchain_deepseek import ChatDeepSeek
13
-
14
-
15
-
16
- load_dotenv()
17
- DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
18
- @tool
19
- def multiply(a: int, b: int) -> int:
20
- """Multiplies two numbers."""
21
- return a * b
22
- @tool
23
- def add (a: int, b: int) -> int:
24
- """Adds two numbers."""
25
- return a + b
26
- @tool
27
- def subtract (a: int, b: int) -> int:
28
- """Subtracts two numbers."""
29
- return a - b
30
- @tool
31
- def divide (a: int, b: int) -> int:
32
- """Divides two numbers."""
33
- return a / b
34
- @tool
35
- def modulo (a: int, b: int) -> int:
36
- """Returns the remainder of two numbers."""
37
- return a % b
38
- @tool
39
- def wiki_search(query:str)->str:
40
- "Using Wikipedia, search for a query and return the first result."
41
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
42
- formatted_search_docs = "\n\n---\n\n".join(
43
- [
44
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
45
- for doc in search_docs
46
- ])
47
- return {"wiki_results": formatted_search_docs}
48
- @tool
49
- def arvix_search(query: str) -> str:
50
- """Search Arxiv for a query and return maximum 3 result.
51
-
52
- Args:
53
- query: The search query."""
54
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
55
- formatted_search_docs = "\n\n---\n\n".join(
56
- [
57
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
58
- for doc in search_docs
59
- ])
60
- return {"arvix_results": formatted_search_docs}
61
-
62
-
63
- # load the system prompt from the file
64
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
65
- system_prompt = f.read()
66
- sys_msg = SystemMessage(content=system_prompt)
67
-
68
- tools = [
69
- multiply,
70
- add,
71
- subtract,
72
- divide,
73
- modulo,
74
- wiki_search,
75
- arvix_search,
76
- ]
77
- def build_graph():
78
- llm = ChatDeepSeek(
79
- model="deepseek-chat",
80
- temperature=0,
81
- max_tokens=None,
82
- timeout=None,
83
- max_retries=2,
84
- api_key=DEEPSEEK_API_KEY,
85
- )
86
- llm_with_tools = llm.bind_tools(tools)
87
- def assistant(state: MessagesState):
88
- """Assistant node"""
89
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
90
-
91
- builder = StateGraph(MessagesState)
92
- builder.add_node("assistant", assistant)
93
- builder.add_node("tools",ToolNode(tools))
94
- builder.add_edge(START, "assistant")
95
- builder.add_conditional_edges(
96
- "assistant",
97
- tools_condition,
98
- )
99
- builder.add_edge("tools", "assistant")
100
- return builder.compile()
101
- if __name__ == "__main__":
102
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
103
- # Build the graph
104
- graph = build_graph()
105
- png_data = graph.get_graph().draw_mermaid_png()
106
- with open("graph.png", "wb") as f:
107
- f.write(png_data)
108
- # Run the graph
109
- messages = [HumanMessage(content=question)]
110
- messages = graph.invoke({"messages": messages})
111
- for m in messages["messages"]:
112
- m.pretty_print()
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langgraph.graph import START, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition
5
+ from langgraph.prebuilt import ToolNode
6
+ from langchain_community.tools.tavily_search import TavilySearchResults
7
+ from langchain_community.document_loaders import WikipediaLoader
8
+ from langchain_community.document_loaders import ArxivLoader
9
+ from langchain_core.messages import SystemMessage, HumanMessage
10
+ from langchain_core.tools import tool
11
+ # from langchain_openai import ChatOpenAI
12
+ from langchain_deepseek import ChatDeepSeek
13
+
14
+
15
+
16
+ load_dotenv()
17
+ DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
18
+ @tool
19
+ def multiply(a: int, b: int) -> int:
20
+ """Multiplies two numbers."""
21
+ return a * b
22
+ @tool
23
+ def add (a: int, b: int) -> int:
24
+ """Adds two numbers."""
25
+ return a + b
26
+ @tool
27
+ def subtract (a: int, b: int) -> int:
28
+ """Subtracts two numbers."""
29
+ return a - b
30
+ @tool
31
+ def divide (a: int, b: int) -> int:
32
+ """Divides two numbers."""
33
+ return a / b
34
+ @tool
35
+ def modulo (a: int, b: int) -> int:
36
+ """Returns the remainder of two numbers."""
37
+ return a % b
38
+ @tool
39
+ def wiki_search(query:str)->str:
40
+ "Using Wikipedia, search for a query and return the first result."
41
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
42
+ formatted_search_docs = "\n\n---\n\n".join(
43
+ [
44
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
45
+ for doc in search_docs
46
+ ])
47
+ return {"wiki_results": formatted_search_docs}
48
+ @tool
49
+ def arvix_search(query: str) -> str:
50
+ """Search Arxiv for a query and return maximum 3 result.
51
+
52
+ Args:
53
+ query: The search query."""
54
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
55
+ formatted_search_docs = "\n\n---\n\n".join(
56
+ [
57
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
58
+ for doc in search_docs
59
+ ])
60
+ return {"arvix_results": formatted_search_docs}
61
+
62
+
63
+ # load the system prompt from the file
64
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
65
+ system_prompt = f.read()
66
+ sys_msg = SystemMessage(content=system_prompt)
67
+
68
+ tools = [
69
+ multiply,
70
+ add,
71
+ subtract,
72
+ divide,
73
+ modulo,
74
+ wiki_search,
75
+ arvix_search,
76
+ ]
77
+ def build_graph():
78
+ llm = ChatDeepSeek(
79
+ model="deepseek-chat",
80
+ temperature=0,
81
+ max_tokens=None,
82
+ timeout=None,
83
+ max_retries=2,
84
+ api_key=DEEPSEEK_API_KEY,
85
+ base_url="https://api.deepseek.com/v1",
86
+ )
87
+ llm_with_tools = llm.bind_tools(tools)
88
+ def assistant(state: MessagesState):
89
+ """Assistant node"""
90
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
91
+
92
+ builder = StateGraph(MessagesState)
93
+ builder.add_node("assistant", assistant)
94
+ builder.add_node("tools",ToolNode(tools))
95
+ builder.add_edge(START, "assistant")
96
+ builder.add_conditional_edges(
97
+ "assistant",
98
+ tools_condition,
99
+ )
100
+ builder.add_edge("tools", "assistant")
101
+ return builder.compile()
102
+ if __name__ == "__main__":
103
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
104
+ # Build the graph
105
+ graph = build_graph()
106
+ png_data = graph.get_graph().draw_mermaid_png()
107
+ with open("graph.png", "wb") as f:
108
+ f.write(png_data)
109
+ # Run the graph
110
+ messages = [HumanMessage(content=question)]
111
+ messages = graph.invoke({"messages": messages})
112
+ for m in messages["messages"]:
113
+ m.pretty_print()