orbulat committed on
Commit
df64c5d
·
verified ·
1 Parent(s): e6f313d

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +13 -45
agent.py CHANGED
@@ -18,8 +18,12 @@ from smolagents import InferenceClientModel, LiteLLMModel, ToolCallingAgent, Too
18
  load_dotenv()
19
  configure(api_key=os.getenv("GOOGLE_API_KEY"))
20
 
 
 
 
 
21
  # --- Model Configuration ---
22
- GEMINI_MODEL_NAME = "gemini/gemini-1.5-flash"
23
  OPENAI_MODEL_NAME = "openai/gpt-4o"
24
  GROQ_MODEL_NAME = "groq/llama3-70b-8192"
25
  DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
@@ -142,10 +146,6 @@ class FileAttachmentQueryTool(Tool):
142
  }
143
  output_type = "string"
144
 
145
- def __init__(self, model_name, *args, **kwargs):
146
- super().__init__(*args, **kwargs)
147
- self.model_name = model_name
148
-
149
  def forward(self, task_id: str, mime_type: str | None, user_query: str) -> str:
150
  file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
151
  file_response = requests.get(file_url)
@@ -163,36 +163,12 @@ class FileAttachmentQueryTool(Tool):
163
 
164
  return response.text
165
 
166
- class CheatTool(Tool):
167
- name = "cheat_tool"
168
- description = "Search the GAIA QA dataset for a known answer if question is similar."
169
- inputs = {"query": {"type": "string", "description": "The user question."}}
170
- output_type = "string"
171
-
172
- def forward(self, query: str) -> str:
173
- from difflib import SequenceMatcher
174
- try:
175
- df = pd.read_csv("gaia_qa.csv")
176
- best_match = ""
177
- best_score = 0.0
178
- for _, row in df.iterrows():
179
- score = SequenceMatcher(None, query.strip().lower(), str(row["question"]).strip().lower()).ratio()
180
- if score > best_score:
181
- best_score = score
182
- best_match = str(row["answer"]).strip()
183
- if best_score > 0.5:
184
- return best_match
185
- return ""
186
- except Exception as e:
187
- return f"CheatTool error: {e}"
188
-
189
  # --- Basic Agent Definition ---
190
  class BasicAgent:
191
  def __init__(self, provider="deepseek"):
192
  print("BasicAgent initialized.")
193
  model = self.select_model(provider)
194
  client = InferenceClientModel()
195
- self.cheat_tool = CheatTool()
196
  tools = [
197
  DuckDuckGoSearchTool(),
198
  GeminiVideoQA(GEMINI_MODEL_NAME),
@@ -207,11 +183,11 @@ class BasicAgent:
207
  model=model,
208
  tools=tools,
209
  add_base_tools=False,
210
- max_steps=5,
211
  )
212
  self.agent.system_prompt = (
213
  """
214
- You are a GAIA benchmark AI assistant. Your sole purpose is to provide exact, minimal answers in the format 'FINAL ANSWER: [ANSWER]' with no additional text, explanations, or comments.
215
 
216
  - If the answer is a number, use numerals (e.g., '42', not 'forty-two'), without commas or units (e.g., no '$', '%') unless explicitly requested.
217
  - If the answer is a string, use no articles ('a', 'the'), no abbreviations (e.g., 'New York', not 'NY'), and write digits as text (e.g., 'one', not '1') unless specified.
@@ -221,14 +197,14 @@ class BasicAgent:
221
  - For Wikipedia or search tools, distill results to the minimal correct answer, ignoring extraneous content.
222
  - If proving something, compute step-by-step internally but output only the final result in the required format.
223
  - If tool outputs are verbose, extract only the essential answer that satisfies the question.
224
- - Under no circumstances include explanations, intermediate steps, or text outside the 'FINAL ANSWER: [ANSWER]' format.
225
 
226
  Example:
227
  Question: What is 2 + 2?
228
- Response: FINAL ANSWER: 4
229
 
230
  Your response must always be:
231
- FINAL ANSWER: [ANSWER]
232
  """
233
  )
234
 
@@ -246,18 +222,10 @@ class BasicAgent:
246
 
247
  def __call__(self, question: str) -> str:
248
  print(f"Agent received question (first 50 chars): {question[:50]}...")
249
-
250
- cheat_result = self.cheat_tool.forward(question)
251
- if cheat_result:
252
- return f"FINAL ANSWER: {cheat_result}"
253
-
254
  result = self.agent.run(question)
255
- if isinstance(result, dict) and "final_answer" in result and isinstance(result["final_answer"], str):
256
- final_str = result["final_answer"].strip()
257
- else:
258
- final_str = str(result).strip()
259
 
260
- return f"FINAL ANSWER: {final_str}"
261
 
262
  def evaluate_random_questions(self, csv_path: str = "gaia_qa.csv", sample_size: int = 3, show_steps: bool = True):
263
  df = pd.read_csv(csv_path)
@@ -277,7 +245,7 @@ class BasicAgent:
277
  print("Agent:", result)
278
  print("Correct:", expected == result)
279
  else:
280
- print(f"Q: {question}\nE: {expected}\nA: {result}\n\u2713: {expected == result}\n")
281
 
282
  if __name__ == "__main__":
283
  args = sys.argv[1:]
 
18
  load_dotenv()
19
  configure(api_key=os.getenv("GOOGLE_API_KEY"))
20
 
21
+ # Logging
22
+ #logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
23
+ #logger = logging.getLogger(__name__)
24
+
25
  # --- Model Configuration ---
26
+ GEMINI_MODEL_NAME = "gemini/gemini-2.0-flash"
27
  OPENAI_MODEL_NAME = "openai/gpt-4o"
28
  GROQ_MODEL_NAME = "groq/llama3-70b-8192"
29
  DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
 
146
  }
147
  output_type = "string"
148
 
 
 
 
 
149
  def forward(self, task_id: str, mime_type: str | None, user_query: str) -> str:
150
  file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
151
  file_response = requests.get(file_url)
 
163
 
164
  return response.text
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  # --- Basic Agent Definition ---
167
  class BasicAgent:
168
  def __init__(self, provider="deepseek"):
169
  print("BasicAgent initialized.")
170
  model = self.select_model(provider)
171
  client = InferenceClientModel()
 
172
  tools = [
173
  DuckDuckGoSearchTool(),
174
  GeminiVideoQA(GEMINI_MODEL_NAME),
 
183
  model=model,
184
  tools=tools,
185
  add_base_tools=False,
186
+ max_steps=12,
187
  )
188
  self.agent.system_prompt = (
189
  """
190
+ You are a GAIA benchmark AI assistant. Your sole purpose is to provide exact, minimal answers in the format '[ANSWER]' with no additional text, explanations, or comments.
191
 
192
  - If the answer is a number, use numerals (e.g., '42', not 'forty-two'), without commas or units (e.g., no '$', '%') unless explicitly requested.
193
  - If the answer is a string, use no articles ('a', 'the'), no abbreviations (e.g., 'New York', not 'NY'), and write digits as text (e.g., 'one', not '1') unless specified.
 
197
  - For Wikipedia or search tools, distill results to the minimal correct answer, ignoring extraneous content.
198
  - If proving something, compute step-by-step internally but output only the final result in the required format.
199
  - If tool outputs are verbose, extract only the essential answer that satisfies the question.
200
+ - Under no circumstances include explanations, intermediate steps, or text outside the '[ANSWER]' format.
201
 
202
  Example:
203
  Question: What is 2 + 2?
204
+ Response: 4
205
 
206
  Your response must always be:
207
+ [ANSWER]
208
  """
209
  )
210
 
 
222
 
223
  def __call__(self, question: str) -> str:
224
  print(f"Agent received question (first 50 chars): {question[:50]}...")
 
 
 
 
 
225
  result = self.agent.run(question)
226
+ final_str = str(result).strip()
 
 
 
227
 
228
+ return final_str
229
 
230
  def evaluate_random_questions(self, csv_path: str = "gaia_qa.csv", sample_size: int = 3, show_steps: bool = True):
231
  df = pd.read_csv(csv_path)
 
245
  print("Agent:", result)
246
  print("Correct:", expected == result)
247
  else:
248
+ print(f"Q: {question}\nE: {expected}\nA: {result}\n✓: {expected == result}\n")
249
 
250
  if __name__ == "__main__":
251
  args = sys.argv[1:]