# Author: claudi47
# Commit a765bf8: "Fix Groq auth and add support for GAIA file tasks"
import os
import base64
import mimetypes
import requests
import pandas as pd
import gradio as gr
from dotenv import load_dotenv
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
OpenAIServerModel,
WikipediaSearchTool,
VisitWebpageTool,
Tool,
)
# Read a local .env (if any) so GROQ_API_KEY works outside the Space too.
load_dotenv()

# --- Constants ---
# Base URL of the agents-course unit 4 scoring service.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# Groq's OpenAI-compatible API root, shared by every model call below.
GROQ_API_BASE = "https://api.groq.com/openai/v1"

# Model identifiers: main reasoning LLM, vision model, speech-to-text.
TEXT_MODEL_ID = "llama-3.3-70b-versatile"
VISION_MODEL_ID = "meta-llama/llama-4-scout-17b-16e-instruct"
AUDIO_MODEL_ID = "whisper-large-v3"

# Suffix appended to every prompt so the agent's final_answer() is a
# bare value that can be scored by exact match.
ANSWER_FORMAT_INSTRUCTIONS = """
IMPORTANT FORMAT INSTRUCTIONS:
Your final_answer must be as concise as possible:
- If the answer is a number, return ONLY the number
(no units, no commas, no $ or % unless asked).
- If the answer is a string, return ONLY the
essential words (no articles like "the"/"a",
no abbreviations for cities, write digits in
plain text unless told otherwise).
- If the answer is a comma separated list, apply
the rules above to each element.
Do NOT include explanations in your final_answer,
just the bare answer."""
# --------------------------------------------------
# Custom tool: download a GAIA task file
# --------------------------------------------------
class GaiaFileFetcherTool(Tool):
    """Download the file attached to a GAIA task.

    The file is fetched from ``{api_url}/files/{task_id}`` and written
    into the system temp directory. The local path is returned so the
    agent can open it with other tools or with pandas. A later download
    that resolves to the same file name overwrites the earlier file.
    """

    name = "fetch_task_file"
    description = (
        "Downloads the file attached to a GAIA task "
        "given its task_id. Returns the local path "
        "to the downloaded file so you can read it."
    )
    inputs = {
        "task_id": {
            "type": "string",
            "description": (
                "The task_id of the GAIA question "
                "whose attached file you need."
            ),
        }
    }
    output_type = "string"

    def __init__(self, api_url: str, **kwargs):
        super().__init__(**kwargs)
        # Root URL of the scoring service; "/files/<task_id>" is appended.
        self.api_url = api_url

    def forward(self, task_id: str) -> str:
        """Fetch the task's attachment and return its local file path.

        Args:
            task_id: GAIA task identifier, supplied by the LLM.

        Returns:
            Path of the saved file inside ``tempfile.gettempdir()``.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        # Imports stay local so forward() is self-contained.
        import mimetypes as _mt
        import os as _os
        import tempfile as _tmp
        from email.message import Message as _Message
        from urllib.parse import quote as _quote

        import requests as _req

        # The id comes from the LLM: trim stray whitespace and URL-quote
        # it so it cannot alter the request path.
        task_id = str(task_id).strip()
        url = f"{self.api_url}/files/{_quote(task_id, safe='')}"
        resp = _req.get(url, timeout=30)
        resp.raise_for_status()

        # Fallback extension derived from the media type (params dropped).
        content_type = resp.headers.get("Content-Type", "")
        ext = _mt.guess_extension(content_type.split(";")[0].strip()) or ""

        # Parse Content-Disposition with the stdlib header parser. A
        # plain split on "filename=" keeps any trailing "; param=..."
        # text and ignores RFC 2231 "filename*=" values.
        fname = ""
        disposition = resp.headers.get("Content-Disposition", "")
        if disposition:
            header = _Message()
            header["Content-Disposition"] = disposition
            fname = header.get_filename() or ""

        # Never trust a server-supplied path: keep only the last
        # component, and reject names that would resolve to a directory.
        fname = _os.path.basename(fname.strip().strip("\"'"))
        if fname in ("", ".", ".."):
            fname = _os.path.basename(f"{task_id}{ext}")
        if fname in ("", ".", ".."):
            fname = f"task_file{ext}"

        path = _os.path.join(_tmp.gettempdir(), fname)
        with open(path, "wb") as f:
            f.write(resp.content)
        return path
class GroqAudioTranscriptionTool(Tool):
    """Speech-to-text tool backed by Groq's hosted Whisper model.

    Uploads a local audio file to Groq's ``/audio/transcriptions``
    endpoint and returns the transcript text.
    """

    name = "transcribe_audio_file"
    description = (
        "Transcribes a local audio file path, such as an "
        "MP3 downloaded with fetch_task_file. Returns the "
        "plain transcript text."
    )
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Local path to the audio file.",
        }
    }
    output_type = "string"

    def forward(self, file_path: str) -> str:
        """Return the whitespace-stripped transcript of ``file_path``.

        Raises:
            RuntimeError: when GROQ_API_KEY is not set.
            requests.HTTPError: when Groq answers with an error status.
        """
        groq_key = os.getenv("GROQ_API_KEY")
        if not groq_key:
            raise RuntimeError(
                "GROQ_API_KEY is required for audio transcription."
            )

        endpoint = f"{GROQ_API_BASE}/audio/transcriptions"
        auth_headers = {"Authorization": f"Bearer {groq_key}"}
        form_fields = {
            "model": AUDIO_MODEL_ID,
            "response_format": "json",
            "temperature": "0",
        }
        upload_name = os.path.basename(file_path)

        # The file handle must stay open while requests streams it.
        with open(file_path, "rb") as handle:
            reply = requests.post(
                endpoint,
                headers=auth_headers,
                files={"file": (upload_name, handle)},
                data=form_fields,
                timeout=120,
            )
        reply.raise_for_status()

        payload = reply.json()
        return payload.get("text", "").strip()
class GroqImageAnalysisTool(Tool):
    """Visual question answering over a local image via Groq.

    The image is inlined as a base64 ``data:`` URL and sent with the
    question to the vision model's ``/chat/completions`` endpoint.
    """

    name = "analyze_image_file"
    description = (
        "Analyzes a local image file path and answers a "
        "specific visual question about it."
    )
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Local path to the image file.",
        },
        "question": {
            "type": "string",
            "description": "The question to answer about the image.",
        },
    }
    output_type = "string"

    def forward(self, file_path: str, question: str) -> str:
        """Ask the vision model ``question`` about the image at ``file_path``.

        Returns:
            The model's reply text, stripped of surrounding whitespace.

        Raises:
            RuntimeError: when GROQ_API_KEY is not set.
            requests.HTTPError: when Groq answers with an error status.
        """
        groq_key = os.getenv("GROQ_API_KEY")
        if not groq_key:
            raise RuntimeError(
                "GROQ_API_KEY is required for image analysis."
            )

        # Media type is guessed from the file extension only.
        guessed_type, _encoding = mimetypes.guess_type(file_path)
        media_type = guessed_type or "application/octet-stream"

        with open(file_path, "rb") as handle:
            raw_bytes = handle.read()
        b64_payload = base64.b64encode(raw_bytes).decode("ascii")
        data_url = f"data:{media_type};base64,{b64_payload}"

        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": question},
                {"type": "image_url", "image_url": {"url": data_url}},
            ],
        }
        request_body = {
            "model": VISION_MODEL_ID,
            "messages": [user_message],
            "temperature": 0.1,
            "max_completion_tokens": 512,
        }
        reply = requests.post(
            f"{GROQ_API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {groq_key}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        reply.raise_for_status()

        first_choice = reply.json()["choices"][0]
        return first_choice["message"]["content"].strip()
# --------------------------------------------------
# Agent wrapper
# --------------------------------------------------
class BasicAgent:
    """Callable wrapper around a smolagents CodeAgent running on Groq.

    Build it once, then call it per GAIA question to get a bare answer
    string.
    """

    def __init__(self):
        """Create the Groq-backed model, the tools and the CodeAgent.

        Raises:
            RuntimeError: when GROQ_API_KEY is not configured.
        """
        print("BasicAgent initialized.")
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise RuntimeError(
                "Missing GROQ_API_KEY. Add it to your "
                "Hugging Face Space secrets or local .env file."
            )

        llm = OpenAIServerModel(
            model_id=TEXT_MODEL_ID,
            api_base=GROQ_API_BASE,
            api_key=api_key,
        )

        # File-handling tools are kept as attributes for later access.
        self.file_tool = GaiaFileFetcherTool(api_url=DEFAULT_API_URL)
        self.audio_tool = GroqAudioTranscriptionTool()
        self.image_tool = GroqImageAnalysisTool()

        toolbox = [
            DuckDuckGoSearchTool(),
            WikipediaSearchTool(user_agent="GaiaAgent/1.0"),
            VisitWebpageTool(),
            self.file_tool,
            self.audio_tool,
            self.image_tool,
        ]
        # Modules the agent's generated Python code is allowed to import.
        allowed_modules = [
            "base64",
            "json",
            "re",
            "csv",
            "math",
            "statistics",
            "datetime",
            "collections",
            "itertools",
            "os",
            "pathlib",
            "mimetypes",
            "pandas",
            "openpyxl",
        ]
        self.agent = CodeAgent(
            model=llm,
            tools=toolbox,
            max_steps=15,
            verbosity_level=0,
            additional_authorized_imports=allowed_modules,
        )

    def __call__(
        self,
        question: str,
        task_id: str,
        has_file: bool = False,
    ) -> str:
        """Answer one GAIA question and return the stripped answer text.

        Args:
            question: The question text.
            task_id: GAIA task id, given to the agent when a file exists.
            has_file: Whether the task has an attachment to download.
        """
        prompt_parts = [question]
        if has_file:
            # Point the agent at the right tool chain for the attachment.
            prompt_parts.append(
                "\n\n[This question has an attached "
                "file. Use the fetch_task_file tool "
                f"with task_id='{task_id}' to "
                "download it. If it is audio, use "
                "transcribe_audio_file. If it is an "
                "image, use analyze_image_file. If it "
                "is a spreadsheet, read it with pandas.]"
            )
        prompt_parts.append(ANSWER_FORMAT_INSTRUCTIONS)

        answer = self.agent.run("".join(prompt_parts))
        return str(answer).strip()
# --------------------------------------------------
# Gradio: run all & submit
# --------------------------------------------------
def _fetch_questions(
    questions_url: str,
) -> tuple[list | None, str | None]:
    """Fetch the GAIA question list from the scoring service.

    Args:
        questions_url: Full URL of the ``/questions`` endpoint.

    Returns:
        ``(questions, None)`` on success, where ``questions`` is a
        non-empty list, or ``(None, error_message)`` on failure.
    """
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    # requests.exceptions.JSONDecodeError is itself a RequestException
    # (requests >= 2.27), so it must be caught BEFORE RequestException;
    # in the reverse order this branch can never run.
    except requests.exceptions.JSONDecodeError as e:
        print(f"Error decoding JSON from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return None, f"Error decoding server response for questions: {e}"
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return None, f"Error fetching questions: {e}"
    except Exception as e:
        print(f"Unexpected error fetching questions: {e}")
        return None, f"Unexpected error fetching questions: {e}"

    # Questions are iterated as a list of dicts later on; reject any
    # other payload shape here instead of crashing mid-run.
    if not questions_data or not isinstance(questions_data, list):
        print("Fetched questions list is empty or not a list.")
        return None, "Fetched questions list is empty or invalid format."

    print(f"Fetched {len(questions_data)} questions.")
    return questions_data, None


def _run_agent_on_questions(
    agent, questions_data: list
) -> tuple[list, list]:
    """Run ``agent`` on every question and collect the outcomes.

    Args:
        agent: Callable taking ``(question, task_id, has_file)`` and
            returning the answer string.
        questions_data: List of question dicts from the scoring API.

    Returns:
        ``(results_log, answers_payload)``: rows for the UI table and
        the ``{"task_id", "submitted_answer"}`` dicts to submit. Tasks
        whose agent call raised appear only in ``results_log``.
    """
    results_log = []
    answers_payload = []
    total = len(questions_data)
    print(f"Running agent on {total} questions...")
    for i, item in enumerate(questions_data, start=1):
        if not isinstance(item, dict):
            print(f"Skipping malformed question item: {item}")
            continue
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(
                "Skipping item with missing "
                f"task_id or question: {item}"
            )
            continue

        # A non-empty file_name means the task has an attachment.
        has_file = bool(item.get("file_name", ""))
        file_note = " (has file)" if has_file else ""
        print(f"[{i}/{total}] Task {task_id}{file_note}")

        try:
            submitted_answer = agent(question_text, task_id, has_file)
        except Exception as e:
            # One failing task must not abort the whole evaluation run.
            print(f"Error on task {task_id}: {e}")
            results_log.append(
                {
                    "Task ID": task_id,
                    "Question": question_text,
                    "Submitted Answer": f"AGENT ERROR: {e}",
                }
            )
            continue

        answers_payload.append(
            {"task_id": task_id, "submitted_answer": submitted_answer}
        )
        results_log.append(
            {
                "Task ID": task_id,
                "Question": question_text,
                "Submitted Answer": submitted_answer,
            }
        )
    return results_log, answers_payload


def _describe_http_error(e) -> str:
    """Summarise an HTTPError from the submit endpoint for the user."""
    resp = e.response
    if resp is None:
        return str(e)
    detail = f"Server responded with status {resp.status_code}."
    try:
        error_json = resp.json()
    except requests.exceptions.JSONDecodeError:
        return f"{detail} Response: {resp.text[:500]}"
    # Only a JSON object can carry a "detail" field; arrays or scalars
    # have no .get(), so fall back to the raw response text.
    if isinstance(error_json, dict):
        return f"{detail} Detail: {error_json.get('detail', resp.text)}"
    return f"{detail} Response: {resp.text[:500]}"


def _submit_answers(submit_url: str, submission_data: dict, results_log: list):
    """POST the answers and build the (status text, results table) pair.

    Every outcome, success or failure, is returned rather than raised,
    so the Gradio UI always has something to display.
    """
    print(
        f"Submitting {len(submission_data['answers'])} "
        f"answers to: {submit_url}"
    )
    results_df = pd.DataFrame(results_log)
    try:
        response = requests.post(
            submit_url,
            json=submission_data,
            timeout=60,
        )
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            "Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            "Overall Score: "
            f"{result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}"
            f"/{result_data.get('total_attempted', '?')}"
            " correct)\n"
            "Message: "
            f"{result_data.get('message', 'N/A')}"
        )
    except requests.exceptions.HTTPError as e:
        status_message = f"Submission Failed: {_describe_http_error(e)}"
    except requests.exceptions.Timeout:
        status_message = "Submission Failed: Request timed out."
    # Must precede RequestException, of which it is a subclass.
    except requests.exceptions.JSONDecodeError as e:
        status_message = (
            f"Submission Failed: Invalid JSON in server response - {e}"
        )
    except requests.exceptions.RequestException as e:
        status_message = f"Submission Failed: Network error - {e}"
    except Exception as e:
        status_message = f"Unexpected error during submission: {e}"
    else:
        print("Submission successful.")
        return final_status, results_df

    print(status_message)
    return status_message, results_df


def run_and_submit_all(
    profile: gr.OAuthProfile | None,
):
    """Fetch all questions, run the agent, submit answers, show results.

    Gradio injects ``profile`` from the type hint when the user is
    logged in with the LoginButton.

    Returns:
        ``(status_text, results_dataframe_or_None)`` for the two Gradio
        output components.
    """
    space_id = os.getenv("SPACE_ID")
    if not profile:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    username = f"{profile.username}"
    print(f"User logged in: {username}")

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    # 1. Instantiate agent
    try:
        agent = BasicAgent()
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    # Link to this Space's code, sent along with the submission.
    agent_code = (
        "https://huggingface.co/spaces/"
        f"{space_id or 'unknown-space'}/tree/main"
    )
    print(agent_code)

    # 2. Fetch questions
    questions_data, error_message = _fetch_questions(questions_url)
    if error_message is not None:
        return error_message, None

    # 3. Run agent on each question
    results_log, answers_payload = _run_agent_on_questions(
        agent, questions_data
    )
    if not answers_payload:
        print("Agent did not produce any answers.")
        return (
            "Agent did not produce any answers to submit.",
            pd.DataFrame(results_log),
        )

    # 4. Prepare submission
    submission_data = {
        "username": username.strip(),
        "agent_code": agent_code,
        "answers": answers_payload,
    }
    print(
        f"Agent finished. Submitting {len(answers_payload)} answers "
        f"for user '{username}'..."
    )

    # 5. Submit
    return _submit_answers(submit_url, submission_data, results_log)
# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
# The Markdown text lists every tool given to the agent in BasicAgent,
# including audio transcription and image analysis.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner")
    gr.Markdown(
        """
**Instructions:**
1. Clone this space and customise the agent.
2. Log in with the button below.
3. Click **Run Evaluation & Submit All Answers**.
---
*Processing all 20 questions will take several
minutes. The agent uses web search, Wikipedia,
page fetching, file download, audio transcription
and image analysis tools.*
"""
    )
    gr.LoginButton()
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(
        label="Run Status / Submission Result",
        lines=5,
        interactive=False,
    )
    results_table = gr.DataFrame(
        label="Questions and Agent Answers",
        wrap=True,
    )
    # No explicit inputs: Gradio supplies the OAuth profile argument
    # from run_and_submit_all's type hint.
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table],
    )

# Enable the request queue; the evaluation run takes several minutes.
demo.queue()
if __name__ == "__main__":
    # Startup banner with the Space environment, then launch the UI.
    banner = "-" * 30
    print("\n" + banner + " App Starting " + banner)

    space_host = os.getenv("SPACE_HOST")
    space_id = os.getenv("SPACE_ID")
    for env_label, env_value in (
        ("SPACE_HOST", space_host),
        ("SPACE_ID", space_id),
    ):
        if not env_value:
            print(f"ℹ️ {env_label} not found.")
        else:
            print(f"✅ {env_label}: {env_value}")

    print("-" * 74 + "\n")
    print("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)