""" GAIA Benchmark Agent - Refactored Version Improvements: - Better error handling with retry logic - Caching for expensive operations - Telemetry and progress tracking - Modular architecture - Parallel processing support - Memory management """ import gc import os import io # Workaround: Gradio 5.x bug where Queue.pending_message_lock stays None if the # ASGI lifespan startup events don't fire (a Python 3.13 asyncio compatibility issue). # Patch Queue.push to lazily initialize the lock before its first use. try: import asyncio as _asyncio from gradio.queueing import Queue as _GradioQueue _orig_push = _GradioQueue.push async def _patched_push(self, *args, **kwargs): if getattr(self, "pending_message_lock", None) is None: self.pending_message_lock = _asyncio.Lock() return await _orig_push(self, *args, **kwargs) _GradioQueue.push = _patched_push print("✅ Applied Gradio queue lock workaround") except Exception as _patch_err: print(f"ℹ️ Gradio queue patch skipped: {_patch_err}") import subprocess import json import re import traceback import contextlib import uuid import time import ast from typing import List, Optional, TypedDict, Annotated, Dict, Tuple from pathlib import Path from collections import Counter, defaultdict from functools import wraps, lru_cache import gradio as gr import pandas as pd import numpy as np import torch from pydantic import BaseModel, Field # Multimodal & Web Tools import chess import chess.engine from transformers import pipeline from youtube_transcript_api import YouTubeTranscriptApi from bs4 import BeautifulSoup import requests from PIL import Image import base64 from googleapiclient.discovery import build from googleapiclient.errors import HttpError import assemblyai as aai # LangChain & LangGraph from langgraph.graph.message import add_messages from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, AnyMessage, ToolCall from langchain_core.tools import tool from langgraph.prebuilt import ToolNode from langgraph.graph import START, END, StateGraph from langchain_groq import ChatGroq from langchain_google_genai import ChatGoogleGenerativeAI # RAG from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_community.vectorstores import FAISS from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_community.tools import DuckDuckGoSearchRun from langchain_core.documents import Document # ============================================================================= # CONFIGURATION # ============================================================================= class Config: DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space" MAX_TURNS = 25 MAX_MESSAGE_LENGTH = 8000 REFLECT_EVERY_N_TURNS = 5 MAX_RETRIES = 3 BASE_RETRY_DELAY = 1 CACHE_SIZE = 100 MAX_PARALLEL_WORKERS = 3 CHUNK_SIZE = 500 CHUNK_OVERLAP = 50 config = Config() # ============================================================================= # UTILITIES: RETRY & CACHING # ============================================================================= def retry_with_backoff(max_retries=None, base_delay=None): """Decorator for automatic retry with exponential backoff""" max_retries = max_retries or Config.MAX_RETRIES base_delay = base_delay or config.BASE_RETRY_DELAY def decorator(func): @wraps(func) def wrapper(*args, **kwargs): for attempt in range(max_retries): try: return func(*args, **kwargs) except Exception as e: if attempt == max_retries - 1: raise delay = base_delay * (2 ** attempt) print(f"⚠️ {func.__name__} retry {attempt+1}/{max_retries} after {delay}s: {e}") time.sleep(delay) return wrapper return decorator def normalize_answer(answer: str, question: str = "") -> str: """ Normalize answer to match expected format. Args: answer: The answer to normalize question: Optional question text to determine if order matters """ if not answer: return answer original = answer answer = answer.strip() # Remove common prefixes prefixes_to_remove = [ "the answer is:", "the answer is", "answer:", "final answer:", "result:", ] for prefix in prefixes_to_remove: if answer.lower().startswith(prefix): answer = answer[len(prefix):].strip() # Handle lists if "," in answer: items = [item.strip() for item in answer.split(",")] items = [item for item in items if item] # Determine if order matters based on question order_matters_keywords = [ "first", "last", "before", "after", "sequence", "order", "chronological", "oldest", "newest", "in the form", "format" ] order_matters = any(kw in question.lower() for kw in order_matters_keywords) if not order_matters: # Sort alphabetically for consistency items.sort() print(f" 📋 Sorted list alphabetically (order doesn't seem to matter)") else: print(f" 📋 Kept original order (question specifies order)") # Normalize each item items = [item.strip().rstrip('.') for item in items] # Consistent spacing answer = ", ".join(items) # Single word capitalization if len(answer.split()) == 1: if answer.lower() in ['right', 'left', 'yes', 'no', 'true', 'false']: answer = answer.capitalize() # Handle "St." vs "Saint" if "without abbreviations" in question.lower(): answer = answer.replace("St.", "Saint") answer = answer.replace("Dr.", "Doctor") answer = answer.replace("Mt.", "Mount") # Remove trailing period (unless decimal) if answer.endswith('.') and not (len(answer) > 1 and answer[-2].isdigit()): answer = answer[:-1] # Remove wrapping quotes if (answer.startswith('"') and answer.endswith('"')) or \ (answer.startswith("'") and answer.endswith("'")): answer = answer[1:-1] return answer class SearchCache: """LRU cache for search results""" def __init__(self, maxsize=None): self.maxsize = maxsize or config.CACHE_SIZE self._cache = {} self._access_order = [] def get(self, key: str) -> Optional[str]: if key in self._cache: # Move to end (most recently used) self._access_order.remove(key) self._access_order.append(key) return self._cache[key] return None def put(self, key: str, value: str): if key in self._cache: self._access_order.remove(key) elif len(self._cache) >= self.maxsize: # Remove least recently used oldest = self._access_order.pop(0) del self._cache[oldest] self._cache[key] = value self._access_order.append(key) def clear(self): self._cache.clear() self._access_order.clear() search_cache = SearchCache() # ============================================================================= # TELEMETRY # ============================================================================= class Telemetry: """Track tool usage, timing, and errors""" def __init__(self): self.tool_times = defaultdict(list) self.tool_errors = defaultdict(int) self.tool_calls = defaultdict(int) self.start_time = time.time() def record_call(self, tool_name: str, duration: float, success: bool): self.tool_calls[tool_name] += 1 self.tool_times[tool_name].append(duration) if not success: self.tool_errors[tool_name] += 1 def report(self): total_time = time.time() - self.start_time print(f"\n{'='*70}") print(f"📊 TELEMETRY REPORT") print(f"{'='*70}") print(f"Total runtime: {total_time:.2f}s") print(f"\nTool Usage:") for tool in sorted(self.tool_calls.keys()): calls = self.tool_calls[tool] times = self.tool_times[tool] errors = self.tool_errors[tool] avg_time = sum(times) / len(times) if times else 0 print(f" {tool}:") print(f" Calls: {calls}") print(f" Avg time: {avg_time:.2f}s") print(f" Errors: {errors}") print(f"{'='*70}\n") def reset(self): self.tool_times.clear() self.tool_errors.clear() self.tool_calls.clear() self.start_time = time.time() telemetry = Telemetry() # ============================================================================= # PROGRESS TRACKER # ============================================================================= class ProgressTracker: """Track question processing progress""" def __init__(self, total: int): self.total = total self.current = 0 self.correct = 0 self.start_time = time.time() def update(self, is_correct: bool): self.current += 1 if is_correct: self.correct += 1 accuracy = (self.correct / self.current) * 100 if self.current > 0 else 0 elapsed = time.time() - self.start_time avg_time = elapsed / self.current if self.current > 0 else 0 eta = avg_time * (self.total - self.current) print(f"📊 Progress: {self.current}/{self.total} ({self.current/self.total*100:.1f}%)") print(f" Accuracy: {accuracy:.1f}% ({self.correct} correct)") print(f" Avg time: {avg_time:.1f}s per question") print(f" ETA: {eta/60:.1f} minutes") # ============================================================================= # CUSTOM EXCEPTIONS # ============================================================================= class ToolError(Exception): """Custom exception with context""" def __init__(self, tool_name: str, error: Exception, suggestion: str = ""): self.tool_name = tool_name self.original_error = error self.suggestion = suggestion message = f"Tool '{tool_name}' failed: {error}" if suggestion: message += f"\n💡 Suggestion: {suggestion}" super().__init__(message) # ============================================================================= # GLOBAL RAG COMPONENTS # ============================================================================= class RAGManager: """Manage RAG components with lazy initialization""" def __init__(self): self.embeddings = None self.text_splitter = None self._initialized = False def initialize(self): if self._initialized: return True print("Initializing RAG components...") try: self.embeddings = HuggingFaceEmbeddings( model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={'device': 'cpu'} ) self.text_splitter = RecursiveCharacterTextSplitter( chunk_size=config.CHUNK_SIZE, chunk_overlap=config.CHUNK_OVERLAP, length_function=len, separators=["\n\n", "\n", ". ", " ", ""] ) self._initialized = True print("✅ RAG components initialized") return True except Exception as e: print(f"❌ RAG initialization failed: {e}") return False def is_ready(self): return self._initialized rag_manager = RAGManager() # ============================================================================= # ASR INITIALIZATION # ============================================================================= class ASRManager: """Manage ASR pipeline""" def __init__(self): self.pipeline = None self._initialized = False def initialize(self): if self._initialized: return True try: print("Loading ASR (Whisper) pipeline...") device = 0 if torch.cuda.is_available() else -1 device_name = "cuda:0" if device == 0 else "cpu" print(f"Using device: {device_name}") self.pipeline = pipeline( "automatic-speech-recognition", model="openai/whisper-base", torch_dtype=torch.float16 if device == 0 else torch.float32, device=device ) self._initialized = True print("✅ ASR pipeline loaded") return True except Exception as e: print(f"⚠️ ASR pipeline failed to load: {e}") return False def is_ready(self): return self._initialized asr_manager = ASRManager() # ============================================================================= # ANSWER VALIDATION # ============================================================================= class AnswerValidator: """Validate and check answers""" @staticmethod def load_answer_sheet(filepath: str = "answer_sheet_json.json") -> Dict[str, str]: """Load answer sheet""" try: if os.path.exists(filepath): with open(filepath, 'r', encoding='utf-8') as f: answers = json.load(f) print(f"✅ Loaded {len(answers)} answers from {filepath}") return answers else: print(f"⚠️ Answer sheet not found: {filepath}") return {} except Exception as e: print(f"❌ Error loading answer sheet: {e}") return {} @staticmethod def check_correctness(submitted: str, correct: str) -> Tuple[bool, str]: """Check if answer is correct with fuzzy matching""" import string submitted_norm = submitted.strip().lower() correct_norm = correct.strip().lower() # Exact match if submitted_norm == correct_norm: return True, "✅ EXACT MATCH" # Remove punctuation trans = str.maketrans('', '', string.punctuation) submitted_clean = submitted_norm.translate(trans) correct_clean = correct_norm.translate(trans) if submitted_clean == correct_clean: return True, "✅ MATCH (punctuation)" # Numeric comparison try: submitted_num = float(submitted_clean.replace(',', '').replace('$', '')) correct_num = float(correct_clean.replace(',', '').replace('$', '')) if abs(submitted_num - correct_num) < 0.01: return True, "✅ MATCH (numeric)" except (ValueError, AttributeError): pass # List comparison if ',' in correct_norm: correct_items = set(item.strip() for item in correct_norm.split(',')) submitted_items = set(item.strip() for item in submitted_norm.split(',')) if correct_items == submitted_items: return True, "✅ MATCH (order)" missing = correct_items - submitted_items extra = submitted_items - correct_items if missing or extra: msg = [] if missing: msg.append(f"MISSING: {', '.join(missing)}") if extra: msg.append(f"EXTRA: {', '.join(extra)}") return False, f"❌ {' | '.join(msg)}" # Partial match if submitted_norm in correct_norm or correct_norm in submitted_norm: return False, f"❌ PARTIAL ('{submitted}' vs '{correct}')" return False, f"❌ WRONG ('{submitted}' vs '{correct}')" # ============================================================================= # UTILITY FUNCTIONS # ============================================================================= def remove_fences_simple(text: str) -> str: """Remove code fences""" text = text.strip() if text.startswith("```") and text.endswith("```"): text = text[3:-3].strip() if '\n' in text: first_line, rest = text.split('\n', 1) if first_line.strip().replace('_','').isalnum() and len(first_line.strip()) < 15: text = rest.strip() return text def truncate_if_needed(content: str, max_length: int = None) -> str: """Truncate long content""" max_length = max_length or config.MAX_MESSAGE_LENGTH if len(content) > max_length: return content[:max_length] + f"\n...[truncated, {len(content)} chars total]" return content def find_file(path: str) -> Optional[Path]: """Find file with multiple path attempts""" script_dir = Path.cwd() safe_path = Path(path).as_posix() paths = [ script_dir / safe_path, Path(safe_path), script_dir / Path(path).name, Path("files") / Path(path).name ] for p in paths: if p.exists(): return p return None # ============================================================================= # TOOL INPUT VALIDATION # ============================================================================= def validate_tool_inputs(tool_name: str, inputs: dict) -> Tuple[bool, str]: """Validate tool inputs before execution""" validators = { "scrape_and_retrieve": lambda i: i.get("url", "").startswith(("http://", "https://")), "calculator": lambda i: bool(re.match(r'^[\d\+\-\*/\(\)\s\.,a-z]+$', i.get("expression", ""), re.I)), "read_file": lambda i: len(i.get("path", "")) > 0 and ".." not in i.get("path", ""), "search_tool": lambda i: len(i.get("query", "").strip()) > 0, "code_interpreter": lambda i: "import os" not in i.get("code", "").lower(), } if tool_name in validators: try: if not validators[tool_name](inputs): return False, f"Invalid input format for {tool_name}" except Exception as e: return False, f"Validation error: {e}" return True, "" # ============================================================================= # PLANNING & REFLECTION TOOLS # ============================================================================= class ThinkInput(BaseModel): reasoning: str = Field(description="Brief reasoning (under 150 chars)") @tool(args_schema=ThinkInput) def think_through_logic(reasoning: str) -> str: """ Use ONLY for logic puzzles and riddles (NOT research questions). Use for: - Brain teasers, logic puzzles, riddles DON'T use for: - Research questions → use search_tool or wikipedia_search - Math → use calculator - File analysis → use file tools """ print(f"🧠 Thinking: {reasoning[:100]}...") return f"""✅ Reasoning: {reasoning[:100]} ⚠️ DO NOT CALL think_through_logic AGAIN! For research → use search_tool() or wikipedia_search() For math → use calculator() If you have answer → use final_answer_tool() TAKE ACTION NOW!""" class PlanInput(BaseModel): task_summary: str = Field(description="Brief task summary (under 80 chars)") @tool(args_schema=PlanInput) def create_plan(task_summary: str) -> str: """Create plan for complex tasks""" start_time = time.time() try: print(f"📋 Planning: {task_summary[:80]}...") result = f"""✅ Plan: {task_summary} Framework: 1. What info needed? 2. Which tools? 3. What order? Execute step 1 now.""" telemetry.record_call("create_plan", time.time() - start_time, True) return result except Exception as e: telemetry.record_call("create_plan", time.time() - start_time, False) raise class ReflectInput(BaseModel): situation: str = Field(description="Brief situation (under 80 chars)") @tool(args_schema=ReflectInput) def reflect_on_progress(situation: str) -> str: """Reflect when stuck""" start_time = time.time() try: print(f"🤔 Reflecting: {situation[:80]}...") result = f"""🔍 Reflection: {situation} Questions: 1. Right approach? 2. Try different tool? 3. Have answer already? Try DIFFERENT approach now.""" telemetry.record_call("reflect_on_progress", time.time() - start_time, True) return result except Exception as e: telemetry.record_call("reflect_on_progress", time.time() - start_time, False) raise class ValidateAnswerInput(BaseModel): proposed_answer: str = Field(description="Answer to validate") original_question: str = Field(description="Original question (first 100 chars)") @tool(args_schema=ValidateAnswerInput) def validate_answer(proposed_answer: str, original_question: str = "") -> str: """ Validate answer format and provide warnings. Returns validation result with normalization suggestions. """ start_time = time.time() try: print(f"✓ Validating: '{proposed_answer[:50]}...'") warnings = [] errors = [] normalization_needed = [] # Normalize for validation normalized = normalize_answer(proposed_answer) if normalized != proposed_answer: normalization_needed.append(f"Consider using normalized form: '{normalized}'") # Check 1: Empty answer if not proposed_answer or not proposed_answer.strip(): errors.append("Answer is empty") # Check 2: Too long (probably explaining instead of answering) if len(proposed_answer) > 200: warnings.append("Answer is very long (>200 chars). Consider if question asks for brief response.") # Check 3: Contains question words question_words = ['what', 'who', 'when', 'where', 'why', 'how', 'which'] if any(word in proposed_answer.lower() for word in question_words): warnings.append("Answer contains question words. Make sure you're providing the answer, not rephrasing the question.") # Check 4: List ordering if "," in proposed_answer: items = [item.strip() for item in proposed_answer.split(",")] if len(items) > 1: warnings.append(f"List detected with {len(items)} items. Verify order matches question requirements.") # Check 5: Capitalization consistency if proposed_answer.lower() in ['right', 'left', 'yes', 'no', 'true', 'false']: if not proposed_answer[0].isupper(): normalization_needed.append(f"Consider capitalizing: '{proposed_answer.capitalize()}'") # Check 6: Abbreviations if any(abbrev in proposed_answer.lower() for abbrev in ['st.', 'dr.', 'mt.']): if "without abbreviations" in str(proposed_answer).lower() or "full" in str(proposed_answer).lower(): warnings.append("Question may ask for full form without abbreviations") # Check 7: Spacing in lists if "," in proposed_answer: # Check for inconsistent spacing if ", " in proposed_answer and "," in proposed_answer.replace(", ", ""): normalization_needed.append("Inconsistent spacing in list. Use consistent ', ' format") # Build result result_parts = [] if errors: result_parts.append("🚫 VALIDATION FAILED:") for error in errors: result_parts.append(f"❌ {error}") result_parts.append("Fix issues then retry validation.") else: result_parts.append("✅ VALIDATION PASSED!") if normalization_needed: result_parts.append("\n💡 NORMALIZATION SUGGESTIONS:") for suggestion in normalization_needed: result_parts.append(f" • {suggestion}") if warnings: result_parts.append("\n⚠️ WARNINGS:") for warning in warnings: result_parts.append(f"⚠️ {warning}") result_parts.append("Proceed if confident, or refine answer.") else: result_parts.append("Call final_answer_tool() now.") result = "\n".join(result_parts) telemetry.record_call("validate_answer", time.time() - start_time, True) return result except Exception as e: telemetry.record_call("validate_answer", time.time() - start_time, False) raise ToolError("validate_answer", e) # ============================================================================= # CORE TOOLS # ============================================================================= class WikipediaInput(BaseModel): query: str = Field(description="Topic to search (just the subject name)") @tool(args_schema=WikipediaInput) def wikipedia_search(query: str) -> str: """ Search Wikipedia directly. Keep query SHORT! ✅ GOOD: "Mercedes Sosa" ❌ BAD: "Mercedes Sosa discography 2022 Wikipedia version" """ # AGGRESSIVE query cleaning original = query query = query.lower().strip() # Remove these phrases (order matters - longest first!) remove_list = [ "2022 english wikipedia version", "english wikipedia version", "2022 version", "wikipedia version", "latest version", "wikipedia", "wiki", "discography", "site:", " the ", " a ", " an " ] for phrase in remove_list: query = query.replace(phrase, "") # Clean whitespace query = " ".join(query.split()).strip() # Fallback if query too short if len(query) < 2: words = original.split() query = words[0] if words else original print(f"📚 Wikipedia: '{original}' → '{query}'") # Try direct page page_name = query.title().replace(" ", "_") page_url = f"https://en.wikipedia.org/wiki/{page_name}" print(f" Trying: {page_url}") try: headers = {'User-Agent': 'Mozilla/5.0'} response = requests.get(page_url, headers=headers, timeout=10) if response.status_code == 200: soup = BeautifulSoup(response.text, 'html.parser') title_tag = soup.find('h1', class_='firstHeading') title = title_tag.get_text() if title_tag else page_name content_div = soup.find('div', class_='mw-parser-output') preview = "" if content_div: paragraphs = content_div.find_all('p', limit=3) for p in paragraphs: text = p.get_text().strip() if len(text) > 50: preview = text[:300] break result = f"""✅ Found: {title} URL: {page_url} Preview: {preview}... NEXT: Use scrape_and_retrieve(url="{page_url}", query="specific info")""" print(f"✓ Success: {title}") return result else: # Try search print(f" 404, trying search") search_url = f"https://en.wikipedia.org/w/index.php?search={query.replace(' ', '+')}" try: search_resp = requests.get(search_url, headers=headers, timeout=10) if "wikipedia.org/wiki/" in search_resp.url and search_resp.url != search_url: return f"✅ Redirected to: {search_resp.url}\n\nUse scrape_and_retrieve() for details." soup = BeautifulSoup(search_resp.text, 'html.parser') results = soup.find_all('div', class_='mw-search-result-heading', limit=3) if results: formatted = [] for i, result in enumerate(results, 1): link = result.find('a') if link: title = link.get_text() href = link.get('href') full_url = f"https://en.wikipedia.org{href}" formatted.append(f"{i}. {title}\n {full_url}") return "Wikipedia results:\n\n" + "\n\n".join(formatted) + "\n\nUse scrape_and_retrieve() with relevant URL." return f"""No Wikipedia page found for '{query}'. Try: 1. search_tool("{query}") 2. Different search term 3. Check spelling""" except Exception as search_err: return f"Wikipedia search failed. Try search_tool('{query}') instead." except requests.Timeout: return f"Wikipedia timed out. Try search_tool('{query}') instead." except Exception as e: print(f"⚠️ Wikipedia error: {str(e)[:100]}") return f"Wikipedia error. Try search_tool('{query}') instead." class SearchInput(BaseModel): query: str = Field(description="Search query (concise)") @tool(args_schema=SearchInput) @retry_with_backoff(max_retries=3) def search_tool(query: str) -> str: """Web search with caching and language filtering""" start_time = time.time() try: # Input validation is_valid, msg = validate_tool_inputs("search_tool", {"query": query}) if not is_valid: raise ValueError(msg) # Check cache cached = search_cache.get(query) if cached: print(f"🔍 Search (cached): {query}") telemetry.record_call("search_tool", time.time() - start_time, True) return cached # Auto-add Wikipedia filter if 'wikipedia' in query.lower() and 'site:' not in query: query = f"{query} site:wikipedia.org" print(f"🔍 Searching: {query}") # DuckDuckGo doesn't support these params directly, # but we can filter by adding language hints # For English results, add hint to query search = DuckDuckGoSearchRun() # Add language hint to force English results if not any(keyword in query.lower() for keyword in ['lang:', 'region:']): query = f"{query} lang:en" result = search.run(query) if not result or len(result) < 50: result = "No results found. Try different terms." result = truncate_if_needed(result) # Cache result search_cache.put(query, result) telemetry.record_call("search_tool", time.time() - start_time, True) return result except Exception as e: telemetry.record_call("search_tool", time.time() - start_time, False) raise ToolError("search_tool", e, "Try rephrasing query") class CalcInput(BaseModel): expression: str = Field(description="Math expression") @tool(args_schema=CalcInput) def calculator(expression: str) -> str: """Evaluate math expressions""" start_time = time.time() try: # Input validation is_valid, msg = validate_tool_inputs("calculator", {"expression": expression}) if not is_valid: raise ValueError(msg) print(f"🧮 Calculating: {expression}") import math safe_dict = { 'sqrt': math.sqrt, 'sin': math.sin, 'cos': math.cos, 'tan': math.tan, 'log': math.log, 'log10': math.log10, 'exp': math.exp, 'pi': math.pi, 'e': math.e, 'abs': abs, 'round': round, 'pow': pow, 'sum': sum, 'min': min, 'max': max } result = eval(expression, {"__builtins__": {}}, safe_dict) telemetry.record_call("calculator", time.time() - start_time, True) return str(result) except Exception as e: telemetry.record_call("calculator", time.time() - start_time, False) raise ToolError("calculator", e, f"Check expression: {expression}") class CodeInput(BaseModel): code: str = Field(description="Python code (MUST use print())") @tool(args_schema=CodeInput) def code_interpreter(code: str) -> str: """Execute Python code""" start_time = time.time() try: # Safety checks dangerous = ['__import__', 'eval(', 'compile(', 'subprocess', 'os.system', 'exec('] if any(d in code.lower() for d in dangerous): raise ValueError("Dangerous operation not allowed") if 'open(' in code.lower() and any(m in code for m in ["'w'", '"w"', "'a'", '"a"']): raise ValueError("File writing not allowed, use write_file tool") print(f"💻 Executing code ({len(code)} chars)...") output_stream = io.StringIO() error_stream = io.StringIO() with contextlib.redirect_stdout(output_stream), contextlib.redirect_stderr(error_stream): safe_globals = { "pd": pd, "np": np, "json": json, "re": re, "__builtins__": __builtins__ } exec(code, safe_globals, {}) stdout = output_stream.getvalue() stderr = error_stream.getvalue() if stderr: result = f"Error:\n{stderr}\n\nOutput:\n{stdout}" elif stdout: result = truncate_if_needed(stdout) else: result = "Code executed but no output. Use print()!" telemetry.record_call("code_interpreter", time.time() - start_time, True) return result except Exception as e: telemetry.record_call("code_interpreter", time.time() - start_time, False) raise ToolError("code_interpreter", e, "Check code syntax") class AnalyzeDataInput(BaseModel): file_path: str = Field(description="Path to CSV or Excel file") question: str = Field(description="What to find (e.g., 'count rows where year > 2000')") @tool(args_schema=AnalyzeDataInput) def analyze_data_file(file_path: str, question: str) -> str: """ Analyze CSV/Excel files with automatic data profiling. Generates Python code to answer questions about data files. Better than code_interpreter alone because it: 1. Profiles the data first (columns, types, sample) 2. Generates appropriate pandas code 3. Handles common data issues (encoding, missing values) Use for questions like: - "How many rows have X?" - "What's the sum/average of column Y?" - "Count items grouped by Z" """ start_time = time.time() try: print(f"📊 Analyzing data file: {file_path}") print(f" Question: {question[:100]}...") # Find file data_file = find_file(file_path) if not data_file: raise FileNotFoundError(f"Data file not found: {file_path}") file_ext = data_file.suffix.lower() if file_ext not in ['.csv', '.xlsx', '.xls', '.tsv']: raise ValueError(f"Unsupported file type: {file_ext}. Use .csv, .xlsx, .xls, or .tsv") print(f" File type: {file_ext}") # Generate profiling code profiling_code = f""" import pandas as pd import numpy as np # Load file file_path = r"{data_file}" """ if file_ext == '.csv': profiling_code += """ # Try different encodings for encoding in ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252']: try: df = pd.read_csv(file_path, encoding=encoding) break except: continue """ elif file_ext == '.tsv': profiling_code += """ df = pd.read_csv(file_path, sep='\\t', encoding='utf-8') """ else: # Excel profiling_code += """ df = pd.read_excel(file_path) """ profiling_code += """ # Profile data print("=" * 60) print("DATA PROFILE") print("=" * 60) print(f"Shape: {df.shape[0]} rows × {df.shape[1]} columns") print(f"\\nColumns: {', '.join(df.columns.tolist())}") print(f"\\nData types:") print(df.dtypes) print(f"\\nFirst 3 rows:") print(df.head(3)) print(f"\\nMissing values:") print(df.isnull().sum()) """ # Execute profiling print(f" Profiling data...") output_stream = io.StringIO() error_stream = io.StringIO() with contextlib.redirect_stdout(output_stream), contextlib.redirect_stderr(error_stream): exec(profiling_code, {"pd": pd, "np": np, "__builtins__": __builtins__}) profile_output = output_stream.getvalue() if error_stream.getvalue(): raise RuntimeError(f"Profiling failed: {error_stream.getvalue()}") print(f" Profiling complete") print(profile_output[:500] + "..." if len(profile_output) > 500 else profile_output) # Now generate analysis code based on question analysis_code = profiling_code + f""" # Analysis for: {question} print("\\n" + "=" * 60) print("ANALYSIS RESULT") print("=" * 60) """ # Add intelligent code based on question keywords q_lower = question.lower() if 'count' in q_lower or 'how many' in q_lower: if 'where' in q_lower or 'with' in q_lower: analysis_code += """ # Count rows matching condition # NOTE: Adjust the filter condition based on your needs result = len(df) # Total count print(f"Total rows: {result}") # Example filters (uncomment and modify as needed): # result = len(df[df['column'] > value]) # result = len(df[df['column'].str.contains('text', na=False)]) """ else: analysis_code += """ result = len(df) print(f"Total rows: {result}") """ elif 'sum' in q_lower or 'total' in q_lower: analysis_code += """ # Sum a numeric column # NOTE: Replace 'column_name' with actual column # result = df['column_name'].sum() # print(f"Sum: {result}") """ elif 'average' in q_lower or 'mean' in q_lower: analysis_code += """ # Average of a column # result = df['column_name'].mean() # print(f"Average: {result}") """ elif 'group' in q_lower or 'by' in q_lower: analysis_code += """ # Group by and count # result = df.groupby('column_name').size() # print(result) """ else: # Generic: show summary analysis_code += """ # Summary statistics print(df.describe()) """ result = f"""Data Profile: {profile_output} Generated Analysis Code: ```python {analysis_code} ``` **IMPORTANT**: The code above needs column names adjusted. Use code_interpreter() with the corrected code to get the answer. Columns available: {", ".join((pd.read_csv(data_file) if file_ext == '.csv' else pd.read_excel(data_file)).columns.tolist())} """ telemetry.record_call("analyze_data_file", time.time() - start_time, True) return truncate_if_needed(result) except Exception as e: telemetry.record_call("analyze_data_file", time.time() - start_time, False) raise ToolError("analyze_data_file", e, "Check file path and format") class ReadFileInput(BaseModel): path: str = Field(description="File path") @tool(args_schema=ReadFileInput) def read_file(path: str) -> str: """Read file content""" start_time = time.time() try: # Input validation is_valid, msg = validate_tool_inputs("read_file", {"path": path}) if not is_valid: raise ValueError(msg) print(f"📄 Reading: {path}") file_path = find_file(path) if not file_path: raise FileNotFoundError(f"File not found: {path}") content = file_path.read_text(encoding='utf-8') telemetry.record_call("read_file", time.time() - start_time, True) return truncate_if_needed(content) except UnicodeDecodeError: telemetry.record_call("read_file", time.time() - start_time, False) return f"Binary file. Try audio_transcription_tool." except Exception as e: telemetry.record_call("read_file", time.time() - start_time, False) raise ToolError("read_file", e, f"Check file path: {path}") class WriteFileInput(BaseModel): path: str = Field(description="File path") content: str = Field(description="Content to write") @tool(args_schema=WriteFileInput) def write_file(path: str, content: str) -> str: """Write content to file""" start_time = time.time() try: print(f"✍️ Writing: {path}") file_path = Path.cwd() / path file_path.parent.mkdir(parents=True, exist_ok=True) file_path.write_text(content, encoding='utf-8') telemetry.record_call("write_file", time.time() - start_time, True) return f"Wrote {len(content)} chars to '{path}'" except Exception as e: telemetry.record_call("write_file", time.time() - start_time, False) raise ToolError("write_file", e) class ListDirInput(BaseModel): path: str = Field(description="Directory path", default=".") @tool(args_schema=ListDirInput) def list_directory(path: str = ".") -> str: """List directory contents""" start_time = time.time() try: print(f"📁 Listing: {path}") dir_path = Path.cwd() / path if path != "." else Path.cwd() if not dir_path.is_dir(): raise NotADirectoryError(f"'{path}' not a directory") items = sorted(dir_path.iterdir()) if not items: return f"Directory '{path}' is empty" files, dirs = [], [] for item in items: if item.is_dir(): dirs.append(f"📁 {item.name}/") else: files.append(f"📄 {item.name} ({item.stat().st_size} bytes)") result = f"Contents of '{path}':\n\n" if dirs: result += "Directories:\n" + "\n".join(dirs) + "\n\n" if files: result += "Files:\n" + "\n".join(files) telemetry.record_call("list_directory", time.time() - start_time, True) return result except Exception as e: telemetry.record_call("list_directory", time.time() - start_time, False) raise ToolError("list_directory", e) class AudioInput(BaseModel): file_path: str = Field(description="Audio file path") @tool(args_schema=AudioInput) def audio_transcription_tool(file_path: str) -> str: """Transcribe audio using Whisper""" start_time = time.time() try: print(f"🎤 Transcribing: {file_path}") if not asr_manager.is_ready(): asr_manager.initialize() if not asr_manager.is_ready(): raise RuntimeError("ASR not available") audio_path = find_file(file_path) if not audio_path: raise FileNotFoundError(f"Audio file not found: {file_path}") transcription = asr_manager.pipeline( str(audio_path), return_timestamps=True, chunk_length_s=30, stride_length_s=5 ) result_text = transcription.get("text", "") if not result_text: raise ValueError("Transcription empty") telemetry.record_call("audio_transcription_tool", time.time() - start_time, True) return f"Transcription:\n{truncate_if_needed(result_text)}" except Exception as e: telemetry.record_call("audio_transcription_tool", time.time() - start_time, False) raise ToolError("audio_transcription_tool", e) class ChessAnalysisInput(BaseModel): image_path: str = Field(description="Path to chess board image") description: str = Field(description="Context about position", default="") @tool(args_schema=ChessAnalysisInput) def analyze_chess_position(image_path: str, description: str = "") -> str: """ Analyze chess position from image using Gemini Vision + Stockfish. Extracts FEN, analyzes best move. """ start_time = time.time() try: print(f"♟️ Analyzing chess: {image_path}") # Find file image_path_obj = find_file(image_path) if not image_path_obj and os.path.exists(image_path): image_path_obj = Path(image_path) if not image_path_obj or not image_path_obj.exists(): raise FileNotFoundError(f"Image not found: {image_path}") GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY") if not GOOGLE_API_KEY: raise ValueError("GEMINI_API_KEY not set") # Read image as base64 with open(image_path_obj, "rb") as f: image_data = base64.b64encode(f.read()).decode("utf-8") # Use Gemini to extract FEN llm = ChatGoogleGenerativeAI( model="gemini-2.5-flash", google_api_key=GOOGLE_API_KEY, temperature=0 ) message = HumanMessage( content=[ { "type": "text", "text": """Analyze this chess position and provide the FEN notation. CRITICAL: The FEN string MUST include whose turn it is: - If White to move: end with "w - - 0 1" - If Black to move: end with "b - - 0 1" Look at the board carefully to determine whose turn it is based on: 1. Any text in the image indicating whose turn 2. The position context 3. If unclear, look at piece positions Respond with ONLY the FEN string, nothing else.""" }, { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{image_data}" } } ] ) response = llm.invoke([message]) fen = response.content.strip() print(f"✓ FEN: {fen}") # ===== FIX: Parse whose turn it is from FEN ===== # FEN format: position w/b castling en-passant halfmove fullmove fen_parts = fen.split() # Ensure we have the turn indicator if len(fen_parts) < 2: # Default to white if not specified fen = f"{fen} w - - 0 1" fen_parts = fen.split() # Get whose turn it is turn = fen_parts[1] if len(fen_parts) > 1 else 'w' print(f"✓ Turn: {'Black' if turn == 'b' else 'White'}") # ===== END FIX ===== # Analyze with Stockfish try: board = chess.Board(fen) except ValueError as e: raise ValueError(f"Invalid FEN from Gemini: {fen}. Error: {e}") # Configure Stockfish stockfish_path = "/usr/games/stockfish" if not os.path.exists(stockfish_path): raise FileNotFoundError("Stockfish not found at /usr/games/stockfish") engine = chess.engine.SimpleEngine.popen_uci(stockfish_path) # ===== FIX: Analyze with appropriate depth ===== # For tactical positions (like mate puzzles), need deeper analysis result = engine.analyse(board, chess.engine.Limit(depth=20)) # ===== END FIX ===== best_move = result["pv"][0] # Principal variation (best line) engine.quit() # Convert to algebraic notation move_san = board.san(best_move) print(f"✓ Best move: {move_san}") telemetry.record_call("analyze_chess_position", time.time() - start_time, True) # ===== FIX: Include turn info in response ===== turn_text = "Black" if turn == 'b' else "White" return f"{move_san} ({turn_text} to move, from FEN: {fen})" # ===== END FIX ===== except Exception as e: telemetry.record_call("analyze_chess_position", time.time() - start_time, False) raise ToolError("analyze_chess_position", e, "Check image quality and Stockfish installation") class ImageAnalysisInput(BaseModel): file_path: str = Field(description="Image file path") query: str = Field(description="What to analyze") @tool(args_schema=ImageAnalysisInput) def analyze_image(file_path: str, query: str) -> str: """Analyze images using Gemini Vision""" start_time = time.time() try: print(f"🖼️ Analyzing: {file_path}") print(f" Query: {query[:100]}...") image_path = find_file(file_path) if not image_path and os.path.exists(file_path): image_path = Path(file_path) if not image_path or not image_path.exists(): raise FileNotFoundError(f"Image not found: {file_path}") GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY") if not GOOGLE_API_KEY: raise ValueError("GEMINI_API_KEY not set") # Load and encode img = Image.open(image_path) if img.mode not in ['RGB', 'RGBA']: img = img.convert('RGB') buffered = io.BytesIO() img.save(buffered, format="JPEG") img_base64 = base64.b64encode(buffered.getvalue()).decode() # Use FLASH model for cost efficiency vision_llm = ChatGoogleGenerativeAI( model="gemini-2.5-flash", google_api_key=GOOGLE_API_KEY, temperature=0 ) message = HumanMessage( content=[ {"type": "text", "text": query}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}} ] ) response = vision_llm.invoke([message]) telemetry.record_call("analyze_image", time.time() - start_time, True) return f"Image Analysis:\n{truncate_if_needed(response.content)}" except Exception as e: telemetry.record_call("analyze_image", time.time() - start_time, False) raise ToolError("analyze_image", e) class YoutubeInput(BaseModel): video_url: str = Field(description="YouTube URL") @tool(args_schema=YoutubeInput) def get_youtube_transcript(video_url: str) -> str: """Get YouTube transcript using AssemblyAI with proper status handling""" start_time = time.time() try: aai.settings.api_key = os.getenv("ASSEMBLYAI_API_KEY") if not aai.settings.api_key: raise ValueError("ASSEMBLYAI_API_KEY not set in Space secrets") print(f"📺 Transcribing YouTube: {video_url}") # Validate URL if not ("youtube.com" in video_url or "youtu.be" in video_url): raise ValueError(f"Invalid YouTube URL: {video_url}") # Submit transcription request transcriber = aai.Transcriber() print(f" Submitting to AssemblyAI...") config_obj = aai.TranscriptionConfig( speech_model=aai.SpeechModel.best, ) transcript = transcriber.transcribe(video_url, config=config_obj) # Wait for completion print(f" Initial status: {transcript.status}") # Poll for completion (max 5 minutes) max_wait = 300 poll_interval = 5 elapsed = 0 while transcript.status == aai.TranscriptStatus.queued or transcript.status == aai.TranscriptStatus.processing: if elapsed >= max_wait: raise TimeoutError(f"Transcription timed out after {max_wait}s. Video may be too long.") time.sleep(poll_interval) elapsed += poll_interval # Refresh transcript object try: transcript = transcriber.get_transcript(transcript.id) print(f" Status after {elapsed}s: {transcript.status}") except Exception as refresh_err: print(f" Warning: Could not refresh status: {refresh_err}") break # Check final status if transcript.status == aai.TranscriptStatus.error: error_msg = getattr(transcript, 'error', 'Unknown error') # ===== NEW: Check for network block ===== if "text/html" in error_msg or "HTML document" in error_msg: raise RuntimeError( "YouTube access blocked. " "If a local video file was provided, use analyze_image or audio_transcription_tool instead. " "Or try downloading the video first." ) # ===== END NEW ===== raise RuntimeError(f"AssemblyAI transcription failed: {error_msg}") if transcript.status != aai.TranscriptStatus.completed: raise RuntimeError(f"Unexpected status: {transcript.status}") # Extract text if not hasattr(transcript, 'text'): raise AttributeError("Transcript object has no 'text' attribute") result_text = transcript.text if not result_text or not isinstance(result_text, str): raise ValueError(f"Transcript text is invalid: {type(result_text)}") result_text = result_text.strip() if len(result_text) < 10: raise ValueError(f"Transcript too short ({len(result_text)} chars). Video may have no audio.") print(f"✓ Transcribed {len(result_text)} chars") telemetry.record_call("get_youtube_transcript", time.time() - start_time, True) return f"YouTube Transcript:\n{truncate_if_needed(result_text)}" except Exception as e: telemetry.record_call("get_youtube_transcript", time.time() - start_time, False) error_msg = str(e) suggestions = [] if "text/html" in error_msg.lower() or "html document" in error_msg.lower(): suggestions.append("YouTube blocked on HuggingFace. Use the local .mp4 file instead with audio_transcription_tool or analyze_image") elif "not found" in error_msg.lower(): suggestions.append("Video may be private or deleted") elif "quota" in error_msg.lower() or "limit" in error_msg.lower(): suggestions.append("AssemblyAI quota exceeded") elif "timeout" in error_msg.lower(): suggestions.append("Video may be too long (try shorter video)") suggestion_text = " | ".join(suggestions) if suggestions else "Check video URL is valid and public" raise ToolError("get_youtube_transcript", e, suggestion_text) class BrowseInput(BaseModel): start_url: str = Field(description="Starting URL (http:// or https://)") goal: str = Field(description="What you're trying to find (e.g., 'Mercedes Sosa albums 2000-2009')") max_steps: int = Field(description="Max pages to visit (1-5)", default=3) @tool(args_schema=BrowseInput) @retry_with_backoff(max_retries=2) def iterative_web_browser(start_url: str, goal: str, max_steps: int = 3) -> str: """ Multi-turn web browsing - follows links iteratively to find information. Use when: - Information requires navigating through multiple pages - Need to follow "Read more" or "Details" links - Example: "Find Mercedes Sosa's discography, then count 2000-2009 albums" This tool: 1. Visits start_url 2. Searches content for goal-related info 3. Extracts relevant links 4. Follows most promising link 5. Repeats until info found or max_steps reached Better than scrape_and_retrieve when single page doesn't have complete info. """ start_time = time.time() try: if not rag_manager.is_ready(): rag_manager.initialize() print(f"🌐 Iterative browsing starting at: {start_url}") print(f" Goal: {goal[:100]}...") print(f" Max steps: {max_steps}") visited_urls = set() current_url = start_url all_findings = [] headers = { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36' } for step in range(max_steps): if current_url in visited_urls: print(f" Step {step+1}: Already visited, stopping") break visited_urls.add(current_url) print(f" Step {step+1}: Visiting {current_url}") try: response = requests.get(current_url, headers=headers, timeout=15) response.raise_for_status() soup = BeautifulSoup(response.text, 'html.parser') # Remove noise for tag in soup(["script", "style", "nav", "footer", "aside", "header", "iframe"]): tag.extract() # Extract main content main = soup.find('main') or soup.find('article') or soup.find('div', class_='mw-parser-output') or soup.body if not main: print(f" No main content found") continue text = main.get_text(separator='\n', strip=True) lines = [l.strip() for l in text.splitlines() if l.strip()] text = '\n'.join(lines) print(f" Extracted {len(text)} chars") # Search for goal-related content chunks = rag_manager.text_splitter.split_text(text) docs = [Document(page_content=c, metadata={"source": current_url, "step": step+1}) for c in chunks] db = FAISS.from_documents(docs, rag_manager.embeddings) retriever = db.as_retriever(search_kwargs={"k": 3}) retrieved = retriever.invoke(goal) # Clean up del db del retriever import gc gc.collect() if retrieved: print(f" Found {len(retrieved)} relevant chunks") for i, doc in enumerate(retrieved): all_findings.append({ 'step': step + 1, 'url': current_url, 'content': doc.page_content }) # Extract links for next step if step < max_steps - 1: links = [] for a in main.find_all('a', href=True): href = a.get('href') text = a.get_text(strip=True).lower() # Make absolute URL if href.startswith('/'): from urllib.parse import urljoin href = urljoin(current_url, href) # Filter relevant links goal_keywords = goal.lower().split() if any(keyword in href.lower() or keyword in text for keyword in goal_keywords): if href.startswith('http') and href not in visited_urls: links.append((href, text)) if links: # Pick most relevant link current_url = links[0][0] print(f" Found {len(links)} potential links, following: {links[0][1][:50]}") else: print(f" No more relevant links found") break else: print(f" Max steps reached") break except Exception as e: print(f" Error on step {step+1}: {e}") break # Compile findings if not all_findings: result = f"Browsed {len(visited_urls)} pages but found no relevant information for: '{goal}'" else: result = f"Information gathered from {len(visited_urls)} pages:\n\n" for finding in all_findings: result += f"[Step {finding['step']} - {finding['url']}]\n{finding['content']}\n\n---\n\n" result = truncate_if_needed(result) telemetry.record_call("iterative_web_browser", time.time() - start_time, True) return result except Exception as e: telemetry.record_call("iterative_web_browser", time.time() - start_time, False) raise ToolError("iterative_web_browser", e, "Try starting from a more specific URL") class ScrapeInput(BaseModel): url: str = Field(description="URL (http:// or https://)") query: str = Field(description="Specific info to find") @tool(args_schema=ScrapeInput) @retry_with_backoff(max_retries=3) def scrape_and_retrieve(url: str, query: str) -> str: """ Scrape webpage and retrieve relevant sections using RAG with smart fallbacks. """ start_time = time.time() try: is_valid, msg = validate_tool_inputs("scrape_and_retrieve", {"url": url, "query": query}) if not is_valid: raise ValueError(msg) print(f"🌐 Scraping: {url}") print(f" Looking for: {query[:50]}...") # ===== TRY PRIMARY URL ===== try: response = requests.get(url, timeout=15, headers={ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36' }) response.raise_for_status() except requests.exceptions.HTTPError as e: if e.response.status_code == 404: print(f" ❌ 404 error, trying fallbacks...") # ===== FALLBACK 1: Try alternative URL formats ===== if "wikipedia.org" in url: fallback_urls = [] # Example: Wikipedia:Featured_articles/2016_November # Try: Wikipedia:Featured_articles#2016 if "/20" in url and "_" in url: # Extract year import re year_match = re.search(r'/(\d{4})', url) if year_match: year = year_match.group(1) # Try anchor link format base_url = url.split('/20')[0] fallback_urls.append(f"{base_url}#{year}") # Try without year suffix fallback_urls.append(base_url) # Try with underscores replaced by spaces (URL encoded) if "_" in url: fallback_urls.append(url.replace("_", "%20")) # Try each fallback for fallback_url in fallback_urls: try: print(f" Trying fallback: {fallback_url}") response = requests.get(fallback_url, timeout=15, headers={ 'User-Agent': 'Mozilla/5.0' }) response.raise_for_status() url = fallback_url # Update URL for later print(f" ✓ Fallback succeeded!") break except: continue else: # All fallbacks failed # ===== FALLBACK 2: Use Wikipedia search ===== print(f" All URL fallbacks failed, trying Wikipedia search...") # Extract search terms from URL search_terms = url.split('/')[-1].replace('_', ' ').replace('%20', ' ') # Search Wikipedia search_url = f"https://en.wikipedia.org/w/api.php?action=opensearch&search={search_terms}&limit=1&format=json" search_response = requests.get(search_url, timeout=10) search_data = search_response.json() if len(search_data) > 3 and search_data[3]: # Found a result wiki_url = search_data[3][0] print(f" ✓ Found via search: {wiki_url}") response = requests.get(wiki_url, timeout=15, headers={ 'User-Agent': 'Mozilla/5.0' }) response.raise_for_status() url = wiki_url else: raise ToolError( "scrape_and_retrieve", Exception(f"404 and all fallbacks failed for {url}"), "Try using wikipedia_search tool to find the correct article first" ) else: # Non-Wikipedia 404 raise else: # Other HTTP error raise # ===== END FALLBACKS ===== # Parse content soup = BeautifulSoup(response.content, 'html.parser') # Remove unwanted elements for element in soup(['script', 'style', 'nav', 'header', 'footer']): element.decompose() text = soup.get_text(separator='\n', strip=True) if len(text) < 100: raise ValueError(f"Insufficient content extracted from {url}") print(f"✓ Extracted {len(text)} characters") # RAG retrieval docs = [Document(page_content=text, metadata={"source": url})] text_splitter = RecursiveCharacterTextSplitter( chunk_size=Config.CHUNK_SIZE, chunk_overlap=Config.CHUNK_OVERLAP ) chunks = text_splitter.split_documents(docs) print(f"✓ Created {len(chunks)} chunks") # Search for relevant chunks vectorstore = FAISS.from_documents(chunks, rag_manager.embeddings) retriever = vectorstore.as_retriever(search_kwargs={"k": 5}) relevant_docs = retriever.invoke(query) print(f"✓ Found {len(relevant_docs)} relevant chunks") # Format results results = [] for i, doc in enumerate(relevant_docs, 1): content = doc.page_content.strip() results.append(f"[Section {i}]\n{content}") result = f"From {url}:\n\n" + "\n\n".join(results) # Cleanup del vectorstore gc.collect() telemetry.record_call("scrape_and_retrieve", time.time() - start_time, True) return truncate_if_needed(result) except Exception as e: telemetry.record_call("scrape_and_retrieve", time.time() - start_time, False) raise ToolError("scrape_and_retrieve", e) class VideoAnalysisInput(BaseModel): file_path: str = Field(description="Path to video file (.mp4, .mov, etc.)") query: str = Field(description="What to find in the video") @tool(args_schema=VideoAnalysisInput) def analyze_video(file_path: str, query: str) -> str: """ Analyze video using Gemini Vision (supports video). Use for: - Counting objects/people/animals in video - Describing what happens - Finding specific moments - Visual Q&A about video content """ start_time = time.time() try: print(f"🎥 Analyzing video: {file_path}") print(f" Query: {query[:100]}...") video_path = find_file(file_path) if not video_path and os.path.exists(file_path): video_path = Path(file_path) if not video_path or not video_path.exists(): raise FileNotFoundError(f"Video not found: {file_path}") GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY") if not GOOGLE_API_KEY: raise ValueError("GEMINI_API_KEY not set") # Use Google GenAI SDK directly — LangChain wrapper doesn't support video # Try new SDK (google-genai) first, fall back to old SDK (google-generativeai) import time as _time try: from google import genai as _genai client = _genai.Client(api_key=GOOGLE_API_KEY) print(f" Uploading video to Gemini Files API (new SDK)...") video_file = client.files.upload(file=str(video_path)) while video_file.state.name == "PROCESSING": _time.sleep(2) video_file = client.files.get(name=video_file.name) if video_file.state.name == "FAILED": raise RuntimeError(f"Gemini file processing failed: {video_file.state}") print(f" Analyzing with Gemini...") response = client.models.generate_content( model="gemini-2.5-flash", contents=[query, video_file] ) result = response.text try: client.files.delete(name=video_file.name) except Exception: pass except ImportError: import google.generativeai as genai_old genai_old.configure(api_key=GOOGLE_API_KEY) print(f" Uploading video to Gemini Files API (old SDK)...") video_file = genai_old.upload_file(str(video_path)) while video_file.state.name == "PROCESSING": _time.sleep(2) video_file = genai_old.get_file(video_file.name) if video_file.state.name == "FAILED": raise RuntimeError(f"Gemini file processing failed: {video_file.state}") print(f" Analyzing with Gemini...") model = genai_old.GenerativeModel("gemini-2.5-flash") response = model.generate_content([query, video_file]) result = response.text try: genai_old.delete_file(video_file.name) except Exception: pass print(f"✓ Analysis complete: {len(result)} chars") telemetry.record_call("analyze_video", time.time() - start_time, True) return f"Video Analysis:\n{truncate_if_needed(result)}" except Exception as e: telemetry.record_call("analyze_video", time.time() - start_time, False) raise ToolError("analyze_video", e, "Check video file path and Gemini API") class FinalAnswerInput(BaseModel): answer: str = Field(description="Final answer - exact, no fluff") @tool(args_schema=FinalAnswerInput) def final_answer_tool(answer: str) -> str: """Submit final answer with normalization""" start_time = time.time() try: # Get question from state (you'll need to pass this through) # For now, normalize without question context original_answer = answer answer = normalize_answer(answer) if answer != original_answer: print(f"📝 Normalized answer:") print(f" Before: '{original_answer}'") print(f" After: '{answer}'") print(f"\n✅ FINAL: '{answer}'\n") telemetry.record_call("final_answer_tool", time.time() - start_time, True) return f"FINAL_ANSWER: {answer}" except Exception as e: telemetry.record_call("final_answer_tool", time.time() - start_time, False) raise ToolError("final_answer_tool", e) # ============================================================================= # TOOLS LIST # ============================================================================= defined_tools = [ # Planning & Reflection think_through_logic, create_plan, reflect_on_progress, validate_answer, analyze_data_file, # Core tools search_tool, wikipedia_search, calculator, analyze_video, code_interpreter, # File operations read_file, write_file, list_directory, # Specialized audio_transcription_tool, analyze_image, get_youtube_transcript, scrape_and_retrieve, analyze_chess_position, # Final final_answer_tool ] # ============================================================================= # AGENT STATE # ============================================================================= class AgentState(TypedDict): messages: Annotated[List[AnyMessage], add_messages] turn: int has_plan: bool consecutive_errors: int tool_history: List[str] last_tool_was_thinking: bool # ============================================================================= # TOOL CALL PARSER # ============================================================================= def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]: """Enhanced fallback parser""" print(f"🔧 Parsing tool call from: {content[:300]}...") tool_name = None tool_input = None # Strategy 1: Groq format groq_match = re.search(r"|)", content, re.DOTALL) if groq_match: try: tool_name = groq_match.group(1).strip() json_str = groq_match.group(2).strip() json_str = json_str.encode().decode('unicode_escape') tool_input = json.loads(json_str) print(f"✓ Parsed Groq format: {tool_name}") except: tool_name = None # Strategy 2: Standard format if not tool_name: func_match = re.search(r"](.*)", content, re.DOTALL | re.IGNORECASE) if func_match: try: tool_name = func_match.group(1).strip().replace("'", "").replace('"', '') remaining = func_match.group(2) json_start = remaining.find('{') if json_start != -1: json_str = remaining[json_start:].strip().rstrip(',') tool_input = json.loads(json_str) print(f"✓ Parsed standard format: {tool_name}") except: tool_name = None # Strategy 3: Code block → code_interpreter if not tool_name and "```python" in content: try: code_match = re.search(r"```python\n(.*?)```", content, re.DOTALL) if code_match: code = code_match.group(1).strip() tool_name = "code_interpreter" tool_input = {"code": code} print(f"✓ Extracted Python code") except: pass # Strategy 4: Tool mention if not tool_name: for tool in tools: if tool.name.lower() in content.lower(): tool_name = tool.name tool_input = {} if tool.args_schema: schema = tool.args_schema.model_json_schema() for prop in schema.get('properties', {}).keys(): if prop in schema.get('required', []): tool_input[prop] = "auto_extracted" print(f"✓ Found mention: {tool_name}") break # Strategy 5: Force thinking if not tool_name: if len(content) > 50: tool_name = "think_through_logic" tool_input = {"reasoning": content[:150]} print(f"⚠️ Forcing think_through_logic") if tool_name and tool_input is not None: matching = [t for t in tools if t.name == tool_name] if matching: return [ToolCall(name=tool_name, args=tool_input, id=str(uuid.uuid4()))] print("❌ All parsing failed") return [] # ============================================================================= # CONDITIONAL EDGE # ============================================================================= def should_continue(state: AgentState): """Decide next step""" messages = state.get('messages', []) if not messages: return "agent" last_message = messages[-1] current_turn = state.get('turn', 0) print(f"📍 Turn {current_turn}, Last: {type(last_message).__name__}") if current_turn >= config.MAX_TURNS: print(f"🛑 Max turns reached") return END if isinstance(last_message, ToolMessage): print(f"📨 Tool result → agent") return "agent" if isinstance(last_message, AIMessage) and last_message.tool_calls: first_tool = last_message.tool_calls[0] if first_tool.get("name") == "final_answer_tool": return END return "tools" if isinstance(last_message, AIMessage) and not last_message.tool_calls: if len(messages) >= 2 and isinstance(messages[-2], AIMessage) and not messages[-2].tool_calls: print(f"⚠️ Loop detected") return END print(f"💭 AI without tool → agent") return "agent" return "agent" # ============================================================================= # MAIN AGENT CLASS # ============================================================================= class PlanningReflectionAgent: def __init__(self): print("🧠 Initializing PlanningReflectionAgent...") # Check API keys GROQ_API_KEY = os.getenv("GROQ_API_KEY") if not GROQ_API_KEY: raise ValueError("GROQ_API_KEY not set") GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY") if not GOOGLE_API_KEY: raise ValueError("GEMINI_API_KEY not set") self.tools = defined_tools # Initialize RAG rag_manager.initialize() # Build tool descriptions tool_desc_list = [] for tool in self.tools: if tool.args_schema: schema = tool.args_schema.model_json_schema() args_desc = [f" - {p}: {d.get('description', '')}" for p, d in schema.get('properties', {}).items()] desc = f"- {tool.name}:\n {tool.description}\n" + "\n".join(args_desc) else: desc = f"- {tool.name}: {tool.description}" tool_desc_list.append(desc) tool_descriptions = "\n".join(tool_desc_list) self.system_prompt = f"""You are an elite AI agent for GAIA benchmark. Your ONLY job: provide the EXACT answer requested. ═══════════════════════════════════════════════════════════════ ⚠️ ABSOLUTE RULES - VIOLATE THESE AND YOU FAIL: ═══════════════════════════════════════════════════════════════ 1. **EVERY TURN MUST CALL EXACTLY ONE TOOL** - No exceptions 2. **NEVER OUTPUT REASONING TEXT WITHOUT A TOOL CALL** - You will fail 3. **IDENTIFY QUESTION TYPE FIRST** - Logic? Factual? Data? Math? 4. **ALWAYS VALIDATE**: Call validate_answer() before final_answer_tool() 5. **FINAL ANSWER FORMAT**: EXACTLY what was asked. NO "The answer is..." or explanations ═══════════════════════════════════════════════════════════════ 📋 QUESTION TYPE → TOOL SEQUENCE: ═══════════════════════════════════════════════════════════════ **LOGIC PUZZLES** (No web search needed): → think_through_logic → calculator (if math) → validate → final_answer **FACTUAL/BIOGRAPHICAL** (Need web): → wikipedia_search (if person/place/thing) → validate → final_answer OR search_tool → scrape_and_retrieve → validate → final_answer **COUNTING FROM WEB** (Need full page content): → wikipedia_search (if Wikipedia topic) → validate → final_answer OR iterative_web_browser (if needs navigation) → validate → final_answer **DATA FILES** (CSV/Excel): → list_directory → analyze_data_file → code_interpreter → validate → final_answer **IMAGES** (Chess, diagrams, photos): → analyze_image → validate → final_answer **AUDIO FILES**: → audio_transcription_tool → validate → final_answer **MATH CALCULATIONS**: → calculator → validate → final_answer ═══════════════════════════════════════════════════════════════ 📚 WIKIPEDIA QUERIES - CRITICAL: ═══════════════════════════════════════════════════════════════ If question mentions Wikipedia: 1. Use wikipedia_search() with SHORT query (just the subject) 2. Get Wikipedia URL 3. Use scrape_and_retrieve() for detailed info ✅ CORRECT Example: Q: "How many albums by Mercedes Sosa 2000-2009 using Wikipedia?" Turn 1: wikipedia_search("Mercedes Sosa") → Returns URL Turn 2: scrape_and_retrieve( url="https://en.wikipedia.org/wiki/Mercedes_Sosa", query="studio albums 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009" ) → Returns full discography Turn 3: code_interpreter("count albums in years 2000-2009") → Returns "3" Turn 4: validate_answer("3", question) Turn 5: final_answer_tool("3") ❌ WRONG Examples: - wikipedia_search("Mercedes Sosa discography 2022 English Wikipedia version") - wikipedia_search("Mercedes Sosa Wikipedia") - wikipedia_search("How many albums Mercedes Sosa") REMEMBER: wikipedia_search() wants just the SUBJECT NAME! ═══════════════════════════════════════════════════════════════ **YOUTUBE VIDEO HANDLING:** ⚠️ YouTube URLs are BLOCKED on HuggingFace Spaces! IF question mentions YouTube URL AND local video file exists: → Use analyze_video tool on the local .mp4 file instead → The local file contains the same video content Example: Question: "In video https://youtube.com/watch?v=abc, how many birds?" File: files/task_123.mp4 ✅ CORRECT: analyze_video("files/task_123.mp4", "count bird species") ❌ WRONG: get_youtube_transcript("https://youtube.com/...") 🚨 ANTI-LOOP RULES: ═══════════════════════════════════════════════════════════════ 1. NEVER call the same tool 3 times in a row 2. think_through_logic is ONLY for logic puzzles (NOT research) 3. Research questions need search_tool or wikipedia_search 4. If stuck for 3 turns → try DIFFERENT tool ═══════════════════════════════════════════════════════════════ ═══════════════════════════════════════════════════════════════ 🎯 CRITICAL TOOL USAGE PATTERNS: ═══════════════════════════════════════════════════════════════ **For Counting Questions:** BAD: search_tool("Mercedes Sosa albums") → snippets only GOOD: wikipedia_search("Mercedes Sosa") → full discography section **For Multi-Step Web Questions:** BAD: scrape_and_retrieve("https://...") → single page only GOOD: iterative_web_browser("https://...", "find X", max_steps=3) **For Data Questions:** BAD: read_file("data.csv") → raw text dump GOOD: analyze_data_file("data.csv", "count rows where X > Y") **For Validation:** ALWAYS: validate_answer("your answer", "original question") THEN: final_answer_tool("your answer") ═══════════════════════════════════════════════════════════════ 📚 AVAILABLE TOOLS: ═══════════════════════════════════════════════════════════════ {tool_descriptions} ═══════════════════════════════════════════════════════════════ ⚡ EXECUTION RULES: ═══════════════════════════════════════════════════════════════ - Text without tool call = FAILURE - Unsure? → think_through_logic() to organize thoughts - After EVERY tool result: "Do I have the answer? → validate → submit" - Stuck after 3 turns? → reflect_on_progress() - For Wikipedia topics → ALWAYS use wikipedia_search, NOT search_tool - For counting from web → Use wikipedia_search or iterative_web_browser - For data files → Use analyze_data_file, NOT just read_file ═══════════════════════════════════════════════════════════════ 🎓 EXAMPLES OF PERFECT EXECUTION: ═══════════════════════════════════════════════════════════════ Example 1: "How many studio albums did Mercedes Sosa release 2000-2009?" Turn 1: wikipedia_search("Mercedes Sosa") → Gets full discography with all albums and years Turn 2: code_interpreter("count albums 2000-2009 from text") → Result: 3 Turn 3: validate_answer("3", "How many studio albums...") → ✅ PASSED Turn 4: final_answer_tool("3") Example 2: "What's the population of Einstein's birthplace in 1900?" Turn 1: wikipedia_search("Albert Einstein") → Birthplace: Ulm, Germany Turn 2: search_tool("Ulm Germany population 1900") → Find sources Turn 3: scrape_and_retrieve("url", "population 1900") → ~50,000 Turn 4: validate_answer("50000", "population 1900") → ✅ PASSED Turn 5: final_answer_tool("50000") Example 3: Logic puzzle Turn 1: think_through_logic("Work through the logic...") → Reasoning recorded Turn 2: calculator("30") [if calculation needed] → 30 Turn 3: validate_answer("30", "coin puzzle") → ✅ PASSED Turn 4: final_answer_tool("30") ═══════════════════════════════════════════════════════════════ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format. ═══════════════════════════════════════════════════════════════ """ # Initialize LLMs print("Initializing LLMs...") # Primary: Groq qwen3-32b self.groq_llm = ChatGroq( temperature=0, groq_api_key=GROQ_API_KEY, model_name="qwen/qwen3-32b", max_tokens=4096, timeout=60 ).bind_tools(self.tools, tool_choice="auto") # Fallback 1: Groq llama-3.3-70b (separate per-model quota) self.groq_llama_llm = ChatGroq( temperature=0, groq_api_key=GROQ_API_KEY, model_name="llama-3.3-70b-versatile", max_tokens=4096, timeout=60 ).bind_tools(self.tools, tool_choice="auto") print("✅ Groq llama-3.3-70b fallback initialized") # Fallback 2: Gemma 3 27B via Gemini API (same key, 15K TPM, 14.4K RPD) self.gemma_llm = ChatGoogleGenerativeAI( model="gemma-3-27b-it", google_api_key=GOOGLE_API_KEY, temperature=0, max_tokens=4096 ).bind_tools(self.tools, tool_choice="auto") print("✅ Gemma 3 27B fallback initialized") # Fallback 3: Claude (if key provided) ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") if ANTHROPIC_API_KEY: from langchain_anthropic import ChatAnthropic self.claude_llm = ChatAnthropic( model="claude-sonnet-4-20250514", anthropic_api_key=ANTHROPIC_API_KEY, temperature=0, max_tokens=4096 ).bind_tools(self.tools, tool_choice="auto") print("✅ Claude fallback initialized") else: self.claude_llm = None print("ℹ️ Claude fallback unavailable (no ANTHROPIC_API_KEY)") chain = "Groq qwen3-32b → Groq llama-3.3-70b → Gemma 3 27B" if ANTHROPIC_API_KEY: chain += " → Claude" print(f"✅ LLM chain: {chain}") # Start with Groq self.llm_with_tools = self.groq_llm self.current_llm = "groq" def prune_context_if_needed(state: AgentState) -> AgentState: """ Prune conversation history if it's getting too long. Keeps system message + recent history to stay under token limits. """ messages = state.get("messages", []) # Keep first message (system prompt) + last N messages MAX_MESSAGES = 20 # ~6000 token limit on Groq; system msg ~3000 chars leaves ~18000 for the rest MAX_TOOL_CONTENT = 1500 # Prune by count if len(messages) > MAX_MESSAGES: print(f"⚠️ Context pruning: {len(messages)} messages → {MAX_MESSAGES}") system_msg = None if messages and isinstance(messages[0], SystemMessage): system_msg = messages[0] messages = messages[1:] recent_messages = messages[-(MAX_MESSAGES-1):] if system_msg: messages = [system_msg] + recent_messages else: messages = recent_messages # Truncate oversized tool outputs to prevent 413 errors pruned = [] for msg in messages: if isinstance(msg, ToolMessage) and len(msg.content) > MAX_TOOL_CONTENT: msg = ToolMessage( content=msg.content[:MAX_TOOL_CONTENT] + "...[truncated]", tool_call_id=msg.tool_call_id, name=msg.name ) pruned.append(msg) state["messages"] = pruned return state # Build agent graph def agent_node(state: AgentState): current_turn = state.get('turn', 0) + 1 max_retries = config.MAX_RETRIES print(f"\n{'='*70}") print(f"🤖 AGENT TURN {current_turn}/{config.MAX_TURNS}") print('='*70) state = prune_context_if_needed(state) if current_turn > config.MAX_TURNS: return { "messages": [SystemMessage(content="Max turns reached.")], "turn": current_turn } tool_history = state.get('tool_history', []) # Check for loops (same tool called 3+ times) if len(tool_history) >= 3: last_3 = tool_history[-3:] # If same tool 3 times in a row, FORCE change if len(set(last_3)) == 1: problem_tool = last_3[0] print(f"🚨 LOOP DETECTED: {problem_tool} called 3x - FORCING CHANGE") force_msg = SystemMessage( content=f"""⚠️ EMERGENCY: You called {problem_tool}() 3 times in a row! THIS IS A LOOP. You MUST use a DIFFERENT tool now. BANNED this turn: {problem_tool} Pick ANY other tool and call it NOW.""" ) messages_to_send = state["messages"].copy() messages_to_send.append(force_msg) else: messages_to_send = state["messages"].copy() else: messages_to_send = state["messages"].copy() # ===== END LOOP DETECTION ===== # Check if we should force reflection consecutive_errors = state.get('consecutive_errors', 0) should_reflect = (current_turn > 5 and current_turn % Config.REFLECT_EVERY_N_TURNS == 0) or consecutive_errors >= 3 # Force tool usage if len(messages_to_send) >= 2: last_msg = messages_to_send[-1] if isinstance(last_msg, AIMessage) and not last_msg.tool_calls: force_msg = SystemMessage( content="⚠️ CRITICAL: MUST call a tool. NO reasoning text." ) messages_to_send.append(force_msg) print("🚨 Forcing tool usage") if should_reflect: hint = SystemMessage( content="⚠️ HINT: No progress. Try reflect_on_progress() or different approach." ) messages_to_send.append(hint) print("🤔 Reflection hint") # Invoke LLM with retries and fallback ai_message = None for attempt in range(max_retries): try: ai_message = self.llm_with_tools.invoke(messages_to_send) if ai_message.tool_calls: break except Exception as e: error_str = str(e) print(f"⚠️ Groq error (attempt {attempt+1}): {error_str[:200]}") # ===== IMPROVED RATE LIMIT HANDLING ===== # Context too large — truncate aggressively and retry immediately if "413" in error_str or "request too large" in error_str.lower(): print("❌ Request too large (413) - aggressively pruning context") # Keep system message + last 4 messages, truncate tool content to 1000 chars pruned = [] for msg in messages_to_send: if isinstance(msg, SystemMessage): pruned.append(msg) break pruned += messages_to_send[-4:] for msg in pruned: if isinstance(msg, ToolMessage) and len(msg.content) > 1000: msg = ToolMessage( content=msg.content[:1000] + "...[truncated]", tool_call_id=msg.tool_call_id, name=msg.name ) messages_to_send = pruned print(f" Pruned to {len(messages_to_send)} messages, retrying...") continue # Check for rate limit if "429" in error_str or "rate limit" in error_str.lower(): print("❌ Groq rate limit hit!") if attempt < max_retries - 1: wait = 10 * (2 ** attempt) # 10s, 20s, 40s print(f" Waiting {wait}s before retry...") time.sleep(wait) continue # Fallback chain: Groq llama → Gemma 3 → Claude → search if self.groq_llama_llm and self.current_llm != "groq_llama": print("🔄 Groq qwen limit - switching to Groq llama-3.3-70b") self.llm_with_tools = self.groq_llama_llm self.current_llm = "groq_llama" try: ai_message = self.groq_llama_llm.invoke(messages_to_send) break except Exception as llama_err: print(f"❌ Groq llama fallback also failed: {llama_err}") if self.gemma_llm: print("🔄 Groq rate limit - switching to Gemma 3 27B fallback") self.llm_with_tools = self.gemma_llm self.current_llm = "gemma" try: ai_message = self.gemma_llm.invoke(messages_to_send) break except Exception as gemma_err: print(f"❌ Gemma fallback also failed: {gemma_err}") if self.claude_llm: print("🔄 Switching to Claude fallback") self.llm_with_tools = self.claude_llm self.current_llm = "claude" try: ai_message = self.claude_llm.invoke(messages_to_send) break except Exception as claude_err: print(f"❌ Claude fallback also failed: {claude_err}") # No LLM available — extract question and do one targeted search print("🔄 No LLM available - attempting targeted search fallback") question_text = "" for msg in state["messages"]: if isinstance(msg, HumanMessage) and msg.content: question_text = str(msg.content)[:200].strip() break ai_message = AIMessage( content="", tool_calls=[ToolCall( name="search_tool", args={"query": question_text or "unknown question"}, id=str(uuid.uuid4()) )] ) break # ===== END RATE LIMIT HANDLING ===== # Tool use failed error if any(kw in error_str for kw in ["tool_use_failed", "tool call validation"]): print("🚨 Tool error - forcing think_through_logic") ai_message = AIMessage( content="", tool_calls=[ToolCall( name="think_through_logic", args={"reasoning": "Processing..."}, id=str(uuid.uuid4()) )] ) break # Final retry if attempt == max_retries - 1: print("🚨 All attempts failed - forcing think_through_logic") ai_message = AIMessage( content="", tool_calls=[ToolCall( name="think_through_logic", args={"reasoning": "Processing"}, id=str(uuid.uuid4()) )] ) else: time.sleep(2 ** attempt) # Ensure tool calls exist if not ai_message.tool_calls: if ai_message.content: parsed = parse_tool_call_from_string(ai_message.content, self.tools) if parsed: ai_message.tool_calls = parsed ai_message.content = "" else: ai_message.tool_calls = [ToolCall( name="think_through_logic", args={"reasoning": "analyzing"}, id=str(uuid.uuid4()) )] ai_message.content = "" # Track usage tool_history = state.get('tool_history', []) has_plan = state.get('has_plan', False) if ai_message.tool_calls: tool_name = ai_message.tool_calls[0]['name'] print(f"🔧 Tool: {tool_name}") tool_history.append(tool_name) if tool_name == "create_plan": has_plan = True return { "messages": [ai_message], "turn": current_turn, "has_plan": has_plan, "tool_history": tool_history, "last_tool_was_thinking": ai_message.tool_calls and ai_message.tool_calls[0]['name'] == 'think_through_logic' } def tool_node_wrapper(state: AgentState): """Execute tools with error tracking""" print(f"🔧 Executing tools...") tool_executor = ToolNode(self.tools) result = tool_executor.invoke(state) consecutive_errors = state.get('consecutive_errors', 0) if result.get('messages'): last_msg = result['messages'][-1] if isinstance(last_msg, ToolMessage): if "Error" in last_msg.content or "error" in last_msg.content.lower(): consecutive_errors += 1 print(f"⚠️ Tool error (consecutive: {consecutive_errors})") else: consecutive_errors = 0 result['consecutive_errors'] = consecutive_errors return result # Build graph print("Building graph...") graph_builder = StateGraph(AgentState) graph_builder.add_node("agent", agent_node) graph_builder.add_node("tools", tool_node_wrapper) graph_builder.add_edge(START, "agent") graph_builder.add_conditional_edges( "agent", should_continue, { "tools": "tools", "agent": "agent", END: END } ) graph_builder.add_edge("tools", "agent") self.graph = graph_builder.compile() print("✅ Graph compiled") def __call__(self, question: str, file_path: str = None) -> str: """Execute agent""" print(f"\n{'='*70}") print(f"🎯 NEW QUESTION") print(f"{'='*70}") print(f"Q: {question[:200]}...") if file_path: print(f"📎 File: {file_path}") print(f"{'='*70}\n") # Build question context question_text = question if file_path: file_ext = Path(file_path).suffix.lower() file_type = "unknown" if file_ext in ['.jpg', '.jpeg', '.png', '.gif']: file_type = "image" elif file_ext in ['.mp3', '.wav', '.m4a']: file_type = "audio" elif file_ext in ['.csv', '.xlsx']: file_type = "data" elif file_ext in ['.txt', '.pdf', '.doc']: file_type = "document" question_text += f"\n\n[FILE: {file_path}]" question_text += f"\n[TYPE: {file_type}]" question_text += f"\nUse appropriate tool first!" graph_input = { "messages": [ SystemMessage(content=self.system_prompt), HumanMessage(content=question_text) ], "file_path": file_path, "turn": 0, "has_plan": False, "consecutive_errors": 0, "tool_history": [], "last_tool_was_thinking": False } # Reset to Groq for each question if self.groq_llm: self.llm_with_tools = self.groq_llm self.current_llm = "groq" final_answer = "AGENT FAILED" all_messages = [] try: config_dict = {"recursion_limit": config.MAX_TURNS * 2 + 10} for event in self.graph.stream(graph_input, stream_mode="values", config=config_dict): if not event.get('messages'): continue all_messages = event["messages"] last_message = all_messages[-1] # Check for final answer if isinstance(last_message, AIMessage) and last_message.tool_calls: for tool_call in last_message.tool_calls: if tool_call.get("name") == "final_answer_tool": args = tool_call.get('args', {}) if 'answer' in args: final_answer = normalize_answer(args['answer']) print(f"\n✅ FINAL: '{final_answer}'\n") break elif isinstance(last_message, ToolMessage): preview = last_message.content[:200].replace('\n', ' ') print(f"📊 Tool '{last_message.name}': {preview}...") # Fallback: extract from tool results if final_answer == "AGENT FAILED": print("⚠️ No final_answer_tool. Checking tools...") for msg in reversed(all_messages): if isinstance(msg, ToolMessage): if msg.name in ["calculator", "think_through_logic", "code_interpreter"]: content = msg.content.strip() if content and len(content) < 200 and not content.startswith("Error"): lines = content.split('\n') for line in reversed(lines): if line.strip() and not line.startswith(('✅', '⚠️', 'Next', 'Remember')): final_answer = line.strip() print(f"📝 Extracted: '{final_answer}'") break break # Clean answer more aggressively cleaned = str(final_answer).strip() # Remove common prefixes (case-insensitive) prefixes = [ "the answer is:", "here is the answer:", "based on", "final answer:", "answer:", "the final answer is:", "my answer is:", "according to", "i found that", "the result is:", "result:", "here's the answer:", "after analysis:", "the correct answer is:", "from the data:", "from the search:", ] for prefix in prefixes: if cleaned.lower().startswith(prefix.lower()): potential = cleaned[len(prefix):].strip() if potential: cleaned = potential break # Remove code fences cleaned = remove_fences_simple(cleaned) # Remove backticks while cleaned.startswith("`") and cleaned.endswith("`"): cleaned = cleaned[1:-1].strip() # Remove quotes (but only if they wrap entire answer) if (cleaned.startswith('"') and cleaned.endswith('"')) or \ (cleaned.startswith("'") and cleaned.endswith("'")): cleaned = cleaned[1:-1].strip() # Remove trailing period for short answers if cleaned.endswith('.') and len(cleaned.split()) < 10: cleaned = cleaned[:-1] # Remove markdown bold/italic cleaned = cleaned.replace('**', '').replace('__', '').replace('*', '').replace('_', '') # Remove bullet points if cleaned.startswith(('- ', '* ', '• ')): cleaned = cleaned[2:].strip() # Remove numbered list prefix import re cleaned = re.sub(r'^\d+\.\s+', '', cleaned) # Final whitespace cleanup cleaned = ' '.join(cleaned.split()) print(f"\n🎉 RETURNING: {cleaned}\n") return cleaned except Exception as e: print(f"❌ Graph error: {e}") print(traceback.format_exc()) return f"ERROR: {e}" # ============================================================================= # GLOBAL AGENT # ============================================================================= agent = None try: rag_manager.initialize() agent = PlanningReflectionAgent() print("✅ Global agent ready") if not callable(agent): print("❌ Agent not callable") agent = None else: print("✅ Agent is callable") except Exception as e: print(f"❌ FATAL: {e}") traceback.print_exc() agent = None # ============================================================================= # RUN AND SUBMIT # ============================================================================= def run_and_submit_all(profile: gr.OAuthProfile | None): """Run evaluation and submit""" space_id = os.getenv("SPACE_ID") if profile: username = profile.username print(f"User: {username}") else: print("Not logged in") return "Please login to HuggingFace", None global agent if agent is None: return "FATAL: Agent failed to initialize", None print("✅ Using global agent") api_url = config.DEFAULT_API_URL questions_url = f"{api_url}/questions" submit_url = f"{api_url}/submit" agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main" # Fetch questions print(f"\n{'='*70}") print(f"📥 FETCHING QUESTIONS") print(f"{'='*70}\n") try: response = requests.get(questions_url, timeout=15) response.raise_for_status() questions_data = response.json() if not questions_data: return "No questions fetched", None print(f"✅ Fetched {len(questions_data)} questions\n") except Exception as e: print(f"❌ Fetch error: {e}") return f"Error fetching questions: {e}", None # Load answer sheet validator = AnswerValidator() answer_sheet = validator.load_answer_sheet("answer_sheet_json.json") # Initialize tracking progress = ProgressTracker(len(questions_data)) telemetry.reset() results_log = [] answers_payload = [] # Process questions print(f"\n{'='*70}") print(f"🚀 STARTING EVALUATION") print(f"{'='*70}\n") for idx, item in enumerate(questions_data, 1): print(f"\n{'='*70}") print(f"📝 QUESTION {idx}/{len(questions_data)}") print(f"{'='*70}\n") task_id = item.get("task_id") question_text = item.get("question") correct_answer = answer_sheet.get(task_id, "") # Find file local_file_path = None files_dir = "files" try: if os.path.exists(files_dir): matching_files = [f for f in os.listdir(files_dir) if f.startswith(task_id)] if matching_files: local_file_path = os.path.join(files_dir, matching_files[0]) print(f"✅ Found file: {matching_files[0]}") else: print(f"ℹ️ No file for {task_id}") else: print(f"⚠️ '{files_dir}' not found") except Exception as e: print(f"❌ File search error: {e}") try: # Run agent submitted_answer = agent(question_text, local_file_path) answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer}) # Check correctness is_correct, feedback = validator.check_correctness(submitted_answer, correct_answer) print(f"\n{feedback} - Task {task_id}") print(f" Submitted: '{submitted_answer}'") print(f" Expected: '{correct_answer}'") results_log.append({ "Task ID": task_id, "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text, "Submitted": submitted_answer, "Correct": correct_answer, "Status": "✅" if is_correct else "❌" }) progress.update(is_correct) print(f"\n✅ Question {idx} completed") except Exception as e: print(f"❌ Error on {task_id}: {e}") print(traceback.format_exc()) results_log.append({ "Task ID": task_id, "Question": question_text[:100] + "...", "Submitted": f"ERROR: {e}", "Correct": correct_answer, "Status": "❌" }) answers_payload.append({"task_id": task_id, "submitted_answer": f"ERROR: {str(e)[:100]}"}) progress.update(False) # Print telemetry telemetry.report() # Summary correct_count = sum(1 for log in results_log if log.get("Status") == "✅") total_count = len(results_log) accuracy = (correct_count / total_count * 100) if total_count > 0 else 0 print(f"\n{'='*70}") print(f"📊 PRE-SUBMISSION SUMMARY") print(f"{'='*70}") print(f"Correct: {correct_count}/{total_count} ({accuracy:.1f}%)") print(f"{'='*70}\n") if not answers_payload: return "No answers produced", pd.DataFrame(results_log) # Submit submission_data = { "username": username.strip(), "agent_code": agent_code, "answers": answers_payload } print(f"\n{'='*70}") print(f"📤 SUBMITTING") print(f"{'='*70}\n") try: response = requests.post(submit_url, json=submission_data, timeout=60) response.raise_for_status() result_data = response.json() final_status = ( f"Submission Successful!\n" f"User: {result_data.get('username')}\n" f"Score: {result_data.get('score', 'N/A')}% " f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')})\n" f"Message: {result_data.get('message', 'No message')}" ) print(final_status) results_df = pd.DataFrame(results_log) return final_status, results_df except Exception as e: print(f"❌ Submission failed: {e}") results_df = pd.DataFrame(results_log) return f"Submission failed: {e}", results_df # ============================================================================= # GRADIO INTERFACE # ============================================================================= with gr.Blocks() as demo: gr.Markdown("# GAIA Agent Evaluation - Refactored") gr.Markdown(""" **Improvements:** - Better error handling with retry logic - Caching for search results - Telemetry and progress tracking - Memory management - Modular architecture **Instructions:** 1. Clone this space and modify as needed 2. Login with HuggingFace account 3. Click 'Run Evaluation & Submit' """) gr.LoginButton() run_button = gr.Button("Run Evaluation & Submit All") status_output = gr.Textbox(label="Status", lines=5, interactive=False) results_table = gr.DataFrame(label="Results", wrap=True) run_button.click( fn=run_and_submit_all, outputs=[status_output, results_table], queue=False ) if __name__ == "__main__": print("\n" + "-"*70) print("Starting Refactored GAIA Agent") print("-"*70 + "\n") space_host = os.getenv("SPACE_HOST") space_id = os.getenv("SPACE_ID") if space_host: print(f"✅ SPACE_HOST: {space_host}") print(f" URL: https://{space_host}.hf.space") if space_id: print(f"✅ SPACE_ID: {space_id}") print(f" Repo: https://huggingface.co/spaces/{space_id}") print("\n" + "-"*70 + "\n") demo.launch(debug=True, share=False, ssr_mode=False)