Spaces:
Sleeping
Sleeping
| import os | |
| import base64 | |
| import mimetypes | |
| import requests | |
| import pandas as pd | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| from smolagents import ( | |
| CodeAgent, | |
| DuckDuckGoSearchTool, | |
| OpenAIServerModel, | |
| WikipediaSearchTool, | |
| VisitWebpageTool, | |
| Tool, | |
| ) | |
# Pull variables from a local .env file into the process environment before
# anything below calls os.getenv() (GROQ_API_KEY is read that way).
load_dotenv()
# --- Service endpoints and model identifiers ---
# Scoring service: serves /questions and /files/<task_id>, receives /submit.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Base URL used for every Groq request made by this file.
GROQ_API_BASE = "https://api.groq.com/openai/v1"
# Model names passed to the text agent, the image tool and the audio tool.
TEXT_MODEL_ID = "llama-3.3-70b-versatile"
VISION_MODEL_ID = "meta-llama/llama-4-scout-17b-16e-instruct"
AUDIO_MODEL_ID = "whisper-large-v3"

# Suffix appended to each question prompt so the agent's final_answer() is
# short enough for exact-match comparison. The text starts with a newline,
# which separates it from whatever precedes it in the prompt.
ANSWER_FORMAT_INSTRUCTIONS = """
IMPORTANT FORMAT INSTRUCTIONS:
Your final_answer must be as concise as possible:
- If the answer is a number, return ONLY the number
(no units, no commas, no $ or % unless asked).
- If the answer is a string, return ONLY the
essential words (no articles like "the"/"a",
no abbreviations for cities, write digits in
plain text unless told otherwise).
- If the answer is a comma separated list, apply
the rules above to each element.
Do NOT include explanations in your final_answer,
just the bare answer."""
| # -------------------------------------------------- | |
| # Custom tool: download a GAIA task file | |
| # -------------------------------------------------- | |
class GaiaFileFetcherTool(Tool):
    """Tool that downloads the file attached to a GAIA task.

    The file is requested from ``{api_url}/files/{task_id}`` and written to
    the system temporary directory; the resulting local path is returned.
    """

    name = "fetch_task_file"
    description = (
        "Downloads the file attached to a GAIA task "
        "given its task_id. Returns the local path "
        "to the downloaded file so you can read it."
    )
    inputs = {
        "task_id": {
            "type": "string",
            "description": (
                "The task_id of the GAIA question "
                "whose attached file you need."
            ),
        }
    }
    output_type = "string"

    def __init__(self, api_url: str, **kwargs):
        """Store the base URL of the file-serving API.

        Args:
            api_url: Base URL; ``/files/{task_id}`` is appended to it.
            **kwargs: Passed through unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.api_url = api_url

    def forward(self, task_id: str) -> str:
        """Download the task's attachment and return where it was saved.

        The file name is taken from the ``Content-Disposition`` header when
        present; otherwise it is ``task_id`` plus an extension guessed from
        ``Content-Type``. Only the final path component is kept, so a
        server-supplied name cannot escape the temporary directory.

        Args:
            task_id: Identifier of the GAIA question whose file is wanted.

        Returns:
            Path of the downloaded file inside the temporary directory.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.RequestException: On connection problems or timeout.
            OSError: If the file cannot be written.
        """
        # Everything this method needs is imported here so it does not
        # depend on module-level names (previously `os` was the exception).
        import mimetypes as _mt
        import os as _os
        import tempfile as _tmp
        from email.message import Message as _Message

        import requests as _req

        url = f"{self.api_url}/files/{task_id}"
        resp = _req.get(url, timeout=30)
        resp.raise_for_status()

        # Extension used only when the server does not name the file.
        content_type = resp.headers.get("Content-Type", "")
        ext = _mt.guess_extension(content_type.split(";")[0].strip()) or ""

        # Let the stdlib header parser extract the filename: it copes with
        # quoting, trailing parameters and RFC 2231 `filename*=` values,
        # which a plain split on "filename=" does not.
        fname = ""
        disposition = resp.headers.get("Content-Disposition", "")
        if disposition:
            header = _Message()
            header["Content-Disposition"] = disposition
            fname = header.get_filename() or ""

        # Keep only the last path component (treating "\" like "/").
        unusable = ("", ".", "..")
        fname = _os.path.basename(fname.replace("\\", "/"))
        if fname in unusable:
            fallback = f"{task_id}{ext}".replace("\\", "/")
            fname = _os.path.basename(fallback)
        if fname in unusable:
            # Neither the header nor the task_id produced a usable name.
            fname = f"task_file{ext}"

        path = _os.path.join(_tmp.gettempdir(), fname)
        with open(path, "wb") as f:
            f.write(resp.content)
        return path
class GroqAudioTranscriptionTool(Tool):
    """Tool that sends a local audio file to Groq and returns the transcript."""

    name = "transcribe_audio_file"
    description = (
        "Transcribes a local audio file path, such as an MP3 downloaded "
        "with fetch_task_file. Returns the plain transcript text."
    )
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Local path to the audio file.",
        }
    }
    output_type = "string"

    def forward(self, file_path: str) -> str:
        """Upload ``file_path`` for transcription and return the text.

        Args:
            file_path: Path of the audio file on local disk.

        Returns:
            The ``text`` field of the JSON reply with surrounding whitespace
            removed (an empty string when the field is absent).

        Raises:
            RuntimeError: If ``GROQ_API_KEY`` is not set in the environment.
            requests.HTTPError: If the API answers with an error status.
        """
        token = os.getenv("GROQ_API_KEY")
        if not token:
            raise RuntimeError(
                "GROQ_API_KEY is required for audio transcription."
            )

        # The file handle must stay open while the multipart body is sent.
        with open(file_path, "rb") as handle:
            upload = {"file": (os.path.basename(file_path), handle)}
            form_fields = {
                "model": AUDIO_MODEL_ID,
                "response_format": "json",
                "temperature": "0",
            }
            reply = requests.post(
                f"{GROQ_API_BASE}/audio/transcriptions",
                headers={"Authorization": f"Bearer {token}"},
                files=upload,
                data=form_fields,
                timeout=120,
            )

        reply.raise_for_status()
        transcript = reply.json().get("text", "")
        return transcript.strip()
class GroqImageAnalysisTool(Tool):
    """Tool that asks a Groq vision model a question about a local image."""

    name = "analyze_image_file"
    description = (
        "Analyzes a local image file path and answers a specific visual "
        "question about it."
    )
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Local path to the image file.",
        },
        "question": {
            "type": "string",
            "description": "The question to answer about the image.",
        },
    }
    output_type = "string"

    def forward(self, file_path: str, question: str) -> str:
        """Send the image and the question in one chat-completions request.

        The image is embedded in the request as a base64 ``data:`` URL whose
        MIME type is guessed from the file name.

        Args:
            file_path: Path of the image on local disk.
            question: What to ask about the image.

        Returns:
            The content of the first returned choice, stripped of
            surrounding whitespace.

        Raises:
            RuntimeError: If ``GROQ_API_KEY`` is not set in the environment.
            requests.HTTPError: If the API answers with an error status.
        """
        token = os.getenv("GROQ_API_KEY")
        if not token:
            raise RuntimeError(
                "GROQ_API_KEY is required for image analysis."
            )

        guessed_type, _ = mimetypes.guess_type(file_path)
        mime_type = guessed_type or "application/octet-stream"

        with open(file_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes).decode("ascii")
        data_url = f"data:{mime_type};base64,{encoded}"

        # One user turn carrying both the text question and the image.
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": question},
                {"type": "image_url", "image_url": {"url": data_url}},
            ],
        }
        payload = {
            "model": VISION_MODEL_ID,
            "messages": [user_message],
            "temperature": 0.1,
            "max_completion_tokens": 512,
        }
        request_headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

        reply = requests.post(
            f"{GROQ_API_BASE}/chat/completions",
            headers=request_headers,
            json=payload,
            timeout=120,
        )
        reply.raise_for_status()

        first_choice = reply.json()["choices"][0]
        return first_choice["message"]["content"].strip()
| # -------------------------------------------------- | |
| # Agent wrapper | |
| # -------------------------------------------------- | |
class BasicAgent:
    """Callable wrapper around a ``CodeAgent`` equipped with the GAIA tools."""

    def __init__(self):
        """Create the model, the three custom tools and the agent.

        Raises:
            RuntimeError: If ``GROQ_API_KEY`` is not set in the environment.
        """
        print("BasicAgent initialized.")

        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise RuntimeError(
                "Missing GROQ_API_KEY. Add it to your "
                "Hugging Face Space secrets or local .env file."
            )

        llm = OpenAIServerModel(
            model_id=TEXT_MODEL_ID,
            api_base=GROQ_API_BASE,
            api_key=api_key,
        )

        # The custom tools are also kept as attributes of the wrapper.
        self.file_tool = GaiaFileFetcherTool(api_url=DEFAULT_API_URL)
        self.audio_tool = GroqAudioTranscriptionTool()
        self.image_tool = GroqImageAnalysisTool()

        toolbox = [
            DuckDuckGoSearchTool(),
            WikipediaSearchTool(user_agent="GaiaAgent/1.0"),
            VisitWebpageTool(),
            self.file_tool,
            self.audio_tool,
            self.image_tool,
        ]
        # Modules the agent's generated code is allowed to import.
        allowed_modules = [
            "base64",
            "json",
            "re",
            "csv",
            "math",
            "statistics",
            "datetime",
            "collections",
            "itertools",
            "os",
            "pathlib",
            "mimetypes",
            "pandas",
            "openpyxl",
        ]

        self.agent = CodeAgent(
            model=llm,
            tools=toolbox,
            max_steps=15,
            verbosity_level=0,
            additional_authorized_imports=allowed_modules,
        )

    def __call__(
        self,
        question: str,
        task_id: str,
        has_file: bool = False,
    ) -> str:
        """Run the agent on one question and return its stripped answer.

        Args:
            question: The question text.
            task_id: Task identifier, quoted in the file hint when
                ``has_file`` is true.
            has_file: Whether to add a hint telling the agent how to fetch
                and read the attached file.

        Returns:
            ``str()`` of the agent's result with surrounding whitespace
            removed.
        """
        sections = [question]
        if has_file:
            sections.append(
                "\n\n[This question has an attached file. Use the "
                f"fetch_task_file tool with task_id='{task_id}' to download "
                "it. If it is audio, use transcribe_audio_file. If it is an "
                "image, use analyze_image_file. If it is a spreadsheet, "
                "read it with pandas.]"
            )
        sections.append(ANSWER_FORMAT_INSTRUCTIONS)

        outcome = self.agent.run("".join(sections))
        return str(outcome).strip()
| # -------------------------------------------------- | |
| # Gradio: run all & submit | |
| # -------------------------------------------------- | |
def _fetch_questions(questions_url: str) -> tuple[list | None, str | None]:
    """Download the question list; return ``(questions, None)`` or ``(None, error)``."""
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    except requests.exceptions.JSONDecodeError as e:
        # Must come before RequestException: requests' JSONDecodeError is a
        # subclass of it, so in the other order this branch never runs.
        print(f"Error decoding JSON from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return None, f"Error decoding server response for questions: {e}"
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return None, f"Error fetching questions: {e}"
    except Exception as e:
        print(f"Unexpected error fetching questions: {e}")
        return None, f"Unexpected error fetching questions: {e}"

    # Anything other than a non-empty list cannot be iterated as questions.
    if not isinstance(questions_data, list) or not questions_data:
        print("Fetched questions list is empty.")
        return None, "Fetched questions list is empty or invalid format."

    print(f"Fetched {len(questions_data)} questions.")
    return questions_data, None


def _answer_questions(agent, questions_data: list) -> tuple[list, list]:
    """Run ``agent`` on every question; return ``(answers_payload, results_log)``."""
    results_log = []
    answers_payload = []
    total = len(questions_data)
    print(f"Running agent on {total} questions...")

    for position, item in enumerate(questions_data, start=1):
        if not isinstance(item, dict):
            # A malformed entry should not abort the whole run.
            print(f"Skipping item with missing task_id or question: {item}")
            continue

        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue

        # A non-empty file_name marks a question with an attachment.
        has_file = bool(item.get("file_name", ""))
        print(
            f"[{position}/{total}] Task {task_id}"
            f"{' (has file)' if has_file else ''}"
        )

        try:
            submitted_answer = agent(question_text, task_id, has_file)
        except Exception as e:
            # Per-task boundary: a failing question is logged for display
            # but not submitted, and the remaining questions still run.
            print(f"Error on task {task_id}: {e}")
            results_log.append(
                {
                    "Task ID": task_id,
                    "Question": question_text,
                    "Submitted Answer": f"AGENT ERROR: {e}",
                }
            )
            continue

        answers_payload.append(
            {"task_id": task_id, "submitted_answer": submitted_answer}
        )
        results_log.append(
            {
                "Task ID": task_id,
                "Question": question_text,
                "Submitted Answer": submitted_answer,
            }
        )

    return answers_payload, results_log


def _submit_answers(submit_url: str, submission_data: dict) -> str:
    """POST the answers to the scoring endpoint and return a status message."""
    print(
        f"Submitting {len(submission_data['answers'])} "
        f"answers to: {submit_url}"
    )
    try:
        response = requests.post(
            submit_url,
            json=submission_data,
            timeout=60,
        )
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: "
            f"{result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}"
            f"/{result_data.get('total_attempted', '?')}"
            f" correct)\n"
            f"Message: "
            f"{result_data.get('message', 'N/A')}"
        )
        print("Submission successful.")
        return final_status
    except requests.exceptions.HTTPError as e:
        error_detail = (
            f"Server responded with status {e.response.status_code}."
        )
        try:
            error_json = e.response.json()
        except requests.exceptions.JSONDecodeError:
            error_detail += f" Response: {e.response.text[:500]}"
        else:
            # The error body is not guaranteed to be a JSON object.
            if isinstance(error_json, dict):
                detail = error_json.get("detail", e.response.text)
            else:
                detail = e.response.text
            error_detail += f" Detail: {detail}"
        status_message = f"Submission Failed: {error_detail}"
    except requests.exceptions.Timeout:
        status_message = "Submission Failed: Request timed out."
    except requests.exceptions.JSONDecodeError as e:
        # Listed before RequestException (its base class) so an unreadable
        # success body is not misreported as a network error.
        status_message = (
            f"Submission Failed: Could not decode server response - {e}"
        )
    except requests.exceptions.RequestException as e:
        status_message = f"Submission Failed: Network error - {e}"
    except Exception as e:
        status_message = f"Unexpected error during submission: {e}"

    print(status_message)
    return status_message


def run_and_submit_all(
    profile: gr.OAuthProfile | None,
):
    """Fetch all questions, run the agent, submit answers, report results.

    Args:
        profile: OAuth profile of the logged-in user, or ``None`` when
            nobody is logged in.

    Returns:
        A ``(status_text, results)`` pair. ``results`` is a
        ``pandas.DataFrame`` with one row per attempted question, or
        ``None`` when the run stopped before any question was attempted
        (not logged in, agent could not be created, questions could not
        be fetched).
    """
    if not profile:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None

    username = f"{profile.username}"
    print(f"User logged in: {username}")

    questions_url = f"{DEFAULT_API_URL}/questions"
    submit_url = f"{DEFAULT_API_URL}/submit"

    # 1. Instantiate the agent.
    try:
        agent = BasicAgent()
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    # Link to this Space's code, sent along with the submission.
    space_id = os.getenv("SPACE_ID")
    agent_code = (
        f"https://huggingface.co/spaces/"
        f"{space_id or 'unknown-space'}/tree/main"
    )
    print(agent_code)

    # 2. Fetch the questions.
    questions_data, fetch_error = _fetch_questions(questions_url)
    if fetch_error is not None:
        return fetch_error, None

    # 3. Run the agent on each question.
    answers_payload, results_log = _answer_questions(agent, questions_data)
    results_df = pd.DataFrame(results_log)
    if not answers_payload:
        print("Agent did not produce any answers.")
        return "Agent did not produce any answers to submit.", results_df

    # 4. Prepare the submission.
    submission_data = {
        "username": username.strip(),
        "agent_code": agent_code,
        "answers": answers_payload,
    }
    print(
        f"Agent finished. Submitting "
        f"{len(answers_payload)} answers for "
        f"user '{username}'..."
    )

    # 5. Submit and report.
    return _submit_answers(submit_url, submission_data), results_df
| # -------------------------------------------------- | |
| # Gradio UI | |
| # -------------------------------------------------- | |
# Page definition. `demo` is the module-level handle that the launch code
# at the bottom of the file calls.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner")
    # Usage notes shown under the title.
    gr.Markdown(
        """
**Instructions:**
1. Clone this space and customise the agent.
2. Log in with the button below.
3. Click **Run Evaluation & Submit All Answers**.
---
*Processing all 20 questions will take several
minutes. The agent uses web search, Wikipedia,
page fetching, and file download tools.*
"""
    )
    # run_and_submit_all refuses to run without a logged-in profile.
    gr.LoginButton()
    run_button = gr.Button(
        "Run Evaluation & Submit All Answers"
    )
    # Read-only box for the status string returned by the handler.
    status_output = gr.Textbox(
        label="Run Status / Submission Result",
        lines=5,
        interactive=False,
    )
    # Table for the per-question results returned by the handler.
    results_table = gr.DataFrame(
        label="Questions and Agent Answers",
        wrap=True,
    )
    # No `inputs` are wired: the handler's only parameter is the OAuth
    # profile (presumably supplied by Gradio from the login session; confirm
    # against the Gradio OAuth docs). Its two return values map onto
    # `outputs` in order.
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table],
    )
demo.queue()
if __name__ == "__main__":
    # Startup banner, then report which Space variables are present.
    print(f"\n{'-' * 30} App Starting {'-' * 30}")

    space_host = os.getenv("SPACE_HOST")
    space_id = os.getenv("SPACE_ID")

    print(
        f"✅ SPACE_HOST: {space_host}"
        if space_host
        else "ℹ️ SPACE_HOST not found."
    )
    print(
        f"✅ SPACE_ID: {space_id}"
        if space_id
        else "ℹ️ SPACE_ID not found."
    )

    print("-" * 74 + "\n")
    print("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)