datdevsteve committed on
Commit
8eb1cc8
·
verified ·
1 Parent(s): 4d5faa9

fixes for gaia submission

Browse files
Files changed (1) hide show
  1. gaia_agent.py +119 -98
gaia_agent.py CHANGED
@@ -1,133 +1,123 @@
1
  import os
2
  import requests
3
- from langchain.agents import create_agent
4
  from langchain.tools import tool
 
5
  from dotenv import load_dotenv
6
  from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
7
- from ddgs import DDGS
8
  from bs4 import BeautifulSoup
9
 
10
  # Load environment variables
11
- #load_dotenv()
12
 
13
  # --- Agent Setup ---
14
- openai_key = os.getenv("OPENAI_API_KEY")
15
- googleai_key = os.getenv("GOOGLE_API_KEY")
16
-
17
- # Use OpenRouter via LangChain's ChatOpenAI
18
  openrouter_key = os.getenv("OPENROUTER_API_KEY")
19
  if not openrouter_key:
20
  raise RuntimeError("Set OPENROUTER_API_KEY in your .env (OpenRouter API key)")
21
 
22
- # Defer ChatOpenAI import until runtime to avoid import-time errors in environments without the package
23
  from langchain_openai import ChatOpenAI
24
 
25
  model = ChatOpenAI(
26
- api_key=openrouter_key,
27
- base_url="https://openrouter.ai/api/v1",
28
- model="gpt-4o-mini",
29
- max_completion_tokens=10000,
 
30
  )
31
 
32
  # --- Tools Definition ---
33
  @tool
34
- def multiply(a: int, b: int) -> int:
35
  """Multiply two numbers.
36
  Args:
37
- a: first int
38
- b: second int
39
  """
40
  return a * b
41
 
42
  @tool
43
- def add(a: int, b: int) -> int:
44
  """Add two numbers.
45
-
46
  Args:
47
- a: first int
48
- b: second int
49
  """
50
  return a + b
51
 
52
  @tool
53
- def subtract(a: int, b: int) -> int:
54
  """Subtract two numbers.
55
-
56
  Args:
57
- a: first int
58
- b: second int
59
  """
60
  return a - b
61
 
62
  @tool
63
- def divide(a: int, b: int) -> int:
64
  """Divide two numbers.
65
-
66
  Args:
67
- a: first int
68
- b: second int
69
  """
70
  if b == 0:
71
  raise ValueError("Cannot divide by zero.")
72
  return a / b
73
 
74
  @tool
75
- def modulus(a: int, b: int) -> int:
76
  """Get the modulus of two numbers.
77
-
78
  Args:
79
- a: first int
80
- b: second int
81
  """
82
  return a % b
83
 
84
  @tool
85
  def wiki_search(query: str) -> str:
86
  """Search Wikipedia for a query and return maximum 2 results."""
87
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
88
- formatted_search_docs = "\n\n---\n\n".join(
89
- [
90
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
91
- for doc in search_docs
92
- ]
93
- )
94
- return formatted_search_docs
 
 
 
95
 
96
  @tool
97
  def web_search(query: str) -> str:
98
- """Search DDGS for a query and return maximum 3 results."""
99
- search_docs = DDGS().text(query, max_results=3)
100
- formatted_search_docs = "\n\n---\n\n".join(
101
- [
102
- f'Title:{doc["title"]}\nContent:{doc["body"]}\n--\n'
103
- for doc in search_docs
104
- ]
105
- )
106
- return formatted_search_docs
 
 
 
107
 
108
  @tool
109
  def arxiv_search(query: str) -> str:
110
  """Search arXiv for a query and return maximum 3 results."""
111
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
112
- formatted_search_docs = "\n\n---\n\n".join(
113
- [
114
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
115
- for doc in search_docs
116
- ]
117
- )
118
- return formatted_search_docs
119
-
120
- @tool
121
- def image_search(query: str) -> str:
122
- """Searches DDGS for an image query and returns maximum 10 image results"""
123
- search_images = DDGS().images(query=query)
124
- formatted_result = "\n\n---\n\n".join(
125
- [
126
- f'Image Title:{image["title"]}\nImage URL: {image["url"]}'
127
- for image in search_images
128
- ]
129
- )
130
- return formatted_result
131
 
132
  @tool
133
  def fetch_url_content(url: str) -> str:
@@ -139,52 +129,83 @@ def fetch_url_content(url: str) -> str:
139
  for script in soup(["script", "style"]):
140
  script.decompose()
141
  text = soup.get_text(separator='\n', strip=True)
142
- return text[:2000] + ("..." if len(text) > 2000 else "")
143
  except Exception as e:
144
  return f"Error fetching URL: {str(e)}"
145
 
146
  # Tools list
147
  tools = [
148
  multiply, add, subtract, divide, modulus,
149
- wiki_search, web_search, arxiv_search, image_search,
150
  fetch_url_content,
151
  ]
152
 
153
- # System prompt
154
- sys_prompt = """You are a helpful agent, please provide clear and concise answers to asked questions.
155
- Keep your word limit for answers as minimum as you can. You are equipped with the following tools:
156
- 1. [multiply], [add], [subtract], [divide], [modulus] - basic calculator operations.
157
- 2. [wiki_search] - search Wikipedia and return up to 2 documents as text.
158
- 3. [web_search] - perform a web search and return up to 3 documents as text.
159
- 4. [arxiv_search] - search arXiv and return up to 3 documents as text.
160
- 5. [image_search] - Searches the internet for an image query and returns maximum 10 image results
161
 
162
- Under any circumstances, if you fail to provide the accurate answer expected by the user, you may say the same to the user and provide a similar answer which is approximately the closest. Disregard spelling mistakes and provide answer with results retreived from the correct spelling.
 
 
 
 
 
 
 
163
 
164
- For every tool you use, append a single line at the end of your response exactly in this format:
165
- [TOOLS USED: (tool_name)]
166
- When no tools are used, append:
167
- [TOOLS USED WERE NONE]
168
- """
 
 
 
 
 
169
 
170
  class GAIAAgent:
171
  def __init__(self):
172
- # create internal agent
173
  try:
174
- self.agent = create_agent(model, tools=tools, system_prompt=sys_prompt)
 
 
 
 
 
 
 
175
  except Exception as e:
 
176
  raise
177
 
178
  def __call__(self, question: str) -> str:
179
- result = self.agent.invoke({"messages": [{"role": "user", "content": question}]})
180
- raw_content = result["messages"][-1].content
181
- if isinstance(raw_content, list) and len(raw_content) > 0:
182
- if isinstance(raw_content[0], dict) and 'text' in raw_content[0]:
183
- answer = raw_content[0]['text']
184
- else:
185
- answer = str(raw_content)
186
- elif isinstance(raw_content, str):
187
- answer = raw_content
188
- else:
189
- answer = str(raw_content)
190
- return answer
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import requests
3
+ from langchain.agents import create_react_agent, AgentExecutor
4
  from langchain.tools import tool
5
+ from langchain_core.prompts import PromptTemplate
6
  from dotenv import load_dotenv
7
  from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
8
+ from duckduckgo_search import DDGS
9
  from bs4 import BeautifulSoup
10
 
11
  # Load environment variables
12
+ # load_dotenv()
13
 
14
  # --- Agent Setup ---
 
 
 
 
15
  openrouter_key = os.getenv("OPENROUTER_API_KEY")
16
  if not openrouter_key:
17
  raise RuntimeError("Set OPENROUTER_API_KEY in your .env (OpenRouter API key)")
18
 
 
19
  from langchain_openai import ChatOpenAI
20
 
21
  model = ChatOpenAI(
22
+ api_key=openrouter_key,
23
+ base_url="https://openrouter.ai/api/v1",
24
+ model="openai/gpt-4o-mini",
25
+ max_tokens=10000,
26
+ temperature=0
27
  )
28
 
29
  # --- Tools Definition ---
30
  @tool
31
+ def multiply(a: float, b: float) -> float:
32
  """Multiply two numbers.
33
  Args:
34
+ a: first number
35
+ b: second number
36
  """
37
  return a * b
38
 
39
  @tool
40
+ def add(a: float, b: float) -> float:
41
  """Add two numbers.
 
42
  Args:
43
+ a: first number
44
+ b: second number
45
  """
46
  return a + b
47
 
48
  @tool
49
+ def subtract(a: float, b: float) -> float:
50
  """Subtract two numbers.
 
51
  Args:
52
+ a: first number
53
+ b: second number
54
  """
55
  return a - b
56
 
57
  @tool
58
+ def divide(a: float, b: float) -> float:
59
  """Divide two numbers.
 
60
  Args:
61
+ a: first number
62
+ b: second number
63
  """
64
  if b == 0:
65
  raise ValueError("Cannot divide by zero.")
66
  return a / b
67
 
68
  @tool
69
+ def modulus(a: float, b: float) -> float:
70
  """Get the modulus of two numbers.
 
71
  Args:
72
+ a: first number
73
+ b: second number
74
  """
75
  return a % b
76
 
77
  @tool
78
  def wiki_search(query: str) -> str:
79
  """Search Wikipedia for a query and return maximum 2 results."""
80
+ try:
81
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
82
+ formatted_search_docs = "\n\n---\n\n".join(
83
+ [
84
+ f'\n{doc.page_content}\n'
85
+ for doc in search_docs
86
+ ]
87
+ )
88
+ return formatted_search_docs
89
+ except Exception as e:
90
+ return f"Error searching Wikipedia: {str(e)}"
91
 
92
  @tool
93
  def web_search(query: str) -> str:
94
+ """Search the web for a query and return maximum 3 results."""
95
+ try:
96
+ search_docs = DDGS().text(query, max_results=3)
97
+ formatted_search_docs = "\n\n---\n\n".join(
98
+ [
99
+ f'Title:{doc["title"]}\nContent:{doc["body"]}\n--\n'
100
+ for doc in search_docs
101
+ ]
102
+ )
103
+ return formatted_search_docs
104
+ except Exception as e:
105
+ return f"Error searching web: {str(e)}"
106
 
107
  @tool
108
  def arxiv_search(query: str) -> str:
109
  """Search arXiv for a query and return maximum 3 results."""
110
+ try:
111
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
112
+ formatted_search_docs = "\n\n---\n\n".join(
113
+ [
114
+ f'\n{doc.page_content[:1000]}\n'
115
+ for doc in search_docs
116
+ ]
117
+ )
118
+ return formatted_search_docs
119
+ except Exception as e:
120
+ return f"Error searching arXiv: {str(e)}"
 
 
 
 
 
 
 
 
 
121
 
122
  @tool
123
  def fetch_url_content(url: str) -> str:
 
129
  for script in soup(["script", "style"]):
130
  script.decompose()
131
  text = soup.get_text(separator='\n', strip=True)
132
+ return text[:3000] + ("..." if len(text) > 3000 else "")
133
  except Exception as e:
134
  return f"Error fetching URL: {str(e)}"
135
 
136
  # Tools list
137
  tools = [
138
  multiply, add, subtract, divide, modulus,
139
+ wiki_search, web_search, arxiv_search,
140
  fetch_url_content,
141
  ]
142
 
143
+ # React prompt template
144
+ react_prompt = PromptTemplate.from_template("""You are a helpful assistant that answers questions accurately and concisely.
145
+
146
+ Answer the following questions as best you can. You have access to the following tools:
147
+
148
+ {tools}
149
+
150
+ Use the following format:
151
 
152
+ Question: the input question you must answer
153
+ Thought: you should always think about what to do
154
+ Action: the action to take, should be one of [{tool_names}]
155
+ Action Input: the input to the action
156
+ Observation: the result of the action
157
+ ... (this Thought/Action/Action Input/Observation can repeat N times)
158
+ Thought: I now know the final answer
159
+ Final Answer: the final answer to the original input question
160
 
161
+ IMPORTANT: Your Final Answer must be:
162
+ - Short and direct (just the answer, no extra explanation)
163
+ - A single value or short phrase
164
+ - No formatting, no bullet points, no extra text
165
+ - Just the factual answer
166
+
167
+ Begin!
168
+
169
+ Question: {input}
170
+ Thought:{agent_scratchpad}""")
171
 
172
  class GAIAAgent:
173
  def __init__(self):
174
+ # create internal agent with React agent
175
  try:
176
+ agent = create_react_agent(model, tools, react_prompt)
177
+ self.agent_executor = AgentExecutor(
178
+ agent=agent,
179
+ tools=tools,
180
+ verbose=True,
181
+ handle_parsing_errors=True,
182
+ max_iterations=15
183
+ )
184
  except Exception as e:
185
+ print(f"Error creating agent: {e}")
186
  raise
187
 
188
  def __call__(self, question: str) -> str:
189
+ try:
190
+ result = self.agent_executor.invoke({"input": question})
191
+ answer = result.get("output", "")
192
+
193
+ # Clean up the answer - remove any extra formatting
194
+ answer = answer.strip()
195
+
196
+ # Remove common prefixes that might be added
197
+ prefixes_to_remove = [
198
+ "The answer is:",
199
+ "The final answer is:",
200
+ "Final Answer:",
201
+ "Answer:",
202
+ ]
203
+
204
+ for prefix in prefixes_to_remove:
205
+ if answer.startswith(prefix):
206
+ answer = answer[len(prefix):].strip()
207
+
208
+ return answer
209
+ except Exception as e:
210
+ print(f"Error invoking agent: {e}")
211
+ return f"Error: {str(e)}"