# GAIA benchmark agent: tool definitions plus a smolagents CodeAgent wrapper.
import ast
import asyncio
import logging
import operator
import os
import random
import sys
from typing import Any

import pandas as pd
import requests
import wikipedia as wiki
from dotenv import load_dotenv
from google.generativeai import types, configure
from markdownify import markdownify as to_markdown
from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool
# Load environment variables and configure the Gemini SDK.
load_dotenv()
# BUG FIX: os.getenv() takes the *name* of an environment variable; the
# original passed a literal API key (which both leaks the secret and makes
# getenv return None). Read the key from the environment instead.
configure(api_key=os.getenv("GOOGLE_API_KEY"))
# Logging setup (currently disabled; uncomment to enable structured logs).
#logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
#logger = logging.getLogger(__name__)

# --- Model Configuration ---
# LiteLLM-style "provider/model" identifiers used by LiteLLMModel;
# HF_MODEL_NAME is a bare Hugging Face Hub repo id for InferenceClientModel.
GEMINI_MODEL_NAME = "gemini/gemini-2.0-flash"
OPENAI_MODEL_NAME = "openai/gpt-4o"
GROQ_MODEL_NAME = "groq/llama3-70b-8192"
DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
HF_MODEL_NAME = "Qwen/Qwen2.5-Coder-32B-Instruct"
| # --- Tool Definitions --- | |
class MathSolver(Tool):
    """Evaluate basic arithmetic expressions without eval().

    The input is parsed with the `ast` module and only numeric literals,
    binary arithmetic, and unary +/- are executed, so untrusted tool input
    cannot run arbitrary code (eval with empty __builtins__ is still
    escapable and was replaced).
    """
    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    # Whitelisted AST operator node -> implementing function.
    _OPERATORS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }

    def forward(self, input: str) -> str:
        """Return str(result), or the original "Math error: ..." message."""
        try:
            tree = ast.parse(input, mode="eval")
            return str(self._evaluate(tree.body))
        except Exception as e:
            # Keep the original contract: report the error, never raise.
            return f"Math error: {e}"

    def _evaluate(self, node):
        """Recursively evaluate a whitelisted arithmetic AST node."""
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in self._OPERATORS:
            apply = self._OPERATORS[type(node.op)]
            return apply(self._evaluate(node.left), self._evaluate(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
            value = self._evaluate(node.operand)
            return +value if isinstance(node.op, ast.UAdd) else -value
        raise ValueError("unsupported expression element")
class RiddleSolver(Tool):
    """Heuristic riddle answerer; recognises a single pattern."""
    name = "riddle_solver"
    description = "Solve basic riddles using logic."
    inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        # Only the "reads the same forward and backward" riddle is handled.
        mentions_both = all(word in input for word in ("forward", "backward"))
        return "A palindrome" if mentions_both else "RiddleSolver failed."
class TextTransformer(Tool):
    """Prefix-dispatched text transformations (reverse/upper/lower)."""
    name = "text_ops"
    description = "Transform text: reverse, upper, lower."
    inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Apply the transformation named by the input's "op:" prefix."""
        op, sep, payload = input.partition(":")
        if not sep:
            return "Unknown transformation."
        payload = payload.strip()
        if op == "reverse":
            flipped = payload[::-1]
            # Quirk preserved from the original implementation: a reversed
            # string containing "left" is answered with the direction "right".
            return "right" if "left" in flipped.lower() else flipped
        if op == "upper":
            return payload.upper()
        if op == "lower":
            return payload.lower()
        return "Unknown transformation."
class GeminiVideoQA(Tool):
    """Ask the Gemini generateContent REST API a question about a video URL."""
    name = "video_inspector"
    description = "Analyze video content to answer questions."
    inputs = {
        "video_url": {"type": "string", "description": "URL of video."},
        "user_query": {"type": "string", "description": "Question about video."}
    }
    output_type = "string"

    def __init__(self, model_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # BUG FIX: callers pass a LiteLLM-style id ("gemini/gemini-2.0-flash"),
        # but the REST endpoint expects the bare model name — keep only the
        # part after the provider prefix.
        self.model_name = model_name.split("/")[-1]

    def forward(self, video_url: str, user_query: str) -> str:
        """Return the model's text answer, or a "Video error ..." message."""
        req = {
            'model': f'models/{self.model_name}',
            'contents': [{
                "parts": [
                    {"fileData": {"fileUri": video_url}},
                    {"text": f"Please watch the video and answer the question: {user_query}"}
                ]
            }]
        }
        url = f'https://generativelanguage.googleapis.com/v1beta/models/{self.model_name}:generateContent?key={os.getenv("GOOGLE_API_KEY")}'
        try:
            # Timeout added: video analysis is slow but must not hang forever.
            res = requests.post(url, json=req, headers={'Content-Type': 'application/json'}, timeout=120)
        except requests.RequestException as e:
            return f"Video error: {e}"
        if res.status_code != 200:
            return f"Video error {res.status_code}: {res.text}"
        try:
            parts = res.json()['candidates'][0]['content']['parts']
        except (ValueError, KeyError, IndexError) as e:
            # Guard the drill-down: a blocked/empty response has no candidates.
            return f"Video error: unexpected response format ({e})"
        return "".join(p.get('text', '') for p in parts)
class WikiTitleFinder(Tool):
    """Look up Wikipedia page titles matching a query."""
    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return a comma-separated title list, or "No results."."""
        hits = wiki.search(query)
        if not hits:
            return "No results."
        return ", ".join(hits)
class WikiContentFetcher(Tool):
    """Fetch a Wikipedia page and convert its HTML to Markdown."""
    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Return the page as Markdown, or an error string on lookup failure."""
        try:
            return to_markdown(wiki.page(page_title).html())
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
        except wiki.exceptions.DisambiguationError as e:
            # BUG FIX: ambiguous titles raise DisambiguationError, which the
            # original let crash the agent step; surface candidates instead.
            return f"'{page_title}' is ambiguous; options: {', '.join(e.options[:10])}"
class GoogleSearchTool(Tool):
    """Top-result snippet lookup via the Google Custom Search JSON API."""
    name = "google_search"
    description = "Search the web using Google. Returns top summary from the web."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the snippet of the first result, or an error/no-result string."""
        try:
            resp = requests.get(
                "https://www.googleapis.com/customsearch/v1",
                params={
                    "q": query,
                    # BUG FIX: the original passed a leaked literal API key as
                    # the env-var *name* (getenv returned None). Read the key
                    # from the environment instead.
                    "key": os.getenv("GOOGLE_API_KEY"),
                    # BUG FIX: the Custom Search API also requires the search
                    # engine id ("cx"); without it every request fails.
                    "cx": os.getenv("GOOGLE_CSE_ID"),
                    "num": 1,
                },
                timeout=30,
            )
            data = resp.json()
            return data["items"][0]["snippet"] if "items" in data else "No results found."
        except Exception as e:
            return f"GoogleSearch error: {e}"
class FileAttachmentQueryTool(Tool):
    """Download a task's attachment and answer a question about it with Gemini."""
    name = "run_query_with_file"
    description = """
    Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
    This assumes the file is 20MB or less.
    """
    inputs = {
        "task_id": {
            "type": "string",
            "description": "A unique identifier for the task related to this file, used to download it.",
            "nullable": True
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the file."
        }
    }
    output_type = "string"

    def __init__(self, model_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # BUG FIX: forward() reads self.model_name, but the original class
        # never stored it (no __init__), so every call raised AttributeError.
        # google.generativeai also expects the bare model name, not a
        # LiteLLM-style "provider/model" id.
        self.model_name = model_name.split("/")[-1]

    def forward(self, task_id: str | None, user_query: str) -> str:
        """Fetch the file for task_id and return Gemini's answer to user_query."""
        file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        try:
            # Timeout added so a stalled download cannot hang the agent.
            file_response = requests.get(file_url, timeout=60)
        except requests.RequestException as e:
            return f"Failed to download file: {e}"
        if file_response.status_code != 200:
            return f"Failed to download file: {file_response.status_code} - {file_response.text}"
        file_data = file_response.content
        from google.generativeai import GenerativeModel
        model = GenerativeModel(self.model_name)
        response = model.generate_content([
            types.Part.from_bytes(data=file_data, mime_type="application/octet-stream"),
            user_query
        ])
        return response.text
| # --- Basic Agent Definition --- | |
class BasicAgent:
    """smolagents CodeAgent configured with GAIA-oriented tools and prompt.

    Call the instance with a question string to get the agent's final answer;
    use evaluate_random_questions() for a small local accuracy check.
    """

    def __init__(self, provider="deepseek"):
        print("BasicAgent initialized.")
        model = self.select_model(provider)
        # (Removed an unused InferenceClientModel() that was instantiated and
        # never referenced.)
        tools = [
            GoogleSearchTool(),
            DuckDuckGoSearchTool(),
            GeminiVideoQA(GEMINI_MODEL_NAME),
            WikiTitleFinder(),
            WikiContentFetcher(),
            MathSolver(),
            RiddleSolver(),
            TextTransformer(),
            FileAttachmentQueryTool(model_name=GEMINI_MODEL_NAME),
        ]
        self.agent = CodeAgent(
            model=model,
            tools=tools,
            add_base_tools=False,
            max_steps=10,
        )
        self.agent.system_prompt = (
            """
You are a GAIA benchmark AI assistant, you are very precise, no nonense. Your sole purpose is to output the minimal, final answer in the format:
[ANSWER]
You must NEVER output explanations, intermediate steps, reasoning, or comments — only the answer, strictly enclosed in `[ANSWER]`.
Your behavior must be governed by these rules:
1. **Format**:
- limit the token used (within 65536 tokens).
- Output ONLY the final answer.
- Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
- No follow-ups, justifications, or clarifications.
2. **Numerical Answers**:
- Use **digits only**, e.g., `4` not `four`.
- No commas, symbols, or units unless explicitly required.
- Never use approximate words like "around", "roughly", "about".
3. **String Answers**:
- Omit **articles** ("a", "the").
- Use **full words**; no abbreviations unless explicitly requested.
- For numbers written as words, use **text** only if specified (e.g., "one", not `1`).
- For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.
4. **Lists**:
- Output in **comma-separated** format with no conjunctions.
- Sort **alphabetically** or **numerically** depending on type.
- No braces or brackets unless explicitly asked.
5. **Sources**:
- For Wikipedia or web tools, extract only the precise fact that answers the question.
- Ignore any unrelated content.
6. **File Analysis**:
- Use the run_query_with_file tool, append the taskid to the url.
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.
7. **Video**:
- Use the relevant video tool.
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.
8. **Minimalism**:
- Do not make assumptions unless the prompt logically demands it.
- If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
- If the answer is not found, say `[ANSWER] - unknown`.
---
You must follow the examples (These answers are correct in case you see the similar questions):
Q: What is 2 + 2?
A: 4
Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
A: 3
Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
A: b, e
Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?,
A: 519
"""
        )

    def select_model(self, provider: str):
        """Return the LLM backend for the given provider name.

        BUG FIX: the original passed leaked literal API keys to os.getenv()
        as the env-var *name* (always None, and a secret leak). Keys are now
        read from properly named environment variables.
        """
        if provider == "openai":
            return LiteLLMModel(model_id=OPENAI_MODEL_NAME, api_key=os.getenv("OPENAI_API_KEY"))
        if provider == "hf":
            return InferenceClientModel()
        # Default: Gemini via LiteLLM.
        return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("GOOGLE_API_KEY"))

    def __call__(self, question: str) -> str:
        """Run the agent on one question and return the stripped final answer."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        return str(self.agent.run(question)).strip()

    def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
        """Sample questions from a GAIA CSV, run the agent, and print accuracy.

        The CSV must contain 'taskid', 'question', and 'answer' columns.
        """
        from rich.table import Table
        from rich.console import Console
        df = pd.read_csv(csv_path)
        if not {"question", "answer"}.issubset(df.columns):
            print("CSV must contain 'question' and 'answer' columns.")
            print("Found columns:", df.columns.tolist())
            return
        samples = df.sample(n=sample_size)
        records = []
        correct_count = 0
        for _, row in samples.iterrows():
            # Coerce to str first: pandas may load taskid as a non-string dtype.
            taskid = str(row["taskid"]).strip()
            question = str(row["question"]).strip()
            expected = str(row["answer"]).strip()
            agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()
            is_correct = (expected == agent_answer)
            correct_count += is_correct
            # Mojibake fix: the original had the same garbled character for
            # both branches; use distinct tick marks.
            records.append((question, expected, agent_answer, "✔" if is_correct else "✘"))
            if show_steps:
                print("---")
                print("Question:", question)
                print("Expected:", expected)
                print("Agent:", agent_answer)
                print("Correct:", is_correct)
        # Print result table
        console = Console()
        table = Table(show_lines=True)
        table.add_column("Question", overflow="fold")
        table.add_column("Expected")
        table.add_column("Agent")
        table.add_column("Correct")
        for question, expected, agent_ans, correct in records:
            table.add_row(question, expected, agent_ans, correct)
        console.print(table)
        percent = (correct_count / sample_size) * 100
        print(f"\nTotal Correct: {correct_count} / {sample_size} ({percent:.2f}%)")
if __name__ == "__main__":
    # CLI: either answer a single question or run the small dev evaluation.
    args = sys.argv[1:]
    if not args or args[0] in {"-h", "--help"}:
        print("Usage: python agent.py [question | dev]")
        print(" - Provide a question to get a GAIA-style answer.")
        # BUG FIX: help text named gaia_qa.csv, but evaluate_random_questions
        # defaults to gaia_extracted.csv.
        print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_extracted.csv.")
        sys.exit(0)
    q = " ".join(args)
    agent = BasicAgent()
    if q == "dev":
        agent.evaluate_random_questions()
    else:
        print(agent(q))