yooke committed on
Commit
d348bcf
·
verified ·
1 Parent(s): 3a98536

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +68 -80
agent.py CHANGED
@@ -3,9 +3,10 @@ from dotenv import load_dotenv
3
  from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
- from langchain_community.tools.tavily_search import TavilySearchResults # 已经导入了
7
  from langchain_community.document_loaders import WikipediaLoader
8
- from langchain_community.document_loaders import ArxivLoader
 
9
  from langchain_core.messages import SystemMessage, HumanMessage
10
  from langchain_core.tools import tool
11
  # from langchain_openai import ChatOpenAI
@@ -18,47 +19,16 @@ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY") # 需要在 Space Secrets 中添加
18
 
19
  if not DEEPSEEK_API_KEY:
20
  raise ValueError("DEEPSEEK_API_KEY not found in environment variables.")
 
 
21
  if not TAVILY_API_KEY:
22
- # Tavily is critical for most questions, raise error if not set
23
- raise ValueError("TAVILY_API_KEY not found in environment variables. Please add it to your Space Secrets.")
24
 
25
 
26
- @tool
27
- def multiply(a: int, b: int) -> int:
28
- """Multiplies two numbers."""
29
- return a * b
30
-
31
-
32
- @tool
33
- def add(a: int, b: int) -> int:
34
- """Adds two numbers."""
35
- return a + b
36
 
37
 
38
- @tool
39
- def subtract(a: int, b: int) -> int:
40
- """Subtracts two numbers."""
41
- return a - b
42
-
43
-
44
- @tool
45
- def divide(a: int, b: int) -> int:
46
- """Divides two numbers."""
47
- # Added check for division by zero
48
- if b == 0:
49
- return "Error: Division by zero."
50
- return a / b
51
-
52
-
53
- @tool
54
- def modulo(a: int, b: int) -> int:
55
- """Returns the remainder of two numbers."""
56
- # Added check for modulo by zero
57
- if b == 0:
58
- return "Error: Modulo by zero."
59
- return a % b
60
-
61
- # Keep Wikipedia and Arxiv, but the general search will be more used
62
  @tool
63
  def wiki_search(query: str) -> str:
64
  "Using Wikipedia, search for a query and return up to 2 relevant results."
@@ -76,32 +46,20 @@ def wiki_search(query: str) -> str:
76
  return f"An error occurred during Wikipedia search: {e}"
77
 
78
 
79
- @tool
80
- def arvix_search(query: str) -> str:
81
- """Search Arxiv for scientific papers by query and return maximum 3 results (title and summary)."""
82
- try:
83
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
84
- if not search_docs:
85
- return "Arxiv search found no relevant papers."
86
- # Format results to be more concise
87
- formatted_search_docs = "\n\n---\n\n".join(
88
- [
89
- f'<Document source="Arxiv - {doc.metadata.get("source", "")}">\nTitle: {doc.metadata.get("Title", "N/A")}\nSummary: {doc.page_content[:1000]}...\n</Document>' # Limit summary length
90
- for doc in search_docs
91
- ])
92
- return formatted_search_docs # Return string directly
93
- except Exception as e:
94
- return f"An error occurred during Arxiv search: {e}"
95
 
96
- # *** ADD TAVILY WEB SEARCH TOOL ***
97
  @tool
98
  def web_search(query: str) -> str:
99
  """Search the web for a query using Tavily and return relevant snippets."""
 
 
100
  try:
101
  tavily = TavilySearchResults(max_results=5) # Get up to 5 results
102
  results = tavily.invoke(query)
103
  if not results:
104
- return "Web search found no relevant results."
105
  # Format Tavily results
106
  formatted_results = "\n\n---\n\n".join([
107
  f'<SearchResult source="{r["source"]}">\nTitle: {r["title"]}\nContent: {r["content"]}\n</SearchResult>'
@@ -109,24 +67,41 @@ def web_search(query: str) -> str:
109
  ])
110
  return formatted_results # Return string directly
111
  except Exception as e:
112
- return f"An error occurred during web search: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
 
115
  # load the system prompt from the file
116
  # Ensure this file exists and has the content from Step 2
117
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
- system_prompt = f.read()
119
- sys_msg = SystemMessage(content=system_prompt)
 
 
 
 
 
120
 
 
121
  tools = [
122
- multiply,
123
- add,
124
- subtract,
125
- divide,
126
- modulo,
127
  wiki_search,
128
- arvix_search,
129
- web_search, # *** ADDED TAVILY WEB SEARCH ***
130
  ]
131
 
132
 
@@ -140,17 +115,21 @@ def build_graph():
140
  api_key=DEEPSEEK_API_KEY,
141
  base_url="https://api.deepseek.com"
142
  )
 
143
  llm_with_tools = llm.bind_tools(tools)
144
 
145
  def assistant(state: MessagesState):
146
  """Assistant node: invoke LLM with tools."""
147
  print("---Calling Assistant---") # Added print for debugging
148
- result = llm_with_tools.invoke(state["messages"])
 
 
149
  print(f"---Assistant Response: {result}") # Added print for debugging
150
  return {"messages": [result]}
151
 
152
  builder = StateGraph(MessagesState)
153
  builder.add_node("assistant", assistant)
 
154
  builder.add_node("tools", ToolNode(tools))
155
 
156
  builder.add_edge(START, "assistant")
@@ -178,20 +157,24 @@ def build_graph():
178
 
179
  if __name__ == "__main__":
180
  # Example Usage (for local testing)
181
- # To run this part, make sure you have DEEPSEEK_API_KEY and TAVILY_API_KEY
182
- # set in your environment or a .env file loaded beforehand.
183
  # If running locally, you'd typically use `load_dotenv()` here or in app.py
184
 
 
 
 
185
  # Test questions covering different tool needs
 
 
186
  questions_for_testing = [
187
- "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?", # Web Search
188
- "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species seen?", # Requires video analysis (will likely fail with current tools)
189
- ".rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI", # Text manipulation (no tool needed)
190
- "What is 12345 * 6789?", # Calculator
191
- "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2023?", # Web Search/Wikipedia
192
- "What country had the least number of athletes at the 1928 Summer Olympics?", # Web Search
193
- "Review the chess position provided in the image. It is black's turn. Provide the correct next move from this position: [Describe the position or mention image input which is not supported]", # Requires image analysis (will likely fail)
194
- # Add more questions from your evaluation set to test
195
  ]
196
 
197
 
@@ -204,7 +187,7 @@ if __name__ == "__main__":
204
  # f.write(png_data)
205
  # print("Graph visualization saved to graph.png")
206
  # except Exception as e:
207
- # print(f"Could not draw graph: {e}")
208
 
209
 
210
  print("\n--- Running single question tests ---")
@@ -212,11 +195,16 @@ if __name__ == "__main__":
212
  print(f"\n--- Testing Question {i+1}: {question}")
213
  try:
214
  # LangGraph returns the final state after execution completes or hits recursion limit
 
 
215
  final_state = graph.invoke({"messages": [HumanMessage(content=question)]})
216
  print("\n--- Final State Messages ---")
 
217
  for m in final_state["messages"]:
218
- m.pretty_print()
219
  print("-" * 30)
220
  except Exception as e:
221
  print(f"--- Error running graph for this question: {e}")
 
 
222
  print("-" * 30)
 
3
  from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
+ from langchain_community.tools.tavily_search import TavilySearchResults
7
  from langchain_community.document_loaders import WikipediaLoader
8
+ # Removed: from langchain_community.document_loaders import ArxivLoader
9
+ from langchain_community.tools import DuckDuckGoSearchRun # Added DuckDuckGo import
10
  from langchain_core.messages import SystemMessage, HumanMessage
11
  from langchain_core.tools import tool
12
  # from langchain_openai import ChatOpenAI
 
19
 
20
  if not DEEPSEEK_API_KEY:
21
  raise ValueError("DEEPSEEK_API_KEY not found in environment variables.")
22
+ # Tavily is still included, so its key is needed if you want to use it.
23
+ # If you ONLY want DuckDuckGo, you could remove Tavily and this check.
24
  if not TAVILY_API_KEY:
25
+ print("Warning: TAVILY_API_KEY not found. Tavily search tool may not work.")
 
26
 
27
 
28
+ # --- Removed math tools (multiply, add, subtract, divide, modulo) ---
 
 
 
 
 
 
 
 
 
29
 
30
 
31
+ # Keep Wikipedia search
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  @tool
33
  def wiki_search(query: str) -> str:
34
  "Using Wikipedia, search for a query and return up to 2 relevant results."
 
46
  return f"An error occurred during Wikipedia search: {e}"
47
 
48
 
49
+ # --- Removed Arxiv search (arvix_search) ---
50
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # *** ADD TAVILY WEB SEARCH TOOL *** (Kept as requested implicitly by keeping web_search)
53
  @tool
54
  def web_search(query: str) -> str:
55
  """Search the web for a query using Tavily and return relevant snippets."""
56
+ if not TAVILY_API_KEY:
57
+ return "Tavily search is not available because TAVILY_API_KEY is not set."
58
  try:
59
  tavily = TavilySearchResults(max_results=5) # Get up to 5 results
60
  results = tavily.invoke(query)
61
  if not results:
62
+ return "Web search (Tavily) found no relevant results."
63
  # Format Tavily results
64
  formatted_results = "\n\n---\n\n".join([
65
  f'<SearchResult source="{r["source"]}">\nTitle: {r["title"]}\nContent: {r["content"]}\n</SearchResult>'
 
67
  ])
68
  return formatted_results # Return string directly
69
  except Exception as e:
70
+ return f"An error occurred during web search (Tavily): {e}"
71
+
72
+ # *** ADD DUCKDUCKGO WEB SEARCH TOOL ***
73
+ duckduckgo_search_tool_instance = DuckDuckGoSearchRun() # Instantiate the DuckDuckGo tool
74
+
75
+ @tool
76
+ def duckduckgo_search(query: str) -> str:
77
+ """Search the web for a query using DuckDuckGo."""
78
+ try:
79
+ # The DuckDuckGoSearchRun tool directly returns a formatted string result
80
+ results = duckduckgo_search_tool_instance.run(query)
81
+ if not results:
82
+ return "DuckDuckGo search found no relevant results."
83
+ # DuckDuckGoSearchRun often returns results as a string ready to be used
84
+ return results
85
+ except Exception as e:
86
+ return f"An error occurred during DuckDuckGo search: {e}"
87
 
88
 
89
  # load the system prompt from the file
90
  # Ensure this file exists and has the content from Step 2
91
+ try:
92
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
93
+ system_prompt = f.read()
94
+ sys_msg = SystemMessage(content=system_prompt)
95
+ except FileNotFoundError:
96
+ print("Warning: system_prompt.txt not found. Using a default system message.")
97
+ sys_msg = SystemMessage(content="You are a helpful AI assistant. You can use tools to find information.")
98
+
99
 
100
+ # Updated tools list: Removed math and Arxiv, Added DuckDuckGo
101
  tools = [
 
 
 
 
 
102
  wiki_search,
103
+ web_search, # Tavily search
104
+ duckduckgo_search, # DuckDuckGo search
105
  ]
106
 
107
 
 
115
  api_key=DEEPSEEK_API_KEY,
116
  base_url="https://api.deepseek.com"
117
  )
118
+ # Bind the updated tools list to the LLM
119
  llm_with_tools = llm.bind_tools(tools)
120
 
121
  def assistant(state: MessagesState):
122
  """Assistant node: invoke LLM with tools."""
123
  print("---Calling Assistant---") # Added print for debugging
124
+ # Include the system message at the beginning of the conversation
125
+ messages_for_llm = [sys_msg] + state["messages"]
126
+ result = llm_with_tools.invoke(messages_for_llm)
127
  print(f"---Assistant Response: {result}") # Added print for debugging
128
  return {"messages": [result]}
129
 
130
  builder = StateGraph(MessagesState)
131
  builder.add_node("assistant", assistant)
132
+ # The ToolNode needs the list of functions, not just the names
133
  builder.add_node("tools", ToolNode(tools))
134
 
135
  builder.add_edge(START, "assistant")
 
157
 
158
  if __name__ == "__main__":
159
  # Example Usage (for local testing)
160
+ # To run this part, make sure you have DEEPSEEK_API_KEY set.
161
+ # TAVILY_API_KEY is needed for the web_search tool. DuckDuckGo usually works without a key.
162
  # If running locally, you'd typically use `load_dotenv()` here or in app.py
163
 
164
+ print("Note: Ensure DEEPSEEK_API_KEY is set.")
165
+ print("Note: TAVILY_API_KEY is required for the 'web_search' tool (Tavily). DuckDuckGo usually works without a key.")
166
+
167
  # Test questions covering different tool needs
168
+ # Removed purely math questions and Arxiv questions.
169
+ # Added questions that might benefit from multiple search tools.
170
  questions_for_testing = [
171
+ "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?", # Web Search (Tavily or DDG)
172
+ "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2023? Use Wikipedia first if possible.", # Wikipedia or Web Search
173
+ "What country had the least number of athletes at the 1928 Summer Olympics? Find this information using web search.", # Web Search (Tavily or DDG)
174
+ "Tell me about the Voyager 1 probe. Use Wikipedia.", # Wikipedia
175
+ "What is the current population of Tokyo?", # Web Search (Tavily or DDG)
176
+ "Give me a brief overview of the concept of 'LangGraph'.", # Web Search (Tavily or DDG)
177
+ ".rewsna eht sa \"tfel\" drow ehT etirw ,ecnetnes siht dnatsrednu uoy fI", # Text manipulation (no tool needed)
 
178
  ]
179
 
180
 
 
187
  # f.write(png_data)
188
  # print("Graph visualization saved to graph.png")
189
  # except Exception as e:
190
+ # print(f"Could not draw graph: {e}. Make sure 'pygraphviz' and graphviz system libraries are installed.")
191
 
192
 
193
  print("\n--- Running single question tests ---")
 
195
  print(f"\n--- Testing Question {i+1}: {question}")
196
  try:
197
  # LangGraph returns the final state after execution completes or hits recursion limit
198
+ # Start with only the first human message; no system message is passed in here.
199
+ # The assistant node prepends the system message internally now.
200
  final_state = graph.invoke({"messages": [HumanMessage(content=question)]})
201
  print("\n--- Final State Messages ---")
202
+ # Print messages more readably
203
  for m in final_state["messages"]:
204
+ print(f"{m.__class__.__name__}: {m.content}")
205
  print("-" * 30)
206
  except Exception as e:
207
  print(f"--- Error running graph for this question: {e}")
208
+ import traceback
209
+ traceback.print_exc() # Print full traceback for debugging
210
  print("-" * 30)