j14i commited on
Commit
86b8466
·
1 Parent(s): 90a12a7

Shoutout Warns

Browse files
Files changed (2) hide show
  1. agent.py +3 -0
  2. test_bench.py +21 -1
agent.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from typing import Annotated, TypedDict
3
 
4
  from dotenv import load_dotenv
@@ -11,6 +12,8 @@ from langgraph.graph.state import END, START, CompiledStateGraph, StateGraph
11
  from langgraph.prebuilt import ToolNode
12
  from pydantic import SecretStr
13
 
 
 
14
  load_dotenv()
15
 
16
 
 
1
  import os
2
+ import warnings
3
  from typing import Annotated, TypedDict
4
 
5
  from dotenv import load_dotenv
 
12
  from langgraph.prebuilt import ToolNode
13
  from pydantic import SecretStr
14
 
15
+ warnings.filterwarnings("ignore", category=UserWarning, module="langchain_tavily")
16
+
17
  load_dotenv()
18
 
19
 
test_bench.py CHANGED
@@ -7,6 +7,7 @@ Usage:
7
  uv run python test_bench.py --level 1 # Run on level 1 only
8
  uv run python test_bench.py --level 1 --n 3 # Run on 3 level 1 questions
9
  uv run python test_bench.py --all # Run on all validation questions
 
10
 
11
  uv run python test_bench.py --type youtube,file,web,excel,pdf,image,audio,text-only
12
  """
@@ -146,6 +147,21 @@ def filter_by_type(
146
  return filtered
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def load_gaia_data(level: int | None = None) -> list[GaiaQuestion]:
150
  """Load GAIA validation dataset with all metadata."""
151
  print("Downloading GAIA dataset...")
@@ -179,6 +195,7 @@ def run_test_bench(
179
  task_type: str | None = None,
180
  run_all: bool = False,
181
  save_results: bool = True,
 
182
  ) -> list[TestResult]:
183
  """
184
  Run the test bench on the agent.
@@ -198,6 +215,7 @@ def run_test_bench(
198
 
199
  questions = load_gaia_data(level=level)
200
  questions = filter_by_type(questions, task_type)
 
201
 
202
  if not questions:
203
  print(f"No questions found for type '{task_type}'")
@@ -219,7 +237,7 @@ def run_test_bench(
219
  if q.file_name:
220
  print(f" File: {q.file_name}")
221
  if q.annotator.tools:
222
- print(f" Tools needed: {q.annotator.tools}")
223
 
224
  try:
225
  actual = agent(q.question)
@@ -349,6 +367,7 @@ def main():
349
  )
350
  parser.add_argument("--all", action="store_true", help="Run all questions")
351
  parser.add_argument("--no-save", action="store_true", help="Don't save results")
 
352
  args = parser.parse_args()
353
 
354
  from agent import BasicAgent
@@ -365,6 +384,7 @@ def main():
365
  task_type=args.type,
366
  run_all=args.all,
367
  save_results=not args.no_save,
 
368
  )
369
 
370
 
 
7
  uv run python test_bench.py --level 1 # Run on level 1 only
8
  uv run python test_bench.py --level 1 --n 3 # Run on 3 level 1 questions
9
  uv run python test_bench.py --all # Run on all validation questions
10
+ uv run python test_bench.py --task-id 1234 # Run on specific task ID
11
 
12
  uv run python test_bench.py --type youtube,file,web,excel,pdf,image,audio,text-only
13
  """
 
147
  return filtered
148
 
149
 
150
+ def filter_by_task_id(
151
+ questions: list[GaiaQuestion], task_id: str | None
152
+ ) -> list[GaiaQuestion]:
153
+ """Filter questions by task id."""
154
+ if not task_id:
155
+ return questions
156
+
157
+ filtered = []
158
+ for q in questions:
159
+ if q.task_id == task_id:
160
+ filtered.append(q)
161
+
162
+ return filtered
163
+
164
+
165
  def load_gaia_data(level: int | None = None) -> list[GaiaQuestion]:
166
  """Load GAIA validation dataset with all metadata."""
167
  print("Downloading GAIA dataset...")
 
195
  task_type: str | None = None,
196
  run_all: bool = False,
197
  save_results: bool = True,
198
+ task_id: str | None = None,
199
  ) -> list[TestResult]:
200
  """
201
  Run the test bench on the agent.
 
215
 
216
  questions = load_gaia_data(level=level)
217
  questions = filter_by_type(questions, task_type)
218
+ questions = filter_by_task_id(questions, task_id)
219
 
220
  if not questions:
221
  print(f"No questions found for type '{task_type}'")
 
237
  if q.file_name:
238
  print(f" File: {q.file_name}")
239
  if q.annotator.tools:
240
+ print(f" Tools needed:\n{q.annotator.tools}")
241
 
242
  try:
243
  actual = agent(q.question)
 
367
  )
368
  parser.add_argument("--all", action="store_true", help="Run all questions")
369
  parser.add_argument("--no-save", action="store_true", help="Don't save results")
370
+ parser.add_argument("--task-id", type=str, help="Run specific task ID")
371
  args = parser.parse_args()
372
 
373
  from agent import BasicAgent
 
384
  task_type=args.type,
385
  run_all=args.all,
386
  save_results=not args.no_save,
387
+ task_id=args.task_id,
388
  )
389
 
390