kamorou committed on
Commit
9590080
·
verified ·
1 Parent(s): 0fa0473

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +103 -210
agent.py CHANGED
@@ -1,253 +1,146 @@
 
 
 
1
  import os
2
  from dotenv import load_dotenv
3
- import gradio as gr
4
 
5
-
6
- # LangGraph & LangChain
7
- from langgraph.graph import START, StateGraph, MessagesState
8
- from langgraph.prebuilt import ToolNode, tools_condition
9
- from langchain_core.messages import SystemMessage, HumanMessage
10
- from langchain_core.tools import tool
11
-
12
- # inference provider
13
  from langchain_huggingface import HuggingFaceEndpoint
14
- # Web search tool
15
  from langchain_community.tools.tavily_search import TavilySearchResults
16
-
17
- # # NEW IMPORT
18
- # from langchain_experimental.tools import PythonREPLTool
19
- # from langchain_core.messages import BaseMessage, HumanMessage
20
- # from langgraph.graph import StateGraph, END
21
- # from langgraph.prebuilt import ToolNode
22
- # from typing import TypedDict, Annotated, List
23
-
24
-
25
- # --- 1. LOAD API KEYS ---
26
  load_dotenv()
27
  hf_token = os.getenv("HF_TOKEN")
28
  tavily_api_key = os.getenv("TAVILY_API_KEY")
29
 
30
  if not hf_token or not tavily_api_key:
31
- raise ValueError("Hugging Face Token or Tavily API Key is not set in the environment variables.")
 
32
  os.environ["TAVILY_API_KEY"] = tavily_api_key
33
 
34
-
35
- # --- 2. DEFINE TOOLS and INITIALIZE LLM ---
36
- # UPDATED TOOLS LIST
37
- # tools = [TavilySearchResults(max_results=3), PythonREPLTool()]
38
- # tool_node = ToolNode(tools)
39
-
40
-
41
- ### TOOLS
42
-
43
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: first int
        b: second int
    """
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    return a - b

@tool
def divide(a: int, b: int) -> float:
    # Annotation fixed: true division always yields a float, not an int.
    """Divide two numbers.

    Args:
        a: first int
        b: second int
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    return a % b
93
-
94
-
95
-
96
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query."""
    # Fix: Runnable.invoke takes the tool input as its first positional
    # argument (here a dict payload), not as a `query=` keyword.
    search_docs = TavilySearchResults(max_results=3).invoke({"query": query})
    # Fix: TavilySearchResults returns a list of dicts with "url" and
    # "content" keys, not Document objects with .metadata/.page_content.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.get("url", "")}"/>\n{doc.get("content", "")}\n</Document>'
            for doc in search_docs
        ])
    # Fix: return a plain string to match the annotated return type
    # (previously returned a dict, inconsistent with `-> str`).
    return formatted_search_docs
109
-
110
# SYSTEM PROMPT
# Instructs the model to report its reasoning and end with the exact
# "FINAL ANSWER: ..." template expected by the benchmark's answer scorer.
system_prompt = """
You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
"""
# Every tool exposed to the agent: arithmetic helpers plus Tavily web search.
tools = [divide, add, multiply,subtract, web_search,modulus ]
119
-
120
- ### LLM
121
- # repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
122
- # llm = HuggingFaceEndpoint(
123
- # repo_id=repo_id,
124
- # huggingfacehub_api_token=hf_token,
125
- # temperature=0.1,
126
- # max_new_tokens=1024,
127
- # )
128
- # llm_with_tools = llm.bind_tools(tools)
129
-
130
-
131
def build_graph():
    """Build and return the compiled LangGraph agent graph.

    Returns:
        A compiled graph that alternates between the LLM "assistant" node
        and the "tools" node until the model stops emitting tool calls.
    """
    repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    # NOTE(review): HuggingFaceEndpoint is a text-completion interface;
    # tool calling via bind_tools() normally requires a chat wrapper such
    # as ChatHuggingFace — confirm this endpoint actually supports it.
    llm = HuggingFaceEndpoint(
        repo_id=repo_id,
        huggingfacehub_api_token=hf_token,
        temperature=0,
        max_new_tokens=1024,
    )
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Assistant node: invoke the LLM on the system prompt + history."""
        # Fix: wrap the raw prompt string in a SystemMessage so it is sent
        # with the "system" role instead of being coerced to a user turn.
        messages = [SystemMessage(content=system_prompt)] + state["messages"]
        return {"messages": [llm_with_tools.invoke(messages)]}

    builder = StateGraph(MessagesState)
    # Nodes
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Edges: loop assistant -> tools -> assistant until no tool calls remain.
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
161
-
162
-
163
-
164
-
165
-
166
 
 
 
 
167
 
168
-
169
-
170
-
171
-
172
-
173
-
174
-
175
-
176
-
177
-
178
- # --- 3. DEFINE THE AGENT'S STATE ---
179
- """
180
  class AgentState(TypedDict):
181
  messages: Annotated[List[BaseMessage], lambda x, y: x + y]
182
 
 
 
 
183
 
 
 
 
 
184
 
 
 
 
 
 
185
 
186
-
187
- # --- 4. DEFINE THE NODES OF THE GRAPH ---
188
- def agent_node(state):
189
- response = llm_with_tools.invoke(state["messages"])
190
  return {"messages": [response]}
191
 
192
-
193
- # --- 5. DEFINE THE EDGES OF THE GRAPH ---
194
  def should_continue(state):
195
  last_message = state["messages"][-1]
196
  if last_message.tool_calls:
197
- return "tools"
198
- return END
199
-
200
 
201
- # --- 6. ASSEMBLE THE GRAPH ---
202
  workflow = StateGraph(AgentState)
203
-
204
  workflow.add_node("agent", agent_node)
205
  workflow.add_node("tools", tool_node)
206
-
207
  workflow.set_entry_point("agent")
208
-
209
  workflow.add_conditional_edges(
210
  "agent",
211
  should_continue,
212
- {
213
- "tools": "tools",
214
- "end": END,
215
- },
216
  )
217
-
218
  workflow.add_edge("tools", "agent")
219
 
 
220
  app = workflow.compile()
221
 
222
 
223
- # --- 7. CREATE THE USER INTERFACE (UI) ---
224
- def run_agent(query: str):
225
- try:
226
- inputs = {"messages": [HumanMessage(content=query)]}
227
- final_response = None
228
- # Using stream to get final output, can be slow for complex tasks
229
- for s in app.stream(inputs, {"recursion_limit": 10}):
230
- if "agent" in s:
231
- final_response = s["agent"]["messages"][-1].content
232
- return final_response if final_response else "Agent did not produce a final answer."
233
-
234
- except Exception as e:
235
- return f"An error occurred: {e}"
236
-
237
- iface = gr.Interface(
238
- fn=run_agent,
239
- inputs=gr.Textbox(lines=2, placeholder="Ask the agent anything..."),
240
- outputs="markdown",
241
- title="GAIA Agent v0.3 (LangGraph + Code Interpreter)",
242
- description="This agent can use web search and a Python code interpreter.",
243
- examples=[
244
- ["What is the square root of the number of states in the USA?"],
245
- ["What is the total number of letters in the names of the first three planets in our solar system?"]
246
- ],
247
- )
248
-
249
-
250
- # --- 8. LAUNCH THE APP ---
251
- iface.launch()
252
- # gr.launch()
253
- """
 
 
 
 
 
 
 
 
 
 
1
+ # ==============================================================================
2
+ # 1. IMPORTS AND SETUP
3
+ # ==============================================================================
4
  import os
5
  from dotenv import load_dotenv
6
+ from typing import TypedDict, Annotated, List
7
 
8
+ # LangChain and LangGraph imports
 
 
 
 
 
 
 
9
  from langchain_huggingface import HuggingFaceEndpoint
 
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_experimental.tools import PythonREPLTool
12
+ from langchain_core.messages import BaseMessage, HumanMessage
13
+ from langchain_core.prompts import ChatPromptTemplate
14
+ from langgraph.graph import StateGraph, END
15
+ from langgraph.prebuilt import ToolNode
16
+
17
+ # ==============================================================================
18
+ # 2. LOAD API KEYS AND DEFINE TOOLS
19
+ # ==============================================================================
 
20
# Load secrets from a local .env file if present (no-op otherwise).
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
tavily_api_key = os.getenv("TAVILY_API_KEY")

if not hf_token or not tavily_api_key:
    # This will show a clear error in the logs if keys are missing
    raise ValueError("HF_TOKEN or TAVILY_API_KEY not set. Please add them to your Space secrets.")
# Re-export so downstream tooling can pick it up from the environment
# (presumably consumed by TavilySearchResults — verify).
os.environ["TAVILY_API_KEY"] = tavily_api_key
28
 
29
# The agent's tools
# WARNING: PythonREPLTool executes model-generated Python in-process with no
# sandboxing — only run this in an isolated/disposable environment.
tools = [TavilySearchResults(max_results=3, description="A search engine for finding up-to-date information on the web."), PythonREPLTool()]
# Prebuilt LangGraph node that executes whichever tool the model called.
tool_node = ToolNode(tools)
32
+
33
# ==============================================================================
# 3. CONFIGURE THE LLM (THE "BRAIN")
# ==============================================================================
# The model we'll use as the agent's brain
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# The system prompt gives the agent its mission and instructions.
# The tool names it mentions ('tavily_search_results_json', 'python_repl')
# must match the names the tool objects register under — verify.
SYSTEM_PROMPT = """You are a highly capable AI agent named 'GAIA-Solver'. Your mission is to accurately answer complex questions.

**Your Instructions:**
1. **Analyze:** Carefully read the user's question to understand all parts of what is being asked.
2. **Plan:** Think step-by-step. Break the problem into smaller tasks. Decide which tool is best for each task. (e.g., use 'tavily_search_results_json' for web searches, use 'python_repl' for calculations or code execution).
3. **Execute:** Call ONE tool at a time.
4. **Observe & Reason:** After getting a tool's result, observe it. Decide if you have the final answer or if you need to use another tool.
5. **Final Answer:** Once you are confident, provide a clear, direct, and concise final answer. Do not include your thought process in the final answer.
"""

# Initialize the LLM endpoint
# NOTE(review): HuggingFaceEndpoint is a text-completion interface; tool
# calling via bind_tools() (used further down) typically requires the
# ChatHuggingFace chat wrapper — confirm this configuration works.
llm = HuggingFaceEndpoint(
    repo_id=repo_id,
    huggingfacehub_api_token=hf_token,
    temperature=0,  # Set to 0 for deterministic, less random output
    max_new_tokens=2048,
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ # ==============================================================================
59
+ # 4. BUILD THE LANGGRAPH AGENT
60
+ # ==============================================================================
61
 
62
+ # Define the Agent's State (its memory)
 
 
 
 
 
 
 
 
 
 
 
63
class AgentState(TypedDict):
    """Graph state: the running conversation history."""
    # The Annotated reducer (x + y) appends newly produced messages to the
    # existing list instead of overwriting it on each node update.
    messages: Annotated[List[BaseMessage], lambda x, y: x + y]
65
 
66
# Bind the tool schemas to the model so it can emit structured tool calls.
# It ensures the system prompt is always used (see agent_node below).
llm_with_tools = llm.bind_tools(tools)

# Define the Agent Node
def agent_node(state):
    """LLM step: invoke the model on the system prompt plus the FULL history.

    Fix: the previous version forwarded only the last message, which
    discarded the user's question and all earlier tool results — after a
    tool ran, the model could never see what it had asked for or why.
    """
    # Prepend the system prompt to the whole conversation so far.
    # (HumanMessage is used because SystemMessage is not imported in this
    # module; a true "system" role message would be preferable — verify.)
    prompt_with_system = [
        HumanMessage(content=SYSTEM_PROMPT, name="system_prompt")
    ] + state["messages"]

    response = llm_with_tools.invoke(prompt_with_system)
    return {"messages": [response]}
83
 
84
# Define the Edge Logic
def should_continue(state):
    """Route after the agent node.

    Returns "tools" when the last message carries tool calls, otherwise
    "end".  Fix: the previous version returned the END sentinel
    ("__end__"), which has no entry in the conditional-edge path map
    {"tools": "tools", "end": END} and therefore raised at runtime.
    """
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "tools"  # Route to the tool node
    return "end"  # Mapped to END by the conditional edges
 
90
 
91
# Assemble the graph
workflow = StateGraph(AgentState)

workflow.add_node("agent", agent_node)
workflow.add_node("tools", tool_node)

workflow.set_entry_point("agent")

workflow.add_conditional_edges(
    "agent",
    should_continue,
    # Fix: accept both the "end" string and the END sentinel as terminal
    # routes, so the router may return either without a KeyError.
    # (END is itself the string "__end__", so it is a valid map key.)
    {"tools": "tools", "end": END, END: END},
)
workflow.add_edge("tools", "agent")

# Compile the graph into a runnable app
app = workflow.compile()
105
 
106
 
107
# ==============================================================================
# 5. THE BASICAGENT CLASS (FOR THE TEST HARNESS)
# This MUST be at the end, after `app` is defined.
# ==============================================================================
class BasicAgent:
    """
    This is the agent class that the GAIA test harness will use.
    """

    def __init__(self):
        # The compiled LangGraph app is our agent executor
        self.agent_executor = app

    def run(self, question: str) -> str:
        """
        This method is called by the test script with each question.
        It runs the LangGraph agent and returns the final answer.
        """
        print(f"Agent received question (first 80 chars): {question[:80]}...")
        try:
            graph_input = {"messages": [HumanMessage(content=question)]}

            # Walk the stream, keeping the latest non-empty agent message
            # as the candidate final answer.
            final_response = ""
            for update in self.agent_executor.stream(graph_input, {"recursion_limit": 15}):
                if "agent" not in update:
                    continue
                content = update["agent"]["messages"][-1].content
                if content:
                    final_response = content

            if not final_response:
                # Fallback when the agent finishes without a clear message.
                final_response = "Agent finished but did not produce a final answer."

            print(f"Agent returning final answer (first 80 chars): {final_response[:80]}...")
            return final_response

        except Exception as e:
            print(f"An error occurred in agent execution: {e}")
            return f"Error: {e}"