skyliulu commited on
Commit
569bb72
·
1 Parent(s): 45d0ffd

格式化输出结果

Browse files
Files changed (4) hide show
  1. agent.py +30 -26
  2. app.py +2 -1
  3. system_prompt.txt +5 -0
  4. tools.py +52 -7
agent.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
  from typing import TypedDict, Annotated
3
  from dotenv import load_dotenv
4
- from langgraph.graph.message import add_messages
5
- from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
6
  from langgraph.prebuilt import ToolNode
7
  from langgraph.graph import START, StateGraph, MessagesState
8
  from langgraph.prebuilt import tools_condition
@@ -10,25 +9,22 @@ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
10
  from langchain_google_genai import ChatGoogleGenerativeAI
11
  from langchain_groq import ChatGroq
12
  from langchain_openai import ChatOpenAI
13
-
14
- from tools import (
15
- divide,
16
- multiply,
17
- modulus,
18
- add,
19
- subtract,
20
- power,
21
- square_root,
22
- web_search,
23
- wiki_search,
24
- arxiv_search,
25
- )
26
 
27
  # load api key
28
  load_dotenv()
29
 
30
 
31
  def buildAgent(provider="huggingface"):
 
 
 
 
 
 
 
 
32
  # Generate the chat interface, including the tools
33
  if provider == "huggingface":
34
  llm = ChatHuggingFace(
@@ -36,14 +32,10 @@ def buildAgent(provider="huggingface"):
36
  )
37
  elif provider == "groq":
38
  llm = ChatGroq(model="qwen-qwq-32b")
39
- elif provider == "openrouter":
40
- llm = ChatOpenAI(
41
- base_url="https://openrouter.ai/api/v1",
42
- api_key=os.environ.get("OPENROUTER_API_KEY"),
43
- model="google/gemini-2.0-flash-exp",
44
- )
45
 
46
- tools = [
47
  multiply,
48
  add,
49
  subtract,
@@ -54,9 +46,10 @@ def buildAgent(provider="huggingface"):
54
  web_search,
55
  wiki_search,
56
  arxiv_search,
 
57
  ]
58
 
59
- chat_with_tools = llm.bind_tools(tools)
60
 
61
  # nodes
62
  def assistant(state: MessagesState):
@@ -64,13 +57,21 @@ def buildAgent(provider="huggingface"):
64
  "messages": [chat_with_tools.invoke(state["messages"])],
65
  }
66
 
 
 
 
 
 
 
67
  ## The graph
68
  builder = StateGraph(MessagesState)
69
  # Define nodes: these do the work
 
70
  builder.add_node("assistant", assistant)
71
- builder.add_node("tools", ToolNode(tools))
72
  # Define edges: these determine how the control flow moves
73
- builder.add_edge(START, "assistant")
 
74
  builder.add_conditional_edges(
75
  "assistant",
76
  # If the latest message requires a tool, route to tools
@@ -82,7 +83,10 @@ def buildAgent(provider="huggingface"):
82
 
83
 
84
  if __name__ == "__main__":
85
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
 
 
 
86
  graph = buildAgent(provider="groq")
87
  messages = [HumanMessage(content=question)]
88
  print(messages)
 
1
  import os
2
  from typing import TypedDict, Annotated
3
  from dotenv import load_dotenv
4
+ from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
 
5
  from langgraph.prebuilt import ToolNode
6
  from langgraph.graph import START, StateGraph, MessagesState
7
  from langgraph.prebuilt import tools_condition
 
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
  from langchain_groq import ChatGroq
11
  from langchain_openai import ChatOpenAI
12
+ import requests
13
+ from tools import *
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # load api key
16
  load_dotenv()
17
 
18
 
19
  def buildAgent(provider="huggingface"):
20
+ # load the system prompt from the file
21
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
22
+ system_prompt = f.read()
23
+ print(system_prompt)
24
+
25
+ # System message
26
+ sys_msg = SystemMessage(content=system_prompt)
27
+
28
  # Generate the chat interface, including the tools
29
  if provider == "huggingface":
30
  llm = ChatHuggingFace(
 
32
  )
33
  elif provider == "groq":
34
  llm = ChatGroq(model="qwen-qwq-32b")
35
+ else:
36
+ raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
 
 
 
 
37
 
38
+ agent_tools = [
39
  multiply,
40
  add,
41
  subtract,
 
46
  web_search,
47
  wiki_search,
48
  arxiv_search,
49
+ download_file,
50
  ]
51
 
52
+ chat_with_tools = llm.bind_tools(agent_tools)
53
 
54
  # nodes
55
  def assistant(state: MessagesState):
 
57
  "messages": [chat_with_tools.invoke(state["messages"])],
58
  }
59
 
60
+ # todo add rag
61
+ def retriever(state: MessagesState):
62
+ """Retriever node"""
63
+ # Handle the case when no similar questions are found
64
+ return {"messages": [sys_msg] + state["messages"]}
65
+
66
  ## The graph
67
  builder = StateGraph(MessagesState)
68
  # Define nodes: these do the work
69
+ builder.add_node("retriever", retriever)
70
  builder.add_node("assistant", assistant)
71
+ builder.add_node("tools", ToolNode(agent_tools))
72
  # Define edges: these determine how the control flow moves
73
+ builder.add_edge(START, "retriever")
74
+ builder.add_edge("retriever", "assistant")
75
  builder.add_conditional_edges(
76
  "assistant",
77
  # If the latest message requires a tool, route to tools
 
83
 
84
 
85
  if __name__ == "__main__":
86
+ random_question_url = "https://agents-course-unit4-scoring.hf.space/random-question"
87
+ response = requests.get(random_question_url, timeout=15)
88
+ questions_data = response.json()
89
+ question = questions_data.get("question")
90
  graph = buildAgent(provider="groq")
91
  messages = [HumanMessage(content=question)]
92
  print(messages)
app.py CHANGED
@@ -9,6 +9,7 @@ from agent import buildAgent
9
  # (Keep Constants as is)
10
  # --- Constants ---
11
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
12
 
13
 
14
  # --- Basic Agent Definition ---
@@ -23,7 +24,7 @@ class BasicAgent:
23
  messages = [HumanMessage(content=question)]
24
  messages = self.agent.invoke({"messages": messages})
25
  fixed_answer = messages["messages"][-1].content
26
- return fixed_answer
27
 
28
 
29
  def run_and_submit_all(profile: gr.OAuthProfile | None):
 
9
  # (Keep Constants as is)
10
  # --- Constants ---
11
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
12
+ FINAL_ANSWER_PADDING = "FINAL ANSWER: "
13
 
14
 
15
  # --- Basic Agent Definition ---
 
24
  messages = [HumanMessage(content=question)]
25
  messages = self.agent.invoke({"messages": messages})
26
  fixed_answer = messages["messages"][-1].content
27
+ return fixed_answer[len(FINAL_ANSWER_PADDING):]
28
 
29
 
30
  def run_and_submit_all(profile: gr.OAuthProfile | None):
system_prompt.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ You are a helpful assistant tasked with answering questions using a set of tools.
2
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
3
+ FINAL ANSWER: [YOUR FINAL ANSWER].
4
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, Apply the rules above for each element (number or string), ensure there is exactly one space after each comma.
5
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
tools.py CHANGED
@@ -1,8 +1,13 @@
1
  import cmath
 
 
 
 
2
  from langchain_core.tools import tool
3
  from langchain_community.tools.tavily_search import TavilySearchResults
4
  from langchain_community.document_loaders import WikipediaLoader
5
  from langchain_community.document_loaders import ArxivLoader
 
6
 
7
  @tool
8
  def multiply(a: int, b: int) -> int:
@@ -13,29 +18,32 @@ def multiply(a: int, b: int) -> int:
13
  """
14
  return a * b
15
 
 
16
  @tool
17
  def add(a: int, b: int) -> int:
18
  """Add two numbers.
19
- Args:
20
- a: first int
21
- b: second int
22
  """
23
  return a - b
24
 
 
25
  @tool
26
  def subtract(a: int, b: int) -> int:
27
  """Subtract two numbers.
28
-
29
  Args:
30
  a: first int
31
  b: second int
32
  """
33
  return a - b
34
 
 
35
  @tool
36
  def divide(a: int, b: int) -> int:
37
  """Divide two numbers.
38
-
39
  Args:
40
  a: first int
41
  b: second int
@@ -44,16 +52,18 @@ def divide(a: int, b: int) -> int:
44
  raise ValueError("Cannot divide by zero.")
45
  return a / b
46
 
 
47
  @tool
48
  def modulus(a: int, b: int) -> int:
49
  """Get the modulus of two numbers.
50
-
51
  Args:
52
  a: first int
53
  b: second int
54
  """
55
  return a % b
56
 
 
57
  @tool
58
  def power(a: float, b: float) -> float:
59
  """
@@ -76,6 +86,7 @@ def square_root(a: float) -> float | complex:
76
  return a**0.5
77
  return cmath.sqrt(a)
78
 
 
79
  @tool
80
  def web_search(query: str) -> str:
81
  """Search Tavily for a query and return maximum 3 results.
@@ -91,6 +102,7 @@ def web_search(query: str) -> str:
91
  )
92
  return {"web_results": formatted_search_docs}
93
 
 
94
  @tool
95
  def wiki_search(query: str) -> str:
96
  """Search Wikipedia for a query and return maximum 2 results.
@@ -106,6 +118,7 @@ def wiki_search(query: str) -> str:
106
  )
107
  return {"wiki_results": formatted_search_docs}
108
 
 
109
  @tool
110
  def arxiv_search(query: str) -> str:
111
  """Search Arxiv for a query and return maximum 3 result.
@@ -119,4 +132,36 @@ def arxiv_search(query: str) -> str:
119
  for doc in search_docs
120
  ]
121
  )
122
- return {"arxiv_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cmath
2
+ import os
3
+ import tempfile
4
+ from urllib.parse import urlparse
5
+ import uuid
6
  from langchain_core.tools import tool
7
  from langchain_community.tools.tavily_search import TavilySearchResults
8
  from langchain_community.document_loaders import WikipediaLoader
9
  from langchain_community.document_loaders import ArxivLoader
10
+ import requests
11
 
12
  @tool
13
  def multiply(a: int, b: int) -> int:
 
18
  """
19
  return a * b
20
 
21
+
22
  @tool
23
  def add(a: int, b: int) -> int:
24
  """Add two numbers.
25
+ Args:
26
+ a: first int
27
+ b: second int
28
  """
29
  return a - b
30
 
31
+
32
  @tool
33
  def subtract(a: int, b: int) -> int:
34
  """Subtract two numbers.
35
+
36
  Args:
37
  a: first int
38
  b: second int
39
  """
40
  return a - b
41
 
42
+
43
  @tool
44
  def divide(a: int, b: int) -> int:
45
  """Divide two numbers.
46
+
47
  Args:
48
  a: first int
49
  b: second int
 
52
  raise ValueError("Cannot divide by zero.")
53
  return a / b
54
 
55
+
56
  @tool
57
  def modulus(a: int, b: int) -> int:
58
  """Get the modulus of two numbers.
59
+
60
  Args:
61
  a: first int
62
  b: second int
63
  """
64
  return a % b
65
 
66
+
67
  @tool
68
  def power(a: float, b: float) -> float:
69
  """
 
86
  return a**0.5
87
  return cmath.sqrt(a)
88
 
89
+
90
  @tool
91
  def web_search(query: str) -> str:
92
  """Search Tavily for a query and return maximum 3 results.
 
102
  )
103
  return {"web_results": formatted_search_docs}
104
 
105
+
106
  @tool
107
  def wiki_search(query: str) -> str:
108
  """Search Wikipedia for a query and return maximum 2 results.
 
118
  )
119
  return {"wiki_results": formatted_search_docs}
120
 
121
+
122
  @tool
123
  def arxiv_search(query: str) -> str:
124
  """Search Arxiv for a query and return maximum 3 result.
 
132
  for doc in search_docs
133
  ]
134
  )
135
+ return {"arxiv_results": formatted_search_docs}
136
+
137
+
138
+ @tool
139
+ def download_file(url: str) -> str:
140
+ """Download file for a web url and return local save path
141
+ Args:
142
+ url: the file web url
143
+ """
144
+ try:
145
+ # Parse URL to get filename if not provided
146
+ if not filename:
147
+ path = urlparse(url).path
148
+ filename = os.path.basename(path)
149
+ if not filename:
150
+ filename = f"downloaded_{uuid.uuid4().hex[:8]}"
151
+
152
+ # Create temporary file
153
+ temp_dir = tempfile.gettempdir()
154
+ filepath = os.path.join(temp_dir, filename)
155
+
156
+ # Download the file
157
+ response = requests.get(url, stream=True)
158
+ response.raise_for_status()
159
+
160
+ # Save the file
161
+ with open(filepath, "wb") as f:
162
+ for chunk in response.iter_content(chunk_size=8192):
163
+ f.write(chunk)
164
+
165
+ return f"File {url} downloaded to {filepath}. You can read this file to process its contents."
166
+ except Exception as e:
167
+ return f"Error downloading file: {str(e)}"