import argparse
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any

import datasets
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download
from scripts.reformulator import prepare_response
from scripts.run_agents import (
    get_single_file_description,
    get_zip_description,
)
from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
    ArchiveSearchTool,
    FinderTool,
    FindNextTool,
    PageDownTool,
    PageUpTool,
    SimpleTextBrowser,
    VisitTool,
)
from scripts.visual_qa import visualizer
from tqdm import tqdm

from smolagents import (
    CodeAgent,
    GoogleSearchTool,
    LiteLLMModel,
    Model,
    ToolCallingAgent,
)


load_dotenv(override=True)
login(os.getenv("HF_TOKEN"))

append_answer_lock = threading.Lock()


def parse_args():
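    """Parse the command-line options for a GAIA evaluation run."""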
    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency", type=int, default=8)
    parser.add_argument("--model-id", type=str, default="o1")
    parser.add_argument("--run-name", type=str, required=True)
    parser.add_argument("--set-to-run", type=str, default="validation")
    # store_true avoids the argparse pitfall where type=bool turns any non-empty string into True.
    parser.add_argument("--use-open-models", action="store_true")
    parser.add_argument("--use-raw-dataset", action="store_true")
    return parser.parse_args()


print("Make sure you have deactivated any VPN like Tailscale, else some URLs will be blocked!")
|
|
custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}

user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"

BROWSER_CONFIG = {
    "viewport_size": 1024 * 5,
    "downloads_folder": "downloads_folder",
    "request_kwargs": {
        "headers": {"User-Agent": user_agent},
        "timeout": 300,
    },
    "serpapi_key": os.getenv("SERPAPI_API_KEY"),
}

os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)


def create_agent_team(model: Model):
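    """Build the agent team: a manager CodeAgent that delegates web research to a tool-calling search_agent."""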
    text_limit = 100000
    ti_tool = TextInspectorTool(model, text_limit)

    browser = SimpleTextBrowser(**BROWSER_CONFIG)

    WEB_TOOLS = [
        GoogleSearchTool(provider="serper"),
        VisitTool(browser),
        PageUpTool(browser),
        PageDownTool(browser),
        FinderTool(browser),
        FindNextTool(browser),
        ArchiveSearchTool(browser),
        TextInspectorTool(model, text_limit),
    ]

    text_webbrowser_agent = ToolCallingAgent(
        model=model,
        tools=WEB_TOOLS,
        max_steps=20,
        verbosity_level=2,
        planning_interval=4,
        name="search_agent",
        description="""A team member that will search the internet to answer your question.
Ask him for all your questions that require browsing the web.
Provide him as much context as possible, in particular if you need to search on a specific timeframe!
And don't hesitate to provide him with a complex search task, like finding a difference between two webpages.
Your request must be a real sentence, not a google search! Like "Find me this information (...)" rather than a few keywords.
""",
        provide_run_summary=True,
    )
    text_webbrowser_agent.prompt_templates["managed_agent"]["task"] += """You can navigate to .txt online files.
If a non-html page is in another format, especially .pdf or a Youtube video, use tool 'inspect_file_as_text' to inspect it.
Additionally, if after some searching you find out that you need more information to answer the question, you can use `final_answer` with your request for clarification as argument to request for more information."""

    manager_agent = CodeAgent(
        model=model,
        tools=[visualizer, ti_tool],
        max_steps=12,
        verbosity_level=2,
        additional_authorized_imports=["*"],
        planning_interval=4,
        managed_agents=[text_webbrowser_agent],
    )
    return manager_agent


def load_gaia_dataset(use_raw_dataset: bool, set_to_run: str) -> datasets.Dataset:
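    """Download the GAIA data if missing, load the requested split, and normalize column names and file paths."""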
    if not os.path.exists("data/gaia"):
        if use_raw_dataset:
            snapshot_download(
                repo_id="gaia-benchmark/GAIA",
                repo_type="dataset",
                local_dir="data/gaia",
                ignore_patterns=[".gitattributes", "README.md"],
            )
        else:
            snapshot_download(
                repo_id="smolagents/GAIA-annotated",
                repo_type="dataset",
                local_dir="data/gaia",
                ignore_patterns=[".gitattributes", "README.md"],
            )

    def preprocess_file_paths(row):
        if len(row["file_name"]) > 0:
            row["file_name"] = f"data/gaia/{set_to_run}/" + row["file_name"]
        return row

    eval_ds = datasets.load_dataset(
        "data/gaia/GAIA.py",
        name="2023_all",
        split=set_to_run,
    )

    eval_ds = eval_ds.rename_columns({"Question": "question", "Final answer": "true_answer", "Level": "task"})
    eval_ds = eval_ds.map(preprocess_file_paths)
    return eval_ds


def append_answer(entry: dict, jsonl_file: str) -> None:
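    """Append one result entry as a JSON line to the answers file, creating parent directories if needed."""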
    jsonl_path = Path(jsonl_file)
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)
    with append_answer_lock, open(jsonl_file, "a", encoding="utf-8") as fp:
        fp.write(json.dumps(entry) + "\n")
    assert jsonl_path.exists(), "File not found!"
    print("Answer exported to file:", jsonl_path.resolve())


def answer_single_question(
    example: dict, model_id: str, answers_file: str, visual_inspection_tool: TextInspectorTool
) -> None:
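    """Run the agent team on a single GAIA example and append the annotated result to the answers file."""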
    model_params: dict[str, Any] = {
        "model_id": model_id,
        "custom_role_conversions": custom_role_conversions,
    }
    if model_id == "o1":
        # o1 takes reasoning-specific parameters and max_completion_tokens instead of max_tokens.
        model_params["reasoning_effort"] = "high"
        model_params["max_completion_tokens"] = 8192
    else:
        model_params["max_tokens"] = 4096
    model = LiteLLMModel(**model_params)

    document_inspection_tool = TextInspectorTool(model, 100000)

    agent = create_agent_team(model)

    augmented_question = """You have one question to answer. It is paramount that you provide a correct answer.
Give it all you can: I know for a fact that you have access to all the relevant tools to solve it and find the correct answer (the answer does exist).
Failure or 'I cannot answer' or 'None found' will not be tolerated, success will be rewarded.
Run verification steps if that's needed, you must make sure you find the correct answer! Here is the task:

""" + example["question"]
|
|
    if example["file_name"]:
        if ".zip" in example["file_name"]:
            prompt_use_files = "\n\nTo solve the task above, you will have to use these attached files:\n"
            prompt_use_files += get_zip_description(
                example["file_name"], example["question"], visual_inspection_tool, document_inspection_tool
            )
        else:
            prompt_use_files = "\n\nTo solve the task above, you will have to use this attached file:\n"
            prompt_use_files += get_single_file_description(
                example["file_name"], example["question"], visual_inspection_tool, document_inspection_tool
            )
        augmented_question += prompt_use_files
|
|
    start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    try:
        # Run the agent on the augmented task.
        final_result = agent.run(augmented_question)

        agent_memory = agent.write_memory_to_messages()

        final_result = prepare_response(augmented_question, agent_memory, reformulation_model=model)

        output = str(final_result)
        for memory_step in agent.memory.steps:
            memory_step.model_input_messages = None
        intermediate_steps = agent_memory

        # Flag runs where the model produced output that could not be parsed into a tool call or code blob.
        parsing_error = any("AgentParsingError" in str(step) for step in intermediate_steps)

        # Flag runs where the agent hit its step limit instead of finishing on its own.
        iteration_limit_exceeded = "Agent stopped due to iteration limit or time limit." in output
        raised_exception = False

    except Exception as e:
        print("Error on ", augmented_question, e)
        output = None
        intermediate_steps = []
        parsing_error = False
        iteration_limit_exceeded = False
        exception = e
        raised_exception = True
    end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    token_counts_manager = agent.monitor.get_total_token_counts()
    token_counts_web = list(agent.managed_agents.values())[0].monitor.get_total_token_counts()
    total_token_counts = {
        "input": token_counts_manager["input"] + token_counts_web["input"],
        "output": token_counts_manager["output"] + token_counts_web["output"],
    }
    annotated_example = {
        "agent_name": model.model_id,
        "question": example["question"],
        "augmented_question": augmented_question,
        "prediction": output,
        "intermediate_steps": intermediate_steps,
        "parsing_error": parsing_error,
        "iteration_limit_exceeded": iteration_limit_exceeded,
        "agent_error": str(exception) if raised_exception else None,
        "task": example["task"],
        "task_id": example["task_id"],
        "true_answer": example["true_answer"],
        "start_time": start_time,
        "end_time": end_time,
        "token_counts": total_token_counts,
    }
    append_answer(annotated_example, answers_file)


def get_examples_to_answer(answers_file: str, eval_ds: datasets.Dataset) -> list[dict]:
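    """Return the examples that still need answering, skipping questions already recorded in the answers file."""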
    print(f"Loading answers from {answers_file}...")
    try:
        done_questions = pd.read_json(answers_file, lines=True)["question"].tolist()
        print(f"Found {len(done_questions)} previous results!")
    except Exception as e:
        print("Error when loading records: ", e)
        print("No usable records! ▶️ Starting new.")
        done_questions = []
    # Keep only unanswered examples that come with an attached file.
    return [line for line in eval_ds.to_list() if line["question"] not in done_questions and line["file_name"]]


def main():
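    """Load the evaluation set and answer the remaining examples concurrently."""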
    args = parse_args()
    print(f"Starting run with arguments: {args}")

    eval_ds = load_gaia_dataset(args.use_raw_dataset, args.set_to_run)
    print("Loaded evaluation dataset:")
    print(pd.DataFrame(eval_ds)["task"].value_counts())

    answers_file = f"output/{args.set_to_run}/{args.run_name}.jsonl"
    tasks_to_run = get_examples_to_answer(answers_file, eval_ds)

    with ThreadPoolExecutor(max_workers=args.concurrency) as exe:
        futures = [
            exe.submit(answer_single_question, example, args.model_id, answers_file, visualizer)
            for example in tasks_to_run
        ]
        for f in tqdm(as_completed(futures), total=len(tasks_to_run), desc="Processing tasks"):
            f.result()

    print("All tasks processed.")


if __name__ == "__main__":
    main()
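# Example invocation (the script name "run_gaia.py" is a placeholder for this file; flags match parse_args above,
# and HF_TOKEN plus SERPAPI_API_KEY are expected in a .env file):
#   python run_gaia.py --run-name first-o1-run --model-id o1 --concurrency 4 --set-to-run validation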
|
|