Spaces:
Runtime error
Runtime error
| """ | |
| HF Agents Course Unit 4 Final Assignment | |
| app.py - Part 1 | |
| Requirements: | |
| - smolagents==1.21.3 | |
| - LiteLLMModel | |
| - Cerebras GPT-OSS | |
| - DuckDuckGoSearchTool | |
| - VisitWebpageTool | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import logging | |
| import os | |
| import pickle | |
| import re | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional | |
| from smolagents import ( | |
| CodeAgent, | |
| LiteLLMModel, | |
| ToolCallingAgent, | |
| ) | |
| from smolagents import ( | |
| DuckDuckGoSearchTool, | |
| VisitWebpageTool, | |
| ) | |
| ############################################################################### | |
| # Configuration | |
| ############################################################################### | |
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger("gaia-agent")

# On-disk answer cache, relative to the current working directory.
CACHE_DIR = Path(".cache")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Retry policy used by ``retry``: total attempts and seconds between them.
MAX_RETRIES = 3
RETRY_DELAY = 2

# Model id in LiteLLM "<provider>/<model>" form; overridable via MODEL_ID.
MODEL_NAME = os.getenv(
    "MODEL_ID",
    "cerebras/gpt-oss-120b",
)
API_KEY = os.getenv("CEREBRAS_API_KEY")
if not API_KEY:
    # The model object is still constructed without a key, but provider calls
    # are expected to fail later with a less obvious error; say so up front.
    logger.warning(
        "CEREBRAS_API_KEY is not set; model calls will likely fail until it is provided."
    )
| ############################################################################### | |
| # LiteLLM Model | |
| ############################################################################### | |
# One LLM client shared by every agent; a low temperature for steadier answers.
model = LiteLLMModel(
    model_id=MODEL_NAME,
    api_key=API_KEY,
    max_tokens=4096,
    temperature=0.2,
)

###############################################################################
# Tools
###############################################################################
# Both agents receive the same tool list: web search plus page fetching.
search_tool, visit_tool = DuckDuckGoSearchTool(), VisitWebpageTool()
TOOLS = [search_tool, visit_tool]
| ############################################################################### | |
| # Cache | |
| ############################################################################### | |
class FileCache:
    """Persist pickled values on disk, one file per key.

    Keys are hashed with SHA-256 so arbitrary strings (including long prompts)
    map to fixed-length file names inside ``folder``.

    NOTE: entries are loaded with ``pickle``, which can execute code; only
    point this cache at a directory that this application itself writes.
    """

    def __init__(self, folder: Path):
        # Accept str or Path, and create the folder so the cache is usable
        # even if the caller did not pre-create it.
        self.folder = Path(folder)
        self.folder.mkdir(parents=True, exist_ok=True)

    def _path(self, key: str) -> Path:
        """Return the file that stores ``key``."""
        digest = hashlib.sha256(key.encode()).hexdigest()
        return self.folder / f"{digest}.pkl"

    def get(self, key):
        """Return the cached value for ``key``, or ``None`` on a miss.

        A corrupt entry (e.g. left behind by a crashed writer) is treated as a
        miss instead of raising.
        """
        try:
            handle = open(self._path(key), "rb")
        except FileNotFoundError:
            return None
        with handle:
            try:
                return pickle.load(handle)
            except Exception:
                # The pickle docs state unpickling may raise many exception
                # types (EOFError, UnpicklingError, AttributeError, ...).
                return None

    def set(self, key, value):
        """Store ``value`` under ``key``, atomically replacing any old entry."""
        import tempfile

        target = self._path(key)
        # Write into a temp file in the same folder, then rename it over the
        # target so readers never see a half-written pickle.
        fd, tmp_name = tempfile.mkstemp(dir=self.folder, suffix=".tmp")
        try:
            with os.fdopen(fd, "wb") as handle:
                pickle.dump(value, handle)
            os.replace(tmp_name, target)
        except BaseException:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
            raise
# Process-wide answer cache stored under CACHE_DIR.
cache = FileCache(folder=CACHE_DIR)
| ############################################################################### | |
| # Retry helper | |
| ############################################################################### | |
def retry(fn, *, retries: Optional[int] = None, delay: Optional[float] = None):
    """Call ``fn()`` until it succeeds, making at most ``retries`` attempts.

    Args:
        fn: Zero-argument callable to invoke.
        retries: Total number of attempts; defaults to ``MAX_RETRIES``.
            Values below 1 are treated as 1, so ``fn`` always runs at least
            once rather than the call silently returning ``None``.
        delay: Seconds to sleep between attempts; defaults to ``RETRY_DELAY``.

    Returns:
        The value returned by the first successful ``fn()`` call.

    Raises:
        Re-raises the exception from the final attempt once all attempts fail.
    """
    attempts = max(1, MAX_RETRIES if retries is None else retries)
    pause = RETRY_DELAY if delay is None else delay
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception as e:
            logger.warning(
                "Attempt %d failed: %s",
                attempt,
                e,
            )
            if attempt == attempts:
                raise
            time.sleep(pause)
| ############################################################################### | |
| # Cleaning | |
| ############################################################################### | |
# Leading answer labels and markdown code fences that models wrap answers in.
ANSWER_PATTERNS = [
    r"^Answer\s*:",
    r"^Final Answer\s*:",
    r"```",
]


def clean_answer(answer: str) -> str:
    """Strip answer labels and code fences from a model response.

    Every pattern in ``ANSWER_PATTERNS`` is removed (case-insensitively) and
    the text is re-stripped, repeating until nothing changes. Repeating is
    needed because removing one decoration can expose another. For example,
    in "```\\nFinal Answer: 42\\n```" the label only reaches the start of the
    string after the fences are gone.
    """
    cleaned = answer.strip()
    while True:
        before = cleaned
        for pattern in ANSWER_PATTERNS:
            cleaned = re.sub(
                pattern,
                "",
                cleaned,
                flags=re.IGNORECASE,
            ).strip()
        # Each change only removes text, so this loop always terminates.
        if cleaned == before:
            return cleaned
| ############################################################################### | |
| # Routing | |
| ############################################################################### | |
# Whole-word cues suggesting a question needs information from the web.
_WEB_CUE_RE = re.compile(
    r"\b(?:who(?:m|se)?|when|where|latest|websites?|news|search(?:es|ed|ing)?)\b",
    re.IGNORECASE,
)


def classify_question(question: str) -> str:
    """Return the route for ``question``: ``"web"`` or ``"reasoning"``.

    Cue words are matched as whole words. A plain substring test fires on words
    that merely contain a cue ("whole" -> "who", "whenever" -> "when",
    "elsewhere" -> "where") and would misroute pure reasoning questions.
    """
    return "web" if _WEB_CUE_RE.search(question) else "reasoning"
| ############################################################################### | |
| # Agent Factory | |
| ############################################################################### | |
def build_web_agent():
    """Create the tool-calling agent used for questions routed to "web"."""
    agent_kwargs = dict(tools=TOOLS, model=model, max_steps=8)
    return ToolCallingAgent(**agent_kwargs)


def build_reasoning_agent():
    """Create the code-writing agent used for questions routed to "reasoning"."""
    return CodeAgent(model=model, tools=TOOLS, max_steps=10)


# Both agents are built once at import time and reused for every question.
web_agent, reasoning_agent = build_web_agent(), build_reasoning_agent()
| ############################################################################### | |
| # Hybrid Router | |
| ############################################################################### | |
class HybridGAIAAgent:
    """
    Routes questions between a web-search oriented agent and a
    reasoning-oriented agent, caching the cleaned answers.
    """

    def __init__(
        self,
        web_agent,
        reasoning_agent,
        cache,
    ):
        self.web_agent = web_agent
        self.reasoning_agent = reasoning_agent
        self.cache = cache  # any object exposing get(key) and set(key, value)

    def _run_agent(self, agent, prompt: str) -> str:
        """Run ``agent`` on ``prompt`` under ``retry`` (re-raises after the last attempt)."""
        return retry(lambda: agent.run(prompt))

    def answer(self, question: str, prompt: Optional[str] = None) -> str:
        """Answer ``question`` with the agent selected by ``classify_question``.

        Args:
            question: The raw user question. Routing looks only at this text,
                so instructions wrapped around it cannot skew the route.
            prompt: Optional full text to send to the agent (for example the
                output of ``build_prompt``). Defaults to ``question``.

        Returns:
            The cleaned answer. It is cached under the stripped prompt, so an
            identical prompt later skips the agent entirely.
        """
        if prompt is None:
            prompt = question
        cache_key = prompt.strip()
        cached = self.cache.get(cache_key)
        if cached is not None:
            logger.info("Cache hit.")
            return cached

        route = classify_question(question)
        logger.info("Selected route: %s", route)
        agent = self.web_agent if route == "web" else self.reasoning_agent
        raw_answer = self._run_agent(agent, prompt)

        cleaned = clean_answer(str(raw_answer))
        self.cache.set(cache_key, cleaned)
        return cleaned


###############################################################################
# Prompt Templates
###############################################################################
SYSTEM_PROMPT = """
You are an expert GAIA benchmark assistant.
Guidelines:
- Think carefully.
- Search the web whenever necessary.
- Visit webpages when search results require deeper inspection.
- Never fabricate facts.
- Return only the final answer.
"""

WEB_PROMPT = """
Use search tools whenever required.
Question:
{question}
"""

REASONING_PROMPT = """
Solve the problem carefully.
Question:
{question}
"""


###############################################################################
# Formatting Utilities
###############################################################################
def build_prompt(question: str) -> str:
    """Return SYSTEM_PROMPT followed by the route-specific template for ``question``."""
    template = WEB_PROMPT if classify_question(question) == "web" else REASONING_PROMPT
    return SYSTEM_PROMPT + "\n\n" + template.format(question=question)


###############################################################################
# JSON Helpers
###############################################################################
def safe_json_loads(text: str) -> Optional[Dict[str, Any]]:
    """Parse ``text`` as a JSON object; return ``None`` if it is not one."""
    try:
        parsed = json.loads(text)
    except (TypeError, ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError subclass.
        return None
    return parsed if isinstance(parsed, dict) else None


def looks_like_json(text: str) -> bool:
    """Cheap pre-check: does ``text`` look like a single JSON object?"""
    text = text.strip()
    return text.startswith("{") and text.endswith("}")


###############################################################################
# Validation
###############################################################################
def validate_answer(answer: str) -> str:
    """Normalize an agent answer into the final submission string.

    ``None`` becomes ``""``. Labels and fences are stripped via
    ``clean_answer``. If the result is a JSON object with an ``answer`` key,
    or failing that a ``final_answer`` key, that value is returned instead.
    """
    if answer is None:
        return ""
    answer = clean_answer(str(answer))
    if looks_like_json(answer):
        parsed = safe_json_loads(answer)
        if parsed is not None:
            for key in ("answer", "final_answer"):
                if key in parsed:
                    return str(parsed[key]).strip()
    return answer.strip()


###############################################################################
# Instantiate Hybrid Agent
###############################################################################
hybrid_agent = HybridGAIAAgent(
    web_agent=web_agent,
    reasoning_agent=reasoning_agent,
    cache=cache,
)


###############################################################################
# GAIA Solver
###############################################################################
class GAIASolver:
    """
    High-level wrapper around the HybridGAIAAgent.
    Prepares the prompt, delegates to the agent (which routes on the raw
    question and retries the agent run), and validates the answer.
    """

    def __init__(self, agent: HybridGAIAAgent):
        self.agent = agent

    def solve(self, question: str) -> str:
        """Return the validated answer for ``question``; agent errors propagate."""
        prompt = build_prompt(question)
        logger.info("=" * 80)
        logger.info("Incoming Question")
        logger.info(question)
        logger.info("=" * 80)
        # HybridGAIAAgent already retries the agent run. A second retry() here
        # would multiply attempts (up to 9 runs and 8 sleeps per failure).
        # Pass the raw question separately so routing is not skewed by the
        # cue words inside SYSTEM_PROMPT ("whenever", "Search").
        answer = self.agent.answer(question, prompt=prompt)
        answer = validate_answer(answer)
        logger.info("Final Answer:")
        logger.info(answer)
        return answer
| ############################################################################### | |
| # Statistics | |
| ############################################################################### | |
class AgentStatistics:
    """Running counters of requests, cache hits and failures seen by ``solve``."""

    def __init__(self):
        self.total_requests = self.cache_hits = self.failures = 0

    def request(self):
        """Count one incoming request."""
        self.total_requests = self.total_requests + 1

    def cache_hit(self):
        """Count one request answered from the cache."""
        self.cache_hits = self.cache_hits + 1

    def failure(self):
        """Count one request that raised while solving."""
        self.failures = self.failures + 1

    def summary(self):
        """Return a new dict snapshot of the three counters."""
        return dict(
            requests=self.total_requests,
            cache_hits=self.cache_hits,
            failures=self.failures,
        )


stats = AgentStatistics()
| ############################################################################### | |
| # Public API | |
| ############################################################################### | |
solver = GAIASolver(hybrid_agent)


def solve(question: str) -> str:
    """
    Main API expected by evaluation scripts.

    Returns a non-empty cached answer for ``question`` if one exists. Otherwise
    runs ``solver`` and caches a non-empty result. Exceptions are logged,
    counted in ``stats`` and turned into ``""`` so batch evaluation continues.
    """
    stats.request()
    cache_value = cache.get(question)
    # An empty answer counts as a miss: caching "" would pin a transient
    # failure or an unhelpful reply for this question forever.
    if cache_value:
        stats.cache_hit()
        return cache_value
    try:
        answer = solver.solve(question)
    except Exception:
        stats.failure()
        logger.exception("Failed to solve question: %s", question)
        return ""
    if answer:
        try:
            cache.set(question, answer)
        except Exception:
            # A cache write failure should not throw away a good answer.
            logger.warning("Could not cache answer.", exc_info=True)
    return answer
| ############################################################################### | |
| # Batch Solver | |
| ############################################################################### | |
def solve_batch(questions):
    """Answer each question in order via ``solve``, returning a list of answers."""
    return [solve(q) for q in questions]
| ############################################################################### | |
| # Optional CLI Utilities | |
| ############################################################################### | |
def interactive():
    """Answer questions typed on stdin until the user quits.

    The loop ends on "exit"/"quit" (any case), end of input (Ctrl-D, or a
    closed or missing stdin) or Ctrl-C. Blank lines are ignored.
    """
    print("=" * 70)
    print("Hybrid GAIA Agent")
    print("Type 'exit' to quit.")
    print("=" * 70)
    while True:
        try:
            question = input("\nQuestion> ").strip()
        except (EOFError, KeyboardInterrupt):
            # input() raises EOFError when stdin is closed or not attached
            # (e.g. a headless container); end the loop instead of crashing.
            print()
            break
        if not question:
            continue
        if question.lower() in {"exit", "quit"}:
            break
        answer = solve(question)
        print("\nAnswer:")
        print(answer)
| ############################################################################### | |
| # HF Unit 4 Entry Points | |
| ############################################################################### | |
def predict(question: str) -> str:
    """Return the agent's answer to ``question`` (entry point for evaluation scripts)."""
    return solve(question=question)


def answer(question: str) -> str:
    """Same as ``predict``; kept for templates that call ``answer`` instead."""
    return solve(question=question)
| ############################################################################### | |
| # Diagnostics | |
| ############################################################################### | |
def print_statistics():
    """Print the counters from ``stats`` as an aligned table between rules."""
    rule = "=" * 80
    print("\n")
    print(rule)
    print("Agent Statistics")
    print(rule)
    for name, count in stats.summary().items():
        print(f"{name:15}: {count}")
    print(rule)
| ############################################################################### | |
| # Health Check | |
| ############################################################################### | |
def healthcheck():
    """Return True if the pipeline produces a non-empty answer to a test question.

    This calls ``solver.solve`` rather than ``solve``, for two reasons.
    ``solve`` catches every exception and returns ``""``, so a failure would
    never reach this function. Its question-level cache could also report a
    stale success.

    NOTE: the hybrid agent keeps its own prompt-level cache, so a previously
    cached answer can still satisfy this check without a fresh model call.
    """
    logger.info("Running health check...")
    test_question = "What is 2 + 2?"
    try:
        result = solver.solve(test_question)
    except Exception:
        logger.exception("Health check failed.")
        return False
    logger.info("Health check completed.")
    logger.info(result)
    if not result:
        logger.error("Health check failed: empty answer.")
        return False
    return True
| ############################################################################### | |
| # Main | |
| ############################################################################### | |
if __name__ == "__main__":
    import argparse

    # Command-line options consumed by the branches below.
    parser = argparse.ArgumentParser(description="Hybrid GAIA agent")
    parser.add_argument(
        "--healthcheck",
        action="store_true",
        help="answer one test question and print OK or FAILED",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="start the interactive prompt (also the default)",
    )
    parser.add_argument(
        "--question",
        help="answer a single question and exit",
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="print request/cache/failure counters before exiting",
    )
    # parse_known_args tolerates extra arguments a hosting launcher may pass.
    args, _unknown = parser.parse_known_args()

    print("GAIA Agent loaded successfully.")
    if args.healthcheck:
        ok = healthcheck()
        print("OK" if ok else "FAILED")
    elif args.interactive:
        interactive()
    elif args.question:
        print(solve(args.question))
    else:
        interactive()
    if args.stats:
        print_statistics()