""" HF Agents Course Unit 4 Final Assignment app.py - Part 1 Requirements: - smolagents==1.21.3 - LiteLLMModel - Cerebras GPT-OSS - DuckDuckGoSearchTool - VisitWebpageTool """ from __future__ import annotations import hashlib import json import logging import os import pickle import re import time from pathlib import Path from typing import Dict, Any, Optional from smolagents import ( CodeAgent, LiteLLMModel, ToolCallingAgent, ) from smolagents import ( DuckDuckGoSearchTool, VisitWebpageTool, ) ############################################################################### # Configuration ############################################################################### logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s", ) logger = logging.getLogger("gaia-agent") CACHE_DIR = Path(".cache") CACHE_DIR.mkdir(exist_ok=True) MAX_RETRIES = 3 RETRY_DELAY = 2 MODEL_NAME = os.getenv( "MODEL_ID", "cerebras/gpt-oss-120b", ) API_KEY = os.getenv("CEREBRAS_API_KEY") ############################################################################### # LiteLLM Model ############################################################################### model = LiteLLMModel( model_id=MODEL_NAME, api_key=API_KEY, temperature=0.2, max_tokens=4096, ) ############################################################################### # Tools ############################################################################### search_tool = DuckDuckGoSearchTool() visit_tool = VisitWebpageTool() TOOLS = [ search_tool, visit_tool, ] ############################################################################### # Cache ############################################################################### class FileCache: def __init__(self, folder: Path): self.folder = folder def _path(self, key: str): digest = hashlib.sha256( key.encode() ).hexdigest() return self.folder / f"{digest}.pkl" def get(self, key): path = self._path(key) if path.exists(): with open(path, "rb") as f: return pickle.load(f) return None def set(self, key, value): with open(self._path(key), "wb") as f: pickle.dump(value, f) cache = FileCache(CACHE_DIR) ############################################################################### # Retry helper ############################################################################### def retry(fn): for attempt in range(MAX_RETRIES): try: return fn() except Exception as e: logger.warning( "Attempt %d failed: %s", attempt + 1, e, ) if attempt == MAX_RETRIES - 1: raise time.sleep(RETRY_DELAY) ############################################################################### # Cleaning ############################################################################### ANSWER_PATTERNS = [ r"^Answer\s*:", r"^Final Answer\s*:", r"```", ] def clean_answer(answer: str) -> str: answer = answer.strip() for pattern in ANSWER_PATTERNS: answer = re.sub( pattern, "", answer, flags=re.IGNORECASE, ) answer = answer.strip() return answer ############################################################################### # Routing ############################################################################### def classify_question(question: str): q = question.lower() if any( word in q for word in [ "who", "when", "where", "latest", "website", "news", "search", ] ): return "web" return "reasoning" ############################################################################### # Agent Factory ############################################################################### def build_web_agent(): return ToolCallingAgent( tools=TOOLS, model=model, max_steps=8, ) def build_reasoning_agent(): return CodeAgent( tools=TOOLS, model=model, max_steps=10, ) web_agent = build_web_agent() reasoning_agent = build_reasoning_agent() ############################################################################### # Hybrid Router ############################################################################### class HybridGAIAAgent: """ Routes questions between a web-search oriented agent and a reasoning-oriented agent. """ def __init__( self, web_agent, reasoning_agent, cache, ): self.web_agent = web_agent self.reasoning_agent = reasoning_agent self.cache = cache def _run_agent(self, agent, prompt: str) -> str: """ Execute an agent with retries. """ def _execute(): return agent.run(prompt) return retry(_execute) def answer(self, question: str) -> str: """ Main inference entrypoint. """ cache_key = question.strip() cached = self.cache.get(cache_key) if cached is not None: logger.info("Cache hit.") return cached route = classify_question(question) logger.info("Selected route: %s", route) if route == "web": raw_answer = self._run_agent( self.web_agent, question, ) else: raw_answer = self._run_agent( self.reasoning_agent, question, ) cleaned = clean_answer(str(raw_answer)) self.cache.set(cache_key, cleaned) return cleaned ############################################################################### # Prompt Templates ############################################################################### SYSTEM_PROMPT = """ You are an expert GAIA benchmark assistant. Guidelines: - Think carefully. - Search the web whenever necessary. - Visit webpages when search results require deeper inspection. - Never fabricate facts. - Return only the final answer. """ WEB_PROMPT = """ Use search tools whenever required. Question: {question} """ REASONING_PROMPT = """ Solve the problem carefully. Question: {question} """ ############################################################################### # Formatting Utilities ############################################################################### def build_prompt(question: str) -> str: route = classify_question(question) if route == "web": body = WEB_PROMPT.format( question=question, ) else: body = REASONING_PROMPT.format( question=question, ) return ( SYSTEM_PROMPT + "\n\n" + body ) ############################################################################### # JSON Helpers ############################################################################### def safe_json_loads(text: str) -> Optional[Dict[str, Any]]: try: return json.loads(text) except Exception: return None def looks_like_json(text: str) -> bool: text = text.strip() return ( text.startswith("{") and text.endswith("}") ) ############################################################################### # Validation ############################################################################### def validate_answer(answer: str) -> str: if answer is None: return "" answer = str(answer) answer = clean_answer(answer) if looks_like_json(answer): parsed = safe_json_loads(answer) if parsed is not None: if "answer" in parsed: return str(parsed["answer"]).strip() if "final_answer" in parsed: return str(parsed["final_answer"]).strip() return answer.strip() ############################################################################### # Instantiate Hybrid Agent ############################################################################### hybrid_agent = HybridGAIAAgent( web_agent=web_agent, reasoning_agent=reasoning_agent, cache=cache, ) ############################################################################### # GAIA Solver ############################################################################### class GAIASolver: """ High-level wrapper around the HybridGAIAAgent. Responsible for preparing prompts, handling retries, validating answers, and providing a stable interface. """ def __init__(self, agent: HybridGAIAAgent): self.agent = agent def solve(self, question: str) -> str: prompt = build_prompt(question) logger.info("=" * 80) logger.info("Incoming Question") logger.info(question) logger.info("=" * 80) answer = retry( lambda: self.agent.answer(prompt) ) answer = validate_answer(answer) logger.info("Final Answer:") logger.info(answer) return answer ############################################################################### # Statistics ############################################################################### class AgentStatistics: def __init__(self): self.total_requests = 0 self.cache_hits = 0 self.failures = 0 def request(self): self.total_requests += 1 def cache_hit(self): self.cache_hits += 1 def failure(self): self.failures += 1 def summary(self): return { "requests": self.total_requests, "cache_hits": self.cache_hits, "failures": self.failures, } stats = AgentStatistics() ############################################################################### # Public API ############################################################################### solver = GAIASolver(hybrid_agent) def solve(question: str) -> str: """ Main API expected by evaluation scripts. """ stats.request() cache_value = cache.get(question) if cache_value is not None: stats.cache_hit() return cache_value try: answer = solver.solve(question) cache.set(question, answer) return answer except Exception as exc: stats.failure() logger.exception(exc) return "" ############################################################################### # Batch Solver ############################################################################### def solve_batch(questions): outputs = [] for question in questions: outputs.append( solve(question) ) return outputs ############################################################################### # Optional CLI Utilities ############################################################################### def interactive(): print("=" * 70) print("Hybrid GAIA Agent") print("Type 'exit' to quit.") print("=" * 70) while True: question = input("\nQuestion> ").strip() if not question: continue if question.lower() in { "exit", "quit", }: break answer = solve(question) print("\nAnswer:") print(answer) ############################################################################### # HF Unit 4 Entry Points ############################################################################### def predict(question: str) -> str: """ Prediction entry point used by many evaluation scripts. """ return solve(question) def answer(question: str) -> str: """ Alias for compatibility with some templates. """ return solve(question) ############################################################################### # Diagnostics ############################################################################### def print_statistics(): print("\n") print("=" * 80) print("Agent Statistics") print("=" * 80) summary = stats.summary() for key, value in summary.items(): print(f"{key:15}: {value}") print("=" * 80) ############################################################################### # Health Check ############################################################################### def healthcheck(): try: logger.info("Running health check...") test_question = "What is 2 + 2?" result = solve(test_question) logger.info("Health check completed.") logger.info(result) return True except Exception as exc: logger.exception(exc) return False ############################################################################### # Main ############################################################################### if __name__ == "__main__": print("GAIA Agent loaded successfully.") if args.healthcheck: ok = healthcheck() print("OK" if ok else "FAILED") elif args.interactive: interactive() elif args.question: print(solve(args.question)) else: interactive() if args.stats: print_statistics()