Spaces:
Sleeping
Sleeping
| import ast | |
| import io | |
| import mimetypes | |
| import os | |
| import re | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| import pandas as pd | |
| import requests | |
| from dotenv import load_dotenv | |
| from google import genai | |
| from google.genai import types as genai_types | |
| from smolagents import CodeAgent, DuckDuckGoSearchTool, OpenAIModel, Tool | |
# Populate os.environ from a .env file, if one is found, before the os.getenv calls below.
load_dotenv()

# Course scoring API; task attachments are downloaded from "<url>/files/<task_id>".
DEFAULT_SCORING_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Model id, overridable via GEMINI_MODEL_ID; a leading "gemini/" prefix is stripped
# by _normalize_model_id before the id is sent to either API.
DEFAULT_GEMINI_MODEL_ID = os.getenv("GEMINI_MODEL_ID", "gemini/gemini-2.0-flash")
# Gemini's OpenAI-compatible endpoint, used as api_base for the smolagents OpenAIModel.
GEMINI_OPENAI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"

# Answer-formatting rules appended to the CodeAgent system prompt.
# NOTE: the CodeAgent must keep emitting its Thought/Code steps, so these rules are
# scoped to the value passed to final_answer (whose surrounding brackets
# _normalize_final_answer strips afterwards), not to every model message.
GAIA_OUTPUT_PROMPT = """
You are an evaluation-grade GAIA benchmark AI assistant. Your execution must be perfectly precise and highly deterministic. Your sole purpose is to isolate and return the exact, minimal final answer.

### CORE DIRECTIVE
Keep following the Thought/Code step format described above while you work; the rules below govern ONLY the final answer, not your intermediate steps.
The value you pass to `final_answer` must consist ONLY of the final answer, strictly enclosed within square brackets, with no explanations, reasoning, conversational filler, or comments.
Pass it as a single string, e.g. final_answer("[a, b, c]").
Format: [FINAL_ANSWER]

### DATA FORMATTING RULES
1. NUMERICAL DATA:
   - Output digits only (e.g., [4], not [four]).
   - Exclude commas, currency symbols, percentage signs, or physical units unless the prompt explicitly dictates their inclusion.
   - Strip all approximation qualifiers (e.g., "around", "roughly", "~").
2. STRING DATA:
   - Omit definite and indefinite articles ("a", "an", "the").
   - Use full, unabbreviated words unless the prompt specifically requests an abbreviation.
3. LISTS AND SETS:
   - Output as a single comma-separated string with exactly one space after each comma (e.g., [a, b, c]).
   - Exclude conjunctions (e.g., do not use "and" or "or").
   - Exclude internal list wrappers (do not include {} or () inside the main answer brackets).
   - Sort elements alphabetically or numerically in ascending order unless the prompt specifies an alternative sorting logic.

### EXECUTION PROTOCOL
1. SOURCE EXTRACTION: When processing data from web searches, files (via `run_query_with_file`), or video tools, extract only the atomic fact that satisfies the query. Do not summarize or quote surrounding context.
2. LITERALISM: Default to the narrowest, most literal interpretation of the prompt. Do not synthesize assumptions.
3. FALLBACK: If the requisite data to answer is demonstrably absent after exhaustive search, return exactly: [unknown]

### EXAMPLES
Q: What is 2 + 2?
A: [4]
Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
A: [3]
Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
A: [b, e]
Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?
A: [519]
""".strip()
| def _get_gemini_api_key() -> str: | |
| api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") | |
| if not api_key: | |
| raise RuntimeError( | |
| "Missing Gemini API key. Set GEMINI_API_KEY in your Hugging Face Space secrets." | |
| ) | |
| return api_key | |
| def _normalize_model_id(model_id: str | None = None) -> str: | |
| selected = model_id or DEFAULT_GEMINI_MODEL_ID | |
| if selected.startswith("gemini/"): | |
| return selected.split("/", 1)[1] | |
| return selected | |
def _genai_client() -> genai.Client:
    """Create a new google-genai client authenticated with the configured Gemini key.

    A fresh client is built on every call; a missing key surfaces as the
    RuntimeError raised by _get_gemini_api_key.
    """
    api_key = _get_gemini_api_key()
    return genai.Client(api_key=api_key)
| def _extract_text(response) -> str: | |
| text = getattr(response, "text", None) | |
| if text: | |
| return text.strip() | |
| parts: list[str] = [] | |
| for candidate in getattr(response, "candidates", []) or []: | |
| content = getattr(candidate, "content", None) | |
| for part in getattr(content, "parts", []) or []: | |
| piece = getattr(part, "text", None) | |
| if piece: | |
| parts.append(piece) | |
| return "\n".join(parts).strip() | |
def _call_gemini_text(prompt: str, system_instruction: str | None = None) -> str:
    """Send a text-only prompt to the default Gemini model and return the reply text.

    Generation runs at temperature 0. ``system_instruction`` is attached to the
    request only when it is a non-empty string.
    """
    generation_config = genai_types.GenerateContentConfig(temperature=0)
    if system_instruction:
        generation_config.system_instruction = system_instruction
    client = _genai_client()
    reply = client.models.generate_content(
        model=_normalize_model_id(),
        contents=prompt,
        config=generation_config,
    )
    return _extract_text(reply)
| def _decode_text_bytes(payload: bytes) -> str: | |
| for encoding in ("utf-8", "utf-8-sig", "cp1252", "latin-1"): | |
| try: | |
| return payload.decode(encoding) | |
| except UnicodeDecodeError: | |
| continue | |
| return payload.decode("utf-8", errors="replace") | |
def _download_task_file(task_id: str) -> tuple[bytes, str]:
    """Fetch the attachment for ``task_id`` from the scoring API.

    Returns:
        ``(body_bytes, content_type)``. The content type defaults to
        ``application/octet-stream`` when the response has no such header.

    Raises:
        requests.HTTPError: for a non-2xx response (via ``raise_for_status``).
    """
    file_url = f"{DEFAULT_SCORING_API_URL}/files/{task_id}"
    reply = requests.get(file_url, timeout=60)
    reply.raise_for_status()
    mime_header = reply.headers.get("content-type", "application/octet-stream")
    return reply.content, mime_header
| def _mime_from_name(file_name: str | None, fallback: str) -> str: | |
| guessed, _ = mimetypes.guess_type(file_name or "") | |
| return guessed or fallback or "application/octet-stream" | |
| def _wait_until_active(client: genai.Client, uploaded_file) -> None: | |
| state = getattr(uploaded_file, "state", None) | |
| state_name = getattr(state, "name", None) | |
| while state_name and state_name != "ACTIVE": | |
| if state_name == "FAILED": | |
| raise RuntimeError("Gemini file processing failed.") | |
| time.sleep(3) | |
| uploaded_file = client.files.get(name=uploaded_file.name) | |
| state = getattr(uploaded_file, "state", None) | |
| state_name = getattr(state, "name", None) | |
def _query_uploaded_file(file_bytes: bytes, file_name: str, mime_type: str, user_query: str) -> str:
    """Upload an attachment to the Gemini Files API and ask ``user_query`` about it.

    The bytes are first written to a named temporary file that keeps the original
    extension, and that path is handed to ``files.upload``. Once the upload is
    ACTIVE, the model is queried at temperature 0. The local temp file is always
    removed. Deleting the remote upload is best-effort: its errors are ignored.
    """
    gemini = _genai_client()
    extension = Path(file_name or "attachment").suffix
    staged_path: str | None = None
    remote_file = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as handle:
            handle.write(file_bytes)
            staged_path = handle.name
        remote_file = gemini.files.upload(
            file=staged_path,
            config={"mime_type": mime_type},
        )
        _wait_until_active(gemini, remote_file)
        answer = gemini.models.generate_content(
            model=_normalize_model_id(),
            contents=[remote_file, user_query],
            config=genai_types.GenerateContentConfig(temperature=0),
        )
        return _extract_text(answer)
    finally:
        if staged_path and os.path.exists(staged_path):
            os.remove(staged_path)
        remote_name = getattr(remote_file, "name", None) if remote_file else None
        if remote_name:
            try:
                gemini.files.delete(name=remote_name)
            except Exception:
                # Best-effort cleanup: a failed delete must not mask the answer or the
                # original error.
                pass
| def _excel_to_text(file_bytes: bytes) -> str: | |
| workbook = pd.read_excel(io.BytesIO(file_bytes), sheet_name=None) | |
| sections: list[str] = [] | |
| for sheet_name, frame in workbook.items(): | |
| csv_text = frame.fillna("").to_csv(index=False) | |
| sections.append(f"Sheet: {sheet_name}\n{csv_text}") | |
| return "\n\n".join(sections) | |
def _textual_file_answer(file_text: str, user_query: str) -> str:
    """Ask Gemini ``user_query`` about an attachment supplied as plain text.

    Only the first 50,000 characters of ``file_text`` go into the prompt; the rest
    is dropped without notice (presumably to bound request size).
    """
    excerpt = file_text[:50000]
    instructions = "\n".join(
        [
            "You are analyzing an attached file for a GAIA benchmark question.",
            "Answer the question using only the file contents below.",
            "Return only the direct final answer, with no explanation.",
            "",
            f"Question:\n{user_query}",
            "",
            f"File contents:\n{excerpt}",
        ]
    )
    return _call_gemini_text(instructions)
| def _normalize_riddle_prompt(prompt: str) -> str: | |
| stripped = prompt.strip() | |
| if stripped and stripped.count(" ") > 3: | |
| weird_ratio = sum(char in ".,!?;:'\"()-" for char in stripped) / max(len(stripped), 1) | |
| if weird_ratio < 0.2: | |
| reversed_candidate = stripped[::-1] | |
| if re.search(r"\b(the|and|you|write|understand|sentence)\b", reversed_candidate.lower()): | |
| return reversed_candidate | |
| return stripped | |
| def _normalize_final_answer(raw_answer: str) -> str: | |
| cleaned = raw_answer.strip() | |
| if cleaned.startswith("[") and cleaned.endswith("]") and len(cleaned) >= 2: | |
| inner = cleaned[1:-1].strip() | |
| if inner: | |
| return inner | |
| return cleaned | |
class MathSolver(Tool):
    """smolagents tool that evaluates plain arithmetic without calling ``eval``.

    The expression is parsed with ``ast`` and walked by hand. Only numeric literals,
    unary ``+``/``-`` and the binary operators ``+ - * / // % **`` are accepted, so
    names, calls and attribute access are all rejected. Failures are returned as a
    ``"Math error: ..."`` string instead of being raised.
    """

    name = "math_solver"
    description = "Evaluate arithmetic expressions with operators like +, -, *, /, //, %, and **."
    inputs = {
        "expression": {
            "type": "string",
            "description": "The arithmetic expression to evaluate.",
        }
    }
    output_type = "string"

    def forward(self, expression: str) -> str:
        """Evaluate ``expression`` and return the result as text.

        Floats with an integral value are printed without a trailing ``.0``.
        """
        # Upper bound on the estimated size (in bits) of an integer power. The input
        # comes from the LLM, and something like "9**9**9" would otherwise keep the
        # process busy building a gigantic integer.
        max_power_bits = 100_000

        def eval_node(node: ast.AST) -> int | float:
            if isinstance(node, ast.Expression):
                return eval_node(node.body)
            # bool subclasses int; exclude True/False so only genuine numbers pass.
            if (
                isinstance(node, ast.Constant)
                and isinstance(node.value, (int, float))
                and not isinstance(node.value, bool)
            ):
                return node.value
            if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
                operand = eval_node(node.operand)
                return operand if isinstance(node.op, ast.UAdd) else -operand
            if isinstance(node, ast.BinOp) and isinstance(
                node.op, (ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow)
            ):
                left = eval_node(node.left)
                right = eval_node(node.right)
                if isinstance(node.op, ast.Add):
                    return left + right
                if isinstance(node.op, ast.Sub):
                    return left - right
                if isinstance(node.op, ast.Mult):
                    return left * right
                if isinstance(node.op, ast.Div):
                    return left / right
                if isinstance(node.op, ast.FloorDiv):
                    return left // right
                if isinstance(node.op, ast.Mod):
                    return left % right
                # Only ast.Pow remains. Bases 0, 1 and -1 are cheap for any exponent;
                # otherwise the result needs roughly bit_length(left) * right bits.
                if (
                    isinstance(left, int)
                    and isinstance(right, int)
                    and abs(left) > 1
                    and right > 0
                    and left.bit_length() * right > max_power_bits
                ):
                    raise ValueError("Exponent too large.")
                return left**right
            raise ValueError("Unsupported expression.")

        try:
            result = eval_node(ast.parse(expression, mode="eval"))
            # Formatting stays inside the try: str() of a very large int can raise
            # ValueError under CPython's integer string-conversion digit limit.
            if isinstance(result, float) and result.is_integer():
                return str(int(result))
            return str(result)
        except Exception as exc:
            return f"Math error: {exc}"
class RiddleSolver(Tool):
    """smolagents tool that hands riddles and wordplay to Gemini.

    The prompt first goes through _normalize_riddle_prompt, which strips it and may
    reverse text that looks like it was written backwards.
    """

    name = "riddle_solver"
    description = "Solve riddles, wordplay, or short trick questions when the wording matters."
    inputs = {
        "prompt": {
            "type": "string",
            "description": "The riddle or wordplay prompt to solve.",
        }
    }
    output_type = "string"

    def forward(self, prompt: str) -> str:
        """Return Gemini's direct answer to the normalized ``prompt``."""
        solver_instruction = (
            "Solve the user's riddle or wordplay. Return only the direct answer with no explanation."
        )
        cleaned_prompt = _normalize_riddle_prompt(prompt)
        return _call_gemini_text(cleaned_prompt, system_instruction=solver_instruction)
class TextTransformer(Tool):
    """smolagents tool applying one deterministic string transform to its input."""

    name = "text_transformer"
    description = "Apply deterministic text transforms like reverse, upper, lower, title, strip, or swapcase."
    inputs = {
        "text": {
            "type": "string",
            "description": "The source text.",
        },
        "operation": {
            "type": "string",
            "description": "One of: reverse, upper, lower, title, strip, swapcase.",
        },
    }
    output_type = "string"

    def forward(self, text: str, operation: str) -> str:
        """Apply ``operation`` to ``text``.

        The operation name is matched case-insensitively, ignoring surrounding
        whitespace. Any unknown name yields ``"Unsupported operation."``.
        """
        transforms = {
            "reverse": lambda value: value[::-1],
            "upper": lambda value: value.upper(),
            "lower": lambda value: value.lower(),
            "title": lambda value: value.title(),
            "strip": lambda value: value.strip(),
            "swapcase": lambda value: value.swapcase(),
        }
        transform = transforms.get(operation.strip().lower())
        if transform is None:
            return "Unsupported operation."
        return transform(text)
class GeminiVideoQA(Tool):
    """smolagents tool that asks Gemini a question about a video given by URL."""

    name = "gemini_video_qa"
    description = "Answer questions about a public video URL, including YouTube links, using Gemini multimodal analysis."
    inputs = {
        "video_url": {
            "type": "string",
            "description": "The public video URL to inspect.",
        },
        "user_query": {
            "type": "string",
            "description": "The exact question to answer about the video.",
        },
    }
    output_type = "string"

    def forward(self, video_url: str, user_query: str) -> str:
        """Return Gemini's reply to ``user_query`` about the video at ``video_url``.

        The URL is passed as a ``file_data`` reference (nothing is downloaded here),
        followed by the question as a text part. Generation runs at temperature 0.
        """
        client = _genai_client()
        video_part = genai_types.Part(file_data=genai_types.FileData(file_uri=video_url))
        question_part = genai_types.Part(text=user_query)
        request = genai_types.Content(role="user", parts=[video_part, question_part])
        reply = client.models.generate_content(
            model=_normalize_model_id(),
            contents=request,
            config=genai_types.GenerateContentConfig(temperature=0),
        )
        return _extract_text(reply)
class WikiTitleFinder(Tool):
    """smolagents tool that searches English Wikipedia for candidate page titles."""

    name = "wiki_title_finder"
    description = "Find likely English Wikipedia page titles for a topic."
    inputs = {
        "query": {
            "type": "string",
            "description": "The topic or search phrase to find on English Wikipedia.",
        }
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return up to five comma-separated titles from the MediaWiki search API.

        Returns a fixed "not found" message when the search has no results.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        search_params = {
            "action": "query",
            "list": "search",
            "srsearch": query,
            "srlimit": 5,
            "format": "json",
        }
        reply = requests.get(
            "https://en.wikipedia.org/w/api.php",
            params=search_params,
            timeout=20,
        )
        reply.raise_for_status()
        hits = reply.json().get("query", {}).get("search", [])
        if not hits:
            return "No matching Wikipedia titles found."
        return ", ".join([hit["title"] for hit in hits])
class WikiContentFetcher(Tool):
    """smolagents tool returning the plain-text extract of an English Wikipedia page."""

    name = "wiki_content_fetcher"
    description = "Fetch plain-text content from an English Wikipedia page."
    inputs = {
        "page_title": {
            "type": "string",
            "description": "The exact English Wikipedia page title to fetch.",
        }
    }
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Return the first non-empty page extract, truncated to 12,000 characters.

        Redirects are followed. Text beyond 12,000 characters is dropped without a
        marker. Returns a fixed "not found" message when no page has an extract.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        query_params = {
            "action": "query",
            "prop": "extracts",
            "explaintext": 1,
            "redirects": 1,
            "titles": page_title,
            "format": "json",
        }
        reply = requests.get(
            "https://en.wikipedia.org/w/api.php",
            params=query_params,
            timeout=20,
        )
        reply.raise_for_status()
        pages = reply.json().get("query", {}).get("pages", {})
        extracts = ((entry or {}).get("extract") for entry in pages.values())
        first_text = next((text for text in extracts if text), None)
        if first_text is None:
            return "Wikipedia page not found."
        return first_text[:12000]
class GoogleSearchTool(Tool):
    """smolagents tool answering a query with Gemini grounded on Google Search."""

    name = "google_search"
    description = "Search the live web using Gemini grounding with Google Search and return a concise result."
    inputs = {
        "query": {
            "type": "string",
            "description": "The web query to search for.",
        }
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return Gemini's reply to ``query`` with the Google Search tool enabled (temperature 0)."""
        search_tool = genai_types.Tool(google_search=genai_types.GoogleSearch())
        client = _genai_client()
        search_config = genai_types.GenerateContentConfig(
            temperature=0,
            tools=[search_tool],
        )
        reply = client.models.generate_content(
            model=_normalize_model_id(),
            contents=query,
            config=search_config,
        )
        return _extract_text(reply)
# Attachment extensions that are plain text. These are decoded and sent inline
# instead of being uploaded to Gemini. NOTE(review): several of them (e.g. .jsonld,
# .pdb, .tsv) may not map to a MIME type the Gemini Files API accepts, which is why
# they are routed here. .pdb is assumed to be the text Protein Data Bank format.
_TEXT_ATTACHMENT_SUFFIXES = frozenset(
    {
        ".txt", ".md", ".json", ".csv", ".py", ".html", ".xml", ".yaml", ".yml", ".log",
        ".tsv", ".jsonld", ".pdb", ".js", ".sql", ".toml", ".ini", ".cfg", ".tex",
    }
)
# Spreadsheet extensions that are rendered to CSV text with pandas before querying.
_SPREADSHEET_SUFFIXES = frozenset({".xlsx", ".xls"})


class FileAttachmentQueryTool(Tool):
    """smolagents tool that downloads a GAIA task attachment and answers a question about it."""

    name = "run_query_with_file"
    description = (
        "Download an attached GAIA benchmark file by task_id, inspect it, and answer a question about it."
    )
    inputs = {
        "task_id": {
            "type": "string",
            "description": "The GAIA task identifier used to download the attachment.",
        },
        "file_name": {
            "type": "string",
            "description": "The attachment file name, including the extension.",
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the attached file.",
        },
    }
    output_type = "string"

    def forward(self, task_id: str, file_name: str, user_query: str) -> str:
        """Answer ``user_query`` about the attachment of GAIA task ``task_id``.

        The route depends on the file extension (matched case-insensitively):
        - text-like files are decoded and sent inline;
        - spreadsheets are rendered to CSV text first;
        - anything else is uploaded to the Gemini Files API, with a MIME type guessed
          from ``file_name`` or, failing that, the download's Content-Type.
        """
        file_bytes, content_type = _download_task_file(task_id)
        suffix = Path(file_name or "").suffix.lower()
        if suffix in _TEXT_ATTACHMENT_SUFFIXES:
            return _textual_file_answer(_decode_text_bytes(file_bytes), user_query)
        if suffix in _SPREADSHEET_SUFFIXES:
            return _textual_file_answer(_excel_to_text(file_bytes), user_query)
        mime_type = _mime_from_name(file_name, content_type)
        return _query_uploaded_file(file_bytes, file_name, mime_type, user_query)
class BasicAgent:
    """GAIA question-answering agent: a smolagents CodeAgent driven by Gemini.

    The CodeAgent talks to Gemini through its OpenAI-compatible endpoint. It has the
    local math/text/riddle helpers, Wikipedia lookups, Gemini-grounded Google search,
    DuckDuckGo search, video QA and attachment QA tools. The GAIA answer-format rules
    are appended to its system prompt.
    """

    def __init__(self):
        # Resolve the key once. This fails fast with a clear RuntimeError when it is
        # missing, and the same value is handed to the model client.
        api_key = _get_gemini_api_key()
        self.agent = CodeAgent(
            model=OpenAIModel(
                model_id=_normalize_model_id(),
                api_base=GEMINI_OPENAI_BASE_URL,
                api_key=api_key,
                temperature=0,
            ),
            tools=[
                MathSolver(),
                RiddleSolver(),
                TextTransformer(),
                GeminiVideoQA(),
                WikiTitleFinder(),
                WikiContentFetcher(),
                GoogleSearchTool(),
                DuckDuckGoSearchTool(),
                FileAttachmentQueryTool(),
            ],
            add_base_tools=False,
            max_steps=8,
        )
        self.agent.prompt_templates["system_prompt"] += (
            "\n\n"
            f"{GAIA_OUTPUT_PROMPT}\n\n"
            "Additional tool routing rules:\n"
            "- If attachment metadata is present, use run_query_with_file.\n"
            "- If a public video URL is present, use gemini_video_qa.\n"
            "- Use google_search or web_search for live web facts.\n"
            "- Use wiki_title_finder and wiki_content_fetcher when the prompt explicitly asks for Wikipedia.\n"
            "- Use text_transformer for reversal or casing tasks and math_solver for arithmetic.\n"
            # The answer-format rules must not override the CodeAgent's own step protocol.
            "- The answer-format rules above apply only to the argument of final_answer: keep "
            "writing the Thought/Code steps this prompt requires, and pass the bracketed answer "
            'as a single string, e.g. final_answer("[4]").'
        )

    def __call__(self, question: str, task_id: str | None = None, file_name: str | None = None) -> str:
        """Run the agent on ``question`` and return the normalized final answer.

        When both ``task_id`` and ``file_name`` are given, attachment metadata is
        appended to the prompt so the agent can call run_query_with_file. The
        agent's result is converted to ``str``, and one enclosing ``[...]`` pair is
        stripped from it.
        """
        prompt_parts = [question.strip()]
        if task_id and file_name:
            prompt_parts.append(
                "\nAttachment metadata:\n"
                f"- task_id: {task_id}\n"
                f"- file_name: {file_name}\n"
                "Use run_query_with_file if the question requires the attachment."
            )
        result = self.agent.run("\n".join(prompt_parts).strip())
        return _normalize_final_answer(str(result))
if __name__ == "__main__":
    # Smoke test: answer one GAIA-style question end to end. This needs a Gemini API
    # key and network access for the model and tools.
    demo_question = (
        "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? "
        "You can use the latest 2022 version of english wikipedia."
    )
    demo_agent = BasicAgent()
    final = demo_agent(demo_question)
    print(final)