Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import subprocess | |
| import json | |
| import re | |
| import traceback | |
| import contextlib | |
| import uuid | |
| import time | |
| import ast | |
| from typing import List, Optional, TypedDict, Annotated, Dict | |
| from pathlib import Path | |
| from collections import Counter | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| from pydantic import BaseModel, Field | |
| # Multimodal & Web Tools | |
| from transformers import pipeline | |
| from youtube_transcript_api import YouTubeTranscriptApi | |
| from bs4 import BeautifulSoup | |
| import requests | |
| from PIL import Image | |
| import base64 | |
| from googleapiclient.discovery import build | |
| from googleapiclient.errors import HttpError | |
| import assemblyai as aai | |
| # LangChain & LangGraph | |
| from langgraph.graph.message import add_messages | |
| from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, AnyMessage, ToolCall | |
| from langchain_core.tools import tool | |
| from langgraph.prebuilt import ToolNode | |
| from langgraph.graph import START, END, StateGraph | |
| from langchain_groq import ChatGroq | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_community.llms import HuggingFaceHub | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| # RAG | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.tools import DuckDuckGoSearchRun | |
| from langchain_core.documents import Document | |
# =============================================================================
# CONFIGURATION
# =============================================================================
# Base URL of the course scoring service.
DEFAULT_API_URL: str = "https://agents-course-unit4-scoring.hf.space"
# Agent-loop limits. Their consumers are outside this section (presumably the
# graph runner caps turns and triggers reflection with these) -- confirm.
MAX_TURNS: int = 25
REFLECT_EVERY_N_TURNS: int = 5
# Default character cap applied by truncate_if_needed() to tool output.
MAX_MESSAGE_LENGTH: int = 8000
# =============================================================================
# GLOBAL RAG COMPONENTS
# =============================================================================
# Built lazily by initialize_rag_components(); None means "not built yet".
global_embeddings = None
global_text_splitter = None
def initialize_rag_components():
    """Lazily build the shared RAG embeddings model and text splitter.

    Each global is only created when it is still None, so repeated calls are
    cheap.

    Returns:
        True once both globals are set; False if the embeddings model could
        not be created (the splitter is then not touched on this call).
    """
    global global_embeddings, global_text_splitter

    if global_embeddings is None:
        print("Initializing RAG embeddings...")
        try:
            global_embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={'device': 'cpu'},
            )
        except Exception as exc:
            print(f"⚠️ Failed to initialize embeddings: {exc}")
            return False
        else:
            print("✅ Embeddings initialized.")

    if global_text_splitter is None:
        print("Initializing text splitter...")
        # ~500-char chunks with 50 chars of overlap; separators are tried in
        # order from paragraph breaks down to single characters.
        global_text_splitter = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n", ". ", " ", ""],
            chunk_size=500,
            chunk_overlap=50,
            length_function=len,
        )
        print("✅ Text splitter initialized.")

    return True
| # ============================================================================= | |
| # ANSWER SHEET VALIDATION FUNCTIONS | |
| # ============================================================================= | |
def load_answer_sheet(filepath: str = "answer_sheet.json") -> Dict[str, str]:
    """Read the reference answers from a JSON file.

    Presumably maps task_id -> expected answer (the shape written by
    create_answer_sheet_template); the loaded JSON is returned as-is.

    Never raises: a missing file or any read/parse error is reported with
    print() and an empty dict is returned instead.
    """
    try:
        if not os.path.exists(filepath):
            print(f"⚠️ Answer sheet not found at {filepath}")
            return {}
        with open(filepath, 'r', encoding='utf-8') as handle:
            answers = json.load(handle)
        print(f"✅ Loaded answer sheet with {len(answers)} answers from {filepath}")
        return answers
    except Exception as exc:
        print(f"❌ Error loading answer sheet: {exc}")
        return {}
def _parse_answer_number(text: str) -> Optional[float]:
    """Return *text* as a float after removing "," and "$", or None if it is not numeric."""
    try:
        return float(text.replace(',', '').replace('$', ''))
    except ValueError:
        return None


def check_answer_correctness(submitted: str, correct: str) -> tuple[bool, str]:
    """
    Check if submitted answer matches correct answer with fuzzy matching.

    Checks are applied in order:
      1. exact match (case- and surrounding-whitespace-insensitive);
      2. numeric equivalence when BOTH answers parse as numbers ("," and "$"
         ignored, absolute tolerance 0.01);
      3. match after removing all ASCII punctuation;
      4. for comma-separated expected answers, order-insensitive item match,
         otherwise a report of missing/extra items;
      5. a substring relation is reported as a partial (still wrong) match.

    Args:
        submitted: The agent's answer. Non-str values (e.g. numbers loaded from
            a JSON answer sheet) are converted with str(); None becomes "".
        correct: The reference answer, converted the same way.

    Returns:
        (is_correct, feedback_message)
    """
    import string

    submitted = "" if submitted is None else str(submitted)
    correct = "" if correct is None else str(correct)

    submitted_norm = submitted.strip().lower()
    correct_norm = correct.strip().lower()

    if submitted_norm == correct_norm:
        return True, "✅ EXACT MATCH"

    # Numbers are compared BEFORE punctuation is stripped: "." and "-" are
    # significant in numbers, and stripping them first would equate "1.5"
    # with "15" or "-3" with "3".
    submitted_num = _parse_answer_number(submitted_norm)
    correct_num = _parse_answer_number(correct_norm)
    if submitted_num is not None and correct_num is not None:
        if submitted_num == correct_num or abs(submitted_num - correct_num) < 0.01:
            return True, "✅ MATCH (numeric equivalence)"
        return False, f"❌ WRONG (submitted: '{submitted}' | correct: '{correct}')"

    strip_punct = str.maketrans('', '', string.punctuation)
    submitted_clean = submitted_norm.translate(strip_punct)
    correct_clean = correct_norm.translate(strip_punct)
    # Require some content to survive, so e.g. "?" does not "match" "!".
    if submitted_clean and submitted_clean == correct_clean:
        return True, "✅ MATCH (punctuation difference)"

    # List-type answers: compare comma-separated items as sets.
    if ',' in correct_norm:
        correct_items = {item.strip() for item in correct_norm.split(',')}
        submitted_items = {item.strip() for item in submitted_norm.split(',')}
        if correct_items == submitted_items:
            return True, "✅ MATCH (item order difference)"
        missing_items = correct_items - submitted_items
        extra_items = submitted_items - correct_items
        if missing_items and not extra_items:
            return False, f"❌ MISSING: {', '.join(missing_items)}"
        elif extra_items and not missing_items:
            return False, f"❌ EXTRA: {', '.join(extra_items)}"
        elif missing_items and extra_items:
            return False, f"❌ MISSING: {', '.join(missing_items)} | EXTRA: {', '.join(extra_items)}"

    if submitted_norm in correct_norm or correct_norm in submitted_norm:
        return False, f"❌ PARTIAL MATCH (submitted: '{submitted}' | correct: '{correct}')"
    return False, f"❌ WRONG (submitted: '{submitted}' | correct: '{correct}')"
def create_answer_sheet_template(questions: List[Dict], filepath: str = "answer_sheet.json"):
    """Write a JSON answer-sheet skeleton mapping every question's task_id to "".

    Args:
        questions: Question dicts; only their 'task_id' key is read.
        filepath: Destination file, overwritten if it exists.
    """
    answer_template = {question['task_id']: "" for question in questions}
    with open(filepath, 'w', encoding='utf-8') as out:
        json.dump(answer_template, out, indent=2)
    print(f"✅ Created answer sheet template at {filepath}")
    print(f" Please fill in the correct answers for {len(answer_template)} questions")
# =============================================================================
# ASR INITIALIZATION
# =============================================================================
# Whisper is loaded once at import time. Any failure is reported and leaves
# asr_pipeline as None, which audio_transcription_tool checks before use.
asr_pipeline = None
try:
    print("Loading ASR (Whisper) pipeline globally...")
    # transformers pipeline convention: 0 = first CUDA device, -1 = CPU.
    device = 0 if torch.cuda.is_available() else -1
    device_name = "cpu" if device == -1 else "cuda:0"
    print(f"Attempting to use device: {device_name} for ASR.")
    asr_pipeline = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-base",
        # Half precision only on GPU; CPU stays in float32.
        torch_dtype=(torch.float16 if device == 0 else torch.float32),
        device=device,
    )
    print("✅ ASR (Whisper) pipeline loaded successfully.")
except Exception as e:
    print(f"⚠️ Warning: Could not load ASR pipeline globally. Error: {e}")
    asr_pipeline = None
| # ============================================================================= | |
| # UTILITY FUNCTIONS | |
| # ============================================================================= | |
def remove_fences_simple(text):
    """Strip a surrounding Markdown code fence from *text*.

    If the stripped text both starts and ends with a triple backtick, the fence
    is removed, along with a short language tag on the first line (alphanumeric
    or underscore, fewer than 15 chars). Otherwise *text* is returned unchanged,
    including its original whitespace.
    """
    stripped = text.strip()
    if not (stripped.startswith("```") and stripped.endswith("```")):
        return text
    inner = stripped[3:-3].strip()
    head, newline, tail = inner.partition('\n')
    if newline:
        tag = head.strip()
        if len(tag) < 15 and tag.replace('_', '').isalnum():
            inner = tail.strip()
    return inner
def truncate_if_needed(content: str, max_length: int = MAX_MESSAGE_LENGTH) -> str:
    """Return *content*, cut to *max_length* chars plus a marker if it is longer.

    The marker records the original length. Note that a truncated result is
    therefore slightly longer than *max_length*.
    """
    total = len(content)
    if total <= max_length:
        return content
    return f"{content[:max_length]}\n...[truncated, {total} total chars]"
def find_file(path: str) -> Optional[Path]:
    """Locate *path*, returning the first existing candidate or None.

    Candidates, in order: *path* relative to the current working directory,
    *path* as given, and the bare file name inside the working directory.
    Absolute paths resolve to themselves in the first candidate.
    """
    cwd = Path.cwd()
    posix_path = Path(path).as_posix()
    candidates = (cwd / posix_path, Path(posix_path), cwd / Path(path).name)
    return next((candidate for candidate in candidates if candidate.exists()), None)
| # ============================================================================= | |
| # PLANNING & REFLECTION TOOLS | |
| # ============================================================================= | |
class ThinkInput(BaseModel):
    # Presumably the argument schema for think_through_logic (wiring not shown here).
    reasoning: str = Field(..., description="Brief reasoning summary (under 150 chars)")
def think_through_logic(reasoning: str) -> str:
    """
    Use this to work through logic puzzles, riddles, or reasoning problems.
    Call this when:
    - The question is a riddle or brain teaser
    - You need to reason through a logical problem
    - No external information is needed, just thinking
    After thinking, use calculator if math is involved, then validate and submit answer.
    """
    # The docstring is kept verbatim (it likely doubles as the tool description).
    # Only the first 100 chars are logged; the reply is a fixed nudge toward
    # the next tool call.
    print(f"🧠 Thinking: {reasoning[:100]}...")
    reply_lines = [
        "✅ Logic reasoning recorded.",
        "Next steps:",
        "1. If math needed → use calculator()",
        "2. Once you have answer → use validate_answer()",
        "3. Then → use final_answer_tool()",
        "Remember: You MUST call another tool. Do not output reasoning text.",
    ]
    return "\n".join(reply_lines)
class PlanInput(BaseModel):
    # Presumably the argument schema for create_plan (wiring not shown here).
    task_summary: str = Field(..., description="Very brief task summary (under 80 chars)")
def create_plan(task_summary: str) -> str:
    """
    Creates a plan for multi-step questions. Use for complex tasks only.
    Keep the summary VERY brief to avoid errors.
    """
    # Logs up to 80 chars of the summary and echoes the full summary back
    # inside a fixed planning framework.
    print(f"📋 Planning: {task_summary[:80]}...")
    reply_lines = [
        f"✅ Plan created for: {task_summary}",
        "FRAMEWORK:",
        "1. What info do I need?",
        "2. What tools will I use?",
        "3. In what order?",
        "Now execute step 1. You MUST call a tool next.",
    ]
    return "\n".join(reply_lines)
class ReflectInput(BaseModel):
    # Presumably the argument schema for reflect_on_progress (wiring not shown here).
    situation: str = Field(..., description="Brief situation summary (under 80 chars)")
def reflect_on_progress(situation: str) -> str:
    """
    Reflects on progress when stuck. Use after 5+ turns without progress.
    Keep situation summary VERY brief.
    """
    # Logs up to 80 chars and returns fixed self-check questions that push
    # the agent toward a different approach.
    print(f"🤔 Reflecting: {situation[:80]}...")
    reply_lines = [
        f"🔍 REFLECTION on: {situation}",
        "QUESTIONS:",
        "1. Am I using the right approach?",
        "2. Should I try a different tool?",
        "3. Do I actually have the answer already?",
        "Take a DIFFERENT approach now. You MUST call a tool next.",
    ]
    return "\n".join(reply_lines)
class ValidateInput(BaseModel):
    # Presumably the argument schema for validate_answer (wiring not shown here).
    proposed_answer: str = Field(..., description="The answer to validate")
    original_question: str = Field(..., description="Original question (first 100 chars)")
def validate_answer(proposed_answer: str, original_question: str) -> str:
    """
    Validates answer format before submission. ALWAYS use before final_answer_tool.
    """
    # The docstring is kept verbatim (it likely doubles as the tool description);
    # notes live in comments.
    # Blocking "issues" (conversational phrasing, code fences) take precedence
    # over advisory "warnings" (very long answer, numeric question without digits).
    print(f"✓ Validating: '{proposed_answer[:50]}...'")
    issues = []
    warnings = []
    answer_lower = proposed_answer.lower()

    fluff = ["the answer is", "based on", "according to", "i found", "here is"]
    if any(phrase in answer_lower for phrase in fluff):
        issues.append("❌ Remove conversational text. Answer only.")

    if "```" in proposed_answer:
        issues.append("❌ Remove code fences (```).")

    if len(proposed_answer) > 500:
        warnings.append("⚠️ Answer very long. Just the answer?")

    # Whole-word match: a bare substring test made "count" fire on words like
    # "country" or "account", wrongly demanding digits in non-numeric answers.
    if re.search(r"\b(?:how many|what number|count)\b", original_question.lower()):
        if not any(ch.isdigit() for ch in proposed_answer):
            warnings.append("⚠️ Question asks for number but answer has no digits.")

    if issues:
        return "🚫 VALIDATION FAILED:\n" + "\n".join(issues) + "\n\nFix then retry."
    if warnings:
        return "⚠️ WARNINGS:\n" + "\n".join(warnings) + "\n\nConsider fixing, or proceed if confident."
    return "✅ VALIDATION PASSED! Now call final_answer_tool() with this answer."
| # ============================================================================= | |
| # CORE TOOLS | |
| # ============================================================================= | |
class SearchInput(BaseModel):
    # Presumably the argument schema for search_tool (wiring not shown here).
    query: str = Field(..., description="Search query (concise)")
def search_tool(query: str) -> str:
    """
    Search the web for information. Returns snippets.
    IMPORTANT: Search results are SNIPPETS only. For complete information:
    1. Use search_tool to find URLs
    2. Use scrape_and_retrieve to get FULL page content
    Example workflow:
    - search_tool("Mercedes Sosa Wikipedia") → get URL
    - scrape_and_retrieve(url=..., query="studio albums 2000-2009")
    """
    # Implementation notes (the docstring above is kept verbatim):
    # - queries mentioning Wikipedia without an explicit "site:" get
    #   "site:wikipedia.org" appended;
    # - up to 3 attempts with exponential backoff (1 s, then 2 s) on errors;
    # - results shorter than 50 chars are treated as "no results".
    if not (isinstance(query, str) and query.strip()):
        return "Error: Invalid query."
    wants_wikipedia = 'wikipedia' in query.lower() and 'site:' not in query
    if wants_wikipedia:
        query = f"{query} site:wikipedia.org"
    print(f"🔍 Searching: {query}")

    max_retries = 3
    for attempt in range(max_retries):
        try:
            result = DuckDuckGoSearchRun().run(query)
            if not result or len(result) < 50:
                return "No relevant results found. Try different search terms or check if the information exists."
            return truncate_if_needed(result)
        except Exception as e:
            if attempt == max_retries - 1:
                return f"Search error after {max_retries} attempts: {str(e)}"
            time.sleep(2 ** attempt)
class CalcInput(BaseModel):
    # Presumably the argument schema for calculator (wiring not shown here).
    expression: str = Field(..., description="Math expression (e.g., '2+2', 'sqrt(16)')")
def calculator(expression: str) -> str:
    """
    Evaluates math expressions. Use for ANY calculations.
    Supports: +, -, *, /, **, sqrt, sin, cos, log, pi, e, etc.
    """
    # Returns str(result), or an error string; it never raises for bad input.
    if not isinstance(expression, str) or not expression.strip():
        return "Error: Invalid expression."
    print(f"🧮 Calculating: {expression}")
    import math

    # LLMs often write powers as "^", which Python evaluates as bitwise XOR
    # (2^10 == 8) -- a silently wrong answer. Treat "^" as exponentiation.
    normalized = expression.replace('^', '**')
    safe_dict = {
        'sqrt': math.sqrt, 'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
        'log': math.log, 'log10': math.log10, 'exp': math.exp,
        'floor': math.floor, 'ceil': math.ceil, 'factorial': math.factorial,
        'pi': math.pi, 'e': math.e, 'abs': abs, 'round': round,
        'pow': pow, 'sum': sum, 'min': min, 'max': max,
    }
    try:
        # SECURITY NOTE(review): eval() with empty __builtins__ is NOT a sandbox
        # (attribute walks such as ().__class__.__base__ escape it). It is only
        # tolerable because the same agent can already run arbitrary Python via
        # code_interpreter; never expose this to direct user input.
        result = eval(normalized, {"__builtins__": {}}, safe_dict)
        return str(result)
    except Exception as e:
        return f"Calculation error for '{expression}': {str(e)}"
class CodeInput(BaseModel):
    # Presumably the argument schema for code_interpreter (wiring not shown here).
    code: str = Field(..., description="Python code (MUST include print() for output)")
def code_interpreter(code: str) -> str:
    """
    Executes Python code with timeout protection.
    CRITICAL: Always use print() to output results!
    """
    # NOTE(review): despite the tool description above, no timeout is enforced
    # here -- a non-terminating snippet blocks the caller.
    # Returns captured stdout (truncated), or an error string describing
    # stderr output / the exception traceback; it never raises.
    if not isinstance(code, str):
        return "Error: code must be string."

    # Best-effort keyword blocklist, NOT a sandbox: it is trivially bypassed
    # (e.g. importlib, os.popen). The snippet runs with full builtins.
    dangerous = ['__import__', 'eval(', 'compile(', 'subprocess', 'os.system', 'exec(']
    if any(d in code.lower() for d in dangerous):
        return "Error: Dangerous operation not allowed."
    if 'open(' in code.lower() and any(m in code for m in ["'w'", '"w"', "'a'", '"a"']):
        return "Error: File writing not allowed. Use write_file tool."

    print(f"💻 Executing code ({len(code)} chars)...")
    output_stream = io.StringIO()
    error_stream = io.StringIO()
    # One shared namespace: with separate globals/locals dicts, top-level
    # definitions land in locals and are invisible to functions defined in the
    # snippet (def f(): return g() -> NameError).
    namespace = {
        "pd": pd,
        "np": np,
        "json": json,
        "re": re,
        "__builtins__": __builtins__,
    }
    try:
        with contextlib.redirect_stdout(output_stream), contextlib.redirect_stderr(error_stream):
            exec(code, namespace)
    except (Exception, SystemExit):
        # SystemExit is caught too, so sys.exit()/exit() in a snippet cannot
        # abort the whole agent run.
        failure = f"Execution failed:\n{traceback.format_exc()}"
        partial = output_stream.getvalue()
        if partial:
            # Keep output printed before the failure; it often helps debugging.
            failure += f"\nStdout before failure:\n{truncate_if_needed(partial)}"
        return failure

    stdout = output_stream.getvalue()
    stderr = error_stream.getvalue()
    if stderr:
        return f"Error:\n{stderr}\n\nStdout:\n{stdout}"
    if stdout:
        return truncate_if_needed(stdout)
    return "Code executed but no output. Remember to use print()!"
class ReadFileInput(BaseModel):
    # Presumably the argument schema for read_file (wiring not shown here).
    path: str = Field(..., description="File path")
def read_file(path: str) -> str:
    """Reads file content."""
    # The path is resolved through find_file(). Files that are not valid UTF-8
    # yield a short "binary file" notice with their size instead of content.
    if not (isinstance(path, str) and path.strip()):
        return "Error: Invalid path."
    print(f"📄 Reading: {path}")
    located = find_file(path)
    if located is None:
        cwd_listing = os.listdir('.')
        return f"Error: File not found: '{path}'\nCWD files: {cwd_listing}"
    try:
        return truncate_if_needed(located.read_text(encoding='utf-8'))
    except UnicodeDecodeError:
        byte_count = located.stat().st_size
        return f"Error: Binary file. Size: {byte_count} bytes. Try audio_transcription_tool for audio."
    except Exception as e:
        return f"Read error: {str(e)}"
class WriteFileInput(BaseModel):
    # Presumably the argument schema for write_file (wiring not shown here).
    path: str = Field(..., description="File path")
    content: str = Field(..., description="Content to write")
def write_file(path: str, content: str) -> str:
    """Writes content to file."""
    # Writes UTF-8 text to *path* relative to the current working directory,
    # creating parent directories as needed. Returns a status string; it never
    # raises.
    if not path or not isinstance(content, str):
        return "Error: Invalid inputs."
    print(f"✍️ Writing: {path}")
    try:
        root = Path.cwd().resolve()
        file_path = (root / path).resolve()
        # The path comes from the LLM, which also reads scraped web content.
        # Confine writes to the working directory so "../" segments or
        # absolute paths cannot overwrite files elsewhere.
        if not file_path.is_relative_to(root):
            return f"Error: Refusing to write outside the working directory: '{path}'."
        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_path.write_text(content, encoding='utf-8')
        return f"Wrote {len(content)} chars to '{path}'."
    except Exception as e:
        return f"Write error: {str(e)}"
class ListDirInput(BaseModel):
    # Presumably the argument schema for list_directory (wiring not shown here).
    # Defaults to the current directory.
    path: str = Field(default=".", description="Directory path")
def list_directory(path: str = ".") -> str:
    """Lists directory contents."""
    # *path* is resolved against the current working directory. Output lists
    # sub-directories first, then files with their size in bytes, each group
    # in sorted path order. Errors are returned as text, never raised.
    print(f"📁 Listing: {path}")
    try:
        cwd = Path.cwd()
        target = cwd if path == "." else cwd / path
        if not target.is_dir():
            return f"Error: '{path}' not a directory."
        entries = sorted(target.iterdir())
        if not entries:
            return f"Directory '{path}' is empty."
        dir_lines = []
        file_lines = []
        for entry in entries:
            if entry.is_dir():
                dir_lines.append(f"📁 {entry.name}/")
            else:
                file_lines.append(f"📄 {entry.name} ({entry.stat().st_size} bytes)")
        parts = [f"Contents of '{path}':\n\n"]
        if dir_lines:
            parts.append("Directories:\n" + "\n".join(dir_lines) + "\n\n")
        if file_lines:
            parts.append("Files:\n" + "\n".join(file_lines))
        return "".join(parts)
    except Exception as e:
        return f"List error: {str(e)}"
class AudioInput(BaseModel):
    # Presumably the argument schema for audio_transcription_tool (wiring not shown here).
    file_path: str = Field(..., description="Audio file path")
def audio_transcription_tool(file_path: str) -> str:
    """Transcribes audio using Whisper."""
    # Uses the module-level asr_pipeline (None if loading failed at import).
    # Only the combined "text" field of the pipeline output is returned;
    # timestamps are discarded.
    if not file_path:
        return "Error: Invalid file path."
    print(f"🎤 Transcribing: {file_path}")
    if asr_pipeline is None:
        return "Error: ASR not available."
    audio_path = find_file(file_path)
    if not audio_path:
        return f"Error: Audio file not found: '{file_path}'"
    # Long inputs are processed in 30 s chunks with 5 s of overlap.
    pipeline_options = {
        "return_timestamps": True,
        "chunk_length_s": 30,
        "stride_length_s": 5,
    }
    try:
        transcription = asr_pipeline(str(audio_path), **pipeline_options)
        text = transcription.get("text", "")
        if not text:
            return "Error: Transcription empty."
        return f"Transcription:\n{truncate_if_needed(text)}"
    except Exception as e:
        return f"Transcription error: {str(e)}"
class ImageAnalysisInput(BaseModel):
    # Presumably the argument schema for analyze_image (wiring not shown here).
    file_path: str = Field(..., description="Image file path")
    query: str = Field(..., description="What to analyze in the image")
def _image_to_jpeg_base64(image_path) -> str:
    """Open an image file and return it JPEG-encoded as a base64 string.

    JPEG cannot store an alpha channel (Pillow refuses to save RGBA as JPEG),
    so images with transparency are flattened onto a white background first.
    """
    with Image.open(image_path) as img:
        print(f" Image size: {img.size}, mode: {img.mode}")
        has_alpha = img.mode in ('RGBA', 'LA') or (img.mode == 'P' and 'transparency' in img.info)
        if has_alpha:
            rgba = img.convert('RGBA')
            rgb = Image.new('RGB', rgba.size, (255, 255, 255))
            rgb.paste(rgba, mask=rgba.split()[-1])
        else:
            rgb = img.convert('RGB')
    buffered = io.BytesIO()
    rgb.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


def analyze_image(file_path: str, query: str) -> str:
    """
    Analyzes images using Google Gemini Vision API.
    Use for: chess positions, diagrams, charts, photos, screenshots.
    Provide the EXACT file path from [FILE ATTACHED: ...] in the question.
    """
    # Sends the image (re-encoded as JPEG) plus *query* to gemini-2.0-flash and
    # returns the model's text. Requires the GEMINI_API_KEY environment
    # variable. All failures are returned as error strings, never raised.
    if not file_path or not query:
        return "Error: file_path and query required."
    print(f"🖼️ Analyzing image: {file_path}")
    print(f" Query: {query[:100]}...")

    image_path = find_file(file_path)
    # Fallback for paths find_file() does not resolve (e.g. attachments in /tmp).
    if not image_path and os.path.exists(file_path):
        image_path = Path(file_path)
    if not image_path or not image_path.exists():
        return f"Error: Image not found at '{file_path}'. Check [FILE ATTACHED: ...] in question for correct path."
    print(f"✓ Found image at: {image_path}")

    try:
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            return "Error: GEMINI_API_KEY not set."

        img_base64 = _image_to_jpeg_base64(image_path)
        print(f" Encoded image: {len(img_base64)} bytes")

        vision_llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            google_api_key=api_key,
            temperature=0
        )
        message = HumanMessage(
            content=[
                {"type": "text", "text": query},
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{img_base64}"
                }
            ]
        )
        print(" Sending to Gemini Vision...")
        response = vision_llm.invoke([message])

        answer = response.content
        # LangChain message content may be a list of str/dict parts instead of
        # a plain string; join their text so truncation below works.
        if not isinstance(answer, str):
            answer = "".join(
                part if isinstance(part, str) else part.get("text", "")
                for part in answer
            )
        print(f"✓ Got response: {len(answer)} chars")
        return f"Image Analysis:\n{truncate_if_needed(answer)}"
    except Exception as e:
        error_msg = f"Image analysis error: {str(e)}"
        print(f"❌ {error_msg}")
        print(traceback.format_exc())
        return error_msg
class YoutubeInput(BaseModel):
    # Presumably the argument schema for get_youtube_transcript (wiring not shown here).
    video_url: str = Field(..., description="YouTube URL")
def get_youtube_transcript(video_url: str) -> str:
    """
    Fetches YouTube video transcript using AssemblyAI.
    Works reliably on Hugging Face Spaces.
    """
    # Requires the ASSEMBLYAI_API_KEY environment variable. Returns the
    # transcript (truncated like every other tool's output) or an error
    # string; it never raises.
    # NOTE(review): AssemblyAI's transcribe() is documented for URLs of media
    # files; confirm YouTube page URLs are actually accepted.
    if not isinstance(video_url, str) or not video_url.strip():
        return "Error: Invalid video URL."
    api_key = os.getenv("ASSEMBLYAI_API_KEY")
    if not api_key:
        return "Error: ASSEMBLYAI_API_KEY not set."
    try:
        aai.settings.api_key = api_key
        print(f"📺 Transcribing: {video_url}")
        # transcribe() waits for the transcription job to finish.
        transcript = aai.Transcriber().transcribe(video_url)
        if transcript.status == aai.TranscriptStatus.error:
            return f"Error: {transcript.error}"
        # text can be None/empty (e.g. silent media); report that explicitly
        # instead of failing on len(None).
        text = transcript.text or ""
        if not text:
            return "Error: Transcript empty."
        print(f"✓ Transcribed {len(text)} chars")
        return f"Transcript:\n{truncate_if_needed(text)}"
    except Exception as e:
        return f"Error: {str(e)}"
class ScrapeInput(BaseModel):
    # Presumably the argument schema for scrape_and_retrieve (wiring not shown here).
    url: str = Field(..., description="URL (must start with http:// or https://)")
    query: str = Field(..., description="Specific information to find on the page")
def scrape_and_retrieve(url: str, query: str) -> str:
    """
    Fetch and search FULL webpage content using RAG (not just snippets like search_tool).
    CRITICAL: Use this after search_tool gives you a URL. This gets the COMPLETE page.
    Workflow Example:
    1. search_tool('Mercedes Sosa Wikipedia') → get URL
    2. scrape_and_retrieve(
        url='https://en.wikipedia.org/wiki/Mercedes_Sosa',
        query='studio albums released between 2000 and 2009'
    ) → Returns FULL discography section
    Use when:
    - Counting items (albums, people, events, etc.)
    - Finding specific names, dates, or numbers
    - Need complete tables or lists
    - Wikipedia articles, documentation, papers
    - Search snippets weren't enough
    """
    # Implementation notes (the docstring above is kept verbatim):
    # - Fetches the page, strips boilerplate, splits the text into chunks,
    #   indexes them in a throwaway FAISS store and returns the 5 chunks most
    #   similar to *query*.
    # - Timeouts, connection errors, 5xx and 429 are retried (3 attempts,
    #   1 s / 2 s backoff). Other 4xx responses fail immediately because
    #   repeating the request cannot help.
    # - Never raises; every failure is returned as an error string.
    url_error = "Error: Invalid URL format. Must start with http:// or https://"
    if not isinstance(url, str):
        return url_error
    url = url.strip()
    if not url.startswith(('http://', 'https://')):
        return url_error
    if not query:
        return "Error: Query required to search the page content."
    if global_embeddings is None or global_text_splitter is None:
        if not initialize_rag_components():
            return "Error: RAG components not initialized."
    print(f"🌐 Scraping: {url}")
    print(f" Looking for: {query[:100]}...")

    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    }
    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = requests.get(url, headers=headers, timeout=20)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')
            # Remove page chrome and scripts so retrieval only sees content.
            for tag in soup(["script", "style", "nav", "footer", "aside", "header", "iframe"]):
                tag.extract()
            # Prefer semantic containers; 'mw-parser-output' is the MediaWiki
            # article body. Fall back to the whole <body>.
            main = (
                soup.find('main')
                or soup.find('article')
                or soup.find('div', class_='mw-parser-output')
                or soup.body
            )
            if not main:
                return "Error: Could not find main content on page."
            text = main.get_text(separator='\n', strip=True)
            text = '\n'.join(line.strip() for line in text.splitlines() if line.strip())
            if len(text) < 50:
                return f"Error: Page content too short ({len(text)} chars). May be blocked or empty."
            print(f"✓ Extracted {len(text)} characters from page")

            chunks = global_text_splitter.split_text(text)
            if not chunks:
                return "Error: Could not process page content."
            print(f"✓ Created {len(chunks)} chunks")
            docs = [Document(page_content=c, metadata={"source": url}) for c in chunks]
            db = FAISS.from_documents(docs, global_embeddings)
            retriever = db.as_retriever(search_kwargs={"k": 5})
            retrieved = retriever.invoke(query)
            if not retrieved:
                return f"No information found matching: '{query}'\nTry a different query or the information may not be on this page."
            print(f"✓ Found {len(retrieved)} relevant chunks")
            context = "\n\n---\n\n".join(
                f"[Section {i+1}]\n{d.page_content}" for i, d in enumerate(retrieved)
            )
            return truncate_if_needed(f"From {url}:\n\n{context}")
        except requests.Timeout:
            if attempt < max_retries - 1:
                print(f"⚠️ Timeout, retrying... (attempt {attempt + 1}/{max_retries})")
                time.sleep(2 ** attempt)
                continue
            return f"Error: Page request timed out after {max_retries} attempts."
        except requests.RequestException as e:
            # e.response is None for connection-level failures. Read
            # status_code via getattr rather than testing the Response's
            # truthiness (a Response is falsy for 4xx/5xx).
            status = getattr(getattr(e, "response", None), "status_code", None)
            transient = status is None or status == 429 or status >= 500
            if transient and attempt < max_retries - 1:
                time.sleep(2 ** attempt)
                continue
            return f"Error fetching page: {str(e)}"
        except Exception as e:
            return f"Error processing page: {str(e)}\n{traceback.format_exc()}"
class ChessAnalysisInput(BaseModel):
    # Presumably the argument schema for analyze_chess_position (wiring not shown here).
    image_path: str = Field(..., description="Path to chess board image file")
    # Optional free-text context; defaults to "" when omitted.
    description: str = Field(default="", description="Any additional context about the position (optional)")
| def analyze_chess_position(image_path: str, description: str = "") -> str: | |
| """ | |
| Analyzes a chess position from an image using Stockfish engine. | |
| MUCH MORE RELIABLE than Lichess API because: | |
| - Works offline | |
| - Analyzes ANY position (not just cloud database) | |
| - Stronger engine (Stockfish 16+) | |
| - No rate limits or 404 errors | |
| Use this tool when: | |
| - Question mentions chess, checkmate, or chess notation | |
| - An image file shows a chess board | |
| - Need to find the best move in a position | |
| Args: | |
| image_path: Path to chess board image | |
| description: The full question text - IMPORTANT for determining whose turn it is! | |
| Returns: Best move in algebraic notation (e.g., "Qh5", "Nf6+", "Rd5") | |
| """ | |
| if not image_path: | |
| return "Error: image_path is required." | |
| print(f"♟️ Analyzing chess position from: {image_path}") | |
| # Find the file | |
| chess_image = find_file(image_path) | |
| # If not found via find_file, try direct path | |
| if not chess_image and os.path.exists(image_path): | |
| chess_image = Path(image_path) | |
| if not chess_image or not chess_image.exists(): | |
| return f"Error: Chess board image not found at '{image_path}'. Check the [FILE ATTACHED: ...] path in the question." | |
| print(f"✓ Found chess image at: {chess_image}") | |
| try: | |
| # ==================================================================== | |
| # STEP 1: Extract FEN notation from image using Gemini Vision | |
| # ==================================================================== | |
| GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY") | |
| if not GOOGLE_API_KEY: | |
| return "Error: GEMINI_API_KEY not set in Space secrets." | |
| print("📸 Extracting chess position from image using Gemini...") | |
| # Load and encode image | |
| img = Image.open(chess_image) | |
| print(f" Image loaded: {img.size}, mode: {img.mode}") | |
| if img.mode not in ['RGB', 'RGBA']: | |
| img = img.convert('RGB') | |
| buffered = io.BytesIO() | |
| img.save(buffered, format="JPEG") | |
| img_base64 = base64.b64encode(buffered.getvalue()).decode() | |
| # Use Gemini Vision to extract FEN | |
| vision_llm = ChatGoogleGenerativeAI( | |
| model="gemini-2.5-pro", | |
| google_api_key=GOOGLE_API_KEY, | |
| temperature=0 | |
| ) | |
| # Check if the question explicitly states whose turn it is | |
| whose_turn = None | |
| if description: | |
| desc_lower = description.lower() | |
| if "black" in desc_lower and ("turn" in desc_lower or "move" in desc_lower): | |
| whose_turn = "b" | |
| elif "white" in desc_lower and ("turn" in desc_lower or "move" in desc_lower): | |
| whose_turn = "w" | |
| fen_prompt = f"""Analyze this chess board image and provide the position in FEN notation. | |
| CRITICAL INSTRUCTIONS: | |
| 1. Carefully identify each piece: | |
| - White pieces (UPPERCASE): K=King, Q=Queen, R=Rook, B=Bishop, N=Knight, P=Pawn | |
| - Black pieces (lowercase): k, q, r, b, n, p | |
| 2. BOARD ORIENTATION - This is CRITICAL: | |
| - In chess diagrams, the board is shown from the perspective of the player to move | |
| - Look at the BOTTOM rank (closest to viewer): | |
| * If bottom pieces are BLACK (lowercase in FEN) → Black to move → active color = 'b' | |
| * If bottom pieces are WHITE (uppercase in FEN) → White to move → active color = 'w' | |
| - The rank labels (1-8) on the side can help: | |
| * If rank 8 is at bottom and rank 1 at top → Black's perspective → use 'b' | |
| * If rank 1 is at bottom and rank 8 at top → White's perspective → use 'w' | |
| {"- OVERRIDE: The question explicitly states BLACK's turn, so use 'b'" if whose_turn == "b" else ""} | |
| {"- OVERRIDE: The question explicitly states WHITE's turn, so use 'w'" if whose_turn == "w" else ""} | |
| 3. FEN Format (read from rank 8 to rank 1, left to right): | |
| - Use numbers (1-8) for consecutive empty squares | |
| - Use '/' to separate ranks | |
| - IMPORTANT: Always write FEN from White's perspective (rank 8 first, rank 1 last) | |
| - But set the active_color based on whose perspective the board shows | |
| 4. Return ONLY the FEN string in this exact format: | |
| piece_placement active_color castling en_passant halfmove fullmove | |
| Example: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1 | |
| DOUBLE-CHECK: | |
| - Did you identify whose pieces are at the BOTTOM of the board? | |
| - Did you set active_color correctly based on board orientation? | |
| - Did you write piece_placement from rank 8 to rank 1? | |
| Return ONLY the FEN string, nothing else.""" | |
| message = HumanMessage( | |
| content=[ | |
| {"type": "text", "text": fen_prompt}, | |
| { | |
| "type": "image_url", | |
| "image_url": f"data:image/jpeg;base64,{img_base64}" | |
| } | |
| ] | |
| ) | |
| response = vision_llm.invoke([message]) | |
| fen_raw = response.content.strip() | |
| print(f"📝 Raw FEN response: {fen_raw}") | |
| # Clean up FEN (remove markdown, explanations, etc.) | |
| fen = None | |
| for line in fen_raw.split('\n'): | |
| line = line.strip().replace('```', '').replace('fen', '') | |
| # FEN should have '/' for ranks and spaces for components | |
| if '/' in line and ' ' in line and not line.startswith('#'): | |
| if any(c in line for c in 'kqrbnpKQRBNP12345678'): | |
| fen = line | |
| break | |
| if not fen: | |
| return f"Error: Could not extract valid FEN notation from image. Response: {fen_raw[:200]}" | |
| print(f"✓ Extracted FEN: {fen}") | |
| # Override the turn indicator if we know from the question | |
| if whose_turn: | |
| fen_parts = fen.split() | |
| if len(fen_parts) >= 2: | |
| old_turn = fen_parts[1] | |
| fen_parts[1] = whose_turn | |
| fen = ' '.join(fen_parts) | |
| print(f"🔄 Corrected turn from '{old_turn}' to '{whose_turn}' based on question") | |
| print(f"✓ Corrected FEN: {fen}") | |
| # Additional verification: Check if board orientation matches turn | |
| # In FEN, rank 8 is first, rank 1 is last | |
| # If bottom of image shows black pieces, it's black's turn | |
| fen_parts = fen.split() | |
| piece_placement = fen_parts[0] | |
| active_color = fen_parts[1] if len(fen_parts) > 1 else 'w' | |
| # Get last rank (rank 1 in FEN, which is bottom if white's perspective) | |
| ranks = piece_placement.split('/') | |
| rank_1 = ranks[-1] # Last rank in FEN | |
| rank_8 = ranks[0] # First rank in FEN | |
| # Check which color dominates bottom rank | |
| # If showing from black's perspective, rank 8 should be at bottom | |
| # and active color should be 'b' | |
| black_pieces_in_rank8 = sum(1 for c in rank_8 if c.islower() and c.isalpha()) | |
| white_pieces_in_rank8 = sum(1 for c in rank_8 if c.isupper() and c.isalpha()) | |
| if black_pieces_in_rank8 > white_pieces_in_rank8 and active_color == 'w': | |
| print(f"⚠️ Warning: Rank 8 has more black pieces, likely black's perspective") | |
| print(f" Changing active color from 'w' to 'b'") | |
| fen_parts[1] = 'b' | |
| fen = ' '.join(fen_parts) | |
| # ==================================================================== | |
| # STEP 2: Validate FEN with python-chess | |
| # ==================================================================== | |
| try: | |
| import chess | |
| except ImportError: | |
| return "Error: python-chess not installed. Add 'python-chess' to requirements.txt" | |
| try: | |
| board = chess.Board(fen) | |
| print(f"✓ FEN validated successfully") | |
| print(f" Turn: {'White' if board.turn else 'Black'}") | |
| print(f" Legal moves: {board.legal_moves.count()}") | |
| except ValueError as e: | |
| return f"Error: Invalid FEN notation: {e}\nExtracted FEN: {fen}" | |
| # ==================================================================== | |
| # STEP 3: Analyze with Stockfish | |
| # ==================================================================== | |
| print("🔍 Analyzing position with Stockfish...") | |
| try: | |
| from stockfish import Stockfish | |
| except ImportError: | |
| return "Error: stockfish not installed. Add 'stockfish' to requirements.txt and install Stockfish binary" | |
| # Try to find Stockfish binary | |
| stockfish_paths = [ | |
| "/usr/games/stockfish", # Linux (apt-get install) | |
| "/usr/local/bin/stockfish", # Mac (brew install) | |
| "/usr/bin/stockfish", # Alternative Linux | |
| "stockfish", # In PATH | |
| "./stockfish", # Local directory | |
| "C:\\Program Files\\stockfish\\stockfish.exe" # Windows | |
| ] | |
| stockfish_path = None | |
| for path in stockfish_paths: | |
| if os.path.exists(path) or os.path.isfile(path): | |
| stockfish_path = path | |
| break | |
| if not stockfish_path: | |
| # Try running 'which stockfish' on Unix systems | |
| try: | |
| import subprocess | |
| result = subprocess.run(['which', 'stockfish'], | |
| capture_output=True, | |
| text=True, | |
| timeout=5) | |
| if result.returncode == 0: | |
| stockfish_path = result.stdout.strip() | |
| except: | |
| pass | |
| if not stockfish_path: | |
| return """Error: Stockfish binary not found. Install it: | |
| - Linux: sudo apt-get install stockfish | |
| - Mac: brew install stockfish | |
| - Windows: Download from stockfishchess.org | |
| Or set the path manually in the code.""" | |
| print(f"✓ Found Stockfish at: {stockfish_path}") | |
| # Initialize Stockfish | |
| try: | |
| stockfish = Stockfish( | |
| path=stockfish_path, | |
| depth=35, # Analysis depth (higher = stronger but slower) | |
| parameters={ | |
| "Threads": 2, | |
| "Minimum Thinking Time": 5000, # milliseconds | |
| "Hash": 1024, # MB of RAM | |
| } | |
| ) | |
| except Exception as e: | |
| return f"Error initializing Stockfish: {e}" | |
| # Set position | |
| stockfish.set_fen_position(fen) | |
| # Get best move | |
| print(" Computing best move...") | |
| best_move_uci = stockfish.get_best_move() | |
| if not best_move_uci: | |
| return "Error: Stockfish could not find a legal move. Check if position is valid." | |
| print(f"🎯 Best move (UCI): {best_move_uci}") | |
| # Get evaluation | |
| evaluation = stockfish.get_evaluation() | |
| eval_type = evaluation.get("type", "cp") | |
| eval_value = evaluation.get("value", 0) | |
| if eval_type == "mate": | |
| eval_str = f" (Mate in {abs(eval_value)})" | |
| else: | |
| # Centipawns to pawns | |
| eval_str = f" (Eval: {eval_value/100:+.2f})" | |
| # ==================================================================== | |
| # STEP 4: Convert UCI to Standard Algebraic Notation (SAN) | |
| # ==================================================================== | |
| try: | |
| uci_move = chess.Move.from_uci(best_move_uci) | |
| san_move = board.san(uci_move) | |
| # Check if move leads to check/checkmate | |
| board.push(uci_move) | |
| if board.is_checkmate(): | |
| check_str = " - Checkmate!" | |
| elif board.is_check(): | |
| check_str = " - Check" | |
| else: | |
| check_str = "" | |
| final_result = f"{san_move}{eval_str}{check_str}" | |
| print(f"✅ Best move: {final_result}") | |
| # Return JUST the move notation for clean submission | |
| return san_move | |
| except Exception as e: | |
| print(f"⚠️ Could not convert to SAN: {e}") | |
| # Fall back to UCI notation | |
| return best_move_uci | |
| except Exception as e: | |
| error_msg = f"Chess analysis failed: {str(e)}" | |
| print(f"❌ {error_msg}") | |
| print(traceback.format_exc()) | |
| return error_msg | |
# Pydantic argument schema with a single required ``answer`` string; named for use
# with final_answer_tool.
class FinalAnswerInput(BaseModel):
    answer: Annotated[
        str,
        Field(description="Final answer - EXACTLY what was asked, nothing more"),
    ]
# Registered as a LangChain tool because every entry of ``defined_tools`` is treated as
# one: the agent reads ``.name``, ``.args_schema`` and ``.description`` from each tool
# when building its prompt. A plain function has none of these attributes. The docstring
# below becomes the tool description shown to the LLM, so it is kept verbatim.
@tool(args_schema=FinalAnswerInput)
def final_answer_tool(answer: str) -> str:
    """
    Submit final answer. CRITICAL RULES:
    1. ALWAYS call validate_answer() first
    2. Answer must be EXACTLY what was asked
    3. NO conversational text
    4. NO explanations
    5. Match requested format exactly
    """
    # Models occasionally pass numbers/lists; normalize to the string the grader expects.
    if not isinstance(answer, str):
        answer = str(answer)
    print(f"✅ FINAL ANSWER SUBMITTED: {answer}")
    return answer
| # ============================================================================= | |
| # DEFINED TOOLS LIST | |
| # ============================================================================= | |
# Every tool the agent may call, grouped by role. List order is significant: the agent
# renders tool descriptions and binds tools to the LLM in this order, and the fallback
# parser scans for tool-name mentions in this order.
defined_tools = [
    # Planning & reflection
    *(think_through_logic, create_plan, reflect_on_progress, validate_answer),
    # Core tools
    *(search_tool, calculator, code_interpreter),
    # File operations
    *(read_file, write_file, list_directory),
    # Specialized (audio, vision, video, web RAG, chess)
    *(
        audio_transcription_tool,
        analyze_image,
        get_youtube_transcript,
        scrape_and_retrieve,
        analyze_chess_position,
    ),
    # Answer submission
    final_answer_tool,
]
| # ============================================================================= | |
| # AGENT STATE | |
| # ============================================================================= | |
class AgentState(TypedDict):
    """State shared by the agent and tool nodes of the LangGraph graph.

    ``messages`` is merged with the ``add_messages`` reducer; the other keys hold the
    value from the most recent node update.
    """

    messages: Annotated[List[AnyMessage], add_messages]
    # Number of agent turns taken so far (the agent node increments it).
    turn: int
    # Set once the agent has called create_plan.
    has_plan: bool
    # Count of consecutive tool results that looked like errors.
    consecutive_errors: int
    # Names of the tools called so far, in order.
    tool_history: List[str]
    last_tool_was_thinking: bool
    # Path of the file attached to the question, if any. The agent passes this key in its
    # graph input; LangGraph only tracks keys declared on the state schema, so it must be
    # declared here to stay in state.
    file_path: Optional[str]
| # ============================================================================= | |
| # ENHANCED FALLBACK PARSER | |
| # ============================================================================= | |
def _decode_json_object(text: str) -> Optional[dict]:
    """Decode the JSON object at the start of *text*, ignoring anything after it.

    Returns None when *text* does not start with valid JSON or the value is not an
    object (tool-call arguments must be a dict).
    """
    try:
        value, _end = json.JSONDecoder().raw_decode(text.lstrip())
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return None
    return value if isinstance(value, dict) else None


def _parse_groq_call(content: str):
    """Strategy 1: Groq's ``<function=name{...}>`` format. Returns ``(name, args)`` or None."""
    match = re.search(r"<function=(\w+)\s*(\{.*?\})\s*(?:>|</function>)", content, re.DOTALL)
    if not match:
        return None
    name = match.group(1).strip()
    raw_args = match.group(2).strip()
    args = _decode_json_object(raw_args)
    if args is None:
        # Some responses double-escape the payload (e.g. \" inside the tag). Undo one
        # escaping layer only as a fallback: unicode_escape mangles non-ASCII characters.
        try:
            args = _decode_json_object(raw_args.encode().decode("unicode_escape"))
        except UnicodeDecodeError:
            args = None
    if args is None:
        return None
    print(f"✓ Parsed Groq format: {name}")
    return name, args


def _parse_function_tag_call(content: str):
    """Strategy 2: ``<function(name)>{...}`` / ``<function=name>{...}``. Returns ``(name, args)`` or None."""
    # The name stops at ')' or '>' so "<function=x>{...}</function>" is not swallowed whole.
    match = re.search(r"<function[(=]\s*([^)>]+)\s*[)>](.*)", content, re.DOTALL | re.IGNORECASE)
    if not match:
        return None
    name = match.group(1).strip().replace("'", "").replace('"', "")
    remaining = match.group(2)
    json_start = remaining.find("{")
    if json_start == -1:
        # No arguments found: let the later strategies try instead of failing outright.
        return None
    args = _decode_json_object(remaining[json_start:])
    if args is None:
        return None
    print(f"✓ Parsed standard format: {name}")
    return name, args


def _parse_python_block_call(content: str):
    """Strategy 3: wrap the first fenced ```python block in a code_interpreter call."""
    match = re.search(r"```python[ \t]*\r?\n(.*?)```", content, re.DOTALL)
    if not match:
        return None
    print(f"✓ Extracted Python code → code_interpreter")
    return "code_interpreter", {"code": match.group(1).strip()}


def _parse_tool_mention_call(content: str, tools: List):
    """Strategy 4: first tool whose name appears in *content*, with placeholder required args."""
    lowered = content.lower()
    for candidate in tools:
        if candidate.name.lower() not in lowered:
            continue
        args = {}
        schema_model = getattr(candidate, "args_schema", None)
        if schema_model:
            schema = schema_model if isinstance(schema_model, dict) else schema_model.model_json_schema()
            required = set(schema.get("required", []))
            # Placeholder values: the real arguments could not be recovered from the text.
            args = {prop: "auto_extracted" for prop in schema.get("properties", {}) if prop in required}
        print(f"✓ Found mention of '{candidate.name}' → creating default call")
        return candidate.name, args
    return None


def _fallback_thinking_call(content: str):
    """Strategy 5: treat longer, error-free text as reasoning for think_through_logic."""
    if len(content) > 50 and not any(kw in content.lower() for kw in ("error", "failed", "invalid")):
        print(f"⚠️ No tool detected → forcing think_through_logic")
        return "think_through_logic", {"reasoning": content[:150]}
    return None


def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]:
    """Recover a tool call from an LLM reply that contained text instead of a native call.

    Strategies are tried in order: Groq ``<function=name{...}>`` tags, generic
    ``<function(name)>{...}`` tags, a fenced ```python block (→ code_interpreter),
    a plain mention of a tool name, and finally think_through_logic for reasoning text.
    The first strategy that yields a tool present in *tools* wins.

    Args:
        content: Raw message content from the model.
        tools: Available tools; each must expose ``.name`` (and optionally ``.args_schema``).

    Returns:
        A one-element list with the recovered ToolCall, or an empty list if nothing matched.
    """
    if not isinstance(content, str):
        content = str(content or "")
    print(f"🔧 Fallback parsing (first 300 chars):\n{content[:300]}")

    available = {t.name for t in tools}
    strategies = (
        _parse_groq_call,
        _parse_function_tag_call,
        _parse_python_block_call,
        lambda text: _parse_tool_mention_call(text, tools),
        _fallback_thinking_call,
    )
    for strategy in strategies:
        parsed = strategy(content)
        if parsed is None:
            continue
        tool_name, tool_input = parsed
        if tool_name in available:
            return [ToolCall(name=tool_name, args=tool_input, id=str(uuid.uuid4()))]
        # A hallucinated or garbled name: keep trying the remaining strategies.
        print(f"❌ Tool '{tool_name}' not in available tools")

    print("❌ All parsing strategies failed")
    return []
| # ============================================================================= | |
| # CONDITIONAL EDGE FUNCTION | |
| # ============================================================================= | |
def should_continue(state: AgentState):
    """Route the graph after the agent node.

    Returns:
        END when the turn limit is reached, when the agent's first tool call is
        final_answer_tool, or when two consecutive AI messages carry no tool calls;
        "tools" for any other tool call; "agent" otherwise.
    """
    messages = state.get("messages", [])
    if not messages:
        return "agent"

    last_message = messages[-1]
    current_turn = state.get("turn", 0)
    msg_type = type(last_message).__name__
    print(f"📍 Conditional check - Turn {current_turn}, Last msg type: {msg_type}")

    # 1. Hard turn limit.
    if current_turn >= MAX_TURNS:
        print(f"🛑 Max turns ({MAX_TURNS}) reached")
        return END

    # 2. A tool result must be processed by the agent.
    if isinstance(last_message, ToolMessage):
        print(f"📨 Tool result received from '{last_message.name}' → back to agent")
        return "agent"

    if isinstance(last_message, AIMessage):
        # 3. Tool call: only the FIRST call decides whether this is the submission.
        if last_message.tool_calls:
            first_tool_name = last_message.tool_calls[0].get("name", "")
            return END if first_tool_name == "final_answer_tool" else "tools"

        # 4. Plain reasoning text: stop if the previous AI message was also tool-less (loop).
        previous = messages[-2] if len(messages) >= 2 else None
        if isinstance(previous, AIMessage) and not previous.tool_calls:
            print(f"⚠️ Loop detected: 2 consecutive AI messages without tools")
            return END
        print(f"💭 AI message without tool call → continuing to agent (will force tool)")
        return "agent"

    # 5. Anything else (e.g. a SystemMessage): continue. The original fell off the end here
    # and returned None, which is not a key of the conditional-edge mapping.
    print(f"🔄 Default → continuing to agent")
    return "agent"
| # ============================================================================= | |
| # ENHANCED AGENT CLASS | |
| # ============================================================================= | |
class PlanningReflectionAgent:
    """Tool-calling LangGraph agent for GAIA-style questions.

    The compiled graph alternates between an ``agent`` node (one Groq LLM call that is
    forced to yield a tool call) and a ``tools`` node (executes the call with ToolNode).
    Routing after the agent node is decided by ``should_continue``. Calling the instance
    runs the graph on one question and returns a cleaned answer string.
    """

    # Prefixes stripped from the raw answer (compared case-insensitively), in this order.
    _ANSWER_PREFIXES = (
        "the answer is:", "here is the answer:", "based on",
        "final answer:", "answer:", "the final answer is:",
        "my answer is:", "according to", "i found that",
        "the result is:", "result:",
    )
    # Tools whose short outputs may serve as the answer when final_answer_tool is never called.
    _FALLBACK_ANSWER_TOOLS = ("calculator", "think_through_logic", "code_interpreter")
    _NO_ANSWER = "AGENT FAILED TO PRODUCE ANSWER"

    def __init__(self):
        """Validate API keys, initialize RAG, the LLM and the graph.

        Raises:
            ValueError: if GROQ_API_KEY, HUGGINGFACEHUB_API_TOKEN or GEMINI_API_KEY is unset.
        """
        print("🧠 PlanningReflectionAgent initializing...")
        GROQ_API_KEY = os.getenv("GROQ_API_KEY")
        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not set!")
        HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not HUGGINGFACEHUB_API_TOKEN:
            raise ValueError("HUGGINGFACEHUB_API_TOKEN secret is not set! Please add it to your Space secrets.")
        GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
        if not GOOGLE_API_KEY:
            # The key is read from GEMINI_API_KEY, so that is the variable to name.
            raise ValueError("GEMINI_API_KEY not set!")
        self._groq_api_key = GROQ_API_KEY
        self.tools = defined_tools

        if not initialize_rag_components():
            print("⚠️ RAG components failed to initialize.")

        self.system_prompt = self._build_system_prompt(self._describe_tools())

        # Initialize the tool-calling LLM (Groq).
        print("Initializing Groq LLM...")
        try:
            self.llm_with_tools = ChatGroq(
                temperature=0,
                groq_api_key=GROQ_API_KEY,
                model_name="qwen/qwen3-32b",
                max_tokens=4096,
                timeout=60,
            ).bind_tools(self.tools, tool_choice="auto")
            print("✅ LLM initialized without FORCED tool usage.")
        except Exception as e:
            print(f"❌ Error initializing Groq: {e}")
            raise
        print("Initializing LLM Endpoint...")
        # Alternative backends, kept for reference (swap in to replace Groq):
        # print("Initializing HuggingFace LLM...")
        #
        # llm = HuggingFaceEndpoint(
        #     repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",  # Free on HF Inference API
        #     huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
        #     max_new_tokens=4096,
        #     temperature=0.01,
        # )
        # chat_llm = ChatHuggingFace(llm=llm)
        # print("✅ HuggingFace LLM Endpoint initialized.")
        #
        # # Bind tools to the LLM
        # self.llm_with_tools = chat_llm.bind_tools(self.tools)
        # print("✅ Tools bound to LLM.")
        # print("Initializing Google Gemini LLM...")
        # try:
        #     self.llm_with_tools = ChatGoogleGenerativeAI(
        #         model="gemini-2.5-flash",  # Latest model
        #         google_api_key=GOOGLE_API_KEY,
        #         temperature=0,
        #         max_output_tokens=8192,
        #         timeout=60,
        #         convert_system_message_to_human=True  # Important for Gemini
        #     ).bind_tools(self.tools, tool_choice="auto")
        #     print("✅ Gemini LLM initialized.")
        # except Exception as e:
        #     print(f"❌ Error initializing Gemini: {e}")
        #     raise

        self.graph = self._build_graph()

    # ------------------------------------------------------------------ setup helpers

    def _describe_tools(self) -> str:
        """Render one entry per tool: name, description and per-argument descriptions."""
        entries = []
        for t in self.tools:
            if t.args_schema:
                schema = t.args_schema.model_json_schema()
                args_desc = [f" - {p}: {d.get('description', '')}"
                             for p, d in schema.get('properties', {}).items()]
                entries.append(f"- {t.name}:\n {t.description}\n" + "\n".join(args_desc))
            else:
                entries.append(f"- {t.name}: {t.description}")
        return "\n".join(entries)

    @staticmethod
    def _build_system_prompt(tool_descriptions: str) -> str:
        """Return the system prompt with *tool_descriptions* embedded."""
        return f"""You are an elite AI agent for GAIA benchmark. Your ONLY job: provide the EXACT answer requested.
═══════════════════════════════════════════════════════════════
⚠️ ABSOLUTE RULES - VIOLATE THESE AND YOU FAIL:
═══════════════════════════════════════════════════════════════
1. **EVERY TURN MUST CALL EXACTLY ONE TOOL** - No exceptions
2. **NEVER OUTPUT REASONING TEXT WITHOUT A TOOL CALL** - You will fail
3. **IDENTIFY QUESTION TYPE FIRST** - Logic? Factual? Data? Math?
4. **LOGIC PUZZLES**: think_through_logic → calculator (if needed) → validate → final_answer
5. **FACTUAL QUESTIONS**: search_tool → validate → final_answer
6. **DATA QUESTIONS**: read_file → code_interpreter → validate → final_answer
7. **ALWAYS VALIDATE**: Call validate_answer() before final_answer_tool()
8. **FINAL ANSWER FORMAT**: EXACTLY what was asked. NO "The answer is..." or explanations
═══════════════════════════════════════════════════════════════
📋 QUESTION TYPE GUIDE:
═══════════════════════════════════════════════════════════════
**RIDDLES/LOGIC PUZZLES** (No web search needed):
- Brain teasers, puzzles, logical deduction
- Strategy: think_through_logic → calculator (if math) → validate → final_answer
- Example: "If 200 coins, 30 face-down, divide into equal piles..."
Turn 1: think_through_logic("Adventurer takes 30 coins and flips them")
Turn 2: calculator("30") [if needed]
Turn 3: validate_answer("30", question)
Turn 4: final_answer_tool("30")
**FACTUAL/RESEARCH** (Need web):
- Who, what, when, where questions
- Strategy: search_tool → scrape_and_retrieve → validate → final_answer
- Example: "What was Einstein's birthplace population in 1900?"
Turn 1: search_tool("Albert Einstein birthplace")
Turn 2: search_tool("Ulm Germany population 1900")
Turn 3: validate_answer("50000", question)
Turn 4: final_answer_tool("50000")
**DATA ANALYSIS** (Need files):
- CSV/Excel questions
- Strategy: list_directory → read_file → code_interpreter → validate → final_answer
**SIMPLE MATH**:
- Calculations
- Strategy: calculator() → validate_answer() → final_answer_tool()
═══════════════════════════════════════════════════════════════
🎓 CRITICAL EXAMPLES:
═══════════════════════════════════════════════════════════════
Example 1: Logic Puzzle
Q: "Coin riddle with 200 coins, 30 face-down..."
✅ CORRECT:
Turn 1: think_through_logic("Take 30 coins, flip all")
Turn 2: validate_answer("30", "coin riddle...")
Turn 3: final_answer_tool("30")
❌ WRONG:
Turn 1: [reasoning text without tool] ← FAILS!
Example 2: Letter Bank Puzzle
Q: "Use letters to spell sentences, which letters need changing?"
✅ CORRECT:
Turn 1: code_interpreter("code to count letters...")
Turn 2: validate_answer("A, B, C", question)
Turn 3: final_answer_tool("A, B, C")
Example 3: Math Problem
Q: "System of equations to solve..."
✅ CORRECT:
Turn 1: code_interpreter("import numpy; solve equations...")
Turn 2: validate_answer("0, 1, 2", question)
Turn 3: final_answer_tool("0, 1, 2")
═══════════════════════════════════════════════════════════════
📚 AVAILABLE TOOLS:
═══════════════════════════════════════════════════════════════
{tool_descriptions}
═══════════════════════════════════════════════════════════════
⚡ EXECUTION RULES:
═══════════════════════════════════════════════════════════════
- If you output text without a tool call, you have FAILED
- If you're unsure, use think_through_logic() to organize thoughts
- ALWAYS call a tool - preferably the right one for the question type
- After EVERY tool result, decide: "Do I have the answer? → validate → submit"
- If stuck after 3 turns: call reflect_on_progress()
REMEMBER: One tool per turn. No reasoning without tools. Exact answer format.
═══════════════════════════════════════════════════════════════
"""

    def _build_graph(self):
        """Compile START → agent → (tools | agent | END), tools → agent."""
        print("Building graph...")
        # ToolNode holds no per-call state, so one instance is reused for every tool step.
        self.tool_executor = ToolNode(self.tools)
        graph_builder = StateGraph(AgentState)
        graph_builder.add_node("agent", self._agent_node)
        graph_builder.add_node("tools", self._tool_node)
        graph_builder.add_edge(START, "agent")
        graph_builder.add_conditional_edges(
            "agent",
            should_continue,
            {
                "tools": "tools",
                "agent": "agent",
                END: END,
            },
        )
        graph_builder.add_edge("tools", "agent")
        graph = graph_builder.compile()
        print("✅ Graph compiled successfully.")
        return graph

    # ------------------------------------------------------------------ agent node

    def _agent_node(self, state: AgentState):
        """Run one agent turn and return the state update (always carrying a tool call)."""
        current_turn = state.get("turn", 0) + 1
        print(f"\n{'='*70}")
        print(f"🤖 AGENT TURN {current_turn}/{MAX_TURNS}")
        print('='*70)
        if current_turn > MAX_TURNS:
            return {
                "messages": [SystemMessage(content="Max turns reached.")],
                "turn": current_turn,
            }

        messages_to_send = self._prepare_messages(state, current_turn)
        ai_message = self._ensure_tool_call(self._invoke_llm(messages_to_send))

        # Copy instead of appending to the list held in the incoming state.
        tool_history = list(state.get("tool_history", []))
        has_plan = state.get("has_plan", False)
        if ai_message.tool_calls:
            tool_name = ai_message.tool_calls[0]["name"]
            print(f"🔧 Tool Call: {tool_name}")
            tool_history.append(tool_name)
            if tool_name == "create_plan":
                has_plan = True
        else:
            print(f"⚠️ No tool call (this shouldn't happen!)")
            print(f"💭 Content: {ai_message.content[:200]}...")

        return {
            "messages": [ai_message],
            "turn": current_turn,
            "has_plan": has_plan,
            "tool_history": tool_history,
            "last_tool_was_thinking": bool(ai_message.tool_calls)
            and ai_message.tool_calls[0]["name"] == "think_through_logic",
        }

    def _prepare_messages(self, state: AgentState, current_turn: int) -> list:
        """Copy the conversation and append tool-forcing / reflection hints when needed."""
        consecutive_errors = state.get("consecutive_errors", 0)
        should_reflect = (
            (current_turn > 5 and current_turn % REFLECT_EVERY_N_TURNS == 0)
            or consecutive_errors >= 3
        )
        messages_to_send = list(state["messages"])

        # The previous turn produced text without a tool call: demand one explicitly.
        if len(messages_to_send) >= 2:
            last_msg = messages_to_send[-1]
            if isinstance(last_msg, AIMessage) and not last_msg.tool_calls:
                messages_to_send.append(SystemMessage(
                    content="⚠️ CRITICAL: You MUST call a tool this turn. NO reasoning text. Pick the most appropriate tool and call it now."
                ))
                print("🚨 Injecting tool-forcing message")

        if should_reflect:
            messages_to_send.append(SystemMessage(
                content="⚠️ HINT: Multiple turns without progress. Consider calling reflect_on_progress() or try a different approach."
            ))
            print("🤔 Injecting reflection hint")
        return messages_to_send

    def _invoke_llm(self, messages_to_send: list, max_retries: int = 3) -> AIMessage:
        """Call the tool-bound LLM with retries; the result may still lack tool calls.

        On a Groq ``tool_use_failed`` error an untooled fallback model is tried. If the
        final attempt raises, a think_through_logic call is synthesized.
        """
        ai_message = None
        for attempt in range(max_retries):
            try:
                ai_message = self.llm_with_tools.invoke(messages_to_send)
                if ai_message.tool_calls:
                    break
                print(f"⚠️ LLM returned no tool calls on attempt {attempt+1}")
            except Exception as e:
                error_str = str(e)
                print(f"⚠️ LLM attempt {attempt+1}/{max_retries} failed: {error_str[:200]}")
                if "tool_use_failed" in error_str and attempt < max_retries - 1:
                    fallback_message = self._invoke_untooled_fallback(messages_to_send)
                    if fallback_message is not None:
                        ai_message = fallback_message
                        if ai_message.tool_calls:
                            break
                if attempt == max_retries - 1:
                    print("🚨 All attempts failed - forcing think_through_logic")
                    ai_message = AIMessage(
                        content="",
                        tool_calls=[ToolCall(
                            name="think_through_logic",
                            args={"reasoning": "Processing question"},
                            id=str(uuid.uuid4()),
                        )],
                    )
                else:
                    time.sleep(2 ** attempt)  # exponential backoff before retrying
        return ai_message

    def _invoke_untooled_fallback(self, messages_to_send: list) -> Optional[AIMessage]:
        """Ask a model without bound tools and parse a tool call out of its text.

        Returns the fallback message (tool calls filled in when parsing succeeded), or
        None if the fallback call itself raised.
        """
        print("🔧 Trying without strict tool enforcement...")
        try:
            simple_llm = ChatGroq(
                temperature=0,
                groq_api_key=self._groq_api_key,
                model_name="llama-3.3-70b-versatile",
                max_tokens=4096,
                timeout=60,
            )
            force_tool_msg = SystemMessage(
                content="You MUST call a tool. Respond with a tool call, not reasoning text."
            )
            ai_message = simple_llm.invoke(messages_to_send + [force_tool_msg])
            if ai_message.content and not ai_message.tool_calls:
                parsed = parse_tool_call_from_string(ai_message.content, self.tools)
                if parsed:
                    ai_message.tool_calls = parsed
                    ai_message.content = ""
                    print("✓ Fallback parsing succeeded")
            return ai_message
        except Exception as e2:
            print(f"⚠️ Fallback also failed: {e2}")
            return None

    def _ensure_tool_call(self, ai_message: AIMessage) -> AIMessage:
        """Guarantee *ai_message* carries a tool call, parsing its text or forcing one."""
        if ai_message.tool_calls:
            return ai_message
        content = ai_message.content
        if isinstance(content, str) and content.strip():
            parsed = parse_tool_call_from_string(content, self.tools)
            if parsed:
                ai_message.tool_calls = parsed
                ai_message.content = ""
                print("✓ Final parse succeeded")
                return ai_message
        # Reached for unparseable text AND for empty / non-string content, which the
        # original left without any tool call.
        print("🚨 EMERGENCY: Forcing think_through_logic")
        ai_message.tool_calls = [ToolCall(
            name="think_through_logic",
            args={"reasoning": "analyzing question"},
            id=str(uuid.uuid4()),
        )]
        ai_message.content = ""
        return ai_message

    # ------------------------------------------------------------------ tool node

    def _tool_node(self, state: AgentState):
        """Execute the pending tool calls and update ``consecutive_errors``."""
        print(f"🔧 Executing tools...")
        result = self.tool_executor.invoke(state)

        consecutive_errors = state.get("consecutive_errors", 0)
        messages = result.get("messages")
        if messages:
            last_msg = messages[-1]
            if isinstance(last_msg, ToolMessage):
                # Heuristic: any mention of "error" (any case) counts as a failed call;
                # str() guards against list-valued tool content.
                if "error" in str(last_msg.content).lower():
                    consecutive_errors += 1
                    print(f"⚠️ Tool error detected (consecutive: {consecutive_errors})")
                else:
                    consecutive_errors = 0
        result["consecutive_errors"] = consecutive_errors
        return result

    # ------------------------------------------------------------------ public API

    def __call__(self, question: str, file_path: Optional[str] = None) -> str:
        """Run the agent on *question* (optionally with an attached file) and return the answer.

        Returns the cleaned final answer, "AGENT FAILED TO PRODUCE ANSWER" if nothing
        usable was produced, or "AGENT ERROR: ..." if graph execution raised.
        """
        print(f"\n{'='*70}")
        print(f"🎯 NEW QUESTION")
        print(f"{'='*70}")
        print(f"Q: {question[:200]}{'...' if len(question) > 200 else ''}")
        if file_path:
            print(f"📎 File attached: {file_path}")
        print(f"{'='*70}\n")

        question_text = question
        if file_path:
            question_text += f"\n\n[FILE ATTACHED: {file_path}]"
            question_text += f"\n[FILE TYPE: {self._classify_file(file_path)}]"
            question_text += f"\nIMPORTANT: Use the appropriate tool to access this file first!"

        graph_input = {
            "messages": [
                SystemMessage(content=self.system_prompt),
                HumanMessage(content=question_text),
            ],
            "file_path": file_path,
            "turn": 0,
            "has_plan": False,
            "consecutive_errors": 0,
            "tool_history": [],
            "last_tool_was_thinking": False,
        }

        final_answer = self._NO_ANSWER
        all_messages = []
        try:
            # Each agent turn can cost two graph steps (agent node + tools node). The old
            # limit of MAX_TURNS + 10 made LangGraph raise before should_continue's turn cap,
            # which skipped the fallback answer extraction below.
            config = {"recursion_limit": 2 * MAX_TURNS + 10}
            for event in self.graph.stream(graph_input, stream_mode="values", config=config):
                messages = event.get("messages")
                if not messages:
                    continue
                all_messages = messages
                last_message = messages[-1]

                if isinstance(last_message, AIMessage) and last_message.tool_calls:
                    for tool_call in last_message.tool_calls:
                        if tool_call.get("name") != "final_answer_tool":
                            continue
                        args = tool_call.get("args", {})
                        if "answer" in args:
                            final_answer = args["answer"]
                            print(f"\n{'='*70}")
                            print(f"✅ FINAL ANSWER: '{final_answer}'")
                            print(f"{'='*70}\n")
                            break
                elif isinstance(last_message, ToolMessage):
                    preview = str(last_message.content)[:200].replace('\n', ' ')
                    print(f"📊 Tool '{last_message.name}' result: {preview}...")
                elif isinstance(last_message, AIMessage):
                    print(f"💭 AI: {last_message.content[:200]}...")

            if final_answer == self._NO_ANSWER:
                print("⚠️ No final_answer_tool called. Checking tool results...")
                extracted = self._extract_answer_from_tools(all_messages)
                if extracted is not None:
                    final_answer = extracted

            cleaned = self._clean_answer(final_answer)
            print(f"\n{'='*70}")
            print(f"🎉 RETURNING ANSWER")
            print(f"{'='*70}")
            print(f"{cleaned}")
            print(f"{'='*70}\n")
            return cleaned
        except Exception as e:
            print(f"❌ Graph error: {e}")
            print(traceback.format_exc())
            return f"AGENT ERROR: {e}"

    # ------------------------------------------------------------------ answer helpers

    @staticmethod
    def _classify_file(file_path: str) -> str:
        """Map a file extension to the coarse type named in the prompt."""
        file_ext = Path(file_path).suffix.lower()
        if file_ext in ('.jpg', '.jpeg', '.png', '.gif', '.bmp'):
            return "image"
        if file_ext in ('.mp3', '.wav', '.m4a', '.flac'):
            return "audio"
        if file_ext in ('.csv', '.xlsx', '.xls'):
            return "data"
        if file_ext in ('.txt', '.pdf', '.doc', '.docx'):
            return "document"
        return "unknown"

    @classmethod
    def _extract_answer_from_tools(cls, messages) -> Optional[str]:
        """Return the last meaningful line of the newest short, non-error result.

        Only results from calculator / think_through_logic / code_interpreter are
        considered. Returns None if no such line exists.
        """
        for msg in reversed(messages):
            if not isinstance(msg, ToolMessage) or msg.name not in cls._FALLBACK_ANSWER_TOOLS:
                continue
            content = str(msg.content).strip()
            if not content or len(content) >= 200 or content.startswith("Error"):
                continue
            for line in reversed(content.split('\n')):
                stripped = line.strip()
                # Skip blank lines and the tools' own status/instruction lines.
                if stripped and not stripped.startswith(('✅', '⚠️', 'Next', 'Remember')):
                    print(f"📝 Extracted from {msg.name}: '{stripped}'")
                    return stripped
        return None

    @classmethod
    def _clean_answer(cls, answer) -> str:
        """Strip lead-in phrases, code fences, wrapping quotes and a short trailing period."""
        cleaned = str(answer).strip()

        lowered = cleaned.lower()
        for prefix in cls._ANSWER_PREFIXES:
            if lowered.startswith(prefix):
                remainder = cleaned[len(prefix):].strip()
                if remainder:
                    cleaned = remainder
                    break

        cleaned = remove_fences_simple(cleaned)
        while cleaned.startswith("`") and cleaned.endswith("`"):
            cleaned = cleaned[1:-1].strip()
        if (cleaned.startswith('"') and cleaned.endswith('"')) or \
                (cleaned.startswith("'") and cleaned.endswith("'")):
            cleaned = cleaned[1:-1].strip()

        # A period ending a short answer is punctuation, not part of the answer.
        if cleaned.endswith('.') and len(cleaned.split()) < 10:
            cleaned = cleaned[:-1]
        return cleaned
| # ============================================================================= | |
| # GLOBAL AGENT INSTANTIATION | |
| # ============================================================================= | |
| agent = None | |
| try: | |
| initialize_rag_components() | |
| agent = PlanningReflectionAgent() | |
| print("✅ Global PlanningReflectionAgent instantiated.") | |
| # Verify it's callable | |
| if not callable(agent): | |
| print("❌ ERROR: Agent not callable!") | |
| agent = None | |
| else: | |
| print("✅ Agent is callable.") | |
| if asr_pipeline is None: | |
| print("⚠️ ASR Pipeline not loaded.") | |
| except Exception as e: | |
| print(f"❌ FATAL: Agent initialization failed: {e}") | |
| traceback.print_exc() | |
| agent = None | |
| # ============================================================================= | |
| # RUN AND SUBMIT FUNCTION | |
| # ============================================================================= | |
def _truncate_question(question_text, limit=100):
    """Return *question_text* cut to *limit* characters (plus "...") for the results table.

    A missing question (``None``) is shown as ``""`` so building a log row can never raise.
    """
    text = "" if question_text is None else str(question_text)
    return text[:limit] + "..." if len(text) > limit else text


def _find_local_task_file(task_id, files_dir="files"):
    """Return the path of a file in *files_dir* whose name starts with *task_id*, or None.

    Candidates are sorted so the choice is deterministic (``os.listdir`` order is
    arbitrary). Lookup errors are logged and treated as "no file found".
    """
    local_file_path = None
    try:
        if os.path.exists(files_dir):
            # Look for any file that starts with the task_id
            matching_files = sorted(f for f in os.listdir(files_dir) if f.startswith(task_id))
            if matching_files:
                local_file_path = os.path.join(files_dir, matching_files[0])
                file_size = os.path.getsize(local_file_path)
                abs_path = os.path.abspath(local_file_path)
                print(f"✅ Found file: {matching_files[0]} ({file_size} bytes)")
                print(f" Path: {abs_path}")
            else:
                print(f"ℹ️ No file found for task {task_id}, proceeding without file.")
        else:
            print(f"⚠️ Warning: '{files_dir}' directory not found.")
    except Exception as e:
        print(f"❌ Error looking for file: {e}")
    return local_file_path


def _fetch_questions(questions_url):
    """Download the evaluation questions from *questions_url*.

    Returns:
        ``(questions_data, None)`` on success, or ``(None, error_message)`` when the
        request fails, the body is not valid JSON, or the payload is empty / not a list.
    """
    print(f"\n{'='*70}")
    print(f"📥 FETCHING QUESTIONS")
    print(f"{'='*70}")
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    # JSONDecodeError must precede RequestException: in requests >= 2.27 it is a
    # subclass of RequestException, so the previous order made this branch unreachable.
    except requests.exceptions.JSONDecodeError as e:
        print(f"❌ Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return None, f"Error decoding server response for questions: {e}"
    except requests.exceptions.RequestException as e:
        print(f"❌ Error fetching questions: {e}")
        return None, f"Error fetching questions: {e}"
    except Exception as e:
        print(f"❌ An unexpected error occurred fetching questions: {e}")
        return None, f"An unexpected error occurred fetching questions: {e}"
    if not questions_data:
        print("Fetched questions list is empty.")
        return None, "Fetched questions list is empty or invalid format."
    if not isinstance(questions_data, list):
        # The loop below expects a list of question dicts.
        print(f"Fetched questions payload has unexpected type: {type(questions_data).__name__}")
        return None, "Fetched questions list is empty or invalid format."
    print(f"✅ Fetched {len(questions_data)} questions.")
    print(f"{'='*70}\n")
    return questions_data, None


def _report_submission_failure(status_message, results_log, with_traceback=False):
    """Print the "submission failed" banner and return the (status, DataFrame) pair for Gradio.

    Must be called from inside an ``except`` block when *with_traceback* is True so
    ``traceback.format_exc()`` sees the active exception.
    """
    print(f"\n{'='*70}")
    print(f"❌ SUBMISSION FAILED")
    print(f"{'='*70}")
    print(status_message)
    if with_traceback:
        print(traceback.format_exc())
    print(f"{'='*70}\n")
    return status_message, pd.DataFrame(results_log)


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """
    Fetch all questions, run the globally instantiated PlanningReflectionAgent on
    each, submit all answers to the scoring API, and return the results for display.

    Args:
        profile: Hugging Face OAuth profile of the logged-in user, or ``None`` when
            nobody is logged in.

    Returns:
        ``(status_message, results_df)``. ``results_df`` is a DataFrame with one row
        per processed question, or ``None`` when the run stopped before any question
        was processed (not logged in, agent missing, question fetch failed).
    """
    space_id = os.getenv("SPACE_ID")
    if profile:
        username = f"{profile.username}"
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    # Use the globally instantiated agent (only read here, so no `global` needed).
    if agent is None:
        error_msg = "FATAL: Agent failed to initialize at startup. Check logs for errors."
        print(error_msg)
        return error_msg, None
    print("✅ Using globally instantiated PlanningReflectionAgent")
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    # 2. Fetch Questions
    questions_data, fetch_error = _fetch_questions(questions_url)
    if fetch_error is not None:
        return fetch_error, None

    # Load the local answer sheet (used only for the pre-submission accuracy check).
    # The template is written to "answer_sheet.json" and the user is told to fill that
    # file, so fall back to it; reading only "answer_sheet_json.json" meant a filled-in
    # template was never used and was regenerated on every run.
    answer_sheet = load_answer_sheet("answer_sheet_json.json") or load_answer_sheet("answer_sheet.json")
    if not answer_sheet:
        create_answer_sheet_template(questions_data, "answer_sheet.json")
        print("\n⚠️ Please fill in the answer_sheet.json file with correct answers")
        print(" Then run the script again to check agent performance\n")
        # The loader may return None; keep the .get() lookups below safe.
        answer_sheet = {}

    # 3. Run your Agent
    print(f"\n{'='*70}")
    print(f"🚀 STARTING EVALUATION")
    print(f"{'='*70}")
    print(f"Total questions to process: {len(questions_data)}")
    print(f"{'='*70}\n")
    results_log = []
    answers_payload = []
    for idx, item in enumerate(questions_data, 1):
        print(f"\n{'='*70}")
        print(f"📝 PROCESSING QUESTION {idx}/{len(questions_data)}")
        print(f"{'='*70}")
        task_id = item.get("task_id")
        question_text = item.get("question")
        expected = answer_sheet.get(task_id, "")
        # Sheet values may be numbers or null (unfilled template); compare as text.
        correct_answer = "" if expected is None else str(expected)
        local_file_path = _find_local_task_file(task_id)
        try:
            # Pass file_path to agent
            submitted_answer = agent(question_text, local_file_path)
        except Exception as e:
            print(f"❌ Error running agent on task {task_id}: {e}")
            print(traceback.format_exc())
            results_log.append({
                "Task ID": task_id,
                "Question": _truncate_question(question_text),
                "Submitted Answer": f"AGENT ERROR: {e}",
                "Correct Answer": correct_answer,
                "Status": "❌"
            })
            # Continue with other questions even if one fails
            answers_payload.append({"task_id": task_id, "submitted_answer": f"ERROR: {str(e)[:100]}"})
            continue
        # Exactly one payload entry per task: previously an exception raised after this
        # append (e.g. .strip() on a non-string) added a second "ERROR" entry for it.
        answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
        is_correct = str(submitted_answer).strip().lower() == correct_answer.strip().lower()
        correctness = "✅ CORRECT" if is_correct else "❌ WRONG"
        # Log with correctness indicator
        print(f"\n{correctness} - Task {task_id}")
        print(f" Submitted: '{submitted_answer}'")
        print(f" Expected: '{correct_answer}'")
        results_log.append({
            "Task ID": task_id,
            "Question": _truncate_question(question_text),
            "Submitted Answer": submitted_answer,
            "Correct Answer": correct_answer,
            "Status": "✅" if is_correct else "❌"
        })
        print(f"✅ Question {idx}/{len(questions_data)} completed")

    # Summary after all questions processed
    print(f"\n{'='*70}")
    print(f"✅ ALL QUESTIONS PROCESSED")
    print(f"{'='*70}")
    print(f"Total answers collected: {len(answers_payload)}")
    # Calculate pre-submission accuracy
    correct_count = sum(1 for log in results_log if log.get("Status") == "✅")
    total_count = len(results_log)
    accuracy = (correct_count / total_count * 100) if total_count > 0 else 0
    print(f"\n{'='*70}")
    print(f"📊 PRE-SUBMISSION SUMMARY")
    print(f"{'='*70}")
    print(f"Correct: {correct_count}/{total_count} ({accuracy:.1f}%)")
    print(f"{'='*70}\n")
    if not answers_payload:
        print("⚠️ Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # 4. Prepare Submission
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}

    # 5. Submit
    print(f"\n{'='*70}")
    print(f"📤 SUBMITTING TO API")
    print(f"{'='*70}")
    print(f"URL: {submit_url}")
    print(f"Username: {username}")
    print(f"Answers to submit: {len(answers_payload)}")
    print(f"{'='*70}\n")
    try:
        print("⏳ Sending POST request...")
        response = requests.post(submit_url, json=submission_data, timeout=60)
        print(f"✅ Got response: Status {response.status_code}")
        response.raise_for_status()
        result_data = response.json()
        print(f"\n{'='*70}")
        print(f"📊 SUBMISSION RESULTS")
        print(f"{'='*70}")
        print(f"Response data: {result_data}")
        print(f"{'='*70}\n")
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print(final_status)
        print("="*70)
        print("✅ Submission successful.")
        return final_status, pd.DataFrame(results_log)
    except requests.exceptions.HTTPError as e:
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
        except requests.exceptions.JSONDecodeError:
            error_json = None
        # Only a JSON object has .get(); a list/str body previously crashed this handler.
        if isinstance(error_json, dict):
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        else:
            error_detail += f" Response: {e.response.text[:500]}"
        return _report_submission_failure(f"Submission Failed: {error_detail}", results_log)
    except requests.exceptions.Timeout:
        return _report_submission_failure("Submission Failed: The request timed out.", results_log)
    except requests.exceptions.JSONDecodeError as e:
        # A 2xx response with a non-JSON body is not a network error; report it as such.
        return _report_submission_failure(
            f"Submission Failed: Could not decode server response - {e}", results_log
        )
    except requests.exceptions.RequestException as e:
        return _report_submission_failure(f"Submission Failed: Network error - {e}", results_log)
    except Exception as e:
        return _report_submission_failure(
            f"An unexpected error occurred during submission: {e}", results_log, with_traceback=True
        )
| # --- Build Gradio Interface using Blocks --- | |
# Gradio UI: login button, a single "run everything" button, and two output widgets
# that receive the (status_message, results_df) pair returned by run_and_submit_all.
with gr.Blocks() as demo:
    gr.Markdown("# Basic Agent Evaluation Runner")
    gr.Markdown(
        """
        **Instructions:**

        1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
        2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
        3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.

        ---
        **Disclaimers:**
        Once you click the "submit" button, it can take quite some time (this is the time for the agent to go through all the questions).
        This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance, for the delay process of the submit button, a solution could be to cache the answers and submit in a separate action or even to answer the questions in async.
        """
    )
    gr.LoginButton()
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
    # NOTE(review): no `inputs` are wired; the OAuth profile is presumably injected by
    # Gradio from the `gr.OAuthProfile | None` annotation on run_and_submit_all — confirm.
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table]
    )
if __name__ == "__main__":
    print("\n" + "-"*30 + " App Starting " + "-"*30)
    # Report the Space environment at startup (informational only).
    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")
    if space_host_startup:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
        # SPACE_HOST is already the full host (e.g. "user-space.hf.space"); appending
        # ".hf.space" unconditionally printed a broken "...hf.space.hf.space" URL.
        if space_host_startup.endswith(".hf.space"):
            runtime_host = space_host_startup
        else:
            runtime_host = f"{space_host_startup}.hf.space"
        print(f" Runtime URL should be: https://{runtime_host}")
    else:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
    if space_id_startup:
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
        print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
    else:
        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
    print("-"*(60 + len(" App Starting ")) + "\n")
    print("Launching Gradio Interface for Basic Agent Evaluation...")
    demo.launch(debug=True, share=False)