# Author: rmjones
# Commit 870142a: "Update agent.py" (verified)
# --- Imports ---
import ast
import asyncio
import logging
import math
import operator
import os
import random
import sys
from typing import Any

import pandas as pd
import requests
import wikipedia as wiki
from dotenv import load_dotenv
from google.generativeai import types, configure
from markdownify import markdownify as to_markdown
from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool, HfApiModel, OpenAIServerModel
# Logging
#logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
#logger = logging.getLogger(__name__)
# --- Model Configuration ---
# Load variables from a local .env file (if one exists) into os.environ so the
# keys below can be configured without editing this file. With load_dotenv's
# default settings, variables already set in the environment are not overridden.
load_dotenv()
HF_MODEL_NAME = "Qwen/Qwen2.5-Coder-32B-Instruct"
# API keys are read from the environment, never hard-coded: anything committed
# to a repository is readable by everyone with access to it. Keys that were
# previously hard-coded here should be treated as leaked and revoked.
# Either value is None when the corresponding variable is not set.
OPENROUTER_API_KEY2 = os.getenv("OPENROUTER_API_KEY2")
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
# --- Tool Definitions ---
# Upper bound on the bit size of an ``int ** int`` result, so inputs such as
# "9**9**9" are refused instead of hanging the process.
_MAX_POWER_BITS = 100_000


def _checked_pow(base, exponent):
    """Return ``base ** exponent``, refusing integer powers that would be huge.

    Raises:
        ValueError: if both operands are ints and the result would need more
            than ``_MAX_POWER_BITS`` bits (estimated as bit_length * exponent).
    """
    if (
        isinstance(base, int)
        and isinstance(exponent, int)
        and base.bit_length() * exponent > _MAX_POWER_BITS
    ):
        raise ValueError("result of ** is too large")
    return base ** exponent


# Whitelists for MathSolver: only these AST operators, functions and names are
# evaluated; any other syntax element is rejected before anything executes.
_MATH_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: _checked_pow,
}
_MATH_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
_MATH_FUNCTIONS = {
    "abs": abs,
    "round": round,
    "min": min,
    "max": max,
    "sqrt": math.sqrt,
    "exp": math.exp,
    "log": math.log,
    "log10": math.log10,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "floor": math.floor,
    "ceil": math.ceil,
}
_MATH_CONSTANTS = {"pi": math.pi, "e": math.e}


def _eval_math_node(node):
    """Recursively evaluate one node of a parsed math expression.

    Raises:
        ValueError: for any node type, operator or name outside the whitelists.
    """
    # Numeric literals only (bool is an int subclass and is accepted, as eval did).
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float, complex)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _MATH_BINARY_OPS:
        left = _eval_math_node(node.left)
        right = _eval_math_node(node.right)
        return _MATH_BINARY_OPS[type(node.op)](left, right)
    if isinstance(node, ast.UnaryOp) and type(node.op) in _MATH_UNARY_OPS:
        return _MATH_UNARY_OPS[type(node.op)](_eval_math_node(node.operand))
    if isinstance(node, ast.Name):
        if node.id in _MATH_CONSTANTS:
            return _MATH_CONSTANTS[node.id]
        raise ValueError(f"unknown name {node.id!r}")
    if (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id in _MATH_FUNCTIONS
        and not node.keywords
    ):
        # Starred arguments are ast.Starred nodes and are rejected by the
        # recursive call, so only plain positional arguments get through.
        args = [_eval_math_node(arg) for arg in node.args]
        return _MATH_FUNCTIONS[node.func.id](*args)
    raise ValueError(f"unsupported expression element: {type(node).__name__}")


def _safe_eval_math(expression):
    """Evaluate an arithmetic expression string without calling ``eval``."""
    # eval() tolerated leading spaces; ast.parse would raise IndentationError.
    tree = ast.parse(expression.strip(), mode="eval")
    return _eval_math_node(tree.body)


class MathSolver(Tool):
    """Tool that evaluates basic arithmetic expressions.

    The expression is parsed with ``ast`` and walked node by node. Only numeric
    literals, the operators + - * / // % ** (unary + and - too), and a small set
    of whitelisted functions and constants are evaluated. Nothing is passed to
    ``eval``, because emptying ``__builtins__`` does not make ``eval`` safe.
    """

    name = "math_solver"
    description = (
        "Safely evaluate basic math expressions. Supports + - * / // % **, "
        "parentheses, abs, round, min, max, sqrt, exp, log, log10, sin, cos, "
        "tan, floor, ceil and the constants pi and e."
    )
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Return the expression's value as a string, or ``"Math error: ..."``."""
        try:
            return str(_safe_eval_math(input))
        except Exception as e:  # any parse/eval failure is reported back to the agent as text
            return f"Math error: {e}"
class RiddleSolver(Tool):
    """Tool that recognizes a single hard-coded riddle pattern.

    A riddle mentioning both "forward" and "backward" is answered with
    "A palindrome". Any other input yields the failure string
    "RiddleSolver failed.".
    """

    name = "riddle_solver"
    description = "Solve basic riddles using logic."
    inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        # Match case-insensitively so a sentence-initial "Forward ..." still counts.
        text = input.lower()
        if "forward" in text and "backward" in text:
            return "A palindrome"
        return "RiddleSolver failed."
class TextTransformer(Tool):
    """Tool applying a simple text transformation selected by an input prefix.

    Recognized prefixes are ``reverse:``, ``upper:`` and ``lower:``. The text
    after the prefix is stripped of surrounding whitespace before it is
    transformed. Input without a recognized prefix yields
    "Unknown transformation.".
    """

    name = "text_ops"
    description = "Transform text: reverse, upper, lower."
    inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
    output_type = "string"

    def forward(self, input: str) -> str:
        if input.startswith("reverse:"):
            flipped = input[len("reverse:"):].strip()[::-1]
            # NOTE(review): special case, not a general transformation. Any
            # reversed text containing "left" (case-insensitive) returns "right"
            # instead of the reversed text. Looks benchmark-specific; confirm intent.
            return "right" if "left" in flipped.lower() else flipped
        for prefix, convert in (("upper:", str.upper), ("lower:", str.lower)):
            if input.startswith(prefix):
                return convert(input[len(prefix):].strip())
        return "Unknown transformation."
class WikiTitleFinder(Tool):
    """Tool listing Wikipedia page titles that match a search query.

    Titles are returned as one comma-separated string, or "No results." when
    the search comes back empty.
    """

    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        titles = wiki.search(query)
        if not titles:
            return "No results."
        return ", ".join(titles)
class WikiContentFetcher(Tool):
    """Tool returning the full content of a Wikipedia page as Markdown.

    The page's HTML is converted with ``markdownify``. When a title is
    ambiguous, the tool returns the candidate page titles (at most 20) instead
    of failing. When a page does not exist, it returns a "not found" message.
    """

    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
    output_type = "string"

    def forward(self, page_title: str) -> str:
        try:
            try:
                # Try the exact title first: titles usually come from wiki_titles,
                # and with auto_suggest enabled the library may substitute a search
                # suggestion for the requested title.
                page = wiki.page(page_title, auto_suggest=False)
            except wiki.exceptions.PageError:
                # Not an exact title; fall back to the library's suggestion lookup.
                page = wiki.page(page_title)
            return to_markdown(page.html())
        except wiki.exceptions.DisambiguationError as err:
            # Give the agent something actionable instead of an unhandled exception.
            options = ", ".join(err.options[:20])
            return f"'{page_title}' is ambiguous. Candidate pages: {options}"
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
# --- Basic Agent Definition ---
class BasicAgent:
    """GAIA-style question answerer built on a smolagents ``CodeAgent``.

    The agent uses an OpenRouter-hosted model (through ``OpenAIServerModel``)
    together with web search, Wikipedia and a few small helper tools. Calling
    the instance returns the agent's final answer as a stripped string.
    """

    # Appended to, not substituted for, smolagents' default CodeAgent system
    # prompt. The default template describes the registered tools and the
    # step/code format that CodeAgent parses. Overwriting it hides both from
    # the model. The final answer must be bare (no "[...]" wrapping) because
    # answers are compared by exact match. No file or video tools are
    # registered below, so the prompt does not tell the model to call any.
    GAIA_INSTRUCTIONS = """
You are a GAIA benchmark AI assistant: very precise, no nonsense.
Work through the task step by step as described above, but the value you pass to `final_answer` must be ONLY the minimal final answer.
Your final answer must follow these rules:
1. **Format**:
- Keep token usage within 65536 tokens.
- Pass only the bare answer to `final_answer`: no brackets, labels, quotes, explanations, reasoning, or follow-ups.
2. **Numerical Answers**:
- Use **digits only**, e.g., `4` not `four`.
- No commas, symbols, or units unless explicitly required.
- Never use approximate words like "around", "roughly", "about".
3. **String Answers**:
- Omit **articles** ("a", "the").
- Use **full words**; no abbreviations unless explicitly requested.
- Write numbers as words (e.g., "one" instead of `1`) only if the question asks for it.
- For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.
4. **Lists**:
- Output in **comma-separated** format with no conjunctions.
- Sort **alphabetically** or **numerically** depending on type.
- No braces or brackets unless explicitly asked.
5. **Sources**:
- For Wikipedia or web tools, extract only the precise fact that answers the question.
- Ignore any unrelated content.
6. **Files and Videos**:
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.
7. **Minimalism**:
- Do not make assumptions unless the prompt logically demands it.
- If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
- If the answer cannot be found, answer `unknown`.
---
Examples (these answers are correct; reuse them if you see the same questions):
Q: What is 2 + 2?
A: 4
Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
A: 3
Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
A: b, e
Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?
A: 519
"""

    def __init__(self):
        """Build the model, the tool set and the underlying ``CodeAgent``.

        Raises:
            RuntimeError: if ``OPENROUTER_API_KEY`` is empty or None.
        """
        print("BasicAgent initialized.")
        if not OPENROUTER_API_KEY:
            raise RuntimeError(
                "OPENROUTER_API_KEY is not set; export it or put it in a .env file."
            )
        # Other back-ends that were tried here (swap in to use them):
        #   HfApiModel(model_id=HF_MODEL_NAME, max_tokens=2096, temperature=0.5)
        #   LiteLLMModel(model_id="ollama_chat/qwen2:7b",
        #                api_base="http://127.0.0.1:11434", num_ctx=8192)
        model = OpenAIServerModel(
            # Any model ID available on OpenRouter can be used here.
            model_id="mistralai/mistral-small-3.2-24b-instruct:free",
            # OpenRouter API base URL
            api_base="https://openrouter.ai/api/v1",
            api_key=OPENROUTER_API_KEY,
        )
        tools = [
            DuckDuckGoSearchTool(),
            WikiTitleFinder(),
            WikiContentFetcher(),
            MathSolver(),
            RiddleSolver(),
            TextTransformer(),
        ]
        self.agent = CodeAgent(
            model=model,
            tools=tools,
            add_base_tools=False,
            max_steps=10,
        )
        # Extend the default template rather than replacing it (see GAIA_INSTRUCTIONS).
        default_prompt = self.agent.prompt_templates["system_prompt"]
        self.agent.prompt_templates["system_prompt"] = (
            default_prompt + "\n" + self.GAIA_INSTRUCTIONS
        )

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer, stripped."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.agent.run(question)
        print("Raw result:", result)
        # Accept a dict with an "output" key, an object exposing ``.output``,
        # or any other value (stringified as-is).
        if isinstance(result, dict) and "output" in result:
            final_str = str(result["output"]).strip()
        elif hasattr(result, "output"):
            final_str = str(result.output).strip()
        else:
            final_str = str(result).strip()
        return final_str

    def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
        """Answer randomly sampled questions from a CSV and print the accuracy.

        Args:
            csv_path: CSV file with ``taskid``, ``question`` and ``answer`` columns.
            sample_size: number of rows to sample. It is capped at the number of
                rows in the file.
            show_steps: if True, print each question, expected and actual answer.

        Answers are compared to ``answer`` by exact string equality after
        stripping. An agent exception on one question is recorded as an
        ``ERROR: ...`` answer (counted as incorrect) instead of aborting the run.
        Errors from ``pd.read_csv`` (e.g. a missing file) propagate.
        """
        # rich is only needed for this dev helper, so it is imported lazily.
        from rich.table import Table
        from rich.console import Console

        df = pd.read_csv(csv_path)
        if not {"taskid", "question", "answer"}.issubset(df.columns):
            print("CSV must contain 'taskid', 'question' and 'answer' columns.")
            print("Found columns:", df.columns.tolist())
            return
        # DataFrame.sample(n=...) raises when n exceeds the number of rows.
        n_samples = min(sample_size, len(df))
        if n_samples <= 0:
            print("No questions to evaluate.")
            return
        samples = df.sample(n=n_samples)
        records = []
        correct_count = 0
        for _, row in samples.iterrows():
            # str() first: empty cells come back as NaN (a float) and all-numeric
            # columns as numbers, neither of which has .strip().
            taskid = str(row["taskid"]).strip()
            question = str(row["question"]).strip()
            expected = str(row["answer"]).strip()
            try:
                agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()
            except Exception as exc:  # one failed question must not lose the whole run
                agent_answer = f"ERROR: {exc}"
            is_correct = expected == agent_answer
            correct_count += is_correct
            records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))
            if show_steps:
                print("---")
                print("Question:", question)
                print("Expected:", expected)
                print("Agent:", agent_answer)
                print("Correct:", is_correct)
        # Print result table
        console = Console()
        table = Table(show_lines=True)
        table.add_column("Question", overflow="fold")
        table.add_column("Expected")
        table.add_column("Agent")
        table.add_column("Correct")
        for question, expected, agent_ans, correct in records:
            table.add_row(question, expected, agent_ans, correct)
        console.print(table)
        percent = (correct_count / n_samples) * 100
        print(f"\nTotal Correct: {correct_count} / {n_samples} ({percent:.2f}%)")
if __name__ == "__main__":
    # Command-line entry point: answer one question, or run the dev evaluation.
    args = sys.argv[1:]
    if not args or args[0] in {"-h", "--help"}:
        print("Usage: python agent.py [question | dev]")
        print(" - Provide a question to get a GAIA-style answer.")
        # Keep in sync with evaluate_random_questions' default csv_path / sample_size.
        print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_extracted.csv.")
        sys.exit(0)
    q = " ".join(args)
    agent = BasicAgent()
    if q == "dev":
        agent.evaluate_random_questions()
    else:
        print(agent(q))