diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,3 +1,14 @@ +""" +GAIA Benchmark Agent - Refactored Version +Improvements: +- Better error handling with retry logic +- Caching for expensive operations +- Telemetry and progress tracking +- Modular architecture +- Parallel processing support +- Memory management +""" + import os import io import subprocess @@ -8,9 +19,10 @@ import contextlib import uuid import time import ast -from typing import List, Optional, TypedDict, Annotated, Dict +from typing import List, Optional, TypedDict, Annotated, Dict, Tuple from pathlib import Path -from collections import Counter +from collections import Counter, defaultdict +from functools import wraps, lru_cache import gradio as gr import pandas as pd import numpy as np @@ -28,7 +40,6 @@ from googleapiclient.discovery import build from googleapiclient.errors import HttpError import assemblyai as aai - # LangChain & LangGraph from langgraph.graph.message import add_messages from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, AnyMessage, ToolCall @@ -37,9 +48,6 @@ from langgraph.prebuilt import ToolNode from langgraph.graph import START, END, StateGraph from langchain_groq import ChatGroq from langchain_google_genai import ChatGoogleGenerativeAI -from langchain_community.llms import HuggingFaceHub -from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint - # RAG from langchain_text_splitters import RecursiveCharacterTextSplitter @@ -51,159 +59,323 @@ from langchain_core.documents import Document # ============================================================================= # CONFIGURATION # ============================================================================= -DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space" -MAX_TURNS = 25 -MAX_MESSAGE_LENGTH = 8000 -REFLECT_EVERY_N_TURNS = 5 +class Config: + DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space" + MAX_TURNS = 25 + 
MAX_MESSAGE_LENGTH = 8000 + REFLECT_EVERY_N_TURNS = 5 + MAX_RETRIES = 3 + BASE_RETRY_DELAY = 1 + CACHE_SIZE = 100 + MAX_PARALLEL_WORKERS = 3 + CHUNK_SIZE = 500 + CHUNK_OVERLAP = 50 + +config = Config() # ============================================================================= -# GLOBAL RAG COMPONENTS +# UTILITIES: RETRY & CACHING # ============================================================================= -global_embeddings = None -global_text_splitter = None +def retry_with_backoff(max_retries=None, base_delay=None): + """Decorator for automatic retry with exponential backoff""" + max_retries = max_retries or config.MAX_RETRIES + base_delay = base_delay or config.BASE_RETRY_DELAY + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + if attempt == max_retries - 1: + raise + delay = base_delay * (2 ** attempt) + print(f"āš ļø {func.__name__} retry {attempt+1}/{max_retries} after {delay}s: {e}") + time.sleep(delay) + return wrapper + return decorator + + +class SearchCache: + """LRU cache for search results""" + def __init__(self, maxsize=None): + self.maxsize = maxsize or config.CACHE_SIZE + self._cache = {} + self._access_order = [] + + def get(self, key: str) -> Optional[str]: + if key in self._cache: + # Move to end (most recently used) + self._access_order.remove(key) + self._access_order.append(key) + return self._cache[key] + return None + + def put(self, key: str, value: str): + if key in self._cache: + self._access_order.remove(key) + elif len(self._cache) >= self.maxsize: + # Remove least recently used + oldest = self._access_order.pop(0) + del self._cache[oldest] + + self._cache[key] = value + self._access_order.append(key) + + def clear(self): + self._cache.clear() + self._access_order.clear() + +search_cache = SearchCache() -def initialize_rag_components(): - """Initialize RAG components globally.""" - global global_embeddings, 
global_text_splitter +# ============================================================================= +# TELEMETRY +# ============================================================================= +class Telemetry: + """Track tool usage, timing, and errors""" + def __init__(self): + self.tool_times = defaultdict(list) + self.tool_errors = defaultdict(int) + self.tool_calls = defaultdict(int) + self.start_time = time.time() + + def record_call(self, tool_name: str, duration: float, success: bool): + self.tool_calls[tool_name] += 1 + self.tool_times[tool_name].append(duration) + if not success: + self.tool_errors[tool_name] += 1 + + def report(self): + total_time = time.time() - self.start_time + print(f"\n{'='*70}") + print(f"šŸ“Š TELEMETRY REPORT") + print(f"{'='*70}") + print(f"Total runtime: {total_time:.2f}s") + print(f"\nTool Usage:") + + for tool in sorted(self.tool_calls.keys()): + calls = self.tool_calls[tool] + times = self.tool_times[tool] + errors = self.tool_errors[tool] + avg_time = sum(times) / len(times) if times else 0 + + print(f" {tool}:") + print(f" Calls: {calls}") + print(f" Avg time: {avg_time:.2f}s") + print(f" Errors: {errors}") + + print(f"{'='*70}\n") - if global_embeddings is None: - print("Initializing RAG embeddings...") + def reset(self): + self.tool_times.clear() + self.tool_errors.clear() + self.tool_calls.clear() + self.start_time = time.time() + +telemetry = Telemetry() + +# ============================================================================= +# PROGRESS TRACKER +# ============================================================================= +class ProgressTracker: + """Track question processing progress""" + def __init__(self, total: int): + self.total = total + self.current = 0 + self.correct = 0 + self.start_time = time.time() + + def update(self, is_correct: bool): + self.current += 1 + if is_correct: + self.correct += 1 + + accuracy = (self.correct / self.current) * 100 if self.current > 0 else 0 + elapsed = time.time() 
- self.start_time + avg_time = elapsed / self.current if self.current > 0 else 0 + eta = avg_time * (self.total - self.current) + + print(f"šŸ“Š Progress: {self.current}/{self.total} ({self.current/self.total*100:.1f}%)") + print(f" Accuracy: {accuracy:.1f}% ({self.correct} correct)") + print(f" Avg time: {avg_time:.1f}s per question") + print(f" ETA: {eta/60:.1f} minutes") + +# ============================================================================= +# CUSTOM EXCEPTIONS +# ============================================================================= +class ToolError(Exception): + """Custom exception with context""" + def __init__(self, tool_name: str, error: Exception, suggestion: str = ""): + self.tool_name = tool_name + self.original_error = error + self.suggestion = suggestion + message = f"Tool '{tool_name}' failed: {error}" + if suggestion: + message += f"\nšŸ’” Suggestion: {suggestion}" + super().__init__(message) + +# ============================================================================= +# GLOBAL RAG COMPONENTS +# ============================================================================= +class RAGManager: + """Manage RAG components with lazy initialization""" + def __init__(self): + self.embeddings = None + self.text_splitter = None + self._initialized = False + + def initialize(self): + if self._initialized: + return True + + print("Initializing RAG components...") try: - global_embeddings = HuggingFaceEmbeddings( + self.embeddings = HuggingFaceEmbeddings( model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={'device': 'cpu'} ) - print("āœ… Embeddings initialized.") + + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.CHUNK_SIZE, + chunk_overlap=config.CHUNK_OVERLAP, + length_function=len, + separators=["\n\n", "\n", ". 
", " ", ""] + ) + + self._initialized = True + print("āœ… RAG components initialized") + return True + except Exception as e: - print(f"āš ļø Failed to initialize embeddings: {e}") + print(f"āŒ RAG initialization failed: {e}") return False - if global_text_splitter is None: - print("Initializing text splitter...") - global_text_splitter = RecursiveCharacterTextSplitter( - chunk_size=500, - chunk_overlap=50, - length_function=len, - separators=["\n\n", "\n", ". ", " ", ""] - ) - print("āœ… Text splitter initialized.") - - return True -# ============================================================================= -# ANSWER SHEET VALIDATION FUNCTIONS -# ============================================================================= + def is_ready(self): + return self._initialized -def load_answer_sheet(filepath: str = "answer_sheet.json") -> Dict[str, str]: - """Load the answer sheet from a JSON file""" - try: - if os.path.exists(filepath): - with open(filepath, 'r', encoding='utf-8') as f: - answers = json.load(f) - print(f"āœ… Loaded answer sheet with {len(answers)} answers from {filepath}") - return answers - else: - print(f"āš ļø Answer sheet not found at {filepath}") - return {} - except Exception as e: - print(f"āŒ Error loading answer sheet: {e}") - return {} +rag_manager = RAGManager() - -def check_answer_correctness(submitted: str, correct: str) -> tuple[bool, str]: - """ - Check if submitted answer matches correct answer with fuzzy matching - Returns: (is_correct, feedback_message) - """ - # Normalize both answers - submitted_norm = submitted.strip().lower() - correct_norm = correct.strip().lower() - - # Exact match - if submitted_norm == correct_norm: - return True, "āœ… EXACT MATCH" - - # Remove common punctuation and check again - import string - submitted_clean = submitted_norm.translate(str.maketrans('', '', string.punctuation)) - correct_clean = correct_norm.translate(str.maketrans('', '', string.punctuation)) - - if submitted_clean == 
correct_clean: - return True, "āœ… MATCH (punctuation difference)" - - # Check if it's a number formatting issue - try: - # Try to parse as numbers - submitted_num = float(submitted_clean.replace(',', '').replace('$', '')) - correct_num = float(correct_clean.replace(',', '').replace('$', '')) - if abs(submitted_num - correct_num) < 0.01: # Allow small floating point differences - return True, "āœ… MATCH (numeric equivalence)" - except (ValueError, AttributeError): - pass +# ============================================================================= +# ASR INITIALIZATION +# ============================================================================= +class ASRManager: + """Manage ASR pipeline""" + def __init__(self): + self.pipeline = None + self._initialized = False - # Check if submitted answer contains correct answer (for list-type answers) - if ',' in correct_norm: - correct_items = set([item.strip() for item in correct_norm.split(',')]) - submitted_items = set([item.strip() for item in submitted_norm.split(',')]) - - if correct_items == submitted_items: - return True, "āœ… MATCH (item order difference)" + def initialize(self): + if self._initialized: + return True - missing_items = correct_items - submitted_items - extra_items = submitted_items - correct_items - - if missing_items and not extra_items: - return False, f"āŒ MISSING: {', '.join(missing_items)}" - elif extra_items and not missing_items: - return False, f"āŒ EXTRA: {', '.join(extra_items)}" - elif missing_items and extra_items: - return False, f"āŒ MISSING: {', '.join(missing_items)} | EXTRA: {', '.join(extra_items)}" - - # Check case-insensitive substring match - if submitted_norm in correct_norm or correct_norm in submitted_norm: - return False, f"āŒ PARTIAL MATCH (submitted: '{submitted}' | correct: '{correct}')" + try: + print("Loading ASR (Whisper) pipeline...") + device = 0 if torch.cuda.is_available() else -1 + device_name = "cuda:0" if device == 0 else "cpu" + print(f"Using device: 
{device_name}") + + self.pipeline = pipeline( + "automatic-speech-recognition", + model="openai/whisper-base", + torch_dtype=torch.float16 if device == 0 else torch.float32, + device=device + ) + + self._initialized = True + print("āœ… ASR pipeline loaded") + return True + + except Exception as e: + print(f"āš ļø ASR pipeline failed to load: {e}") + return False - return False, f"āŒ WRONG (submitted: '{submitted}' | correct: '{correct}')" - + def is_ready(self): + return self._initialized -def create_answer_sheet_template(questions: List[Dict], filepath: str = "answer_sheet.json"): - """Create an answer sheet template from questions""" - answer_template = {} - for q in questions: - answer_template[q['task_id']] = "" - - with open(filepath, 'w', encoding='utf-8') as f: - json.dump(answer_template, f, indent=2) - - print(f"āœ… Created answer sheet template at {filepath}") - print(f" Please fill in the correct answers for {len(answer_template)} questions") +asr_manager = ASRManager() - # ============================================================================= -# ASR INITIALIZATION +# ANSWER VALIDATION # ============================================================================= -asr_pipeline = None -try: - print("Loading ASR (Whisper) pipeline globally...") - device = 0 if torch.cuda.is_available() else -1 - device_name = "cuda:0" if device == 0 else "cpu" - print(f"Attempting to use device: {device_name} for ASR.") - asr_pipeline = pipeline( - "automatic-speech-recognition", - model="openai/whisper-base", - torch_dtype=torch.float16 if device == 0 else torch.float32, - device=device - ) - print("āœ… ASR (Whisper) pipeline loaded successfully.") -except Exception as e: - print(f"āš ļø Warning: Could not load ASR pipeline globally. 
Error: {e}") - asr_pipeline = None +class AnswerValidator: + """Validate and check answers""" + + @staticmethod + def load_answer_sheet(filepath: str = "answer_sheet.json") -> Dict[str, str]: + """Load answer sheet""" + try: + if os.path.exists(filepath): + with open(filepath, 'r', encoding='utf-8') as f: + answers = json.load(f) + print(f"āœ… Loaded {len(answers)} answers from {filepath}") + return answers + else: + print(f"āš ļø Answer sheet not found: {filepath}") + return {} + except Exception as e: + print(f"āŒ Error loading answer sheet: {e}") + return {} + + @staticmethod + def check_correctness(submitted: str, correct: str) -> Tuple[bool, str]: + """Check if answer is correct with fuzzy matching""" + import string + + submitted_norm = submitted.strip().lower() + correct_norm = correct.strip().lower() + + # Exact match + if submitted_norm == correct_norm: + return True, "āœ… EXACT MATCH" + + # Remove punctuation + trans = str.maketrans('', '', string.punctuation) + submitted_clean = submitted_norm.translate(trans) + correct_clean = correct_norm.translate(trans) + + if submitted_clean == correct_clean: + return True, "āœ… MATCH (punctuation)" + + # Numeric comparison + try: + submitted_num = float(submitted_clean.replace(',', '').replace('$', '')) + correct_num = float(correct_clean.replace(',', '').replace('$', '')) + if abs(submitted_num - correct_num) < 0.01: + return True, "āœ… MATCH (numeric)" + except (ValueError, AttributeError): + pass + + # List comparison + if ',' in correct_norm: + correct_items = set(item.strip() for item in correct_norm.split(',')) + submitted_items = set(item.strip() for item in submitted_norm.split(',')) + + if correct_items == submitted_items: + return True, "āœ… MATCH (order)" + + missing = correct_items - submitted_items + extra = submitted_items - correct_items + + if missing or extra: + msg = [] + if missing: + msg.append(f"MISSING: {', '.join(missing)}") + if extra: + msg.append(f"EXTRA: {', '.join(extra)}") + return 
False, f"āŒ {' | '.join(msg)}" + + # Partial match + if submitted_norm in correct_norm or correct_norm in submitted_norm: + return False, f"āŒ PARTIAL ('{submitted}' vs '{correct}')" + + return False, f"āŒ WRONG ('{submitted}' vs '{correct}')" # ============================================================================= # UTILITY FUNCTIONS # ============================================================================= -def remove_fences_simple(text): - """Remove code fences from text.""" - original_text = text +def remove_fences_simple(text: str) -> str: + """Remove code fences""" text = text.strip() if text.startswith("```") and text.endswith("```"): text = text[3:-3].strip() @@ -211,207 +383,239 @@ def remove_fences_simple(text): first_line, rest = text.split('\n', 1) if first_line.strip().replace('_','').isalnum() and len(first_line.strip()) < 15: text = rest.strip() - return text - return original_text + return text -def truncate_if_needed(content: str, max_length: int = MAX_MESSAGE_LENGTH) -> str: - """Truncate content if it exceeds max length.""" +def truncate_if_needed(content: str, max_length: int = None) -> str: + """Truncate long content""" + max_length = max_length or config.MAX_MESSAGE_LENGTH if len(content) > max_length: - return content[:max_length] + f"\n...[truncated, {len(content)} total chars]" + return content[:max_length] + f"\n...[truncated, {len(content)} chars total]" return content def find_file(path: str) -> Optional[Path]: - """Find a file by trying multiple path variations.""" + """Find file with multiple path attempts""" script_dir = Path.cwd() safe_path = Path(path).as_posix() - paths_to_try = [ + paths = [ script_dir / safe_path, Path(safe_path), - script_dir / Path(path).name + script_dir / Path(path).name, + Path("files") / Path(path).name ] - for attempt_path in paths_to_try: - if attempt_path.exists(): - return attempt_path + for p in paths: + if p.exists(): + return p return None # 
============================================================================= -# PLANNING & REFLECTION TOOLS +# TOOL INPUT VALIDATION # ============================================================================= +def validate_tool_inputs(tool_name: str, inputs: dict) -> Tuple[bool, str]: + """Validate tool inputs before execution""" + validators = { + "scrape_and_retrieve": lambda i: i.get("url", "").startswith(("http://", "https://")), + "calculator": lambda i: bool(re.match(r'^[\d\+\-\*/\(\)\s\.,a-z]+$', i.get("expression", ""), re.I)), + "read_file": lambda i: len(i.get("path", "")) > 0 and ".." not in i.get("path", ""), + "search_tool": lambda i: len(i.get("query", "").strip()) > 0, + "code_interpreter": lambda i: "import os" not in i.get("code", "").lower(), + } + + if tool_name in validators: + try: + if not validators[tool_name](inputs): + return False, f"Invalid input format for {tool_name}" + except Exception as e: + return False, f"Validation error: {e}" + + return True, "" +# ============================================================================= +# PLANNING & REFLECTION TOOLS +# ============================================================================= class ThinkInput(BaseModel): - reasoning: str = Field(description="Brief reasoning summary (under 150 chars)") + reasoning: str = Field(description="Brief reasoning (under 150 chars)") @tool(args_schema=ThinkInput) def think_through_logic(reasoning: str) -> str: - """ - Use this to work through logic puzzles, riddles, or reasoning problems. - - Call this when: - - The question is a riddle or brain teaser - - You need to reason through a logical problem - - No external information is needed, just thinking - - After thinking, use calculator if math is involved, then validate and submit answer. - """ - print(f"🧠 Thinking: {reasoning[:100]}...") - - return f"""āœ… Logic reasoning recorded. 
+ """Think through logic puzzles and riddles""" + start_time = time.time() + try: + print(f"🧠 Thinking: {reasoning[:100]}...") + result = f"""āœ… Logic reasoning recorded. -Next steps: -1. If math needed → use calculator() -2. Once you have answer → use validate_answer() -3. Then → use final_answer_tool() +Next: calculator (if math) → validate_answer → final_answer_tool -Remember: You MUST call another tool. Do not output reasoning text.""" +Remember: MUST call another tool.""" + telemetry.record_call("think_through_logic", time.time() - start_time, True) + return result + except Exception as e: + telemetry.record_call("think_through_logic", time.time() - start_time, False) + raise class PlanInput(BaseModel): - task_summary: str = Field(description="Very brief task summary (under 80 chars)") + task_summary: str = Field(description="Brief task summary (under 80 chars)") @tool(args_schema=PlanInput) def create_plan(task_summary: str) -> str: - """ - Creates a plan for multi-step questions. Use for complex tasks only. - Keep the summary VERY brief to avoid errors. - """ - print(f"šŸ“‹ Planning: {task_summary[:80]}...") - - return f"""āœ… Plan created for: {task_summary} + """Create plan for complex tasks""" + start_time = time.time() + try: + print(f"šŸ“‹ Planning: {task_summary[:80]}...") + result = f"""āœ… Plan: {task_summary} -FRAMEWORK: -1. What info do I need? -2. What tools will I use? -3. In what order? +Framework: +1. What info needed? +2. Which tools? +3. What order? -Now execute step 1. 
You MUST call a tool next.""" +Execute step 1 now.""" + telemetry.record_call("create_plan", time.time() - start_time, True) + return result + except Exception as e: + telemetry.record_call("create_plan", time.time() - start_time, False) + raise class ReflectInput(BaseModel): - situation: str = Field(description="Brief situation summary (under 80 chars)") + situation: str = Field(description="Brief situation (under 80 chars)") @tool(args_schema=ReflectInput) def reflect_on_progress(situation: str) -> str: - """ - Reflects on progress when stuck. Use after 5+ turns without progress. - Keep situation summary VERY brief. - """ - print(f"šŸ¤” Reflecting: {situation[:80]}...") - - return f"""šŸ” REFLECTION on: {situation} + """Reflect when stuck""" + start_time = time.time() + try: + print(f"šŸ¤” Reflecting: {situation[:80]}...") + result = f"""šŸ” Reflection: {situation} -QUESTIONS: -1. Am I using the right approach? -2. Should I try a different tool? -3. Do I actually have the answer already? +Questions: +1. Right approach? +2. Try different tool? +3. Have answer already? -Take a DIFFERENT approach now. You MUST call a tool next.""" +Try DIFFERENT approach now.""" + telemetry.record_call("reflect_on_progress", time.time() - start_time, True) + return result + except Exception as e: + telemetry.record_call("reflect_on_progress", time.time() - start_time, False) + raise class ValidateInput(BaseModel): - proposed_answer: str = Field(description="The answer to validate") + proposed_answer: str = Field(description="Answer to validate") original_question: str = Field(description="Original question (first 100 chars)") @tool(args_schema=ValidateInput) def validate_answer(proposed_answer: str, original_question: str) -> str: - """ - Validates answer format before submission. ALWAYS use before final_answer_tool. 
- """ - print(f"āœ“ Validating: '{proposed_answer[:50]}...'") - - issues = [] - warnings = [] - - # Check for conversational fluff - fluff = ["the answer is", "based on", "according to", "i found", "here is"] - if any(p in proposed_answer.lower() for p in fluff): - issues.append("āŒ Remove conversational text. Answer only.") - - # Check for code fences - if "```" in proposed_answer: - issues.append("āŒ Remove code fences (```).") - - # Check length - if len(proposed_answer) > 500: - warnings.append("āš ļø Answer very long. Just the answer?") - - # Check for number questions - if any(k in original_question.lower() for k in ["how many", "what number", "count"]): - if not any(c.isdigit() for c in proposed_answer): - warnings.append("āš ļø Question asks for number but answer has no digits.") - - if issues: - return "🚫 VALIDATION FAILED:\n" + "\n".join(issues) + "\n\nFix then retry." - - if warnings: - return "āš ļø WARNINGS:\n" + "\n".join(warnings) + "\n\nConsider fixing, or proceed if confident." - - return "āœ… VALIDATION PASSED! Now call final_answer_tool() with this answer." 
- + """Validate answer before submission""" + start_time = time.time() + try: + print(f"āœ“ Validating: '{proposed_answer[:50]}...'") + + issues = [] + warnings = [] + + # Check conversational fluff + fluff = ["the answer is", "based on", "according to", "i found", "here is"] + if any(p in proposed_answer.lower() for p in fluff): + issues.append("āŒ Remove conversational text") + + # Check code fences + if "```" in proposed_answer: + issues.append("āŒ Remove code fences") + + # Check length + if len(proposed_answer) > 500: + warnings.append("āš ļø Very long answer") + + # Check numbers + if any(k in original_question.lower() for k in ["how many", "what number", "count"]): + if not any(c.isdigit() for c in proposed_answer): + warnings.append("āš ļø Number expected but none found") + + if issues: + result = "🚫 VALIDATION FAILED:\n" + "\n".join(issues) + elif warnings: + result = "āš ļø WARNINGS:\n" + "\n".join(warnings) + "\n\nProceed if confident." + else: + result = "āœ… PASSED! Call final_answer_tool() now." + + telemetry.record_call("validate_answer", time.time() - start_time, True) + return result + + except Exception as e: + telemetry.record_call("validate_answer", time.time() - start_time, False) + raise # ============================================================================= # CORE TOOLS # ============================================================================= - class SearchInput(BaseModel): query: str = Field(description="Search query (concise)") @tool(args_schema=SearchInput) +@retry_with_backoff(max_retries=3) def search_tool(query: str) -> str: - """ - Search the web for information. Returns snippets. - - IMPORTANT: Search results are SNIPPETS only. For complete information: - 1. Use search_tool to find URLs - 2. 
Use scrape_and_retrieve to get FULL page content - - Example workflow: - - search_tool("Mercedes Sosa Wikipedia") → get URL - - scrape_and_retrieve(url=..., query="studio albums 2000-2009") - """ - if not isinstance(query, str) or not query.strip(): - return "Error: Invalid query." - - # Auto-add Wikipedia site filter if mentioned - if 'wikipedia' in query.lower() and 'site:' not in query: - query = f"{query} site:wikipedia.org" - - print(f"šŸ” Searching: {query}") + """Web search with caching""" + start_time = time.time() - max_retries = 3 - for attempt in range(max_retries): - try: - search = DuckDuckGoSearchRun() - result = search.run(query) - - if not result or len(result) < 50: - return "No relevant results found. Try different search terms or check if the information exists." - - return truncate_if_needed(result) - except Exception as e: - if attempt < max_retries - 1: - time.sleep(2 ** attempt) - continue - return f"Search error after {max_retries} attempts: {str(e)}" + try: + # Input validation + is_valid, msg = validate_tool_inputs("search_tool", {"query": query}) + if not is_valid: + raise ValueError(msg) + + # Auto-add Wikipedia filter FIRST, so the cache is keyed on the + # normalized query (lookup and store must use the same key) + if 'wikipedia' in query.lower() and 'site:' not in query: + query = f"{query} site:wikipedia.org" + + # Check cache + cached = search_cache.get(query) + if cached: + print(f"šŸ” Search (cached): {query}") + telemetry.record_call("search_tool", time.time() - start_time, True) + return cached + + print(f"šŸ” Searching: {query}") + + search = DuckDuckGoSearchRun() + result = search.run(query) + + if not result or len(result) < 50: + result = "No results found. Try different terms."
+ + result = truncate_if_needed(result) + + # Cache only genuine results; caching the failure message would pin + # a transient "no results" outcome for every future retry of this query + if not result.startswith("No results found"): + search_cache.put(query, result) + + telemetry.record_call("search_tool", time.time() - start_time, True) + return result + + except Exception as e: + telemetry.record_call("search_tool", time.time() - start_time, False) + raise ToolError("search_tool", e, "Try rephrasing query") class CalcInput(BaseModel): - expression: str = Field(description="Math expression (e.g., '2+2', 'sqrt(16)')") + expression: str = Field(description="Math expression") @tool(args_schema=CalcInput) def calculator(expression: str) -> str: - """ - Evaluates math expressions. Use for ANY calculations. - Supports: +, -, *, /, **, sqrt, sin, cos, log, pi, e, etc. - """ - if not isinstance(expression, str) or not expression.strip(): - return "Error: Invalid expression." - - print(f"🧮 Calculating: {expression}") + """Evaluate math expressions""" + start_time = time.time() try: + # Input validation + is_valid, msg = validate_tool_inputs("calculator", {"expression": expression}) + if not is_valid: + raise ValueError(msg) + + print(f"🧮 Calculating: {expression}") + import math safe_dict = { 'sqrt': math.sqrt, 'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
- """ - if not isinstance(code, str): - return "Error: code must be string." - - # Safety checks - dangerous = ['__import__', 'eval(', 'compile(', 'subprocess', 'os.system', 'exec('] - if any(d in code.lower() for d in dangerous): - return f"Error: Dangerous operation not allowed." - - if 'open(' in code.lower() and any(m in code for m in ["'w'", '"w"', "'a'", '"a"']): - return "Error: File writing not allowed. Use write_file tool." - - print(f"šŸ’» Executing code ({len(code)} chars)...") - output_stream = io.StringIO() - error_stream = io.StringIO() + """Execute Python code""" + start_time = time.time() try: + # Safety checks + dangerous = ['__import__', 'eval(', 'compile(', 'subprocess', 'os.system', 'exec('] + if any(d in code.lower() for d in dangerous): + raise ValueError("Dangerous operation not allowed") + + if 'open(' in code.lower() and any(m in code for m in ["'w'", '"w"', "'a'", '"a"']): + raise ValueError("File writing not allowed, use write_file tool") + + print(f"šŸ’» Executing code ({len(code)} chars)...") + + output_stream = io.StringIO() + error_stream = io.StringIO() + with contextlib.redirect_stdout(output_stream), contextlib.redirect_stderr(error_stream): safe_globals = { "pd": pd, @@ -460,20 +665,23 @@ def code_interpreter(code: str) -> str: "__builtins__": __builtins__ } exec(code, safe_globals, {}) - + stdout = output_stream.getvalue() stderr = error_stream.getvalue() if stderr: - return f"Error:\n{stderr}\n\nStdout:\n{stdout}" - - if stdout: - return truncate_if_needed(stdout) + result = f"Error:\n{stderr}\n\nOutput:\n{stdout}" + elif stdout: + result = truncate_if_needed(stdout) + else: + result = "Code executed but no output. Use print()!" - return "Code executed but no output. Remember to use print()!" 
+ telemetry.record_call("code_interpreter", time.time() - start_time, True) + return result except Exception as e: - return f"Execution failed:\n{traceback.format_exc()}" + telemetry.record_call("code_interpreter", time.time() - start_time, False) + raise ToolError("code_interpreter", e, "Check code syntax") class ReadFileInput(BaseModel): @@ -481,23 +689,32 @@ class ReadFileInput(BaseModel): @tool(args_schema=ReadFileInput) def read_file(path: str) -> str: - """Reads file content.""" - if not isinstance(path, str) or not path.strip(): - return "Error: Invalid path." - - print(f"šŸ“„ Reading: {path}") - - file_path = find_file(path) - if not file_path: - return f"Error: File not found: '{path}'\nCWD files: {os.listdir('.')}" + """Read file content""" + start_time = time.time() try: + # Input validation + is_valid, msg = validate_tool_inputs("read_file", {"path": path}) + if not is_valid: + raise ValueError(msg) + + print(f"šŸ“„ Reading: {path}") + + file_path = find_file(path) + if not file_path: + raise FileNotFoundError(f"File not found: {path}") + content = file_path.read_text(encoding='utf-8') + + telemetry.record_call("read_file", time.time() - start_time, True) return truncate_if_needed(content) + except UnicodeDecodeError: - return f"Error: Binary file. Size: {file_path.stat().st_size} bytes. Try audio_transcription_tool for audio." + telemetry.record_call("read_file", time.time() - start_time, False) + return f"Binary file. Try audio_transcription_tool." except Exception as e: - return f"Read error: {str(e)}" + telemetry.record_call("read_file", time.time() - start_time, False) + raise ToolError("read_file", e, f"Check file path: {path}") class WriteFileInput(BaseModel): @@ -506,19 +723,22 @@ class WriteFileInput(BaseModel): @tool(args_schema=WriteFileInput) def write_file(path: str, content: str) -> str: - """Writes content to file.""" - if not path or not isinstance(content, str): - return "Error: Invalid inputs." 
- - print(f"āœļø Writing: {path}") + """Write content to file""" + start_time = time.time() try: + print(f"āœļø Writing: {path}") + file_path = Path.cwd() / path file_path.parent.mkdir(parents=True, exist_ok=True) file_path.write_text(content, encoding='utf-8') - return f"Wrote {len(content)} chars to '{path}'." + + telemetry.record_call("write_file", time.time() - start_time, True) + return f"Wrote {len(content)} chars to '{path}'" + except Exception as e: - return f"Write error: {str(e)}" + telemetry.record_call("write_file", time.time() - start_time, False) + raise ToolError("write_file", e) class ListDirInput(BaseModel): @@ -526,19 +746,21 @@ class ListDirInput(BaseModel): @tool(args_schema=ListDirInput) def list_directory(path: str = ".") -> str: - """Lists directory contents.""" - print(f"šŸ“ Listing: {path}") + """List directory contents""" + start_time = time.time() try: + print(f"šŸ“ Listing: {path}") + dir_path = Path.cwd() / path if path != "." else Path.cwd() if not dir_path.is_dir(): - return f"Error: '{path}' not a directory." + raise NotADirectoryError(f"'{path}' not a directory") items = sorted(dir_path.iterdir()) if not items: - return f"Directory '{path}' is empty." + return f"Directory '{path}' is empty" files, dirs = [], [] @@ -554,9 +776,12 @@ def list_directory(path: str = ".") -> str: if files: result += "Files:\n" + "\n".join(files) + telemetry.record_call("list_directory", time.time() - start_time, True) return result + except Exception as e: - return f"List error: {str(e)}" + telemetry.record_call("list_directory", time.time() - start_time, False) + raise ToolError("list_directory", e) class AudioInput(BaseModel): @@ -564,91 +789,76 @@ class AudioInput(BaseModel): @tool(args_schema=AudioInput) def audio_transcription_tool(file_path: str) -> str: - """Transcribes audio using Whisper.""" - if not file_path: - return "Error: Invalid file path." 
- - print(f"šŸŽ¤ Transcribing: {file_path}") - - if asr_pipeline is None: - return "Error: ASR not available." - - audio_path = find_file(file_path) - if not audio_path: - return f"Error: Audio file not found: '{file_path}'" + """Transcribe audio using Whisper""" + start_time = time.time() try: - transcription = asr_pipeline( + print(f"šŸŽ¤ Transcribing: {file_path}") + + if not asr_manager.is_ready(): + asr_manager.initialize() + + if not asr_manager.is_ready(): + raise RuntimeError("ASR not available") + + audio_path = find_file(file_path) + if not audio_path: + raise FileNotFoundError(f"Audio file not found: {file_path}") + + transcription = asr_manager.pipeline( str(audio_path), - return_timestamps=True, # ← Add this! - chunk_length_s=30, # ← Process in 30-second chunks - stride_length_s=5 # ← 5-second overlap between chunks + return_timestamps=True, + chunk_length_s=30, + stride_length_s=5 ) - # Extract just the text (ignore timestamps) result_text = transcription.get("text", "") - # OR if you want to see the chunks: - # chunks = transcription.get("chunks", []) - # result_text = " ".join([chunk["text"] for chunk in chunks]) if not result_text: - return "Error: Transcription empty." + raise ValueError("Transcription empty") + telemetry.record_call("audio_transcription_tool", time.time() - start_time, True) return f"Transcription:\n{truncate_if_needed(result_text)}" + except Exception as e: - return f"Transcription error: {str(e)}" + telemetry.record_call("audio_transcription_tool", time.time() - start_time, False) + raise ToolError("audio_transcription_tool", e) class ImageAnalysisInput(BaseModel): file_path: str = Field(description="Image file path") - query: str = Field(description="What to analyze in the image") + query: str = Field(description="What to analyze") @tool(args_schema=ImageAnalysisInput) def analyze_image(file_path: str, query: str) -> str: - """ - Analyzes images using Google Gemini Vision API. 
- Use for: chess positions, diagrams, charts, photos, screenshots. - Provide the EXACT file path from [FILE ATTACHED: ...] in the question. - """ - if not file_path or not query: - return "Error: file_path and query required." - - print(f"šŸ–¼ļø Analyzing image: {file_path}") - print(f" Query: {query[:100]}...") - - # Try to find the file - image_path = find_file(file_path) - - # If not found via find_file, try the path directly (for /tmp files) - if not image_path and os.path.exists(file_path): - image_path = Path(file_path) - - if not image_path or not image_path.exists(): - return f"Error: Image not found at '{file_path}'. Check [FILE ATTACHED: ...] in question for correct path." - - print(f"āœ“ Found image at: {image_path}") + """Analyze images using Gemini Vision""" + start_time = time.time() try: + print(f"šŸ–¼ļø Analyzing: {file_path}") + print(f" Query: {query[:100]}...") + + image_path = find_file(file_path) + if not image_path and os.path.exists(file_path): + image_path = Path(file_path) + + if not image_path or not image_path.exists(): + raise FileNotFoundError(f"Image not found: {file_path}") + GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY") if not GOOGLE_API_KEY: - return "Error: GEMINI_API_KEY not set." 
+ raise ValueError("GEMINI_API_KEY not set") - # Load and encode image + # Load and encode img = Image.open(image_path) - print(f" Image size: {img.size}, mode: {img.mode}") - - # Convert to RGB if necessary if img.mode not in ['RGB', 'RGBA']: img = img.convert('RGB') - # Convert to base64 buffered = io.BytesIO() img.save(buffered, format="JPEG") img_base64 = base64.b64encode(buffered.getvalue()).decode() - print(f" Encoded image: {len(img_base64)} bytes") - - # Use Gemini Vision + # Use FLASH model for cost efficiency vision_llm = ChatGoogleGenerativeAI( model="gemini-2.0-flash", google_api_key=GOOGLE_API_KEY, @@ -658,526 +868,171 @@ def analyze_image(file_path: str, query: str) -> str: message = HumanMessage( content=[ {"type": "text", "text": query}, - { - "type": "image_url", - "image_url": f"data:image/jpeg;base64,{img_base64}" - } + {"type": "image_url", "image_url": f"data:image/jpeg;base64,{img_base64}"} ] ) - print(f" Sending to Gemini Vision...") response = vision_llm.invoke([message]) - print(f"āœ“ Got response: {len(response.content)} chars") + telemetry.record_call("analyze_image", time.time() - start_time, True) return f"Image Analysis:\n{truncate_if_needed(response.content)}" except Exception as e: - error_msg = f"Image analysis error: {str(e)}" - print(f"āŒ {error_msg}") - print(traceback.format_exc()) - return error_msg + telemetry.record_call("analyze_image", time.time() - start_time, False) + raise ToolError("analyze_image", e) class YoutubeInput(BaseModel): video_url: str = Field(description="YouTube URL") @tool(args_schema=YoutubeInput) +@retry_with_backoff(max_retries=2) def get_youtube_transcript(video_url: str) -> str: - """ - Fetches YouTube video transcript using AssemblyAI. - Works reliably on Hugging Face Spaces. 
- """ + """Get YouTube transcript using AssemblyAI""" + start_time = time.time() + try: - # Set API key (store in HF Spaces secrets) aai.settings.api_key = os.getenv("ASSEMBLYAI_API_KEY") - print(f"šŸ“ŗ Transcribing: {video_url}") - # Transcribe directly from YouTube URL transcriber = aai.Transcriber() transcript = transcriber.transcribe(video_url) - # Wait for transcription if transcript.status == aai.TranscriptStatus.error: - return f"Error: {transcript.error}" + raise RuntimeError(transcript.error) - print(f"āœ“ Transcribed {len(transcript.text)} chars") + telemetry.record_call("get_youtube_transcript", time.time() - start_time, True) return f"Transcript:\n{transcript.text}" except Exception as e: - return f"Error: {str(e)}" + telemetry.record_call("get_youtube_transcript", time.time() - start_time, False) + raise ToolError("get_youtube_transcript", e) class ScrapeInput(BaseModel): - url: str = Field(description="URL (must start with http:// or https://)") - query: str = Field(description="Specific information to find on the page") + url: str = Field(description="URL (http:// or https://)") + query: str = Field(description="Specific info to find") @tool(args_schema=ScrapeInput) +@retry_with_backoff(max_retries=3) def scrape_and_retrieve(url: str, query: str) -> str: - """ - Fetch and search FULL webpage content using RAG (not just snippets like search_tool). - - CRITICAL: Use this after search_tool gives you a URL. This gets the COMPLETE page. - - Workflow Example: - 1. search_tool('Mercedes Sosa Wikipedia') → get URL - 2. scrape_and_retrieve( - url='https://en.wikipedia.org/wiki/Mercedes_Sosa', - query='studio albums released between 2000 and 2009' - ) → Returns FULL discography section - - Use when: - - Counting items (albums, people, events, etc.) 
- - Finding specific names, dates, or numbers - - Need complete tables or lists - - Wikipedia articles, documentation, papers - - Search snippets weren't enough - """ - if not url.startswith(('http://', 'https://')): - return f"Error: Invalid URL format. Must start with http:// or https://" - if not query: - return "Error: Query required to search the page content." - - if global_embeddings is None or global_text_splitter is None: - if not initialize_rag_components(): - return "Error: RAG components not initialized." - - print(f"🌐 Scraping: {url}") - print(f" Looking for: {query[:100]}...") - - max_retries = 3 - for attempt in range(max_retries): - try: - headers = { - 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36' - } - response = requests.get(url, headers=headers, timeout=20) - response.raise_for_status() - - soup = BeautifulSoup(response.text, 'html.parser') - - # Remove noise - for tag in soup(["script", "style", "nav", "footer", "aside", "header", "iframe"]): - tag.extract() - - # Extract main content - main = soup.find('main') or soup.find('article') or soup.find('div', class_='mw-parser-output') or soup.body - - if not main: - return "Error: Could not find main content on page." - - text = main.get_text(separator='\n', strip=True) - lines = [l.strip() for l in text.splitlines() if l.strip()] - text = '\n'.join(lines) - - if len(text) < 50: - return f"Error: Page content too short ({len(text)} chars). May be blocked or empty." - - print(f"āœ“ Extracted {len(text)} characters from page") - - # Chunk and search - chunks = global_text_splitter.split_text(text) - - if not chunks: - return "Error: Could not process page content." 
- - print(f"āœ“ Created {len(chunks)} chunks") - - docs = [Document(page_content=c, metadata={"source": url}) for c in chunks] - - db = FAISS.from_documents(docs, global_embeddings) - retriever = db.as_retriever(search_kwargs={"k": 5}) - retrieved = retriever.invoke(query) - - if not retrieved: - return f"No information found matching: '{query}'\nTry a different query or the information may not be on this page." - - print(f"āœ“ Found {len(retrieved)} relevant chunks") - - context = "\n\n---\n\n".join([f"[Section {i+1}]\n{d.page_content}" for i, d in enumerate(retrieved)]) - - return truncate_if_needed(f"From {url}:\n\n{context}") - - except requests.Timeout: - if attempt < max_retries - 1: - print(f"āš ļø Timeout, retrying... (attempt {attempt + 1}/{max_retries})") - time.sleep(2 ** attempt) - continue - return f"Error: Page request timed out after {max_retries} attempts." - except requests.RequestException as e: - if attempt < max_retries - 1: - time.sleep(2 ** attempt) - continue - return f"Error fetching page: {str(e)}" - except Exception as e: - return f"Error processing page: {str(e)}\n{traceback.format_exc()}" - -class ChessAnalysisInput(BaseModel): - image_path: str = Field(description="Path to chess board image file") - description: str = Field(description="Any additional context about the position (optional)", default="") - -@tool(args_schema=ChessAnalysisInput) -@tool(args_schema=ChessAnalysisInput) -def analyze_chess_position(image_path: str, description: str = "") -> str: - """ - Analyzes a chess position from an image using Stockfish engine. 
- - MUCH MORE RELIABLE than Lichess API because: - - Works offline - - Analyzes ANY position (not just cloud database) - - Stronger engine (Stockfish 16+) - - No rate limits or 404 errors - - Use this tool when: - - Question mentions chess, checkmate, or chess notation - - An image file shows a chess board - - Need to find the best move in a position - - Args: - image_path: Path to chess board image - description: The full question text - IMPORTANT for determining whose turn it is! - - Returns: Best move in algebraic notation (e.g., "Qh5", "Nf6+", "Rd5") - """ - if not image_path: - return "Error: image_path is required." - - print(f"ā™Ÿļø Analyzing chess position from: {image_path}") - - # Find the file - chess_image = find_file(image_path) - - # If not found via find_file, try direct path - if not chess_image and os.path.exists(image_path): - chess_image = Path(image_path) - - if not chess_image or not chess_image.exists(): - return f"Error: Chess board image not found at '{image_path}'. Check the [FILE ATTACHED: ...] path in the question." - - print(f"āœ“ Found chess image at: {chess_image}") + """Fetch and search full webpage with RAG""" + start_time = time.time() try: - # ==================================================================== - # STEP 1: Extract FEN notation from image using Gemini Vision - # ==================================================================== - GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY") - if not GOOGLE_API_KEY: - return "Error: GEMINI_API_KEY not set in Space secrets." 
+ # Input validation + is_valid, msg = validate_tool_inputs("scrape_and_retrieve", {"url": url}) + if not is_valid: + raise ValueError(msg) - print("šŸ“ø Extracting chess position from image using Gemini...") + if not rag_manager.is_ready(): + rag_manager.initialize() - # Load and encode image - img = Image.open(chess_image) - print(f" Image loaded: {img.size}, mode: {img.mode}") + if not rag_manager.is_ready(): + raise RuntimeError("RAG not available") - if img.mode not in ['RGB', 'RGBA']: - img = img.convert('RGB') + print(f"🌐 Scraping: {url}") + print(f" Looking for: {query[:100]}...") - buffered = io.BytesIO() - img.save(buffered, format="JPEG") - img_base64 = base64.b64encode(buffered.getvalue()).decode() + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36' + } - # Use Gemini Vision to extract FEN - vision_llm = ChatGoogleGenerativeAI( - model="gemini-2.5-pro", - google_api_key=GOOGLE_API_KEY, - temperature=0 - ) + response = requests.get(url, headers=headers, timeout=20) + response.raise_for_status() - # Check if the question explicitly states whose turn it is - whose_turn = None - if description: - desc_lower = description.lower() - if "black" in desc_lower and ("turn" in desc_lower or "move" in desc_lower): - whose_turn = "b" - elif "white" in desc_lower and ("turn" in desc_lower or "move" in desc_lower): - whose_turn = "w" - - fen_prompt = f"""Analyze this chess board image and provide the position in FEN notation. - -CRITICAL INSTRUCTIONS: -1. Carefully identify each piece: - - White pieces (UPPERCASE): K=King, Q=Queen, R=Rook, B=Bishop, N=Knight, P=Pawn - - Black pieces (lowercase): k, q, r, b, n, p - -2. 
BOARD ORIENTATION - This is CRITICAL: - - In chess diagrams, the board is shown from the perspective of the player to move - - Look at the BOTTOM rank (closest to viewer): - * If bottom pieces are BLACK (lowercase in FEN) → Black to move → active color = 'b' - * If bottom pieces are WHITE (uppercase in FEN) → White to move → active color = 'w' - - The rank labels (1-8) on the side can help: - * If rank 8 is at bottom and rank 1 at top → Black's perspective → use 'b' - * If rank 1 is at bottom and rank 8 at top → White's perspective → use 'w' - {"- OVERRIDE: The question explicitly states BLACK's turn, so use 'b'" if whose_turn == "b" else ""} - {"- OVERRIDE: The question explicitly states WHITE's turn, so use 'w'" if whose_turn == "w" else ""} - -3. FEN Format (read from rank 8 to rank 1, left to right): - - Use numbers (1-8) for consecutive empty squares - - Use '/' to separate ranks - - IMPORTANT: Always write FEN from White's perspective (rank 8 first, rank 1 last) - - But set the active_color based on whose perspective the board shows - -4. Return ONLY the FEN string in this exact format: - piece_placement active_color castling en_passant halfmove fullmove - -Example: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1 - -DOUBLE-CHECK: -- Did you identify whose pieces are at the BOTTOM of the board? -- Did you set active_color correctly based on board orientation? -- Did you write piece_placement from rank 8 to rank 1? 
- -Return ONLY the FEN string, nothing else.""" + soup = BeautifulSoup(response.text, 'html.parser') - message = HumanMessage( - content=[ - {"type": "text", "text": fen_prompt}, - { - "type": "image_url", - "image_url": f"data:image/jpeg;base64,{img_base64}" - } - ] - ) + # Remove noise + for tag in soup(["script", "style", "nav", "footer", "aside", "header", "iframe"]): + tag.extract() - response = vision_llm.invoke([message]) - fen_raw = response.content.strip() - print(f"šŸ“ Raw FEN response: {fen_raw}") - - # Clean up FEN (remove markdown, explanations, etc.) - fen = None - for line in fen_raw.split('\n'): - line = line.strip().replace('```', '').replace('fen', '') - # FEN should have '/' for ranks and spaces for components - if '/' in line and ' ' in line and not line.startswith('#'): - if any(c in line for c in 'kqrbnpKQRBNP12345678'): - fen = line - break - - if not fen: - return f"Error: Could not extract valid FEN notation from image. Response: {fen_raw[:200]}" - - print(f"āœ“ Extracted FEN: {fen}") - - # Override the turn indicator if we know from the question - if whose_turn: - fen_parts = fen.split() - if len(fen_parts) >= 2: - old_turn = fen_parts[1] - fen_parts[1] = whose_turn - fen = ' '.join(fen_parts) - print(f"šŸ”„ Corrected turn from '{old_turn}' to '{whose_turn}' based on question") - print(f"āœ“ Corrected FEN: {fen}") - - # Additional verification: Check if board orientation matches turn - # In FEN, rank 8 is first, rank 1 is last - # If bottom of image shows black pieces, it's black's turn - fen_parts = fen.split() - piece_placement = fen_parts[0] - active_color = fen_parts[1] if len(fen_parts) > 1 else 'w' - - # Get last rank (rank 1 in FEN, which is bottom if white's perspective) - ranks = piece_placement.split('/') - rank_1 = ranks[-1] # Last rank in FEN - rank_8 = ranks[0] # First rank in FEN - - # Check which color dominates bottom rank - # If showing from black's perspective, rank 8 should be at bottom - # and active color should be 
'b' - black_pieces_in_rank8 = sum(1 for c in rank_8 if c.islower() and c.isalpha()) - white_pieces_in_rank8 = sum(1 for c in rank_8 if c.isupper() and c.isalpha()) - - if black_pieces_in_rank8 > white_pieces_in_rank8 and active_color == 'w': - print(f"āš ļø Warning: Rank 8 has more black pieces, likely black's perspective") - print(f" Changing active color from 'w' to 'b'") - fen_parts[1] = 'b' - fen = ' '.join(fen_parts) - - # ==================================================================== - # STEP 2: Validate FEN with python-chess - # ==================================================================== - try: - import chess - except ImportError: - return "Error: python-chess not installed. Add 'python-chess' to requirements.txt" + main = soup.find('main') or soup.find('article') or soup.find('div', class_='mw-parser-output') or soup.body - try: - board = chess.Board(fen) - print(f"āœ“ FEN validated successfully") - print(f" Turn: {'White' if board.turn else 'Black'}") - print(f" Legal moves: {board.legal_moves.count()}") - except ValueError as e: - return f"Error: Invalid FEN notation: {e}\nExtracted FEN: {fen}" - - # ==================================================================== - # STEP 3: Analyze with Stockfish - # ==================================================================== - print("šŸ” Analyzing position with Stockfish...") + if not main: + raise ValueError("Could not find main content") - try: - from stockfish import Stockfish - except ImportError: - return "Error: stockfish not installed. 
Add 'stockfish' to requirements.txt and install Stockfish binary" - - # Try to find Stockfish binary - stockfish_paths = [ - "/usr/games/stockfish", # Linux (apt-get install) - "/usr/local/bin/stockfish", # Mac (brew install) - "/usr/bin/stockfish", # Alternative Linux - "stockfish", # In PATH - "./stockfish", # Local directory - "C:\\Program Files\\stockfish\\stockfish.exe" # Windows - ] - - stockfish_path = None - for path in stockfish_paths: - if os.path.exists(path) or os.path.isfile(path): - stockfish_path = path - break + text = main.get_text(separator='\n', strip=True) + lines = [l.strip() for l in text.splitlines() if l.strip()] + text = '\n'.join(lines) - if not stockfish_path: - # Try running 'which stockfish' on Unix systems - try: - import subprocess - result = subprocess.run(['which', 'stockfish'], - capture_output=True, - text=True, - timeout=5) - if result.returncode == 0: - stockfish_path = result.stdout.strip() - except: - pass + if len(text) < 50: + raise ValueError(f"Content too short ({len(text)} chars)") - if not stockfish_path: - return """Error: Stockfish binary not found. 
Install it: -- Linux: sudo apt-get install stockfish -- Mac: brew install stockfish -- Windows: Download from stockfishchess.org -Or set the path manually in the code.""" + print(f"āœ“ Extracted {len(text)} characters") - print(f"āœ“ Found Stockfish at: {stockfish_path}") + # RAG retrieval + chunks = rag_manager.text_splitter.split_text(text) + print(f"āœ“ Created {len(chunks)} chunks") - # Initialize Stockfish - try: - stockfish = Stockfish( - path=stockfish_path, - depth=35, # Analysis depth (higher = stronger but slower) - parameters={ - "Threads": 2, - "Minimum Thinking Time": 5000, # milliseconds - "Hash": 1024, # MB of RAM - } - ) - except Exception as e: - return f"Error initializing Stockfish: {e}" + docs = [Document(page_content=c, metadata={"source": url}) for c in chunks] - # Set position - stockfish.set_fen_position(fen) + db = FAISS.from_documents(docs, rag_manager.embeddings) + retriever = db.as_retriever(search_kwargs={"k": 5}) + retrieved = retriever.invoke(query) - # Get best move - print(" Computing best move...") - best_move_uci = stockfish.get_best_move() + # Clean up memory + del db + del retriever + import gc + gc.collect() - if not best_move_uci: - return "Error: Stockfish could not find a legal move. Check if position is valid." + if not retrieved: + return f"No info found for: '{query}'. Try different query." 
- print(f"šŸŽÆ Best move (UCI): {best_move_uci}") + print(f"āœ“ Found {len(retrieved)} relevant chunks") - # Get evaluation - evaluation = stockfish.get_evaluation() - eval_type = evaluation.get("type", "cp") - eval_value = evaluation.get("value", 0) + context = "\n\n---\n\n".join([f"[Section {i+1}]\n{d.page_content}" for i, d in enumerate(retrieved)]) - if eval_type == "mate": - eval_str = f" (Mate in {abs(eval_value)})" - else: - # Centipawns to pawns - eval_str = f" (Eval: {eval_value/100:+.2f})" + telemetry.record_call("scrape_and_retrieve", time.time() - start_time, True) + return truncate_if_needed(f"From {url}:\n\n{context}") - # ==================================================================== - # STEP 4: Convert UCI to Standard Algebraic Notation (SAN) - # ==================================================================== - try: - uci_move = chess.Move.from_uci(best_move_uci) - san_move = board.san(uci_move) - - # Check if move leads to check/checkmate - board.push(uci_move) - if board.is_checkmate(): - check_str = " - Checkmate!" 
- elif board.is_check(): - check_str = " - Check" - else: - check_str = "" - - final_result = f"{san_move}{eval_str}{check_str}" - print(f"āœ… Best move: {final_result}") - - # Return JUST the move notation for clean submission - return san_move - - except Exception as e: - print(f"āš ļø Could not convert to SAN: {e}") - # Fall back to UCI notation - return best_move_uci - + except requests.Timeout: + telemetry.record_call("scrape_and_retrieve", time.time() - start_time, False) + raise ToolError("scrape_and_retrieve", TimeoutError("Request timed out"), "Check URL or try later") except Exception as e: - error_msg = f"Chess analysis failed: {str(e)}" - print(f"āŒ {error_msg}") - print(traceback.format_exc()) - return error_msg + telemetry.record_call("scrape_and_retrieve", time.time() - start_time, False) + raise ToolError("scrape_and_retrieve", e) + class FinalAnswerInput(BaseModel): - answer: str = Field(description="Final answer - EXACTLY what was asked, nothing more") + answer: str = Field(description="Final answer - exact, no fluff") @tool(args_schema=FinalAnswerInput) def final_answer_tool(answer: str) -> str: - """ - Submit final answer. CRITICAL RULES: - 1. ALWAYS call validate_answer() first - 2. Answer must be EXACTLY what was asked - 3. NO conversational text - 4. NO explanations - 5. 
Match requested format exactly - """ - if not isinstance(answer, str): - answer = str(answer) + """Submit final answer""" + start_time = time.time() - print(f"āœ… FINAL ANSWER SUBMITTED: {answer}") - return answer + try: + print(f"āœ… FINAL ANSWER: {answer}") + telemetry.record_call("final_answer_tool", time.time() - start_time, True) + return answer + except Exception as e: + telemetry.record_call("final_answer_tool", time.time() - start_time, False) + raise # ============================================================================= -# DEFINED TOOLS LIST +# TOOLS LIST # ============================================================================= defined_tools = [ - # Planning & Reflection think_through_logic, create_plan, reflect_on_progress, validate_answer, - - # Core tools search_tool, calculator, code_interpreter, - - # File operations read_file, write_file, list_directory, - - # Specialized audio_transcription_tool, - analyze_image, + analyze_image, get_youtube_transcript, scrape_and_retrieve, - analyze_chess_position, - - # Final final_answer_tool ] - # ============================================================================= # AGENT STATE # ============================================================================= @@ -1189,18 +1044,17 @@ class AgentState(TypedDict): tool_history: List[str] last_tool_was_thinking: bool - # ============================================================================= -# ENHANCED FALLBACK PARSER +# TOOL CALL PARSER # ============================================================================= def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]: - """Enhanced parser with multiple strategies.""" - print(f"šŸ”§ Fallback parsing (first 300 chars):\n{content[:300]}") + """Enhanced fallback parser""" + print(f"šŸ”§ Parsing tool call from: {content[:300]}...") tool_name = None tool_input = None - # STRATEGY 1: Groq's format + # Strategy 1: Groq format groq_match = re.search(r"|)", content, 
re.DOTALL) if groq_match: try: @@ -1212,7 +1066,7 @@ def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]: except: tool_name = None - # STRATEGY 2: Standard {...} format + # Strategy 2: Standard format if not tool_name: func_match = re.search(r"](.*)", content, re.DOTALL | re.IGNORECASE) if func_match: @@ -1227,7 +1081,7 @@ def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]: except: tool_name = None - # STRATEGY 3: Tool mention with code block → wrap in code_interpreter + # Strategy 3: Code block → code_interpreter if not tool_name and "```python" in content: try: code_match = re.search(r"```python\n(.*?)```", content, re.DOTALL) @@ -1235,53 +1089,46 @@ def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]: code = code_match.group(1).strip() tool_name = "code_interpreter" tool_input = {"code": code} - print(f"āœ“ Extracted Python code → code_interpreter") + print(f"āœ“ Extracted Python code") except: pass - # STRATEGY 4: Direct tool mention → create minimal valid call + # Strategy 4: Tool mention if not tool_name: for tool in tools: if tool.name.lower() in content.lower(): tool_name = tool.name tool_input = {} - # Try to extract arguments from content if tool.args_schema: schema = tool.args_schema.model_json_schema() for prop in schema.get('properties', {}).keys(): if prop in schema.get('required', []): - # Use placeholder tool_input[prop] = "auto_extracted" - print(f"āœ“ Found mention of '{tool_name}' → creating default call") + print(f"āœ“ Found mention: {tool_name}") break - # STRATEGY 5: Emergency - if no tool detected, force a reasonable one + # Strategy 5: Force thinking if not tool_name: - # If content looks like reasoning, use think_through_logic - if len(content) > 50 and not any(kw in content.lower() for kw in ["error", "failed", "invalid"]): + if len(content) > 50: tool_name = "think_through_logic" tool_input = {"reasoning": content[:150]} - print(f"āš ļø No tool detected → forcing 
think_through_logic") + print(f"āš ļø Forcing think_through_logic") - # Validate and create tool call if tool_name and tool_input is not None: - matching_tools = [t for t in tools if t.name == tool_name] - if matching_tools: + matching = [t for t in tools if t.name == tool_name] + if matching: return [ToolCall(name=tool_name, args=tool_input, id=str(uuid.uuid4()))] - else: - print(f"āŒ Tool '{tool_name}' not in available tools") - print("āŒ All parsing strategies failed") + print("āŒ All parsing failed") return [] - # ============================================================================= -# CONDITIONAL EDGE FUNCTION +# CONDITIONAL EDGE # ============================================================================= def should_continue(state: AgentState): - """Decide next step with robust logic.""" + """Decide next step""" messages = state.get('messages', []) if not messages: return "agent" @@ -1289,67 +1136,51 @@ def should_continue(state: AgentState): last_message = messages[-1] current_turn = state.get('turn', 0) - # Debug: Print what we're checking - msg_type = type(last_message).__name__ - print(f"šŸ“ Conditional check - Turn {current_turn}, Last msg type: {msg_type}") + print(f"šŸ“ Turn {current_turn}, Last: {type(last_message).__name__}") - # 1. Check turn limit - if current_turn >= MAX_TURNS: - print(f"šŸ›‘ Max turns ({MAX_TURNS}) reached") + if current_turn >= config.MAX_TURNS: + print(f"šŸ›‘ Max turns reached") return END - # 2. If last message is ToolMessage, agent needs to process it if isinstance(last_message, ToolMessage): - print(f"šŸ“Ø Tool result received from '{last_message.name}' → back to agent") + print(f"šŸ“Ø Tool result → agent") return "agent" - # 3. 
If last message is AIMessage with tool calls if isinstance(last_message, AIMessage) and last_message.tool_calls: - # Only check the FIRST tool call, not all of them first_tool = last_message.tool_calls[0] - tool_name = first_tool.get("name", "") - - if tool_name == "final_answer_tool": + if first_tool.get("name") == "final_answer_tool": return END - else: - return "tools" + return "tools" - # 4. If AIMessage but no tool calls (reasoning text) if isinstance(last_message, AIMessage) and not last_message.tool_calls: - # Check for consecutive AI messages (loop) if len(messages) >= 2 and isinstance(messages[-2], AIMessage) and not messages[-2].tool_calls: - print(f"āš ļø Loop detected: 2 consecutive AI messages without tools") + print(f"āš ļø Loop detected") return END - - print(f"šŸ’­ AI message without tool call → continuing to agent (will force tool)") + print(f"šŸ’­ AI without tool → agent") return "agent" - # 5. Default: continue to agent - print(f"šŸ”„ Default → continuing to agent") - + return "agent" # ============================================================================= -# ENHANCED AGENT CLASS +# MAIN AGENT CLASS # ============================================================================= class PlanningReflectionAgent: def __init__(self): - print("🧠 PlanningReflectionAgent initializing...") + print("🧠 Initializing PlanningReflectionAgent...") + # Check API keys GROQ_API_KEY = os.getenv("GROQ_API_KEY") if not GROQ_API_KEY: - raise ValueError("GROQ_API_KEY not set!") - HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") - if not HUGGINGFACEHUB_API_TOKEN: - raise ValueError("HUGGINGFACEHUB_API_TOKEN secret is not set! 
Please add it to your Space secrets.") + raise ValueError("GROQ_API_KEY not set") + GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY") if not GOOGLE_API_KEY: - raise ValueError("GOOGLE_API_KEY not set!") - + raise ValueError("GEMINI_API_KEY not set") + self.tools = defined_tools # Initialize RAG - if not initialize_rag_components(): - print("āš ļø RAG components failed to initialize.") + rag_manager.initialize() # Build tool descriptions tool_desc_list = [] @@ -1363,278 +1194,140 @@ class PlanningReflectionAgent: desc = f"- {tool.name}: {tool.description}" tool_desc_list.append(desc) tool_descriptions = "\n".join(tool_desc_list) - - # ULTRA-AGGRESSIVE SYSTEM PROMPT - self.system_prompt = f"""You are an elite AI agent for GAIA benchmark. Your ONLY job: provide the EXACT answer requested. - -═══════════════════════════════════════════════════════════════ -āš ļø ABSOLUTE RULES - VIOLATE THESE AND YOU FAIL: -═══════════════════════════════════════════════════════════════ - -1. **EVERY TURN MUST CALL EXACTLY ONE TOOL** - No exceptions -2. **NEVER OUTPUT REASONING TEXT WITHOUT A TOOL CALL** - You will fail -3. **IDENTIFY QUESTION TYPE FIRST** - Logic? Factual? Data? Math? -4. **LOGIC PUZZLES**: think_through_logic → calculator (if needed) → validate → final_answer -5. **FACTUAL QUESTIONS**: search_tool → validate → final_answer -6. **DATA QUESTIONS**: read_file → code_interpreter → validate → final_answer -7. **ALWAYS VALIDATE**: Call validate_answer() before final_answer_tool() -8. **FINAL ANSWER FORMAT**: EXACTLY what was asked. NO "The answer is..." or explanations - -═══════════════════════════════════════════════════════════════ -šŸ“‹ QUESTION TYPE GUIDE: -═══════════════════════════════════════════════════════════════ - -**RIDDLES/LOGIC PUZZLES** (No web search needed): -- Brain teasers, puzzles, logical deduction -- Strategy: think_through_logic → calculator (if math) → validate → final_answer -- Example: "If 200 coins, 30 face-down, divide into equal piles..." 
- Turn 1: think_through_logic("Adventurer takes 30 coins and flips them") - Turn 2: calculator("30") [if needed] - Turn 3: validate_answer("30", question) - Turn 4: final_answer_tool("30") - -**FACTUAL/RESEARCH** (Need web): -- Who, what, when, where questions -- Strategy: search_tool → scrape_and_retrieve → validate → final_answer -- Example: "What was Einstein's birthplace population in 1900?" - Turn 1: search_tool("Albert Einstein birthplace") - Turn 2: search_tool("Ulm Germany population 1900") - Turn 3: validate_answer("50000", question) - Turn 4: final_answer_tool("50000") - -**DATA ANALYSIS** (Need files): -- CSV/Excel questions -- Strategy: list_directory → read_file → code_interpreter → validate → final_answer - -**SIMPLE MATH**: -- Calculations -- Strategy: calculator() → validate_answer() → final_answer_tool() + + self.system_prompt = f"""You are an elite AI agent for GAIA benchmark. ═══════════════════════════════════════════════════════════════ -šŸŽ“ CRITICAL EXAMPLES: +āš ļø ABSOLUTE RULES: ═══════════════════════════════════════════════════════════════ -Example 1: Logic Puzzle -Q: "Coin riddle with 200 coins, 30 face-down..." -āœ… CORRECT: - Turn 1: think_through_logic("Take 30 coins, flip all") - Turn 2: validate_answer("30", "coin riddle...") - Turn 3: final_answer_tool("30") - -āŒ WRONG: - Turn 1: [reasoning text without tool] ← FAILS! - -Example 2: Letter Bank Puzzle -Q: "Use letters to spell sentences, which letters need changing?" -āœ… CORRECT: - Turn 1: code_interpreter("code to count letters...") - Turn 2: validate_answer("A, B, C", question) - Turn 3: final_answer_tool("A, B, C") - -Example 3: Math Problem -Q: "System of equations to solve..." -āœ… CORRECT: - Turn 1: code_interpreter("import numpy; solve equations...") - Turn 2: validate_answer("0, 1, 2", question) - Turn 3: final_answer_tool("0, 1, 2") +1. EVERY TURN MUST CALL EXACTLY ONE TOOL +2. NEVER OUTPUT REASONING TEXT WITHOUT TOOL CALL +3. IDENTIFY QUESTION TYPE FIRST +4. 
LOGIC: think → calc → validate → final +5. FACTUAL: search → scrape → validate → final +6. DATA: read → code → validate → final +7. ALWAYS VALIDATE before final_answer +8. FINAL FORMAT: EXACTLY what asked, NO fluff ═══════════════════════════════════════════════════════════════ -šŸ“š AVAILABLE TOOLS: +šŸ“š TOOLS: ═══════════════════════════════════════════════════════════════ {tool_descriptions} ═══════════════════════════════════════════════════════════════ -⚔ EXECUTION RULES: +⚔ EXECUTION: ═══════════════════════════════════════════════════════════════ -- If you output text without a tool call, you have FAILED -- If you're unsure, use think_through_logic() to organize thoughts -- ALWAYS call a tool - preferably the right one for the question type -- After EVERY tool result, decide: "Do I have the answer? → validate → submit" -- If stuck after 3 turns: call reflect_on_progress() - -REMEMBER: One tool per turn. No reasoning without tools. Exact answer format. +- Text without tool = FAILURE +- Unsure? → think_through_logic() +- After each tool: Have answer? → validate → submit +- Stuck after 3 turns? → reflect_on_progress() ═══════════════════════════════════════════════════════════════ """ - - #. 
Initialize the LLM () - print("Initializing Groq LLM...") - try: - self.llm_with_tools = ChatGroq( - temperature=0, - groq_api_key=GROQ_API_KEY, - model_name="qwen/qwen3-32b", - max_tokens=4096, - timeout=60 - ).bind_tools(self.tools, tool_choice="auto") - print("āœ… LLM initialized without FORCED tool usage.") - - except Exception as e: - print(f"āŒ Error initializing HuggingFace: {e}") - raise - print("Initializing LLM Endpoint...") -# print("Initializing HuggingFace LLM...") -# -# llm = HuggingFaceEndpoint( -# repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", # Free on HF Inference API -# huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN, -# max_new_tokens=4096, -# temperature=0.01, -# ) -# chat_llm = ChatHuggingFace(llm=llm) -# print("āœ… HuggingFace LLM Endpoint initialized.") -# -# # Bind tools to the LLM -# self.llm_with_tools = chat_llm.bind_tools(self.tools) -# print("āœ… Tools bound to LLM.") -# print("Initializing Google Gemini LLM...") -# try: -# self.llm_with_tools = ChatGoogleGenerativeAI( -# model="gemini-2.5-flash", # Latest model -# google_api_key=GOOGLE_API_KEY, -# temperature=0, -# max_output_tokens=8192, -# timeout=60, -# convert_system_message_to_human=True # Important for Gemini -# ).bind_tools(self.tools, tool_choice="auto") -# print("āœ… Gemini LLM initialized.") -# except Exception as e: -# print(f"āŒ Error initializing Gemini: {e}") -# raise - - # Agent Node with AGGRESSIVE tool forcing + + # Initialize LLM + print("Initializing Groq LLM...") + self.llm_with_tools = ChatGroq( + temperature=0, + groq_api_key=GROQ_API_KEY, + model_name="llama-3.3-70b-versatile", + max_tokens=4096, + timeout=60 + ).bind_tools(self.tools, tool_choice="auto") + + print("āœ… LLM initialized") + + # Build agent graph def agent_node(state: AgentState): current_turn = state.get('turn', 0) + 1 print(f"\n{'='*70}") - print(f"šŸ¤– AGENT TURN {current_turn}/{MAX_TURNS}") + print(f"šŸ¤– TURN {current_turn}/{config.MAX_TURNS}") print('='*70) - if current_turn > 
MAX_TURNS: + if current_turn > config.MAX_TURNS: return { - "messages": [SystemMessage(content="Max turns reached.")], + "messages": [SystemMessage(content="Max turns reached")], "turn": current_turn } - - # Check if we should force reflection + consecutive_errors = state.get('consecutive_errors', 0) - should_reflect = (current_turn > 5 and current_turn % REFLECT_EVERY_N_TURNS == 0) or consecutive_errors >= 3 + should_reflect = (current_turn > 5 and current_turn % config.REFLECT_EVERY_N_TURNS == 0) or consecutive_errors >= 3 messages_to_send = state["messages"].copy() - # Add tool-forcing message if last turn had no tool call + # Force tool usage if len(messages_to_send) >= 2: last_msg = messages_to_send[-1] if isinstance(last_msg, AIMessage) and not last_msg.tool_calls: force_msg = SystemMessage( - content="āš ļø CRITICAL: You MUST call a tool this turn. NO reasoning text. Pick the most appropriate tool and call it now." + content="āš ļø CRITICAL: MUST call a tool. NO reasoning text." ) messages_to_send.append(force_msg) - print("🚨 Injecting tool-forcing message") + print("🚨 Forcing tool usage") - # Add reflection hint if needed if should_reflect: hint = SystemMessage( - content="āš ļø HINT: Multiple turns without progress. Consider calling reflect_on_progress() or try a different approach." + content="āš ļø HINT: No progress. Try reflect_on_progress() or different approach." 
) messages_to_send.append(hint) - print("šŸ¤” Injecting reflection hint") + print("šŸ¤” Reflection hint") - # Invoke LLM with retries and fallback - max_retries = 3 + # Invoke LLM with retries ai_message = None - for attempt in range(max_retries): + for attempt in range(config.MAX_RETRIES): try: ai_message = self.llm_with_tools.invoke(messages_to_send) - # If we got a valid response with tool calls, break if ai_message.tool_calls: break - # If no tool calls, this is a problem - print(f"āš ļø LLM returned no tool calls on attempt {attempt+1}") + print(f"āš ļø No tool calls (attempt {attempt+1})") except Exception as e: - error_str = str(e) - print(f"āš ļø LLM attempt {attempt+1}/{max_retries} failed: {error_str[:200]}") - - # If tool_use_failed, try without strict binding - if "tool_use_failed" in error_str and attempt < max_retries - 1: - print("šŸ”§ Trying without strict tool enforcement...") - try: - simple_llm = ChatGroq( - temperature=0, - groq_api_key=os.getenv("GROQ_API_KEY"), - model_name="llama-3.3-70b-versatile", - max_tokens=4096, - timeout=60 - ) - - # Add explicit tool forcing to the message - force_tool_msg = SystemMessage( - content="You MUST call a tool. Respond with a tool call, not reasoning text." 
- ) - ai_message = simple_llm.invoke(messages_to_send + [force_tool_msg]) - - # Try to parse tool calls from content - if ai_message.content and not ai_message.tool_calls: - parsed = parse_tool_call_from_string(ai_message.content, self.tools) - if parsed: - ai_message.tool_calls = parsed - ai_message.content = "" - print("āœ“ Fallback parsing succeeded") - break - except Exception as e2: - print(f"āš ļø Fallback also failed: {e2}") + print(f"āš ļø LLM error (attempt {attempt+1}): {str(e)[:200]}") - if attempt == max_retries - 1: - # Last resort: inject a default tool call - print("🚨 All attempts failed - forcing think_through_logic") + if attempt == config.MAX_RETRIES - 1: + print("🚨 Forcing think_through_logic") ai_message = AIMessage( content="", tool_calls=[ToolCall( name="think_through_logic", - args={"reasoning": "Processing question"}, + args={"reasoning": "Processing"}, id=str(uuid.uuid4()) )] ) else: - time.sleep(2 ** attempt) + time.sleep(config.BASE_RETRY_DELAY * (2 ** attempt)) - # If still no tool calls after all attempts, force one + # Ensure tool calls exist if not ai_message.tool_calls: - if isinstance(ai_message.content, str) and ai_message.content.strip(): - # Try one more parse + if ai_message.content: parsed = parse_tool_call_from_string(ai_message.content, self.tools) if parsed: ai_message.tool_calls = parsed ai_message.content = "" - print("āœ“ Final parse succeeded") else: - # Absolute last resort - print("🚨 EMERGENCY: Forcing think_through_logic") ai_message.tool_calls = [ToolCall( name="think_through_logic", - args={"reasoning": "analyzing question"}, + args={"reasoning": "analyzing"}, id=str(uuid.uuid4()) )] ai_message.content = "" - # Track tool usage + # Track usage tool_history = state.get('tool_history', []) has_plan = state.get('has_plan', False) if ai_message.tool_calls: tool_name = ai_message.tool_calls[0]['name'] - print(f"šŸ”§ Tool Call: {tool_name}") + print(f"šŸ”§ Tool: {tool_name}") tool_history.append(tool_name) if tool_name 
== "create_plan": has_plan = True - else: - print(f"āš ļø No tool call (this shouldn't happen!)") - print(f"šŸ’­ Content: {ai_message.content[:200]}...") return { "messages": [ai_message], @@ -1643,19 +1336,14 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format. "tool_history": tool_history, "last_tool_was_thinking": ai_message.tool_calls and ai_message.tool_calls[0]['name'] == 'think_through_logic' } - - # Tool Node with Error Tracking (FIXED) + def tool_node_wrapper(state: AgentState): - """Executes tools and tracks errors.""" + """Execute tools with error tracking""" print(f"šŸ”§ Executing tools...") - # Create fresh ToolNode instance tool_executor = ToolNode(self.tools) - - # Invoke properly result = tool_executor.invoke(state) - # Track errors consecutive_errors = state.get('consecutive_errors', 0) if result.get('messages'): @@ -1663,14 +1351,14 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format. if isinstance(last_msg, ToolMessage): if "Error" in last_msg.content or "error" in last_msg.content.lower(): consecutive_errors += 1 - print(f"āš ļø Tool error detected (consecutive: {consecutive_errors})") + print(f"āš ļø Tool error (consecutive: {consecutive_errors})") else: consecutive_errors = 0 result['consecutive_errors'] = consecutive_errors return result - # Build Graph + # Build graph print("Building graph...") graph_builder = StateGraph(AgentState) @@ -1692,36 +1380,36 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format. graph_builder.add_edge("tools", "agent") self.graph = graph_builder.compile() - print("āœ… Graph compiled successfully.") - + print("āœ… Graph compiled") + def __call__(self, question: str, file_path: str = None) -> str: - """Execute agent on a question.""" + """Execute agent""" print(f"\n{'='*70}") print(f"šŸŽÆ NEW QUESTION") print(f"{'='*70}") - print(f"Q: {question[:200]}{'...' 
if len(question) > 200 else ''}") + print(f"Q: {question[:200]}...") if file_path: - print(f"šŸ“Ž File attached: {file_path}") + print(f"šŸ“Ž File: {file_path}") print(f"{'='*70}\n") - # Enhanced question context with file information + # Build question context question_text = question if file_path: file_ext = Path(file_path).suffix.lower() file_type = "unknown" - if file_ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']: + if file_ext in ['.jpg', '.jpeg', '.png', '.gif']: file_type = "image" - elif file_ext in ['.mp3', '.wav', '.m4a', '.flac']: + elif file_ext in ['.mp3', '.wav', '.m4a']: file_type = "audio" - elif file_ext in ['.csv', '.xlsx', '.xls']: + elif file_ext in ['.csv', '.xlsx']: file_type = "data" - elif file_ext in ['.txt', '.pdf', '.doc', '.docx']: + elif file_ext in ['.txt', '.pdf', '.doc']: file_type = "document" - question_text += f"\n\n[FILE ATTACHED: {file_path}]" - question_text += f"\n[FILE TYPE: {file_type}]" - question_text += f"\nIMPORTANT: Use the appropriate tool to access this file first!" + question_text += f"\n\n[FILE: {file_path}]" + question_text += f"\n[TYPE: {file_type}]" + question_text += f"\nUse appropriate tool first!" graph_input = { "messages": [ @@ -1736,13 +1424,13 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format. "last_tool_was_thinking": False } - final_answer = "AGENT FAILED TO PRODUCE ANSWER" + final_answer = "AGENT FAILED" all_messages = [] try: - config = {"recursion_limit": MAX_TURNS + 10} + config_dict = {"recursion_limit": config.MAX_TURNS + 10} - for event in self.graph.stream(graph_input, stream_mode="values", config=config): + for event in self.graph.stream(graph_input, stream_mode="values", config=config_dict): if not event.get('messages'): continue @@ -1756,38 +1444,31 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format. 
args = tool_call.get('args', {}) if 'answer' in args: final_answer = args['answer'] - print(f"\n{'='*70}") - print(f"āœ… FINAL ANSWER: '{final_answer}'") - print(f"{'='*70}\n") + print(f"\nāœ… FINAL: '{final_answer}'\n") break - + elif isinstance(last_message, ToolMessage): preview = last_message.content[:200].replace('\n', ' ') - print(f"šŸ“Š Tool '{last_message.name}' result: {preview}...") - - elif isinstance(last_message, AIMessage) and not last_message.tool_calls: - print(f"šŸ’­ AI: {last_message.content[:200]}...") + print(f"šŸ“Š Tool '{last_message.name}': {preview}...") - # If no final answer, try to extract from tool messages - if final_answer == "AGENT FAILED TO PRODUCE ANSWER": - print("āš ļø No final_answer_tool called. Checking tool results...") + # Fallback: extract from tool results + if final_answer == "AGENT FAILED": + print("āš ļø No final_answer_tool. Checking tools...") for msg in reversed(all_messages): if isinstance(msg, ToolMessage): if msg.name in ["calculator", "think_through_logic", "code_interpreter"]: content = msg.content.strip() - # Look for short, answer-like content if content and len(content) < 200 and not content.startswith("Error"): - # Extract just the result part lines = content.split('\n') for line in reversed(lines): if line.strip() and not line.startswith(('āœ…', 'āš ļø', 'Next', 'Remember')): final_answer = line.strip() - print(f"šŸ“ Extracted from {msg.name}: '{final_answer}'") + print(f"šŸ“ Extracted: '{final_answer}'") break break - - # Clean the answer + + # Clean answer cleaned = str(final_answer).strip() # Remove prefixes @@ -1804,7 +1485,7 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format. cleaned = potential break - # Remove code fences and quotes + # Remove code fences cleaned = remove_fences_simple(cleaned) while cleaned.startswith("`") and cleaned.endswith("`"): @@ -1814,214 +1495,169 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format. 
(cleaned.startswith("'") and cleaned.endswith("'")): cleaned = cleaned[1:-1].strip() - # Remove trailing period for short answers if cleaned.endswith('.') and len(cleaned.split()) < 10: cleaned = cleaned[:-1] - print(f"\n{'='*70}") - print(f"šŸŽ‰ RETURNING ANSWER") - print(f"{'='*70}") - print(f"{cleaned}") - print(f"{'='*70}\n") + print(f"\nšŸŽ‰ RETURNING: {cleaned}\n") return cleaned except Exception as e: print(f"āŒ Graph error: {e}") print(traceback.format_exc()) - return f"AGENT ERROR: {e}" - + return f"ERROR: {e}" # ============================================================================= -# GLOBAL AGENT INSTANTIATION +# GLOBAL AGENT # ============================================================================= agent = None try: - initialize_rag_components() - + rag_manager.initialize() agent = PlanningReflectionAgent() - print("āœ… Global PlanningReflectionAgent instantiated.") + print("āœ… Global agent ready") - # Verify it's callable if not callable(agent): - print("āŒ ERROR: Agent not callable!") + print("āŒ Agent not callable") agent = None else: - print("āœ… Agent is callable.") + print("āœ… Agent is callable") - if asr_pipeline is None: - print("āš ļø ASR Pipeline not loaded.") - except Exception as e: - print(f"āŒ FATAL: Agent initialization failed: {e}") + print(f"āŒ FATAL: {e}") traceback.print_exc() agent = None # ============================================================================= -# RUN AND SUBMIT FUNCTION +# RUN AND SUBMIT # ============================================================================= - def run_and_submit_all(profile: gr.OAuthProfile | None): - """ - Fetches all questions, runs the BasicAgent on them, submits all answers, - and displays the results. 
- """ + """Run evaluation and submit""" space_id = os.getenv("SPACE_ID") - + if profile: - username = f"{profile.username}" - print(f"User logged in: {username}") + username = profile.username + print(f"User: {username}") else: - print("User not logged in.") - return "Please Login to Hugging Face with the button.", None + print("Not logged in") + return "Please login to HuggingFace", None - # Use the globally instantiated agent global agent if agent is None: - error_msg = "FATAL: Agent failed to initialize at startup. Check logs for errors." - print(error_msg) - return error_msg, None + return "FATAL: Agent failed to initialize", None - print("āœ… Using globally instantiated PlanningReflectionAgent") + print("āœ… Using global agent") - api_url = DEFAULT_API_URL + api_url = config.DEFAULT_API_URL questions_url = f"{api_url}/questions" submit_url = f"{api_url}/submit" - agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main" - print(agent_code) - - # 2. Fetch Questions + # Fetch questions print(f"\n{'='*70}") print(f"šŸ“„ FETCHING QUESTIONS") - print(f"{'='*70}") - print(f"Fetching questions from: {questions_url}") + print(f"{'='*70}\n") try: response = requests.get(questions_url, timeout=15) response.raise_for_status() questions_data = response.json() + if not questions_data: - print("Fetched questions list is empty.") - return "Fetched questions list is empty or invalid format.", None - print(f"āœ… Fetched {len(questions_data)} questions.") - print(f"{'='*70}\n") - except requests.exceptions.RequestException as e: - print(f"āŒ Error fetching questions: {e}") - return f"Error fetching questions: {e}", None - except requests.exceptions.JSONDecodeError as e: - print(f"āŒ Error decoding JSON response from questions endpoint: {e}") - print(f"Response text: {response.text[:500]}") - return f"Error decoding server response for questions: {e}", None + return "No questions fetched", None + + print(f"āœ… Fetched {len(questions_data)} questions\n") + except 
Exception as e: - print(f"āŒ An unexpected error occurred fetching questions: {e}") - return f"An unexpected error occurred fetching questions: {e}", None - + print(f"āŒ Fetch error: {e}") + return f"Error fetching questions: {e}", None + # Load answer sheet - answer_sheet = load_answer_sheet("answer_sheet_json.json") + validator = AnswerValidator() + answer_sheet = validator.load_answer_sheet("answer_sheet.json") - # If answer sheet doesn't exist, create template - if not answer_sheet: - create_answer_sheet_template(questions_data, "answer_sheet.json") - print("\nāš ļø Please fill in the answer_sheet.json file with correct answers") - print(" Then run the script again to check agent performance\n") + # Initialize tracking + progress = ProgressTracker(len(questions_data)) + telemetry.reset() - results = [] - local_correct = 0 - local_total = 0 + results_log = [] + answers_payload = [] - # 3. Run your Agent + # Process questions print(f"\n{'='*70}") print(f"šŸš€ STARTING EVALUATION") - print(f"{'='*70}") - print(f"Total questions to process: {len(questions_data)}") print(f"{'='*70}\n") - results_log = [] - answers_payload = [] - for idx, item in enumerate(questions_data, 1): print(f"\n{'='*70}") - print(f"šŸ“ PROCESSING QUESTION {idx}/{len(questions_data)}") - print(f"{'='*70}") + print(f"šŸ“ QUESTION {idx}/{len(questions_data)}") + print(f"{'='*70}\n") task_id = item.get("task_id") question_text = item.get("question") correct_answer = answer_sheet.get(task_id, "") - # Look for file locally in files/ directory + # Find file local_file_path = None files_dir = "files" try: - # Check if files directory exists if os.path.exists(files_dir): - # Look for any file that starts with the task_id matching_files = [f for f in os.listdir(files_dir) if f.startswith(task_id)] if matching_files: - # Use the first matching file local_file_path = os.path.join(files_dir, matching_files[0]) - file_size = os.path.getsize(local_file_path) - abs_path = 
os.path.abspath(local_file_path) - - print(f"āœ… Found file: {matching_files[0]} ({file_size} bytes)") - print(f" Path: {abs_path}") + print(f"āœ… Found file: {matching_files[0]}") else: - print(f"ā„¹ļø No file found for task {task_id}, proceeding without file.") + print(f"ā„¹ļø No file for {task_id}") else: - print(f"āš ļø Warning: '{files_dir}' directory not found.") - + print(f"āš ļø '{files_dir}' not found") except Exception as e: - print(f"āŒ Error looking for file: {e}") - + print(f"āŒ File search error: {e}") + try: - # Pass file_path to agent + # Run agent submitted_answer = agent(question_text, local_file_path) answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer}) - # Check if answer is correct - is_correct = submitted_answer.strip().lower() == correct_answer.strip().lower() - correctness = "āœ… CORRECT" if is_correct else "āŒ WRONG" + # Check correctness + is_correct, feedback = validator.check_correctness(submitted_answer, correct_answer) - # Log with correctness indicator - print(f"\n{correctness} - Task {task_id}") + print(f"\n{feedback} - Task {task_id}") print(f" Submitted: '{submitted_answer}'") print(f" Expected: '{correct_answer}'") results_log.append({ - "Task ID": task_id, - "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text, - "Submitted Answer": submitted_answer, - "Correct Answer": correct_answer, + "Task ID": task_id, + "Question": question_text[:100] + "..." 
if len(question_text) > 100 else question_text, + "Submitted": submitted_answer, + "Correct": correct_answer, "Status": "āœ…" if is_correct else "āŒ" }) - print(f"āœ… Question {idx}/{len(questions_data)} completed") + progress.update(is_correct) + print(f"\nāœ… Question {idx} completed") except Exception as e: - print(f"āŒ Error running agent on task {task_id}: {e}") + print(f"āŒ Error on {task_id}: {e}") print(traceback.format_exc()) + results_log.append({ - "Task ID": task_id, - "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text, - "Submitted Answer": f"AGENT ERROR: {e}", - "Correct Answer": correct_answer, + "Task ID": task_id, + "Question": question_text[:100] + "...", + "Submitted": f"ERROR: {e}", + "Correct": correct_answer, "Status": "āŒ" }) - # Continue with other questions even if one fails + answers_payload.append({"task_id": task_id, "submitted_answer": f"ERROR: {str(e)[:100]}"}) - - # Summary after all questions processed - print(f"\n{'='*70}") - print(f"āœ… ALL QUESTIONS PROCESSED") - print(f"{'='*70}") - print(f"Total answers collected: {len(answers_payload)}") + progress.update(False) + + # Print telemetry + telemetry.report() - # Calculate pre-submission accuracy + # Summary correct_count = sum(1 for log in results_log if log.get("Status") == "āœ…") total_count = len(results_log) accuracy = (correct_count / total_count * 100) if total_count > 0 else 0 @@ -2031,147 +1667,89 @@ def run_and_submit_all(profile: gr.OAuthProfile | None): print(f"{'='*70}") print(f"Correct: {correct_count}/{total_count} ({accuracy:.1f}%)") print(f"{'='*70}\n") - + if not answers_payload: - print("āš ļø Agent did not produce any answers to submit.") - return "Agent did not produce any answers to submit.", pd.DataFrame(results_log) - - # 4. Prepare Submission - submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload} - - # 5. 
Submit + return "No answers produced", pd.DataFrame(results_log) + + # Submit + submission_data = { + "username": username.strip(), + "agent_code": agent_code, + "answers": answers_payload + } + print(f"\n{'='*70}") - print(f"šŸ“¤ SUBMITTING TO API") - print(f"{'='*70}") - print(f"URL: {submit_url}") - print(f"Username: {username}") - print(f"Answers to submit: {len(answers_payload)}") + print(f"šŸ“¤ SUBMITTING") print(f"{'='*70}\n") try: - print("ā³ Sending POST request...") response = requests.post(submit_url, json=submission_data, timeout=60) - print(f"āœ… Got response: Status {response.status_code}") - response.raise_for_status() result_data = response.json() - print(f"\n{'='*70}") - print(f"šŸ“Š SUBMISSION RESULTS") - print(f"{'='*70}") - print(f"Response data: {result_data}") - print(f"{'='*70}\n") - final_status = ( f"Submission Successful!\n" f"User: {result_data.get('username')}\n" - f"Overall Score: {result_data.get('score', 'N/A')}% " - f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n" - f"Message: {result_data.get('message', 'No message received.')}" + f"Score: {result_data.get('score', 'N/A')}% " + f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')})\n" + f"Message: {result_data.get('message', 'No message')}" ) print(final_status) - print("="*70) - print("āœ… Submission successful.") - results_df = pd.DataFrame(results_log) return final_status, results_df - except requests.exceptions.HTTPError as e: - error_detail = f"Server responded with status {e.response.status_code}." 
- try: - error_json = e.response.json() - error_detail += f" Detail: {error_json.get('detail', e.response.text)}" - except requests.exceptions.JSONDecodeError: - error_detail += f" Response: {e.response.text[:500]}" - status_message = f"Submission Failed: {error_detail}" - print(f"\n{'='*70}") - print(f"āŒ SUBMISSION FAILED") - print(f"{'='*70}") - print(status_message) - print(f"{'='*70}\n") - results_df = pd.DataFrame(results_log) - return status_message, results_df - - except requests.exceptions.Timeout: - status_message = "Submission Failed: The request timed out." - print(f"\n{'='*70}") - print(f"āŒ SUBMISSION FAILED") - print(f"{'='*70}") - print(status_message) - print(f"{'='*70}\n") - results_df = pd.DataFrame(results_log) - return status_message, results_df - - except requests.exceptions.RequestException as e: - status_message = f"Submission Failed: Network error - {e}" - print(f"\n{'='*70}") - print(f"āŒ SUBMISSION FAILED") - print(f"{'='*70}") - print(status_message) - print(f"{'='*70}\n") - results_df = pd.DataFrame(results_log) - return status_message, results_df - except Exception as e: - status_message = f"An unexpected error occurred during submission: {e}" - print(f"\n{'='*70}") - print(f"āŒ SUBMISSION FAILED") - print(f"{'='*70}") - print(status_message) - print(traceback.format_exc()) - print(f"{'='*70}\n") + print(f"āŒ Submission failed: {e}") results_df = pd.DataFrame(results_log) - return status_message, results_df - + return f"Submission failed: {e}", results_df -# --- Build Gradio Interface using Blocks --- +# ============================================================================= +# GRADIO INTERFACE +# ============================================================================= with gr.Blocks() as demo: - gr.Markdown("# Basic Agent Evaluation Runner") - gr.Markdown( - """ - **Instructions:** - 1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ... - 2. 
Log in to your Hugging Face account using the button below. This uses your HF username for submission. - 3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score. - --- - **Disclaimers:** - Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions). - This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance for the delay process of the submit button, a solution could be to cache the answers and submit in a seperate action or even to answer the questions in async. - """ - ) - + gr.Markdown("# GAIA Agent Evaluation - Refactored") + gr.Markdown(""" + **Improvements:** + - Better error handling with retry logic + - Caching for search results + - Telemetry and progress tracking + - Memory management + - Modular architecture + + **Instructions:** + 1. Clone this space and modify as needed + 2. Login with HuggingFace account + 3. 
Click 'Run Evaluation & Submit' + """) + gr.LoginButton() - - run_button = gr.Button("Run Evaluation & Submit All Answers") - - status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False) - results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True) - + + run_button = gr.Button("Run Evaluation & Submit All") + status_output = gr.Textbox(label="Status", lines=5, interactive=False) + results_table = gr.DataFrame(label="Results", wrap=True) + run_button.click( fn=run_and_submit_all, outputs=[status_output, results_table] ) if __name__ == "__main__": - print("\n" + "-"*30 + " App Starting " + "-"*30) - space_host_startup = os.getenv("SPACE_HOST") - space_id_startup = os.getenv("SPACE_ID") - - if space_host_startup: - print(f"āœ… SPACE_HOST found: {space_host_startup}") - print(f" Runtime URL should be: https://{space_host_startup}.hf.space") - else: - print("ā„¹ļø SPACE_HOST environment variable not found (running locally?).") - - if space_id_startup: - print(f"āœ… SPACE_ID found: {space_id_startup}") - print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}") - print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main") - else: - print("ā„¹ļø SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.") - - print("-"*(60 + len(" App Starting ")) + "\n") - - print("Launching Gradio Interface for Basic Agent Evaluation...") + print("\n" + "-"*70) + print("Starting Refactored GAIA Agent") + print("-"*70 + "\n") + + space_host = os.getenv("SPACE_HOST") + space_id = os.getenv("SPACE_ID") + + if space_host: + print(f"āœ… SPACE_HOST: {space_host}") + print(f" URL: https://{space_host}.hf.space") + + if space_id: + print(f"āœ… SPACE_ID: {space_id}") + print(f" Repo: https://huggingface.co/spaces/{space_id}") + + print("\n" + "-"*70 + "\n") + demo.launch(debug=True, share=False) \ No newline at end of file