wahibtim's picture
Update app.py
466c18b verified
import os
import gradio as gr
import requests
import pandas as pd
import time
import io
import re
from smolagents import LiteLLMModel, tool, CodeAgent
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# ====================== TOOLS ======================
@tool
def web_search(query: str) -> str:
    """
    Search the web using DuckDuckGo and return up to five results.
    Results are joined one per line as "title: snippet [url]".
    Args:
        query: The search query string.
    """
    if not query or not query.strip():
        return "Search failed: empty query."
    try:
        # Imported lazily so a missing duckduckgo_search package is reported
        # to the agent as a failed search instead of breaking module import.
        from duckduckgo_search import DDGS

        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=5))
    except Exception as e:
        return f"Search failed: {e}"
    if not results:
        return "No results found."
    lines = []
    for r in results:
        # Result fields may be absent; avoid emitting the literal "None".
        title = r.get("title") or "(no title)"
        body = r.get("body") or ""
        href = r.get("href")
        line = f"{title}: {body}"
        # Keeping the URL lets the agent open the page with another tool.
        lines.append(f"{line} [{href}]" if href else line)
    return "\n".join(lines)
# Extracts the filename from a Content-Disposition header (plain or RFC 5987 form).
_FILENAME_RE = re.compile(r"""filename\*?=(?:utf-8'')?"?([^";]+)""", re.IGNORECASE)


def _attachment_name(headers) -> str:
    """Return the lower-cased filename from a Content-Disposition header, or ""."""
    match = _FILENAME_RE.search(headers.get("content-disposition", ""))
    return match.group(1).strip().lower() if match else ""


@tool
def download_and_read_file(task_id: str) -> str:
    """
    Downloads the file for a task and returns its content.
    CSV and Excel files come back as a table preview, text files as a snippet,
    and any other file type only as its size.
    Args:
        task_id: The unique ID for the task file.
    """
    url = f"{DEFAULT_API_URL}/files/{task_id}"
    try:
        r = requests.get(url, timeout=30)
        r.raise_for_status()
    except Exception as e:
        return f"Download failed: {e}"

    content_type = r.headers.get("content-type", "").lower()
    # Prefer the server-supplied filename for extension checks; task IDs
    # presumably carry no extension (TODO confirm), so task_id is a fallback.
    name = _attachment_name(r.headers) or task_id.lower()

    try:
        if "csv" in content_type or name.endswith(".csv"):
            df = pd.read_csv(io.BytesIO(r.content))
            return f"CSV Content (First 15 rows):\n{df.head(15).to_string()}\n\nColumns: {df.columns.tolist()}"
        if "spreadsheet" in content_type or "excel" in content_type or name.endswith((".xlsx", ".xls")):
            # pandas needs an Excel engine (e.g. openpyxl); if it is missing,
            # the except clause below reports the error to the agent.
            df = pd.read_excel(io.BytesIO(r.content))
            return f"Excel Content (First 15 rows):\n{df.head(15).to_string()}\n\nColumns: {df.columns.tolist()}"
        if "text" in content_type or "json" in content_type or name.endswith((".txt", ".py", ".json", ".md")):
            return f"Text Content (Snippet):\n{r.text[:2000]}"
    except Exception as e:
        return f"File downloaded but could not be parsed: {e}"

    return (
        f"File downloaded. Size: {len(r.content)} bytes. "
        f"If this is an image/pdf, use web_search to find related facts about task {task_id}."
    )
# ====================== AGENT ======================
class GaiaAgent:
    """Answer GAIA questions with a Groq-backed smolagents ``CodeAgent``.

    Instances are callable as ``agent(question, task_id)`` and return a short,
    cleaned answer string.
    """

    # Filler phrases stripped from the start of an answer. The \b stops the
    # bare "answer" alternative from eating the start of words like "Answering".
    _ANSWER_PREFIX = re.compile(
        r"^(?:the answer is|final answer|result is|answer)\b[:\s]*",
        re.IGNORECASE,
    )
    # Quote characters removed when a matching pair wraps the whole answer;
    # the prompt's examples are quoted, so the model may copy the quotes.
    _QUOTE_CHARS = "'\"`"

    def __init__(self):
        """Create the Groq model and the tool-using code agent.

        Raises:
            ValueError: If the ``GROQ_API_KEY`` environment variable is unset or empty.
        """
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            raise ValueError("❌ GROQ_API_KEY secret is not set! Add it in HF Spaces → Settings → Secrets.")
        # Llama 3.3 70B served by Groq; the author chose it as the strongest
        # free-tier model for reasoning.
        self.model = LiteLLMModel(
            model_id="groq/llama-3.3-70b-versatile",
            api_key=groq_api_key,
        )
        # add_base_tools=True lets smolagents add its default toolbox next to
        # the two custom tools; max_steps caps the reasoning steps per question.
        self.agent = CodeAgent(
            tools=[web_search, download_and_read_file],
            model=self.model,
            add_base_tools=True,
            max_steps=12,
        )

    def clean_answer(self, raw_result: str) -> str:
        """Remove conversational filler that fails the GAIA grader.

        Strips, in this order: one leading phrase ("the answer is",
        "final answer", "result is", "answer") with any following colon or
        whitespace, then trailing periods, then a pair of matching quotes
        that wraps the whole answer.

        Args:
            raw_result: The agent's final output; non-strings go through ``str()``.

        Returns:
            The cleaned answer string.
        """
        text = str(raw_result).strip()
        text = self._ANSWER_PREFIX.sub("", text, count=1)
        # rstrip, not strip: leading dots belong to answers such as ".NET".
        text = text.rstrip(".").strip()
        if len(text) >= 2 and text[0] == text[-1] and text[0] in self._QUOTE_CHARS:
            text = text[1:-1].strip().rstrip(".").strip()
        return text

    def __call__(self, question: str, task_id: str) -> str:
        """Run the agent on one question and return its cleaned answer.

        Args:
            question: The question text from the scoring API.
            task_id: The task identifier, embedded in the prompt so the agent
                can fetch any attached file.

        Returns:
            The cleaned answer, or "Unknown" if the agent raised.
        """
        prompt = f"""Task ID: {task_id}
Question: {question}
INSTRUCTIONS:
- Use your tools to find the exact factual answer.
- If the question mentions a file or attachment, call download_and_read_file("{task_id}") first.
- If you need up-to-date facts, use web_search.
- YOUR FINAL ANSWER MUST BE EXTREMELY BRIEF AND EXACT:
* Numbers: just the number, e.g. '42' or '4.52'
* Names: just the name, e.g. 'Marie Curie'
* Dates: just the date, e.g. '1923' or 'July 4, 1776'
* Lists: comma-separated, e.g. 'apple, banana, cherry'
- Do NOT write sentences. Do NOT explain. Just the answer.
"""
        try:
            result = self.agent.run(prompt)
        except Exception as e:
            # One failing task must not abort the whole evaluation run.
            print(f"Agent error on task {task_id}: {e}")
            return "Unknown"
        return self.clean_answer(result)
# ====================== MAIN LOGIC ======================
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Answer every GAIA question with ``GaiaAgent`` and submit the answers.

    The ``gr.OAuthProfile`` annotation lets Gradio fill ``profile`` from the
    Hugging Face login, so the annotation must stay.

    Args:
        profile: The logged-in Hugging Face profile, or None when not logged in.

    Returns:
        A ``(status_message, answers_dataframe)`` pair for the two Gradio
        outputs; the dataframe is None when no questions were answered.
    """
    if not profile:
        return "❌ Please Login with Hugging Face first!", None
    username = profile.username
    print(f"✅ Logged in as: {username}")

    try:
        agent = GaiaAgent()
    except ValueError as e:
        return str(e), None

    try:
        resp = requests.get(f"{DEFAULT_API_URL}/questions", timeout=30)
        resp.raise_for_status()
        questions = resp.json()
    except Exception as e:
        return f"❌ Failed to fetch questions: {e}", None
    if not questions:
        # Submitting an empty answer list would only record a zero score.
        return "❌ No questions received from the scoring API.", None

    total = len(questions)
    print(f"📋 Fetched {total} questions.")

    answers_payload = []
    results_log = []
    for i, item in enumerate(questions, start=1):
        t_id = item.get("task_id")
        q_text = item.get("question") or ""
        if not t_id:
            print(f"Skipping item {i}/{total}: missing task_id")
            continue
        print(f"\n--- [{i}/{total}] Task: {t_id} ---")
        print(f"Q: {q_text[:120]}...")
        answer = str(agent(q_text, t_id))
        print(f"A: {answer}")
        answers_payload.append({"task_id": t_id, "submitted_answer": answer})
        results_log.append({"Task ID": t_id, "Question": q_text[:80], "Answer": answer})
        if i < total:
            # Short pause between tasks to stay within the Groq free-tier rate
            # limit (the original note estimated ~30 requests/min).
            time.sleep(3)

    results_df = pd.DataFrame(results_log)

    # ===== SUBMIT =====
    space_id = os.getenv("SPACE_ID", "unknown")
    submission_data = {
        "username": username,
        "agent_code": f"https://huggingface.co/spaces/{space_id}",
        "answers": answers_payload,
    }
    try:
        r = requests.post(f"{DEFAULT_API_URL}/submit", json=submission_data, timeout=300)
        if r.status_code != 200:
            return f"❌ Submission Error {r.status_code}: {r.text}", results_df
        res = r.json()
    except Exception as e:
        return f"❌ Submission Failed: {e}", results_df
    score = res.get("score", 0)
    message = res.get("message", "")
    return f"✅ SCORE: {score}% | {message}", results_df
# ====================== UI ======================
# Gradio UI: the login button and start button sit in one row; the status box
# and the answer table below them receive run_and_submit_all's two outputs.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# 🏆 GAIA Certificate Agent (Unit 4 Final)")
    gr.Markdown(
        "**Steps:** 1) Login with HF below → 2) Click Start → 3) Wait ~5 mins → 4) Check your score!\n\n"
        "> Make sure `GROQ_API_KEY` is set in your Space **Settings → Secrets**."
    )
    with gr.Row():
        gr.LoginButton()
        run_btn = gr.Button("🚀 Start Evaluation", variant="primary")
    status_output = gr.Textbox(label="Final Result", lines=3)
    table_output = gr.DataFrame(label="Answer Log")
    # No explicit inputs: Gradio supplies the OAuth profile to the handler.
    run_btn.click(fn=run_and_submit_all, outputs=[status_output, table_output])

if __name__ == "__main__":
    demo.launch()