import os
import base64
import mimetypes

import requests
import pandas as pd
import gradio as gr
from dotenv import load_dotenv
from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    OpenAIServerModel,
    WikipediaSearchTool,
    VisitWebpageTool,
    Tool,
)

load_dotenv()

# --- Constants ---
DEFAULT_API_URL = (
    "https://agents-course-unit4-scoring.hf.space"
)
GROQ_API_BASE = "https://api.groq.com/openai/v1"
TEXT_MODEL_ID = "llama-3.3-70b-versatile"
VISION_MODEL_ID = (
    "meta-llama/llama-4-scout-17b-16e-instruct"
)
AUDIO_MODEL_ID = "whisper-large-v3"

# Format instructions appended to every question
# so that the agent returns exact-match-friendly
# answers via final_answer().
# The literal starts with blank lines so the instructions stay visually
# separated from whatever question text precedes them in the prompt.
ANSWER_FORMAT_INSTRUCTIONS = """

IMPORTANT FORMAT INSTRUCTIONS:
Your final_answer must be as concise as possible:
- If the answer is a number, return ONLY the number (no units, no commas, no $ or % unless asked).
- If the answer is a string, return ONLY the essential words (no articles like "the"/"a", no abbreviations for cities, write digits in plain text unless told otherwise).
- If the answer is a comma separated list, apply the rules above to each element.
Do NOT include explanations in your final_answer, just the bare answer."""


# --------------------------------------------------
# Custom tool: download a GAIA task file
# --------------------------------------------------
class GaiaFileFetcherTool(Tool):
    """Downloads the file attached to a GAIA task."""

    name = "fetch_task_file"
    description = (
        "Downloads the file attached to a GAIA task "
        "given its task_id. Returns the local path "
        "to the downloaded file so you can read it."
    )
    inputs = {
        "task_id": {
            "type": "string",
            "description": (
                "The task_id of the GAIA question "
                "whose attached file you need."
            ),
        }
    }
    output_type = "string"

    def __init__(self, api_url: str, **kwargs):
        """Store the base URL of the scoring API.

        Args:
            api_url: Base URL; files are fetched from ``{api_url}/files/{task_id}``.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.api_url = api_url

    def forward(self, task_id: str) -> str:
        """Download the task's file into the temp directory and return its path.

        The file name is taken from the ``Content-Disposition`` header when
        present; otherwise it is ``task_id`` plus an extension guessed from
        the ``Content-Type`` header.

        Args:
            task_id: Identifier of the task whose file should be fetched.

        Returns:
            Local filesystem path of the downloaded file.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        import requests as _req
        import tempfile as _tmp
        import mimetypes as _mt
        from email.message import Message as _Msg

        url = f"{self.api_url}/files/{task_id}"
        resp = _req.get(url, timeout=30)
        resp.raise_for_status()

        # Derive a sensible extension from headers
        ct = resp.headers.get("Content-Type", "")
        ext = _mt.guess_extension(ct.split(";")[0].strip()) or ""
        fallback = os.path.basename(f"{task_id}{ext}")

        cd = resp.headers.get("Content-Disposition", "")
        fname = ""
        if cd:
            # Let the stdlib header parser extract the filename parameter:
            # it unquotes the value and ignores sibling parameters, which a
            # plain split on "filename=" would leave glued to the name.
            header = _Msg()
            header["Content-Disposition"] = cd
            fname = (header.get_filename() or "").strip("\"' ")

        # basename() drops any directory components sent by the server; if
        # nothing usable remains, fall back to the task-derived name so that
        # we never try to open the temp directory itself.
        fname = os.path.basename(fname) or fallback

        path = os.path.join(_tmp.gettempdir(), fname)
        with open(path, "wb") as f:
            f.write(resp.content)
        return path


class GroqAudioTranscriptionTool(Tool):
    """Transcribes an audio file with Groq Whisper."""

    name = "transcribe_audio_file"
    description = (
        "Transcribes a local audio file path, such as an "
        "MP3 downloaded with fetch_task_file. Returns the "
        "plain transcript text."
    )
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Local path to the audio file.",
        }
    }
    output_type = "string"

    def forward(self, file_path: str) -> str:
        """Upload the audio file for transcription and return the text.

        Args:
            file_path: Local path to the audio file.

        Returns:
            The stripped ``text`` field of the JSON response, or an empty
            string when that field is absent.

        Raises:
            RuntimeError: If ``GROQ_API_KEY`` is not set.
            requests.HTTPError: If the API answers with an error status.
        """
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise RuntimeError(
                "GROQ_API_KEY is required for audio transcription."
            )

        # The request is sent while the file is still open because the
        # handle is streamed as the multipart "file" field.
        with open(file_path, "rb") as audio_file:
            response = requests.post(
                f"{GROQ_API_BASE}/audio/transcriptions",
                headers={
                    "Authorization": f"Bearer {api_key}",
                },
                files={
                    "file": (
                        os.path.basename(file_path),
                        audio_file,
                    )
                },
                data={
                    "model": AUDIO_MODEL_ID,
                    "response_format": "json",
                    "temperature": "0",
                },
                timeout=120,
            )
        response.raise_for_status()
        return response.json().get("text", "").strip()


class GroqImageAnalysisTool(Tool):
    """Answers questions about a local image with Groq vision."""

    name = "analyze_image_file"
    description = (
        "Analyzes a local image file path and answers a "
        "specific visual question about it."
    )
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Local path to the image file.",
        },
        "question": {
            "type": "string",
            "description": "The question to answer about the image.",
        },
    }
    output_type = "string"

    def forward(self, file_path: str, question: str) -> str:
        """Send the image and the question to the vision model.

        The image is embedded in the request as a base64 ``data:`` URL whose
        MIME type is guessed from the file name.

        Args:
            file_path: Local path to the image file.
            question: The question to answer about the image.

        Returns:
            The stripped content of the first choice in the response.

        Raises:
            RuntimeError: If ``GROQ_API_KEY`` is not set.
            requests.HTTPError: If the API answers with an error status.
        """
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise RuntimeError(
                "GROQ_API_KEY is required for image analysis."
            )

        mime_type = (
            mimetypes.guess_type(file_path)[0]
            or "application/octet-stream"
        )
        with open(file_path, "rb") as image_file:
            encoded = base64.b64encode(
                image_file.read()
            ).decode("ascii")

        response = requests.post(
            f"{GROQ_API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": VISION_MODEL_ID,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": question,
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": (
                                        f"data:{mime_type};"
                                        f"base64,{encoded}"
                                    )
                                },
                            },
                        ],
                    }
                ],
                "temperature": 0.1,
                "max_completion_tokens": 512,
            },
            timeout=120,
        )
        response.raise_for_status()
        return (
            response.json()["choices"][0]["message"]
            ["content"]
            .strip()
        )


# --------------------------------------------------
# Agent wrapper
# --------------------------------------------------
class BasicAgent:
    """Wraps a smolagents ``CodeAgent`` configured with search and file tools."""

    def __init__(self):
        """Build the model, the custom tools and the underlying agent.

        Raises:
            RuntimeError: If ``GROQ_API_KEY`` is not set.
        """
        print("BasicAgent initialized.")

        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            raise RuntimeError(
                "Missing GROQ_API_KEY. Add it to your "
                "Hugging Face Space secrets or local .env file."
            )

        model = OpenAIServerModel(
            model_id=TEXT_MODEL_ID,
            api_base=GROQ_API_BASE,
            api_key=groq_api_key,
        )

        self.file_tool = GaiaFileFetcherTool(
            api_url=DEFAULT_API_URL,
        )
        self.audio_tool = GroqAudioTranscriptionTool()
        self.image_tool = GroqImageAnalysisTool()

        self.agent = CodeAgent(
            model=model,
            tools=[
                DuckDuckGoSearchTool(),
                WikipediaSearchTool(
                    user_agent="GaiaAgent/1.0"
                ),
                VisitWebpageTool(),
                self.file_tool,
                self.audio_tool,
                self.image_tool,
            ],
            max_steps=15,
            verbosity_level=0,
            additional_authorized_imports=[
                "base64",
                "json",
                "re",
                "csv",
                "math",
                "statistics",
                "datetime",
                "collections",
                "itertools",
                "os",
                "pathlib",
                "mimetypes",
                "pandas",
                "openpyxl",
            ],
        )

    def __call__(
        self,
        question: str,
        task_id: str,
        has_file: bool = False,
    ) -> str:
        """Run the agent on one question and return its stripped answer.

        Args:
            question: The question text.
            task_id: Task identifier, mentioned in the prompt when a file
                is attached so the agent can download it.
            has_file: Whether the question has an attached file.

        Returns:
            The agent's result converted to ``str`` and stripped.
        """
        # Build the prompt for the agent
        prompt = question
        if has_file:
            prompt += (
                f"\n\n[This question has an attached "
                f"file. Use the fetch_task_file tool "
                f"with task_id='{task_id}' to "
                f"download it. If it is audio, use "
                f"transcribe_audio_file. If it is an "
                f"image, use analyze_image_file. If it "
                f"is a spreadsheet, read it with pandas.]"
            )
        prompt += ANSWER_FORMAT_INSTRUCTIONS

        raw = str(self.agent.run(prompt))
        return raw.strip()


# --------------------------------------------------
# Gradio: run all & submit
# --------------------------------------------------
def run_and_submit_all(
    profile: gr.OAuthProfile | None,
):
    """
    Fetches all questions, runs the agent,
    submits answers, and displays results.

    Args:
        profile: The logged-in Hugging Face profile, or ``None`` when the
            user is not logged in.

    Returns:
        A ``(status_message, results)`` tuple where ``results`` is a
        ``pandas.DataFrame`` of per-task answers, or ``None`` when the run
        stopped before any question was processed.
    """
    space_id = os.getenv("SPACE_ID")

    if profile:
        username = f"{profile.username}"
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
        return (
            "Please Login to Hugging Face "
            "with the button.",
            None,
        )

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    # 1. Instantiate Agent
    try:
        agent = BasicAgent()
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    agent_code = (
        f"https://huggingface.co/spaces/"
        f"{space_id or 'unknown-space'}/tree/main"
    )
    print(agent_code)

    # 2. Fetch Questions
    print(
        f"Fetching questions from: {questions_url}"
    )
    try:
        response = requests.get(
            questions_url, timeout=15
        )
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return (
                "Fetched questions list is empty "
                "or invalid format.",
                None,
            )
        print(
            f"Fetched {len(questions_data)} "
            f"questions."
        )
    # JSONDecodeError must be handled before RequestException: it is a
    # subclass of RequestException, so the reverse order would make this
    # branch unreachable.
    except requests.exceptions.JSONDecodeError as e:
        print(
            "Error decoding JSON from questions "
            f"endpoint: {e}"
        )
        print(f"Response text: {response.text[:500]}")
        return (
            "Error decoding server response "
            f"for questions: {e}",
            None,
        )
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    except Exception as e:
        print(
            "Unexpected error fetching "
            f"questions: {e}"
        )
        return (
            "Unexpected error fetching "
            f"questions: {e}",
            None,
        )

    # 3. Run Agent on each question
    results_log = []
    answers_payload = []
    total = len(questions_data)
    print(f"Running agent on {total} questions...")

    for i, item in enumerate(questions_data):
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(
                "Skipping item with missing "
                f"task_id or question: {item}"
            )
            continue

        # Check if the question has a file
        file_name = item.get("file_name", "")
        has_file = bool(file_name)

        print(
            f"[{i+1}/{total}] Task {task_id}"
            f"{' (has file)' if has_file else ''}"
        )

        try:
            submitted_answer = agent(
                question_text,
                task_id,
                has_file,
            )
            answers_payload.append(
                {
                    "task_id": task_id,
                    "submitted_answer": (
                        submitted_answer
                    ),
                }
            )
            results_log.append(
                {
                    "Task ID": task_id,
                    "Question": question_text,
                    "Submitted Answer": (
                        submitted_answer
                    ),
                }
            )
        except Exception as e:
            # A failing task is logged for display but not submitted, so
            # one bad question does not abort the whole run.
            print(
                f"Error on task {task_id}: {e}"
            )
            results_log.append(
                {
                    "Task ID": task_id,
                    "Question": question_text,
                    "Submitted Answer": (
                        f"AGENT ERROR: {e}"
                    ),
                }
            )

    if not answers_payload:
        print(
            "Agent did not produce any answers."
        )
        return (
            "Agent did not produce any answers "
            "to submit.",
            pd.DataFrame(results_log),
        )

    # 4. Prepare Submission
    submission_data = {
        "username": username.strip(),
        "agent_code": agent_code,
        "answers": answers_payload,
    }
    status_update = (
        f"Agent finished. Submitting "
        f"{len(answers_payload)} answers for "
        f"user '{username}'..."
    )
    print(status_update)

    # 5. Submit
    print(
        f"Submitting {len(answers_payload)} "
        f"answers to: {submit_url}"
    )
    try:
        response = requests.post(
            submit_url,
            json=submission_data,
            timeout=60,
        )
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: "
            f"{result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}"
            f"/{result_data.get('total_attempted', '?')}"
            f" correct)\n"
            f"Message: "
            f"{result_data.get('message', 'N/A')}"
        )
        print("Submission successful.")
        return final_status, pd.DataFrame(results_log)
    except requests.exceptions.HTTPError as e:
        error_detail = (
            "Server responded with status "
            f"{e.response.status_code}."
        )
        try:
            error_json = e.response.json()
        except requests.exceptions.JSONDecodeError:
            error_detail += (
                f" Response: "
                f"{e.response.text[:500]}"
            )
        else:
            # The error body is not guaranteed to be a JSON object; only
            # look up "detail" when it actually is one.
            if isinstance(error_json, dict):
                detail = error_json.get(
                    "detail", e.response.text
                )
            else:
                detail = e.response.text
            error_detail += f" Detail: {detail}"
        status_message = (
            f"Submission Failed: {error_detail}"
        )
    except requests.exceptions.Timeout:
        status_message = (
            "Submission Failed: Request timed out."
        )
    except requests.exceptions.RequestException as e:
        status_message = (
            f"Submission Failed: Network error - {e}"
        )
    except Exception as e:
        status_message = (
            "Unexpected error during "
            f"submission: {e}"
        )

    # Shared tail for every failure branch above.
    print(status_message)
    return status_message, pd.DataFrame(results_log)


# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner")
    gr.Markdown(
        """
        **Instructions:**

        1. Clone this space and customise the agent.
        2. Log in with the button below.
        3. Click **Run Evaluation & Submit All Answers**.

        ---
        *Processing all 20 questions will take several
        minutes. The agent uses web search, Wikipedia,
        page fetching, and file download tools.*
        """
    )

    gr.LoginButton()

    run_button = gr.Button(
        "Run Evaluation & Submit All Answers"
    )

    status_output = gr.Textbox(
        label="Run Status / Submission Result",
        lines=5,
        interactive=False,
    )
    results_table = gr.DataFrame(
        label="Questions and Agent Answers",
        wrap=True,
    )

    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table],
    )

demo.queue()

if __name__ == "__main__":
    print(
        "\n" + "-" * 30
        + " App Starting "
        + "-" * 30
    )
    space_host = os.getenv("SPACE_HOST")
    space_id = os.getenv("SPACE_ID")

    if space_host:
        print(f"✅ SPACE_HOST: {space_host}")
    else:
        print("ℹ️ SPACE_HOST not found.")

    if space_id:
        print(f"✅ SPACE_ID: {space_id}")
    else:
        print("ℹ️ SPACE_ID not found.")

    print("-" * 74 + "\n")
    print("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)