aloks111's picture
Update app.py
ea63030 verified
Raw
History Blame Contribute Delete
13.6 kB
"""
HF Agents Course Unit 4 Final Assignment
app.py - Part 1
Requirements:
- smolagents==1.21.3
- LiteLLMModel
- Cerebras GPT-OSS
- DuckDuckGoSearchTool
- VisitWebpageTool
"""
from __future__ import annotations
import hashlib
import json
import logging
import os
import pickle
import re
import time
from pathlib import Path
from typing import Dict, Any, Optional
from smolagents import (
CodeAgent,
LiteLLMModel,
ToolCallingAgent,
)
from smolagents import (
DuckDuckGoSearchTool,
VisitWebpageTool,
)
###############################################################################
# Configuration
###############################################################################
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger("gaia-agent")
CACHE_DIR = Path(".cache")
CACHE_DIR.mkdir(exist_ok=True)
MAX_RETRIES = 3
RETRY_DELAY = 2
MODEL_NAME = os.getenv(
"MODEL_ID",
"cerebras/gpt-oss-120b",
)
API_KEY = os.getenv("CEREBRAS_API_KEY")
###############################################################################
# LiteLLM Model
###############################################################################
model = LiteLLMModel(
model_id=MODEL_NAME,
api_key=API_KEY,
temperature=0.2,
max_tokens=4096,
)
###############################################################################
# Tools
###############################################################################
search_tool = DuckDuckGoSearchTool()
visit_tool = VisitWebpageTool()
TOOLS = [
search_tool,
visit_tool,
]
###############################################################################
# Cache
###############################################################################
class FileCache:
def __init__(self, folder: Path):
self.folder = folder
def _path(self, key: str):
digest = hashlib.sha256(
key.encode()
).hexdigest()
return self.folder / f"{digest}.pkl"
def get(self, key):
path = self._path(key)
if path.exists():
with open(path, "rb") as f:
return pickle.load(f)
return None
def set(self, key, value):
with open(self._path(key), "wb") as f:
pickle.dump(value, f)
cache = FileCache(CACHE_DIR)
###############################################################################
# Retry helper
###############################################################################
def retry(fn):
for attempt in range(MAX_RETRIES):
try:
return fn()
except Exception as e:
logger.warning(
"Attempt %d failed: %s",
attempt + 1,
e,
)
if attempt == MAX_RETRIES - 1:
raise
time.sleep(RETRY_DELAY)
###############################################################################
# Cleaning
###############################################################################
ANSWER_PATTERNS = [
r"^Answer\s*:",
r"^Final Answer\s*:",
r"```",
]
def clean_answer(answer: str) -> str:
answer = answer.strip()
for pattern in ANSWER_PATTERNS:
answer = re.sub(
pattern,
"",
answer,
flags=re.IGNORECASE,
)
answer = answer.strip()
return answer
###############################################################################
# Routing
###############################################################################
def classify_question(question: str):
q = question.lower()
if any(
word in q
for word in [
"who",
"when",
"where",
"latest",
"website",
"news",
"search",
]
):
return "web"
return "reasoning"
###############################################################################
# Agent Factory
###############################################################################
def build_web_agent():
return ToolCallingAgent(
tools=TOOLS,
model=model,
max_steps=8,
)
def build_reasoning_agent():
return CodeAgent(
tools=TOOLS,
model=model,
max_steps=10,
)
web_agent = build_web_agent()
reasoning_agent = build_reasoning_agent()
###############################################################################
# Hybrid Router
###############################################################################
class HybridGAIAAgent:
"""
Routes questions between a web-search oriented agent and a
reasoning-oriented agent.
"""
def __init__(
self,
web_agent,
reasoning_agent,
cache,
):
self.web_agent = web_agent
self.reasoning_agent = reasoning_agent
self.cache = cache
def _run_agent(self, agent, prompt: str) -> str:
"""
Execute an agent with retries.
"""
def _execute():
return agent.run(prompt)
return retry(_execute)
def answer(self, question: str) -> str:
"""
Main inference entrypoint.
"""
cache_key = question.strip()
cached = self.cache.get(cache_key)
if cached is not None:
logger.info("Cache hit.")
return cached
route = classify_question(question)
logger.info("Selected route: %s", route)
if route == "web":
raw_answer = self._run_agent(
self.web_agent,
question,
)
else:
raw_answer = self._run_agent(
self.reasoning_agent,
question,
)
cleaned = clean_answer(str(raw_answer))
self.cache.set(cache_key, cleaned)
return cleaned
###############################################################################
# Prompt Templates
###############################################################################
SYSTEM_PROMPT = """
You are an expert GAIA benchmark assistant.
Guidelines:
- Think carefully.
- Search the web whenever necessary.
- Visit webpages when search results require deeper inspection.
- Never fabricate facts.
- Return only the final answer.
"""
WEB_PROMPT = """
Use search tools whenever required.
Question:
{question}
"""
REASONING_PROMPT = """
Solve the problem carefully.
Question:
{question}
"""
###############################################################################
# Formatting Utilities
###############################################################################
def build_prompt(question: str) -> str:
route = classify_question(question)
if route == "web":
body = WEB_PROMPT.format(
question=question,
)
else:
body = REASONING_PROMPT.format(
question=question,
)
return (
SYSTEM_PROMPT
+ "\n\n"
+ body
)
###############################################################################
# JSON Helpers
###############################################################################
def safe_json_loads(text: str) -> Optional[Dict[str, Any]]:
try:
return json.loads(text)
except Exception:
return None
def looks_like_json(text: str) -> bool:
text = text.strip()
return (
text.startswith("{")
and text.endswith("}")
)
###############################################################################
# Validation
###############################################################################
def validate_answer(answer: str) -> str:
if answer is None:
return ""
answer = str(answer)
answer = clean_answer(answer)
if looks_like_json(answer):
parsed = safe_json_loads(answer)
if parsed is not None:
if "answer" in parsed:
return str(parsed["answer"]).strip()
if "final_answer" in parsed:
return str(parsed["final_answer"]).strip()
return answer.strip()
###############################################################################
# Instantiate Hybrid Agent
###############################################################################
hybrid_agent = HybridGAIAAgent(
web_agent=web_agent,
reasoning_agent=reasoning_agent,
cache=cache,
)
###############################################################################
# GAIA Solver
###############################################################################
class GAIASolver:
"""
High-level wrapper around the HybridGAIAAgent.
Responsible for preparing prompts, handling retries,
validating answers, and providing a stable interface.
"""
def __init__(self, agent: HybridGAIAAgent):
self.agent = agent
def solve(self, question: str) -> str:
prompt = build_prompt(question)
logger.info("=" * 80)
logger.info("Incoming Question")
logger.info(question)
logger.info("=" * 80)
answer = retry(
lambda: self.agent.answer(prompt)
)
answer = validate_answer(answer)
logger.info("Final Answer:")
logger.info(answer)
return answer
###############################################################################
# Statistics
###############################################################################
class AgentStatistics:
def __init__(self):
self.total_requests = 0
self.cache_hits = 0
self.failures = 0
def request(self):
self.total_requests += 1
def cache_hit(self):
self.cache_hits += 1
def failure(self):
self.failures += 1
def summary(self):
return {
"requests": self.total_requests,
"cache_hits": self.cache_hits,
"failures": self.failures,
}
stats = AgentStatistics()
###############################################################################
# Public API
###############################################################################
solver = GAIASolver(hybrid_agent)
def solve(question: str) -> str:
"""
Main API expected by evaluation scripts.
"""
stats.request()
cache_value = cache.get(question)
if cache_value is not None:
stats.cache_hit()
return cache_value
try:
answer = solver.solve(question)
cache.set(question, answer)
return answer
except Exception as exc:
stats.failure()
logger.exception(exc)
return ""
###############################################################################
# Batch Solver
###############################################################################
def solve_batch(questions):
outputs = []
for question in questions:
outputs.append(
solve(question)
)
return outputs
###############################################################################
# Optional CLI Utilities
###############################################################################
def interactive():
print("=" * 70)
print("Hybrid GAIA Agent")
print("Type 'exit' to quit.")
print("=" * 70)
while True:
question = input("\nQuestion> ").strip()
if not question:
continue
if question.lower() in {
"exit",
"quit",
}:
break
answer = solve(question)
print("\nAnswer:")
print(answer)
###############################################################################
# HF Unit 4 Entry Points
###############################################################################
def predict(question: str) -> str:
"""
Prediction entry point used by many evaluation scripts.
"""
return solve(question)
def answer(question: str) -> str:
"""
Alias for compatibility with some templates.
"""
return solve(question)
###############################################################################
# Diagnostics
###############################################################################
def print_statistics():
print("\n")
print("=" * 80)
print("Agent Statistics")
print("=" * 80)
summary = stats.summary()
for key, value in summary.items():
print(f"{key:15}: {value}")
print("=" * 80)
###############################################################################
# Health Check
###############################################################################
def healthcheck():
try:
logger.info("Running health check...")
test_question = "What is 2 + 2?"
result = solve(test_question)
logger.info("Health check completed.")
logger.info(result)
return True
except Exception as exc:
logger.exception(exc)
return False
###############################################################################
# Main
###############################################################################
if __name__ == "__main__":
print("GAIA Agent loaded successfully.")
if args.healthcheck:
ok = healthcheck()
print("OK" if ok else "FAILED")
elif args.interactive:
interactive()
elif args.question:
print(solve(args.question))
else:
interactive()
if args.stats:
print_statistics()