| |
| import os |
| import re |
| import json |
| import math |
| import csv |
| import time |
| import gradio as gr |
| import requests |
| import inspect |
| import pandas as pd |
| from ast import literal_eval |
| from dotenv import load_dotenv |
| from pathlib import Path |
| from typing import Optional, Tuple, Dict, Any, List |
|
|
| from smolagents import CodeAgent, DuckDuckGoSearchTool, OpenAIServerModel |
| from tools import ( |
| ReverseTextTool, |
| ExtractTextFromImageTool, |
| AnalyzeCSVTool, |
| AnalyzeExcelTool, |
| DateCalculatorTool, |
| DownloadFileTool |
| ) |
|
|
| |
| try: |
| load_dotenv() |
| print("Loaded environment variables from .env file") |
| except Exception as e: |
| print(f"Note: Could not load .env file - {e}") |
|
|
| |
| DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space" |
| DEFAULT_GOLD_CSV = os.environ.get("GAIA_GOLD_CSV", "answers.csv") |
|
|
| |
| FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER\s*:\s*(.*)", re.IGNORECASE) |
|
|
| def extract_final_answer(text: str) -> str: |
| """ |
| Extract whatever comes after 'FINAL ANSWER:' on the last occurrence. |
| Falls back to entire text if pattern not found. |
| """ |
| if not isinstance(text, str): |
| return "" |
| matches = FINAL_ANSWER_RE.findall(text) |
| if matches: |
| return matches[-1].strip() |
| return (text or "").strip() |
|
|
| def is_number(s: str) -> bool: |
| try: |
| float(s) |
| return True |
| except Exception: |
| return False |
|
|
| def try_parse_number(s: str) -> Optional[float]: |
| try: |
| return float(s) |
| except Exception: |
| return None |
|
|
| def split_csv_like(s: str) -> List[str]: |
| parts = [p.strip() for p in re.split(r",", s)] |
| parts = [p for p in parts if p != ""] |
| return parts |
|
|
| def normalize_final(s: str) -> str: |
| """ |
| GAIA-style normalization: |
| - lowercase |
| - strip articles 'a', 'an', 'the' |
| - strip $, % if any |
| - collapse spaces, trim punctuation |
| """ |
| s = (s or "").strip().lower() |
| s = s.replace("%", "").replace("$", "") |
| s = re.sub(r"[^\w\s,.\-]+", "", s) |
| s = re.sub(r"\s+", " ", s).strip() |
| s = re.sub(r"^(a|an|the)\s+", "", s) |
| return s |
|
|
| def list_like_equal(a: str, b: str) -> bool: |
| la = [normalize_final(p) for p in split_csv_like(a)] |
| lb = [normalize_final(p) for p in split_csv_like(b)] |
| return sorted(la) == sorted(lb) |
|
|
| def numeric_close(a: str, b: str, rel_tol=1e-9, abs_tol=1e-6) -> bool: |
| na, nb = try_parse_number(a), try_parse_number(b) |
| if na is None or nb is None: |
| return False |
| return math.isclose(na, nb, rel_tol=rel_tol, abs_tol=abs_tol) |
|
|
| def fast_heuristic_match(pred: str, gold: str) -> bool: |
| pn = normalize_final(pred) |
| gn = normalize_final(gold) |
| if pn == gn: |
| return True |
| if numeric_close(pn, gn): |
| return True |
| if ("," in pred) or ("," in gold): |
| if list_like_equal(pred, gold): |
| return True |
| return False |
|
|
| def quick_format_fix(answer: str, question: str) -> str: |
| """ |
| Deterministic, judge-friendly cleanup. We DO NOT use gold here. |
| - Remove leading articles for strings |
| - Strip currency & percent unless explicitly requested by question |
| - Remove thousands commas in numbers |
| - Trim trailing punctuation |
| - Normalize whitespace |
| - Unify separators to comma for list-like strings |
| """ |
| if not isinstance(answer, str): |
| return answer |
|
|
| s = answer.strip() |
|
|
| |
| s = re.sub(r"^```.*?\n", "", s, flags=re.DOTALL) |
| s = s.replace("```", "").strip() |
|
|
| |
| s = re.sub(r"\s+", " ", s).strip() |
|
|
| |
| s = re.sub(r"[.。]+$", "", s) |
|
|
| |
| if ";" in s or "/" in s: |
| s = re.sub(r"[;/]+", ",", s) |
| s = re.sub(r"\s*,\s*", ", ", s) |
|
|
| |
| s = re.sub(r"^(?i)(a|an|the)\s+", "", s) |
|
|
| |
| |
| if "," in s and not re.search(r".*,.*", s): |
| if re.fullmatch(r"\d{1,3}(,\d{3})+(\.\d+)?", s): |
| s = s.replace(",", "") |
|
|
| |
| if "$" in s and not re.search(r"(?i)\b(dollar|usd|\$)\b.*(include|keep|use)|include\s*\$", question): |
| s = s.replace("$", "") |
|
|
| |
| needs_percent = bool(re.search(r"(?i)\b(percent|%)\b.*(include|with|as sign)|include\s*%", question)) |
| if "%" in s and not needs_percent: |
| s = s.replace("%", "") |
|
|
| return s.strip() |
|
|
| |
| class GoldAnswers: |
| """ |
| Loads answers.csv like your example and indexes by task_id. |
| - content column includes "... Final answer : X" |
| - metadata column includes "{'task_id': '...'}" |
| """ |
| def __init__(self, path: str = DEFAULT_GOLD_CSV): |
| self.by_task_id: Dict[str, str] = {} |
| self.load(path) |
|
|
| def load(self, path: str): |
| p = Path(path) |
| if not p.exists(): |
| print(f"[GoldAnswers] Warning: {path} not found. Local judge will skip.") |
| return |
| with p.open("r", encoding="utf-8") as f: |
| reader = csv.DictReader(f) |
| for row in reader: |
| content = row.get("content", "") |
| metadata_str = row.get("metadata", "") |
| |
| gold_full = extract_final_answer(content) |
| gold_full = re.sub(r"^final answer\s*:\s*", "", gold_full, flags=re.IGNORECASE).strip() |
|
|
| task_id = None |
| try: |
| md = literal_eval(metadata_str) if metadata_str else {} |
| task_id = md.get("task_id") |
| except Exception: |
| pass |
|
|
| if task_id and gold_full: |
| self.by_task_id[task_id] = gold_full |
|
|
| print(f"[GoldAnswers] Loaded {len(self.by_task_id)} gold answers from {path}.") |
|
|
| |
| JUDGE_SYSTEM = ( |
| "You are a strict grader for short answers. " |
| "Follow these GAIA rules: answers must be exact, concise, and obey units/format rules. " |
| "However, accept semantically equivalent forms (e.g., pluralization or minor punctuation) " |
| "and unordered lists if order is not required by the question. " |
| "For numeric answers, small rounding differences are acceptable. " |
| "Return ONLY a compact JSON object with keys: is_correct (true/false), score (0..1), justification (short). " |
| "Do not include any additional text outside the JSON." |
| ) |
|
|
| def build_judge_prompt(question: str, predicted: str, gold: str) -> str: |
| return f""" |
| You are grading whether the predicted answer matches the gold answer for this GAIA-style item. |
| |
| Question: |
| {question} |
| |
| Predicted answer: |
| {predicted} |
| |
| Gold answer: |
| {gold} |
| |
| Evaluate correctness according to GAIA formatting rules and semantics. |
| Output strictly this JSON: |
| {{ |
| "is_correct": true|false, |
| "score": number between 0 and 1, |
| "justification": "≤ 2 short sentences; no chain-of-thought" |
| }} |
| """ |
|
|
| class JudgeAgent: |
| """ |
| A smolagents CodeAgent used purely for grading. We call .run(prompt) to avoid any |
| direct use of model.generate signatures — this mirrors the GAIA agent path. |
| """ |
| def __init__(self, base_model: OpenAIServerModel, verbose: bool = False): |
| self.verbose = verbose |
| self.agent = CodeAgent( |
| tools=[], |
| model=base_model, |
| add_base_tools=False, |
| planning_interval=0, |
| verbosity_level=2 if verbose else 0, |
| additional_authorized_imports=[] |
| ) |
|
|
| def judge(self, question: str, predicted: str, gold: str) -> Dict[str, Any]: |
| |
| if fast_heuristic_match(predicted, gold): |
| return {"is_correct": True, "score": 1.0, "justification": "Heuristic match."} |
|
|
| prompt = f"{JUDGE_SYSTEM}\n\n{build_judge_prompt(question, predicted, gold)}" |
| try: |
| raw = self.agent.run(prompt) |
| text = (raw or "").strip() |
| m = re.search(r"\{.*\}", text, flags=re.DOTALL) |
| payload = json.loads(m.group(0) if m else text) |
|
|
| is_correct = bool(payload.get("is_correct", False)) |
| score = float(payload.get("score", 0.0)) |
| justification = str(payload.get("justification", "")).strip()[:300] |
|
|
| return {"is_correct": is_correct, "score": score, "justification": justification} |
| except Exception as e: |
| return {"is_correct": False, "score": 0.0, "justification": f"Judge error: {e}"} |
|
|
| |
| class GAIAAgent: |
| def __init__(self, verbose=False): |
| self.verbose = verbose |
| print("Initializing GAIA Agent...") |
|
|
| |
| api_key = os.environ.get("OPENAI_API_KEY") |
| if not api_key: |
| raise ValueError("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.") |
|
|
| |
| model_id = os.environ.get("OPENAI_MODEL_ID", "gpt-4o-mini") |
| print(f"Using OpenAI model: {model_id}") |
|
|
| model = OpenAIServerModel( |
| model_id=model_id, |
| api_key=api_key, |
| temperature=0.1 |
| ) |
|
|
| duck_search_tool = DuckDuckGoSearchTool() |
|
|
| self.tools = [ |
| duck_search_tool, |
| ReverseTextTool(), |
| ExtractTextFromImageTool(), |
| AnalyzeCSVTool(), |
| AnalyzeExcelTool(), |
| DateCalculatorTool(), |
| DownloadFileTool() |
| ] |
|
|
| additional_imports = [ |
| "PyPDF2", "pdf2image", "PIL", "nltk", "sklearn", |
| "networkx", "matplotlib", "seaborn", "scipy", "time" |
| ] |
|
|
| self.agent = CodeAgent( |
| tools=self.tools, |
| model=model, |
| add_base_tools=True, |
| planning_interval=3, |
| verbosity_level=2 if self.verbose else 0, |
| additional_authorized_imports=additional_imports |
| ) |
|
|
| print("GAIA Agent initialized and ready") |
|
|
| def _is_reversed_text(self, text): |
| return ( |
| text.startswith(".") or |
| ".rewsna eht sa" in text or |
| "esrever" in text or |
| "sdrawkcab" in text |
| ) |
|
|
| def _base_prompt(self, question: str, allow_extra_searches: bool = False) -> str: |
| |
| search_budget_line = ( |
| "- Limit to 1-2 web searches per question.\n" |
| if not allow_extra_searches else |
| "- You may use up to 3-4 web searches if needed.\n" |
| ) |
| return f""" |
| You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. |
| |
| YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. |
| - If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. |
| - If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. |
| - If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string. |
| |
| Question: {question} |
| |
| IMPORTANT NOTES TO LIMIT COSTS AND PREVENT ERRORS: |
| - Use web search sparingly and only when absolutely necessary. |
| {search_budget_line}- If a search fails due to rate limiting, add a 3-5 second delay using time.sleep() before retrying with a different search term. |
| - Do not import libraries that aren't available - stick to basic Python and the tools provided. |
| - Focus on answering directly with what you already know when possible. |
| - If you've made more than 3 attempts to solve a problem, prioritize providing your best guess. |
| - Always add a delay of 2-3 seconds between web searches using time.sleep() to avoid rate limiting. |
| |
| Remember to structure your response in Python code format using the final_answer() function. |
| """ |
|
|
| def _reversed_prompt(self, question: str, allow_extra_searches: bool = False) -> str: |
| search_budget_line = ( |
| "- Limit to 1-2 web searches per question.\n" |
| if not allow_extra_searches else |
| "- You may use up to 3-4 web searches if needed.\n" |
| ) |
| return f""" |
| You are a general AI assistant. I will ask you a question. |
| |
| This question appears to be in reversed text. Here is the reversed version for clarity: |
| {question[::-1]} |
| |
| Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. |
| |
| YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. |
| - If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. |
| - If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. |
| - If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string. |
| |
| IMPORTANT NOTES TO LIMIT COSTS AND PREVENT ERRORS: |
| - Use web search sparingly and only when absolutely necessary. |
| {search_budget_line}- If a search fails due to rate limiting, add a 3-5 second delay using time.sleep() before retrying with a different search term. |
| - Do not import libraries that aren't available - stick to basic Python and the tools provided. |
| - Focus on answering directly with what you already know when possible. |
| - If you've made more than 3 attempts to solve a problem, prioritize providing your best guess. |
| - Always add a delay of 2-3 seconds between web searches using time.sleep() to avoid rate limiting. |
| |
| Remember to structure your response in Python code format using the final_answer() function. |
| """ |
|
|
| def __call__(self, question: str, allow_extra_searches: bool = False) -> str: |
| if self.verbose: |
| msg = f"Processing question: {question[:100]}..." if len(question) > 100 else f"Processing question: {question}" |
| print(msg) |
|
|
| prompt = ( |
| self._reversed_prompt(question, allow_extra_searches) |
| if self._is_reversed_text(question) |
| else self._base_prompt(question, allow_extra_searches) |
| ) |
| try: |
| answer = self.agent.run(prompt) |
| if self.verbose: |
| print(f"Generated answer: {answer}") |
| return answer |
| except Exception as e: |
| error_msg = f"Error processing question: {e}" |
| if self.verbose: |
| print(error_msg) |
| return error_msg |
|
|
| def refine(self, question: str, prev_answer: str, judge_feedback: str, attempt_no: int) -> str: |
| """ |
| Reflection-based reattempt without using gold. |
| """ |
| if self.verbose: |
| print(f"Refining (attempt {attempt_no}) based on judge note: {judge_feedback}") |
|
|
| allow_extra = attempt_no >= 2 |
| base = self._base_prompt(question, allow_extra_searches=allow_extra) |
|
|
| refinement_addendum = f""" |
| Your previous FINAL ANSWER was: |
| {prev_answer} |
| |
| A strict judge said this answer was incorrect for the following reason(s) (be concise): {judge_feedback} |
| |
| Re-evaluate the question carefully. Consider possible formatting issues (units, articles, thousands commas), list ordering (only if the question requires a specific order), and rounding. |
| Produce a NEW final answer. Do not repeat the previous final answer if you think it was wrong. |
| """ |
|
|
| try: |
| answer = self.agent.run(base + refinement_addendum) |
| if self.verbose: |
| print(f"Refined answer: {answer}") |
| return answer |
| except Exception as e: |
| err = f"Error refining: {e}" |
| if self.verbose: |
| print(err) |
| return err |
|
|
| |
| gold_answers = GoldAnswers(path=DEFAULT_GOLD_CSV) |
| _judge_agent_singleton: Optional[JudgeAgent] = None |
|
|
| |
| def _ensure_judge(model: OpenAIServerModel) -> JudgeAgent: |
| global _judge_agent_singleton |
| if _judge_agent_singleton is None: |
| _judge_agent_singleton = JudgeAgent(base_model=model, verbose=False) |
| return _judge_agent_singleton |
|
|
| def run_and_submit_all(sample_size: int = 0, max_retries: int = 1, use_local_judge_to_select: bool = True): |
| """ |
| Fetches all questions, runs the agent on them, judges locally (if gold available), |
| optionally reattempts on incorrect results, submits answers, and returns: |
| - final status string |
| - final results dataframe (one row per question) |
| - attempt log dataframe (one row per attempt) |
| """ |
| username = "Gralon" |
| print(f"Using username: {username}") |
|
|
| api_url = DEFAULT_API_URL |
| questions_url = f"{api_url}/questions" |
| submit_url = f"{api_url}/submit" |
|
|
| |
| try: |
| agent = GAIAAgent(verbose=True) |
| except Exception as e: |
| print(f"Error instantiating agent: {e}") |
| return f"Error initializing agent: {e}", None, None |
|
|
| |
| judge_agent = _ensure_judge(agent.agent.model) |
|
|
| |
| space_id = os.getenv("SPACE_ID") |
| agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main" if space_id else "local" |
|
|
| |
| print(f"Fetching questions from: {questions_url}") |
| try: |
| response = requests.get(questions_url, timeout=15) |
| response.raise_for_status() |
| questions_data = response.json() |
| if not questions_data: |
| print("Fetched questions list is empty.") |
| return "Fetched questions list is empty or invalid format.", None, None |
| print(f"Fetched {len(questions_data)} questions.") |
| except requests.exceptions.RequestException as e: |
| print(f"Error fetching questions: {e}") |
| return f"Error fetching questions: {e}", None, None |
| except json.JSONDecodeError as e: |
| print(f"Error decoding JSON response from questions endpoint: {e}") |
| print(f"Response text: {response.text[:500]}") |
| return f"Error decoding server response for questions: {e}", None, None |
| except Exception as e: |
| print(f"An unexpected error occurred fetching questions: {e}") |
| return f"An unexpected error occurred fetching questions: {e}", None, None |
|
|
| |
| if sample_size > 0 and sample_size < len(questions_data): |
| import random |
| print(f"Using a sample of {sample_size} questions from {len(questions_data)} total questions") |
| questions_data = random.sample(questions_data, sample_size) |
|
|
| print(f"Running agent on {len(questions_data)} questions...") |
| results_log: List[Dict[str, Any]] = [] |
| attempts_log: List[Dict[str, Any]] = [] |
| answers_payload: List[Dict[str, Any]] = [] |
|
|
| for i, item in enumerate(questions_data): |
| task_id = item.get("task_id") |
| question_text = item.get("question") |
| if not task_id or question_text is None: |
| print(f"Skipping item with missing task_id or question: {item}") |
| continue |
|
|
| gold = gold_answers.by_task_id.get(task_id) |
| per_question_attempts: List[Dict[str, Any]] = [] |
|
|
| try: |
| print(f"Processing question {i+1}/{len(questions_data)}: Task ID {task_id}") |
|
|
| |
| raw = agent(question_text, allow_extra_searches=False) |
| ans = extract_final_answer(raw) |
| fixed = quick_format_fix(ans, question_text) or ans |
|
|
| |
| jres = None |
| j_is_correct = None |
| j_score = None |
| j_note = None |
| if gold: |
| jres = judge_agent.judge(question_text, fixed, gold) |
| j_is_correct = jres.get("is_correct") |
| j_score = jres.get("score") |
| j_note = jres.get("justification") |
|
|
| per_question_attempts.append({ |
| "Task ID": task_id, |
| "Attempt": 1, |
| "Submitted Answer (raw)": ans, |
| "Submitted Answer (fixed)": fixed, |
| "Judge Correct?": j_is_correct, |
| "Judge Score": j_score, |
| "Judge Note": j_note |
| }) |
|
|
| best_answer = fixed |
| best_score = j_score if j_score is not None else 0.0 |
| best_correct = j_is_correct |
|
|
| retries = 0 |
| while (j_is_correct is False) and (retries < max_retries): |
| retries += 1 |
|
|
| |
| refined_raw = agent.refine( |
| question=question_text, |
| prev_answer=fixed, |
| judge_feedback=j_note or "Format/content mismatch.", |
| attempt_no=retries |
| ) |
| refined = extract_final_answer(refined_raw) |
| refined_fixed = quick_format_fix(refined, question_text) or refined |
|
|
| |
| j2 = None |
| j2_is_correct = None |
| j2_score = None |
| j2_note = None |
| if gold: |
| j2 = judge_agent.judge(question_text, refined_fixed, gold) |
| j2_is_correct = j2.get("is_correct") |
| j2_score = j2.get("score") |
| j2_note = j2.get("justification") |
|
|
| per_question_attempts.append({ |
| "Task ID": task_id, |
| "Attempt": retries + 1, |
| "Submitted Answer (raw)": refined, |
| "Submitted Answer (fixed)": refined_fixed, |
| "Judge Correct?": j2_is_correct, |
| "Judge Score": j2_score, |
| "Judge Note": j2_note |
| }) |
|
|
| |
| if use_local_judge_to_select and gold and (j2_score is not None): |
| if (j2_score > (best_score or 0)) or (best_score is None): |
| best_answer, best_score, best_correct = refined_fixed, j2_score, j2_is_correct |
| else: |
| |
| best_answer = refined_fixed |
| best_score = j2_score if j2_score is not None else best_score |
| best_correct = j2_is_correct if j2_is_correct is not None else best_correct |
|
|
| |
| fixed = refined_fixed |
| j_is_correct = j2_is_correct |
| j_score = j2_score |
| j_note = j2_note |
|
|
| if j2_is_correct: |
| break |
|
|
| if retries < max_retries: |
| print("Waiting 2 seconds before next attempt...") |
| time.sleep(2) |
|
|
| |
| answers_payload.append({"task_id": task_id, "submitted_answer": best_answer}) |
| results_log.append({ |
| "Task ID": task_id, |
| "Question": question_text, |
| "Submitted Answer": best_answer, |
| "Gold (local)": gold if gold else "", |
| "Judge Correct?": best_correct, |
| "Judge Score": best_score, |
| "Judge Note": j_note |
| }) |
| print(f"Finished question {i+1}") |
|
|
| |
| attempts_log.extend(per_question_attempts) |
|
|
| if i < len(questions_data) - 1: |
| print("Waiting 2 seconds before next question...") |
| time.sleep(2) |
|
|
| except Exception as e: |
| print(f"Error running agent on task {task_id}: {e}") |
| results_log.append({ |
| "Task ID": task_id, |
| "Question": question_text, |
| "Submitted Answer": f"AGENT ERROR: {e}", |
| "Gold (local)": gold_answers.by_task_id.get(task_id, ""), |
| "Judge Correct?": False, |
| "Judge Score": 0.0, |
| "Judge Note": "agent error" |
| }) |
|
|
| if not answers_payload: |
| print("Agent did not produce any answers to submit.") |
| return "Agent did not produce any answers to submit.", pd.DataFrame(results_log), pd.DataFrame(attempts_log) |
|
|
| |
| submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload} |
| status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..." |
| print(status_update) |
|
|
| |
| print(f"Submitting {len(answers_payload)} answers to: {submit_url}") |
| try: |
| response = requests.post(submit_url, json=submission_data, timeout=60) |
| response.raise_for_status() |
| result_data = response.json() |
| final_status = ( |
| f"Submission Successful!\n" |
| f"User: {result_data.get('username')}\n" |
| f"Overall Score: {result_data.get('score', 'N/A')}% " |
| f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n" |
| f"Message: {result_data.get('message', 'No message received.')}" |
| ) |
| print("Submission successful.") |
| results_df = pd.DataFrame(results_log) |
| attempts_df = pd.DataFrame(attempts_log) |
| return final_status, results_df, attempts_df |
| except requests.exceptions.HTTPError as e: |
| error_detail = f"Server responded with status {e.response.status_code}." |
| try: |
| error_json = e.response.json() |
| error_detail += f" Detail: {error_json.get('detail', e.response.text)}" |
| except json.JSONDecodeError: |
| error_detail += f" Response: {e.response.text[:500]}" |
| status_message = f"Submission Failed: {error_detail}" |
| print(status_message) |
| results_df = pd.DataFrame(results_log) |
| attempts_df = pd.DataFrame(attempts_log) |
| return status_message, results_df, attempts_df |
| except requests.exceptions.Timeout: |
| status_message = "Submission Failed: The request timed out." |
| print(status_message) |
| results_df = pd.DataFrame(results_log) |
| attempts_df = pd.DataFrame(attempts_log) |
| return status_message, results_df, attempts_df |
| except requests.exceptions.RequestException as e: |
| status_message = f"Submission Failed: Network error - {e}" |
| print(status_message) |
| results_df = pd.DataFrame(results_log) |
| attempts_df = pd.DataFrame(attempts_log) |
| return status_message, results_df, attempts_df |
| except Exception as e: |
| status_message = f"An unexpected error occurred during submission: {e}" |
| print(status_message) |
| results_df = pd.DataFrame(results_log) |
| attempts_df = pd.DataFrame(attempts_log) |
| return status_message, results_df, attempts_df |
|
|
| def test_single_question(question: str, retries: int = 1) -> str: |
| """Test the agent on a single question (no submission), with judge-aware retries if gold exists.""" |
| try: |
| agent = GAIAAgent(verbose=True) |
| judge_agent = _ensure_judge(agent.agent.model) |
| gold = None |
| |
| raw = agent(question) |
| ans = extract_final_answer(raw) |
| fixed = quick_format_fix(ans, question) or ans |
|
|
| if retries <= 0: |
| return fixed |
|
|
| |
| last = fixed |
| note = "Possible format/content mismatch; re-evaluate." |
| for k in range(retries): |
| refined_raw = agent.refine(question, prev_answer=last, judge_feedback=note, attempt_no=k+1) |
| refined = extract_final_answer(refined_raw) |
| refined_fixed = quick_format_fix(refined, question) or refined |
| last = refined_fixed |
| return last |
| except Exception as e: |
| return f"Error: {e}" |
|
|
| |
| def local_judge_single(question: str, predicted: str, task_id_or_gold: str): |
| |
| gold = gold_answers.by_task_id.get(task_id_or_gold, task_id_or_gold) |
| agent = GAIAAgent(verbose=False) |
| judge_agent = _ensure_judge(agent.agent.model) |
| res = judge_agent.judge(question, predicted, gold) |
| out = { |
| "Gold": gold, |
| "is_correct": res["is_correct"], |
| "score": res["score"], |
| "note": res["justification"] |
| } |
| return json.dumps(out, ensure_ascii=False, indent=2) |
|
|
| |
| with gr.Blocks() as demo: |
| gr.Markdown("# GAIA Agent Evaluation Runner + Local LLM Judge (with smart retries)") |
| gr.Markdown( |
| """ |
| ## Instructions: |
| |
| 1. Log in to your Hugging Face account using the button below |
| 2. Test your agent on individual questions in the Testing tab |
| 3. Run the full evaluation on the GAIA benchmark in the Evaluation tab |
| |
| This agent runs locally, uses an LLM judge against your answers.csv (if present), |
| **retries intelligently** when the judge says 'incorrect', and then submits answers to the server. |
| """ |
| ) |
|
|
| gr.LoginButton() |
|
|
| with gr.Tab("Test Single Question"): |
| test_input = gr.Textbox(label="Enter a question to test", lines=3) |
| test_retries = gr.Slider(minimum=0, maximum=3, value=1, step=1, label="Retries (no gold here, heuristic only)") |
| test_output = gr.Textbox(label="Answer", lines=3) |
| test_button = gr.Button("Test Question") |
|
|
| test_button.click( |
| fn=test_single_question, |
| inputs=[test_input, test_retries], |
| outputs=test_output |
| ) |
|
|
| with gr.Tab("Local Judge (manual)"): |
| lj_q = gr.Textbox(label="Question", lines=3) |
| lj_pred = gr.Textbox(label="Predicted (your FINAL ANSWER)", lines=1) |
| lj_gold_or_id = gr.Textbox(label="Task ID (to fetch gold) OR paste a Gold answer", lines=1) |
| lj_out = gr.Textbox(label="Judge Result (JSON)", lines=8) |
| gr.Button("Judge Now").click(local_judge_single, inputs=[lj_q, lj_pred, lj_gold_or_id], outputs=lj_out) |
|
|
| with gr.Tab("Full Evaluation"): |
| with gr.Row(): |
| sample_size = gr.Slider( |
| minimum=0, |
| maximum=20, |
| value=0, |
| step=1, |
| label="Sample Size (0 for all questions)", |
| info="Set a number to limit how many questions to process (reduces costs)" |
| ) |
| max_retries = gr.Slider( |
| minimum=0, |
| maximum=3, |
| value=1, |
| step=1, |
| label="Max judge-driven retries per question", |
| info="0 = no retries; 1-3 = progressively more effort" |
| ) |
| use_local = gr.Checkbox( |
| value=True, |
| label="Use local judge (gold) to pick best attempt when available", |
| info="If unchecked, we submit the last attempt instead." |
| ) |
|
|
| run_button = gr.Button("Run Evaluation, Judge Locally, Retry & Submit") |
| status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False) |
| results_table = gr.DataFrame(label="Final Results (per question)", wrap=True) |
| attempts_table = gr.DataFrame(label="Attempt Log (expanded)", wrap=True) |
|
|
| run_button.click( |
| fn=run_and_submit_all, |
| inputs=[sample_size, max_retries, use_local], |
| outputs=[status_output, results_table, attempts_table] |
| ) |
|
|
| if __name__ == "__main__": |
| print("\n" + "-"*30 + " GAIA Agent Starting " + "-"*30) |
|
|
| |
| api_key = os.environ.get("OPENAI_API_KEY") |
| if not api_key: |
| print("WARNING: OpenAI API key not found. Please set OPENAI_API_KEY environment variable.") |
| else: |
| print("OpenAI API key found.") |
|
|
| |
| space_host = os.getenv("SPACE_HOST") |
| space_id = os.getenv("SPACE_ID") |
|
|
| if space_host: |
| print(f"✅ Running in Hugging Face Space: {space_host}") |
| print(f" Runtime URL: https://{space_host}.hf.space") |
| else: |
| print("ℹ️ Running locally") |
|
|
| if space_id: |
| print(f"✅ Space ID: {space_id}") |
| print(f" Repo URL: https://huggingface.co/spaces/{space_id}") |
| print(f" Code URL: https://huggingface.co/spaces/{space_id}/tree/main") |
|
|
| print("-"*78 + "\n") |
|
|
| if Path(DEFAULT_GOLD_CSV).exists(): |
| print(f"Local gold answers found at: {DEFAULT_GOLD_CSV}") |
| else: |
| print(f"No local gold CSV found at: {DEFAULT_GOLD_CSV} (judge will skip gold for unknown tasks)") |
|
|
| print("Launching Gradio Interface...") |
| demo.launch(debug=True) |