# NOTE: HF Spaces page chrome ("Spaces: Sleeping") removed from extracted source.
| import os | |
| import io | |
| import subprocess | |
| import json | |
| import re | |
| import traceback | |
| import contextlib | |
| import uuid | |
| import time | |
| import ast | |
| from typing import List, Optional, TypedDict, Annotated, Dict | |
| from pathlib import Path | |
| from collections import Counter | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| from pydantic import BaseModel, Field | |
| # Multimodal & Web Tools | |
| from transformers import pipeline | |
| from youtube_transcript_api import YouTubeTranscriptApi | |
| from bs4 import BeautifulSoup | |
| import requests | |
| from PIL import Image | |
| import base64 | |
| from googleapiclient.discovery import build | |
| from googleapiclient.errors import HttpError | |
| # LangChain & LangGraph | |
| from langgraph.graph.message import add_messages | |
| from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, AnyMessage, ToolCall | |
| from langchain_core.tools import tool | |
| from langgraph.prebuilt import ToolNode | |
| from langgraph.graph import START, END, StateGraph | |
| from langchain_groq import ChatGroq | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_community.llms import HuggingFaceHub | |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
| # RAG | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.tools import DuckDuckGoSearchRun | |
| from langchain_core.documents import Document | |
| # ============================================================================= | |
| # CONFIGURATION | |
| # ============================================================================= | |
# Base URL of the GAIA scoring service for this HF Space.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Hard ceiling on agent loop iterations (enforced in should_continue).
MAX_TURNS = 25
# Max characters any tool result may contain (see truncate_if_needed).
MAX_MESSAGE_LENGTH = 8000
# Intended reflection cadence — NOTE(review): not referenced in the visible code; confirm use.
REFLECT_EVERY_N_TURNS = 5
| # ============================================================================= | |
| # GLOBAL RAG COMPONENTS | |
| # ============================================================================= | |
# Lazily created by initialize_rag_components(); both stay None until first use.
global_embeddings = None
global_text_splitter = None
def initialize_rag_components():
    """Lazily set up the module-level embeddings model and text splitter.

    Returns:
        True when both components are available, False if the embedding
        model failed to load (the splitter never fails).
    """
    global global_embeddings, global_text_splitter
    if global_embeddings is None:
        print("Initializing RAG embeddings...")
        try:
            # Small CPU-friendly sentence-transformer; loaded once per process.
            global_embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={'device': 'cpu'}
            )
        except Exception as e:
            print(f"⚠️ Failed to initialize embeddings: {e}")
            return False
        print("✅ Embeddings initialized.")
    if global_text_splitter is None:
        print("Initializing text splitter...")
        # Prefer paragraph, then sentence, then word boundaries when chunking.
        splitter_separators = ["\n\n", "\n", ". ", " ", ""]
        global_text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=splitter_separators
        )
        print("✅ Text splitter initialized.")
    return True
| # ============================================================================= | |
| # ASR INITIALIZATION | |
| # ============================================================================= | |
# Loaded once at import time so every request shares a single Whisper pipeline.
asr_pipeline = None
try:
    print("Loading ASR (Whisper) pipeline globally...")
    # transformers pipeline convention: 0 = first CUDA device, -1 = CPU.
    device = 0 if torch.cuda.is_available() else -1
    device_name = "cuda:0" if device == 0 else "cpu"
    print(f"Attempting to use device: {device_name} for ASR.")
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        # fp16 only on GPU; CPU inference needs fp32.
        torch_dtype=torch.float16 if device == 0 else torch.float32,
        device=device
    )
    print("✅ ASR (Whisper) pipeline loaded successfully.")
except Exception as e:
    print(f"⚠️ Warning: Could not load ASR pipeline globally. Error: {e}")
    # Tools check for None and degrade gracefully when ASR is unavailable.
    asr_pipeline = None
| # ============================================================================= | |
| # UTILITY FUNCTIONS | |
| # ============================================================================= | |
def remove_fences_simple(text):
    """Strip a surrounding ``` code fence (and optional language tag) from text.

    If *text* is not fenced, it is returned completely unchanged,
    including any original surrounding whitespace.
    """
    stripped = text.strip()
    if not (stripped.startswith("```") and stripped.endswith("```")):
        return text
    inner = stripped[3:-3].strip()
    if '\n' in inner:
        first_line, rest = inner.split('\n', 1)
        tag = first_line.strip()
        # Drop a short language tag (e.g. "python", "json") on the opening fence.
        if tag.replace('_', '').isalnum() and len(tag) < 15:
            inner = rest.strip()
    return inner
def truncate_if_needed(content: str, max_length: int = MAX_MESSAGE_LENGTH) -> str:
    """Clip *content* to *max_length* chars, appending a truncation notice."""
    if len(content) <= max_length:
        return content
    return content[:max_length] + f"\n...[truncated, {len(content)} total chars]"
def find_file(path: str) -> Optional[Path]:
    """Resolve *path* against several candidate locations.

    Tries, in order: CWD-relative, the path as given, and the bare
    filename inside the CWD. Returns the first existing Path, else None.
    """
    cwd = Path.cwd()
    posix = Path(path).as_posix()
    candidates = (cwd / posix, Path(posix), cwd / Path(path).name)
    return next((c for c in candidates if c.exists()), None)
| # ============================================================================= | |
| # PLANNING & REFLECTION TOOLS | |
| # ============================================================================= | |
class ThinkInput(BaseModel):
    """Argument schema for the think_through_logic tool."""
    # Kept short: oversized tool-call payloads are a common provider failure mode.
    reasoning: str = Field(description="Brief reasoning summary (under 150 chars)")
def think_through_logic(reasoning: str) -> str:
    """
    Use this to work through logic puzzles, riddles, or reasoning problems.
    Call this when:
    - The question is a riddle or brain teaser
    - You need to reason through a logical problem
    - No external information is needed, just thinking
    After thinking, use calculator if math is involved, then validate and submit answer.
    """
    # The docstring above doubles as the tool description shown to the LLM,
    # so it is kept verbatim.
    preview = reasoning[:100]
    print(f"🧠 Thinking: {preview}...")
    return ("✅ Logic reasoning recorded.\n"
            "Next steps:\n"
            "1. If math needed → use calculator()\n"
            "2. Once you have answer → use validate_answer()\n"
            "3. Then → use final_answer_tool()\n"
            "Remember: You MUST call another tool. Do not output reasoning text.")
class PlanInput(BaseModel):
    """Argument schema for the create_plan tool."""
    task_summary: str = Field(description="Very brief task summary (under 80 chars)")
def create_plan(task_summary: str) -> str:
    """
    Creates a plan for multi-step questions. Use for complex tasks only.
    Keep the summary VERY brief to avoid errors.
    """
    # Docstring doubles as the LLM-facing tool description; kept verbatim.
    print(f"📋 Planning: {task_summary[:80]}...")
    response_lines = [
        f"✅ Plan created for: {task_summary}",
        "FRAMEWORK:",
        "1. What info do I need?",
        "2. What tools will I use?",
        "3. In what order?",
        "Now execute step 1. You MUST call a tool next.",
    ]
    return "\n".join(response_lines)
class ReflectInput(BaseModel):
    """Argument schema for the reflect_on_progress tool."""
    situation: str = Field(description="Brief situation summary (under 80 chars)")
def reflect_on_progress(situation: str) -> str:
    """
    Reflects on progress when stuck. Use after 5+ turns without progress.
    Keep situation summary VERY brief.
    """
    # Docstring doubles as the LLM-facing tool description; kept verbatim.
    print(f"🤔 Reflecting: {situation[:80]}...")
    return (f"🔍 REFLECTION on: {situation}\n"
            "QUESTIONS:\n"
            "1. Am I using the right approach?\n"
            "2. Should I try a different tool?\n"
            "3. Do I actually have the answer already?\n"
            "Take a DIFFERENT approach now. You MUST call a tool next.")
class ValidateInput(BaseModel):
    """Argument schema for the validate_answer tool."""
    proposed_answer: str = Field(description="The answer to validate")
    original_question: str = Field(description="Original question (first 100 chars)")
def validate_answer(proposed_answer: str, original_question: str) -> str:
    """
    Validates answer format before submission. ALWAYS use before final_answer_tool.
    """
    print(f"✓ Validating: '{proposed_answer[:50]}...'")
    answer_lower = proposed_answer.lower()
    question_lower = original_question.lower()
    issues = []
    warnings = []
    # Conversational preamble disqualifies the answer outright (GAIA wants bare answers).
    if any(phrase in answer_lower for phrase in
           ("the answer is", "based on", "according to", "i found", "here is")):
        issues.append("❌ Remove conversational text. Answer only.")
    if "```" in proposed_answer:
        issues.append("❌ Remove code fences (```).")
    if len(proposed_answer) > 500:
        warnings.append("⚠️ Answer very long. Just the answer?")
    # Questions that ask for a quantity should get at least one digit back.
    if any(kw in question_lower for kw in ("how many", "what number", "count")):
        if not any(ch.isdigit() for ch in proposed_answer):
            warnings.append("⚠️ Question asks for number but answer has no digits.")
    if issues:
        return "🚫 VALIDATION FAILED:\n" + "\n".join(issues) + "\n\nFix then retry."
    if warnings:
        return "⚠️ WARNINGS:\n" + "\n".join(warnings) + "\n\nConsider fixing, or proceed if confident."
    return "✅ VALIDATION PASSED! Now call final_answer_tool() with this answer."
| # ============================================================================= | |
| # CORE TOOLS | |
| # ============================================================================= | |
class SearchInput(BaseModel):
    """Argument schema for the search_tool tool."""
    query: str = Field(description="Search query (concise)")
def search_tool(query: str) -> str:
    """Searches web via DuckDuckGo. Use for facts, recent info."""
    # Docstring doubles as the LLM-facing tool description; kept verbatim.
    if not isinstance(query, str) or not query.strip():
        return "Error: Invalid query."
    print(f"🔍 Searching: {query}")
    try:
        # A fresh runner per call; the wrapper is stateless and cheap.
        raw = DuckDuckGoSearchRun().run(query)
        return truncate_if_needed(raw)
    except Exception as e:
        return f"Search error: {str(e)}"
class CalcInput(BaseModel):
    """Argument schema for the calculator tool."""
    expression: str = Field(description="Math expression (e.g., '2+2', 'sqrt(16)')")
def calculator(expression: str) -> str:
    """
    Evaluates math expressions. Use for ANY calculations.
    Supports: +, -, *, /, **, sqrt, sin, cos, log, pi, e, etc.
    """
    # Docstring doubles as the LLM-facing tool description; kept verbatim.
    if not isinstance(expression, str) or not expression.strip():
        return "Error: Invalid expression."
    print(f"🧮 Calculating: {expression}")
    import math
    # Whitelisted names only; builtins are disabled for the eval call below.
    allowed = {
        'sqrt': math.sqrt, 'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
        'log': math.log, 'log10': math.log10, 'exp': math.exp,
        'pi': math.pi, 'e': math.e, 'abs': abs, 'round': round,
        'pow': pow, 'sum': sum, 'min': min, 'max': max,
    }
    try:
        return str(eval(expression, {"__builtins__": {}}, allowed))
    except Exception as e:
        return f"Calculation error for '{expression}': {str(e)}"
class CodeInput(BaseModel):
    """Argument schema for the code_interpreter tool."""
    code: str = Field(description="Python code (MUST include print() for output)")
def code_interpreter(code: str) -> str:
    """
    Executes Python code. Use for data processing, complex logic.
    Available: pandas, numpy, json, re, datetime
    CRITICAL: Always use print() to output results!
    """
    if not isinstance(code, str):
        return "Error: code must be string."
    # Best-effort safety screen. NOTE(review): this blocklist is easily
    # bypassed (e.g. getattr tricks) and real __builtins__ are exposed below —
    # treat this sandbox as cooperative, not adversarial.
    dangerous = ['__import__', 'eval(', 'compile(', 'subprocess', 'os.system', 'exec(']
    if any(d in code.lower() for d in dangerous):
        return f"Error: Dangerous operation not allowed."
    if 'open(' in code.lower() and any(m in code for m in ["'w'", '"w"', "'a'", '"a"']):
        return "Error: File writing not allowed. Use write_file tool."
    print(f"💻 Executing code ({len(code)} chars)...")
    output_stream = io.StringIO()
    error_stream = io.StringIO()
    try:
        with contextlib.redirect_stdout(output_stream), contextlib.redirect_stderr(error_stream):
            safe_globals = {
                "pd": pd,
                "np": np,
                "json": json,
                "re": re,
                "__builtins__": __builtins__
            }
            # BUGFIX: this previously ran exec(code, safe_globals, {}) with a
            # separate locals dict. Top-level defs then landed in locals while
            # name lookups inside those functions went to globals, so a
            # snippet function could not call another function defined in the
            # same snippet. A single namespace restores module-like semantics.
            exec(code, safe_globals)
        stdout = output_stream.getvalue()
        stderr = error_stream.getvalue()
        if stderr:
            return f"Error:\n{stderr}\n\nStdout:\n{stdout}"
        if stdout:
            return truncate_if_needed(stdout)
        return "Code executed but no output. Remember to use print()!"
    except Exception as e:
        return f"Execution failed:\n{traceback.format_exc()}"
class ReadFileInput(BaseModel):
    """Argument schema for the read_file tool."""
    path: str = Field(description="File path")
def read_file(path: str) -> str:
    """Reads file content."""
    # Docstring doubles as the LLM-facing tool description; kept verbatim.
    if not isinstance(path, str) or not path.strip():
        return "Error: Invalid path."
    print(f"📄 Reading: {path}")
    resolved = find_file(path)
    if resolved is None:
        # Listing the CWD helps the agent self-correct a wrong filename.
        return f"Error: File not found: '{path}'\nCWD files: {os.listdir('.')}"
    try:
        return truncate_if_needed(resolved.read_text(encoding='utf-8'))
    except UnicodeDecodeError:
        return f"Error: Binary file. Size: {resolved.stat().st_size} bytes. Try audio_transcription_tool for audio."
    except Exception as e:
        return f"Read error: {str(e)}"
class WriteFileInput(BaseModel):
    """Argument schema for the write_file tool."""
    path: str = Field(description="File path")
    content: str = Field(description="Content to write")
def write_file(path: str, content: str) -> str:
    """Writes content to file."""
    # Docstring doubles as the LLM-facing tool description; kept verbatim.
    if not path or not isinstance(content, str):
        return "Error: Invalid inputs."
    print(f"✍️ Writing: {path}")
    try:
        target = Path.cwd() / path
        # Create intermediate directories so nested paths just work.
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(content, encoding='utf-8')
    except Exception as e:
        return f"Write error: {str(e)}"
    return f"Wrote {len(content)} chars to '{path}'."
class ListDirInput(BaseModel):
    """Argument schema for the list_directory tool."""
    path: str = Field(description="Directory path", default=".")
def list_directory(path: str = ".") -> str:
    """Lists directory contents."""
    # Docstring doubles as the LLM-facing tool description; kept verbatim.
    print(f"📁 Listing: {path}")
    try:
        target = Path.cwd() if path == "." else Path.cwd() / path
        if not target.is_dir():
            return f"Error: '{path}' not a directory."
        entries = sorted(target.iterdir())
        if not entries:
            return f"Directory '{path}' is empty."
        dirs = [f"📁 {e.name}/" for e in entries if e.is_dir()]
        files = [f"📄 {e.name} ({e.stat().st_size} bytes)" for e in entries if not e.is_dir()]
        listing = f"Contents of '{path}':\n\n"
        if dirs:
            listing += "Directories:\n" + "\n".join(dirs) + "\n\n"
        if files:
            listing += "Files:\n" + "\n".join(files)
        return listing
    except Exception as e:
        return f"List error: {str(e)}"
class AudioInput(BaseModel):
    """Argument schema for the audio_transcription_tool tool."""
    file_path: str = Field(description="Audio file path")
def audio_transcription_tool(file_path: str) -> str:
    """Transcribes audio using Whisper."""
    # Docstring doubles as the LLM-facing tool description; kept verbatim.
    if not file_path:
        return "Error: Invalid file path."
    print(f"🎤 Transcribing: {file_path}")
    # asr_pipeline is None when the module-level Whisper load failed.
    if asr_pipeline is None:
        return "Error: ASR not available."
    resolved = find_file(file_path)
    if resolved is None:
        return f"Error: Audio file not found: '{file_path}'"
    try:
        output = asr_pipeline(str(resolved))
        text = output.get("text", "")
        if not text:
            return "Error: Transcription empty."
        return f"Transcription:\n{truncate_if_needed(text)}"
    except Exception as e:
        return f"Transcription error: {str(e)}"
class ImageAnalysisInput(BaseModel):
    """Argument schema for the analyze_image tool."""
    file_path: str = Field(description="Image file path")
    query: str = Field(description="What to analyze in the image")
def analyze_image(file_path: str, query: str) -> str:
    """
    Analyzes images using Google Gemini Vision API.
    Use for: chess positions, diagrams, charts, photos, screenshots.
    Provide the EXACT file path from [FILE ATTACHED: ...] in the question.
    """
    if not file_path or not query:
        return "Error: file_path and query required."
    print(f"🖼️ Analyzing image: {file_path}")
    print(f"   Query: {query[:100]}...")
    # Try the usual path-resolution helper first.
    image_path = find_file(file_path)
    # If not found via find_file, try the path directly (for /tmp files).
    if not image_path and os.path.exists(file_path):
        image_path = Path(file_path)
    if not image_path or not image_path.exists():
        return f"Error: Image not found at '{file_path}'. Check [FILE ATTACHED: ...] in question for correct path."
    print(f"✓ Found image at: {image_path}")
    try:
        GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
        if not GOOGLE_API_KEY:
            return "Error: GEMINI_API_KEY not set."
        # Load and encode image
        img = Image.open(image_path)
        print(f"   Image size: {img.size}, mode: {img.mode}")
        # BUGFIX: previously only modes outside {RGB, RGBA} were converted,
        # but the JPEG encoder below rejects RGBA (no alpha channel in JPEG),
        # so RGBA PNGs/screenshots crashed. Convert anything that is not RGB.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # Serialize to base64 JPEG for the inline data URL payload.
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        img_base64 = base64.b64encode(buffered.getvalue()).decode()
        print(f"   Encoded image: {len(img_base64)} bytes")
        # Use Gemini Vision with deterministic output (temperature=0).
        vision_llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-exp",
            google_api_key=GOOGLE_API_KEY,
            temperature=0
        )
        message = HumanMessage(
            content=[
                {"type": "text", "text": query},
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{img_base64}"
                }
            ]
        )
        print(f"   Sending to Gemini Vision...")
        response = vision_llm.invoke([message])
        print(f"✓ Got response: {len(response.content)} chars")
        return f"Image Analysis:\n{truncate_if_needed(response.content)}"
    except Exception as e:
        error_msg = f"Image analysis error: {str(e)}"
        print(f"❌ {error_msg}")
        print(traceback.format_exc())
        return error_msg
class YoutubeInput(BaseModel):
    """Argument schema for the get_youtube_transcript tool."""
    video_url: str = Field(description="YouTube URL")
def get_youtube_transcript(video_url: str) -> str:
    """Fetches YouTube video transcript using yt-dlp."""
    # Flow: extract video id → shell out to yt-dlp for English .vtt subtitles
    # → strip VTT markup → join caption lines → delete the temp files.
    # Requires the yt-dlp binary on PATH (handled by the FileNotFoundError
    # branch below). Files are written into the CWD.
    if not video_url:
        return "Error: Invalid URL."
    print(f"📺 YouTube transcript: {video_url}")
    try:
        # Extract video ID (supports watch?v= and youtu.be/ URL shapes only).
        video_id = None
        if "watch?v=" in video_url:
            video_id = video_url.split("v=")[1].split("&")[0]
        elif "youtu.be/" in video_url:
            video_id = video_url.split("youtu.be/")[1].split("?")[0]
        if not video_id:
            return f"Error: Could not extract video ID."
        # Use yt-dlp to get subtitles.
        # NOTE(review): subtitle_file is immediately overwritten by the glob
        # result below; this initial value is effectively unused.
        subtitle_file = f'{video_id}.en.vtt'
        cmd = [
            'yt-dlp',
            '--skip-download',
            '--write-auto-subs',
            '--write-subs',
            '--sub-lang', 'en',
            '--sub-format', 'vtt',
            '--output', video_id,
            video_url
        ]
        print(f"🔧 Running: {' '.join(cmd)}")
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=45)
        if result.returncode != 0:
            print(f"⚠️ yt-dlp stderr: {result.stderr}")
            return f"Error: Could not fetch subtitles - {result.stderr[:200]}"
        # Try to find the subtitle file (might have different naming).
        import glob
        vtt_files = glob.glob(f"{video_id}*.vtt")
        if not vtt_files:
            return "Error: No English subtitles found for this video."
        subtitle_file = vtt_files[0]
        print(f"✓ Found subtitle file: {subtitle_file}")
        # Read and parse VTT file.
        with open(subtitle_file, 'r', encoding='utf-8') as f:
            content = f.read()
        # Remove VTT headers and timestamps, keeping only caption text lines.
        lines = content.split('\n')
        transcript_parts = []
        for line in lines:
            line = line.strip()
            # Skip WEBVTT header, timestamps ("-->"), cue numbers, metadata,
            # and empty lines.
            if (line and
                not line.startswith('WEBVTT') and
                not '-->' in line and
                not line.isdigit() and
                not line.startswith('Kind:') and
                not line.startswith('Language:')):
                transcript_parts.append(line)
        full_transcript = " ".join(transcript_parts)
        # Cleanup subtitle files (best-effort; a leftover file is harmless).
        for vtt_file in vtt_files:
            try:
                os.remove(vtt_file)
            except:
                # NOTE(review): bare except — narrowing to OSError would be
                # safer, kept as-is to avoid behavior change in this pass.
                pass
        if not full_transcript:
            return "Error: Transcript was empty."
        print(f"✓ Transcript extracted: {len(full_transcript)} chars")
        return f"Transcript:\n{truncate_if_needed(full_transcript)}"
    except subprocess.TimeoutExpired:
        return "Error: yt-dlp timed out after 45 seconds."
    except FileNotFoundError:
        return "Error: yt-dlp not installed. Add 'yt-dlp' to requirements.txt"
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        print(traceback.format_exc())
        return f"Transcript error: {str(e)}"
class ScrapeInput(BaseModel):
    """Argument schema for the scrape_and_retrieve tool."""
    url: str = Field(description="URL (must start with http:// or https://)")
    query: str = Field(description="What to find on the page")
def scrape_and_retrieve(url: str, query: str) -> str:
    """
    Scrapes webpage and uses RAG to find relevant info.
    Use when you need specific info from a known URL.
    """
    # Pipeline: fetch page → strip boilerplate tags → chunk text →
    # build an ephemeral in-memory FAISS index → return top-5 chunks
    # most relevant to *query*. Nothing is persisted between calls.
    if not url.startswith(('http://', 'https://')):
        return f"Error: Invalid URL format."
    if not query:
        return "Error: Query required."
    # Lazy-init the shared embeddings/splitter on first use.
    if global_embeddings is None or global_text_splitter is None:
        if not initialize_rag_components():
            return "Error: RAG not initialized."
    print(f"🌐 Scraping: {url}")
    try:
        # Browser-like UA: some sites reject the default requests UA.
        headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
        response = requests.get(url, headers=headers, timeout=20)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
        # Drop non-content elements before text extraction.
        for tag in soup(["script", "style", "nav", "footer", "aside", "header", "iframe"]):
            tag.extract()
        main = soup.find('main') or soup.find('article') or soup.body
        if not main:
            return "Error: No main content found."
        text = main.get_text(separator='\n', strip=True)
        # Collapse blank lines left behind by tag removal.
        lines = [l.strip() for l in text.splitlines() if l.strip()]
        text = '\n'.join(lines)
        if len(text) < 50:
            return f"Error: Content too short ({len(text)} chars)."
        chunks = global_text_splitter.split_text(text)
        if not chunks:
            return "Error: Could not chunk text."
        docs = [Document(page_content=c, metadata={"source": url}) for c in chunks]
        db = FAISS.from_documents(docs, global_embeddings)
        retriever = db.as_retriever(search_kwargs={"k": 5})
        retrieved = retriever.invoke(query)
        if not retrieved:
            return f"No relevant info found for: '{query}'"
        context = "\n\n---\n\n".join([f"[Chunk {i+1}]\n{d.page_content}" for i, d in enumerate(retrieved)])
        return truncate_if_needed(f"From {url}:\n\n{context}")
    except requests.RequestException as e:
        return f"Fetch error: {str(e)}"
    except Exception as e:
        return f"Scrape error: {str(e)}\n{traceback.format_exc()}"
class FinalAnswerInput(BaseModel):
    """Argument schema for the final_answer_tool tool."""
    answer: str = Field(description="Final answer - EXACTLY what was asked, nothing more")
def final_answer_tool(answer: str) -> str:
    """
    Submit final answer. CRITICAL RULES:
    1. ALWAYS call validate_answer() first
    2. Answer must be EXACTLY what was asked
    3. NO conversational text
    4. NO explanations
    5. Match requested format exactly
    """
    # Coerce non-string payloads (models sometimes pass raw numbers).
    answer = answer if isinstance(answer, str) else str(answer)
    print(f"✅ FINAL ANSWER SUBMITTED: {answer}")
    return answer
| # ============================================================================= | |
| # DEFINED TOOLS LIST | |
| # ============================================================================= | |
# Registry of callables exposed to the agent graph / ToolNode.
# NOTE(review): these are plain functions in this copy of the file, yet
# downstream code accesses tool.name / tool.description / tool.args_schema
# (LangChain tool attributes) — confirm the @tool decorators were not lost.
defined_tools = [
    # Planning & Reflection
    think_through_logic,
    create_plan,
    reflect_on_progress,
    validate_answer,
    # Core tools
    search_tool,
    calculator,
    code_interpreter,
    # File operations
    read_file,
    write_file,
    list_directory,
    # Specialized
    audio_transcription_tool,
    analyze_image,  # NEW: Image analysis tool
    get_youtube_transcript,
    scrape_and_retrieve,
    # Final
    final_answer_tool
]
| # ============================================================================= | |
| # AGENT STATE | |
| # ============================================================================= | |
class AgentState(TypedDict):
    """Shared state threaded through the LangGraph agent loop."""
    # Conversation so far; add_messages merges new messages instead of replacing.
    messages: Annotated[List[AnyMessage], add_messages]
    # Completed agent turns, compared against MAX_TURNS in should_continue.
    turn: int
    # Whether create_plan has been invoked for this task — confirm against agent node.
    has_plan: bool
    # Count of back-to-back tool failures (presumably for loop-breaking) — TODO confirm usage.
    consecutive_errors: int
    # Names of tools invoked so far, in order.
    tool_history: List[str]
    # True if the previous tool was a planning/reflection tool — confirm against agent node.
    last_tool_was_thinking: bool
| # ============================================================================= | |
| # ENHANCED FALLBACK PARSER | |
| # ============================================================================= | |
| def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]: | |
| """Enhanced parser with multiple strategies.""" | |
| print(f"🔧 Fallback parsing (first 300 chars):\n{content[:300]}") | |
| tool_name = None | |
| tool_input = None | |
| # STRATEGY 1: Groq's <function=name{...}> format | |
| groq_match = re.search(r"<function=(\w+)\s*(\{.*?\})\s*(?:>|</function>)", content, re.DOTALL) | |
| if groq_match: | |
| try: | |
| tool_name = groq_match.group(1).strip() | |
| json_str = groq_match.group(2).strip() | |
| json_str = json_str.encode().decode('unicode_escape') | |
| tool_input = json.loads(json_str) | |
| print(f"✓ Parsed Groq format: {tool_name}") | |
| except: | |
| tool_name = None | |
| # STRATEGY 2: Standard <function(name)>{...} format | |
| if not tool_name: | |
| func_match = re.search(r"<function[(=]\s*([^)]+)\s*[)>](.*)", content, re.DOTALL | re.IGNORECASE) | |
| if func_match: | |
| try: | |
| tool_name = func_match.group(1).strip().replace("'", "").replace('"', '') | |
| remaining = func_match.group(2) | |
| json_start = remaining.find('{') | |
| if json_start != -1: | |
| json_str = remaining[json_start:].strip().rstrip(',') | |
| tool_input = json.loads(json_str) | |
| print(f"✓ Parsed standard format: {tool_name}") | |
| except: | |
| tool_name = None | |
| # STRATEGY 3: Tool mention with code block → wrap in code_interpreter | |
| if not tool_name and "```python" in content: | |
| try: | |
| code_match = re.search(r"```python\n(.*?)```", content, re.DOTALL) | |
| if code_match: | |
| code = code_match.group(1).strip() | |
| tool_name = "code_interpreter" | |
| tool_input = {"code": code} | |
| print(f"✓ Extracted Python code → code_interpreter") | |
| except: | |
| pass | |
| # STRATEGY 4: Direct tool mention → create minimal valid call | |
| if not tool_name: | |
| for tool in tools: | |
| if tool.name.lower() in content.lower(): | |
| tool_name = tool.name | |
| tool_input = {} | |
| # Try to extract arguments from content | |
| if tool.args_schema: | |
| schema = tool.args_schema.model_json_schema() | |
| for prop in schema.get('properties', {}).keys(): | |
| if prop in schema.get('required', []): | |
| # Use placeholder | |
| tool_input[prop] = "auto_extracted" | |
| print(f"✓ Found mention of '{tool_name}' → creating default call") | |
| break | |
| # STRATEGY 5: Emergency - if no tool detected, force a reasonable one | |
| if not tool_name: | |
| # If content looks like reasoning, use think_through_logic | |
| if len(content) > 50 and not any(kw in content.lower() for kw in ["error", "failed", "invalid"]): | |
| tool_name = "think_through_logic" | |
| tool_input = {"reasoning": content[:150]} | |
| print(f"⚠️ No tool detected → forcing think_through_logic") | |
| # Validate and create tool call | |
| if tool_name and tool_input is not None: | |
| matching_tools = [t for t in tools if t.name == tool_name] | |
| if matching_tools: | |
| return [ToolCall(name=tool_name, args=tool_input, id=str(uuid.uuid4()))] | |
| else: | |
| print(f"❌ Tool '{tool_name}' not in available tools") | |
| print("❌ All parsing strategies failed") | |
| return [] | |
| # ============================================================================= | |
| # CONDITIONAL EDGE FUNCTION | |
| # ============================================================================= | |
def should_continue(state: AgentState):
    """Conditional-edge router for the LangGraph agent loop.

    Returns "agent" (let the model act), "tools" (execute pending tool
    calls), or END (finish). Every path must return a valid edge target.
    """
    messages = state.get('messages', [])
    if not messages:
        return "agent"
    last_message = messages[-1]
    current_turn = state.get('turn', 0)
    # Debug: Print what we're checking
    msg_type = type(last_message).__name__
    print(f"📍 Conditional check - Turn {current_turn}, Last msg type: {msg_type}")
    # 1. Hard stop once the turn budget is exhausted.
    if current_turn >= MAX_TURNS:
        print(f"🛑 Max turns ({MAX_TURNS}) reached")
        return END
    # 2. A ToolMessage means a tool just ran; the agent must process it.
    if isinstance(last_message, ToolMessage):
        print(f"📨 Tool result received from '{last_message.name}' → back to agent")
        return "agent"
    # 3. AIMessage with tool calls: finish on final answer, else dispatch.
    if isinstance(last_message, AIMessage) and last_message.tool_calls:
        # Only check the FIRST tool call, not all of them
        first_tool = last_message.tool_calls[0]
        tool_name = first_tool.get("name", "")
        if tool_name == "final_answer_tool":
            return END
        else:
            return "tools"
    # 4. AIMessage without tool calls (bare reasoning text).
    if isinstance(last_message, AIMessage) and not last_message.tool_calls:
        # Two bare AI messages in a row indicate the model is looping.
        if len(messages) >= 2 and isinstance(messages[-2], AIMessage) and not messages[-2].tool_calls:
            print(f"⚠️ Loop detected: 2 consecutive AI messages without tools")
            return END
        print(f"💭 AI message without tool call → continuing to agent (will force tool)")
        return "agent"
    # 5. Default: continue to agent
    print(f"🔄 Default → continuing to agent")
    # BUGFIX: this branch previously fell off the end and implicitly returned
    # None, which is not a valid edge target for StateGraph routing.
    return "agent"
| # ============================================================================= | |
| # ENHANCED AGENT CLASS | |
| # ============================================================================= | |
class PlanningReflectionAgent:
    """GAIA-benchmark agent built on a LangGraph agent↔tools loop.

    The compiled graph has two nodes: ``agent`` (a tool-bound chat LLM that
    is aggressively pushed to emit exactly one tool call per turn) and
    ``tools`` (a ToolNode wrapper that executes the call and counts
    consecutive tool errors). Routing between them is handled by the
    module-level ``should_continue`` function. Calling the instance with a
    question streams the graph to completion and returns a cleaned, bare
    answer string suitable for exact-match scoring.
    """

    def __init__(self):
        """Read API secrets, build the system prompt, initialize the LLM,
        and compile the LangGraph state machine.

        Raises:
            ValueError: if any required API-key environment variable is
                missing (checked eagerly so the Space fails fast at startup).
        """
        print("🧠 PlanningReflectionAgent initializing...")
        # Fail fast if a required secret is absent.
        GROQ_API_KEY = os.getenv("GROQ_API_KEY")
        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not set!")
        HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not HUGGINGFACEHUB_API_TOKEN:
            raise ValueError("HUGGINGFACEHUB_API_TOKEN secret is not set! Please add it to your Space secrets.")
        # NOTE(review): the env var read here is GEMINI_API_KEY while the
        # error message says GOOGLE_API_KEY — confirm which secret name the
        # Space actually defines.
        GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
        if not GOOGLE_API_KEY:
            raise ValueError("GOOGLE_API_KEY not set!")
        # `defined_tools` is the module-level list of @tool callables
        # (defined earlier in this file).
        self.tools = defined_tools
        # Initialize RAG (best effort — the agent still works without it).
        if not initialize_rag_components():
            print("⚠️ RAG components failed to initialize.")
        # Build a human-readable tool catalogue for the system prompt:
        # name, description and, when the tool declares an args schema,
        # one line per argument.
        tool_desc_list = []
        for tool in self.tools:
            if tool.args_schema:
                schema = tool.args_schema.model_json_schema()
                args_desc = [f" - {p}: {d.get('description', '')}"
                             for p, d in schema.get('properties', {}).items()]
                desc = f"- {tool.name}:\n {tool.description}\n" + "\n".join(args_desc)
            else:
                desc = f"- {tool.name}: {tool.description}"
            tool_desc_list.append(desc)
        tool_descriptions = "\n".join(tool_desc_list)
        # ULTRA-AGGRESSIVE SYSTEM PROMPT: repeats the one-tool-per-turn rule
        # many times because the routing logic depends on every AI turn
        # carrying a tool call.
        self.system_prompt = f"""You are an elite AI agent for GAIA benchmark. Your ONLY job: provide the EXACT answer requested.
═══════════════════════════════════════════════════════════════
⚠️ ABSOLUTE RULES - VIOLATE THESE AND YOU FAIL:
═══════════════════════════════════════════════════════════════
1. **EVERY TURN MUST CALL EXACTLY ONE TOOL** - No exceptions
2. **NEVER OUTPUT REASONING TEXT WITHOUT A TOOL CALL** - You will fail
3. **IDENTIFY QUESTION TYPE FIRST** - Logic? Factual? Data? Math?
4. **LOGIC PUZZLES**: think_through_logic → calculator (if needed) → validate → final_answer
5. **FACTUAL QUESTIONS**: search_tool → validate → final_answer
6. **DATA QUESTIONS**: read_file → code_interpreter → validate → final_answer
7. **ALWAYS VALIDATE**: Call validate_answer() before final_answer_tool()
8. **FINAL ANSWER FORMAT**: EXACTLY what was asked. NO "The answer is..." or explanations
═══════════════════════════════════════════════════════════════
📋 QUESTION TYPE GUIDE:
═══════════════════════════════════════════════════════════════
**RIDDLES/LOGIC PUZZLES** (No web search needed):
- Brain teasers, puzzles, logical deduction
- Strategy: think_through_logic → calculator (if math) → validate → final_answer
- Example: "If 200 coins, 30 face-down, divide into equal piles..."
Turn 1: think_through_logic("Adventurer takes 30 coins and flips them")
Turn 2: calculator("30") [if needed]
Turn 3: validate_answer("30", question)
Turn 4: final_answer_tool("30")
**FACTUAL/RESEARCH** (Need web):
- Who, what, when, where questions
- Strategy: search_tool → scrape_and_retrieve → validate → final_answer
- Example: "What was Einstein's birthplace population in 1900?"
Turn 1: search_tool("Albert Einstein birthplace")
Turn 2: search_tool("Ulm Germany population 1900")
Turn 3: validate_answer("50000", question)
Turn 4: final_answer_tool("50000")
**DATA ANALYSIS** (Need files):
- CSV/Excel questions
- Strategy: list_directory → read_file → code_interpreter → validate → final_answer
**SIMPLE MATH**:
- Calculations
- Strategy: calculator() → validate_answer() → final_answer_tool()
═══════════════════════════════════════════════════════════════
🎓 CRITICAL EXAMPLES:
═══════════════════════════════════════════════════════════════
Example 1: Logic Puzzle
Q: "Coin riddle with 200 coins, 30 face-down..."
✅ CORRECT:
Turn 1: think_through_logic("Take 30 coins, flip all")
Turn 2: validate_answer("30", "coin riddle...")
Turn 3: final_answer_tool("30")
❌ WRONG:
Turn 1: [reasoning text without tool] ← FAILS!
Example 2: Letter Bank Puzzle
Q: "Use letters to spell sentences, which letters need changing?"
✅ CORRECT:
Turn 1: code_interpreter("code to count letters...")
Turn 2: validate_answer("A, B, C", question)
Turn 3: final_answer_tool("A, B, C")
Example 3: Math Problem
Q: "System of equations to solve..."
✅ CORRECT:
Turn 1: code_interpreter("import numpy; solve equations...")
Turn 2: validate_answer("0, 1, 2", question)
Turn 3: final_answer_tool("0, 1, 2")
═══════════════════════════════════════════════════════════════
📚 AVAILABLE TOOLS:
═══════════════════════════════════════════════════════════════
{tool_descriptions}
═══════════════════════════════════════════════════════════════
⚡ EXECUTION RULES:
═══════════════════════════════════════════════════════════════
- If you output text without a tool call, you have FAILED
- If you're unsure, use think_through_logic() to organize thoughts
- ALWAYS call a tool - preferably the right one for the question type
- After EVERY tool result, decide: "Do I have the answer? → validate → submit"
- If stuck after 3 turns: call reflect_on_progress()
REMEMBER: One tool per turn. No reasoning without tools. Exact answer format.
═══════════════════════════════════════════════════════════════
"""
        #. Initialize the LLM ()
        # Alternative backend kept for reference — Groq (disabled):
        # print("Initializing Groq LLM...")
        # try:
        #     self.llm_with_tools = ChatGroq(
        #         temperature=0,
        #         groq_api_key=GROQ_API_KEY,
        #         model_name="llama-3.1-8b-instant",
        #         max_tokens=4096,
        #         timeout=60
        #     ).bind_tools(self.tools, tool_choice="auto")
        #     print("✅ LLM initialized without FORCED tool usage.")
        #
        # except Exception as e:
        #     print(f"❌ Error initializing HuggingFace: {e}")
        #     raise
        # print("Initializing LLM Endpoint...")
        print("Initializing HuggingFace LLM...")
        llm = HuggingFaceEndpoint(
            repo_id="meta-llama/Llama-3.1-70B-Instruct",  # Free on HF Inference API
            huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
            max_new_tokens=4096,
            temperature=0.01,  # near-greedy decoding for reproducible answers
        )
        # Wrap the raw endpoint in the chat interface so tools can be bound.
        chat_llm = ChatHuggingFace(llm=llm)
        print("✅ HuggingFace LLM Endpoint initialized.")
        # Bind tools to the LLM
        self.llm_with_tools = chat_llm.bind_tools(self.tools)
        print("✅ Tools bound to LLM.")
        # Alternative backend kept for reference — Google Gemini (disabled):
        # print("Initializing Google Gemini LLM...")
        # try:
        #     self.llm_with_tools = ChatGoogleGenerativeAI(
        #         model="gemini-2.5-flash",  # Latest model
        #         google_api_key=GOOGLE_API_KEY,
        #         temperature=0,
        #         max_output_tokens=8192,
        #         timeout=60,
        #         convert_system_message_to_human=True  # Important for Gemini
        #     ).bind_tools(self.tools, tool_choice="auto")
        #     print("✅ Gemini LLM initialized.")
        # except Exception as e:
        #     print(f"❌ Error initializing Gemini: {e}")
        #     raise

        # Agent Node with AGGRESSIVE tool forcing
        def agent_node(state: AgentState):
            """Run one LLM turn.

            Invokes the tool-bound LLM with the conversation so far, retrying
            on errors, falling back to a plain Groq model with manual
            tool-call parsing when the endpoint rejects tool use, and — as a
            last resort — injecting a synthetic `think_through_logic` call so
            the graph always has a tool call to route on.
            """
            current_turn = state.get('turn', 0) + 1
            print(f"\n{'='*70}")
            print(f"🤖 AGENT TURN {current_turn}/{MAX_TURNS}")
            print('='*70)
            # Hard stop: past the turn budget we emit a terminal system
            # message and let `should_continue` end the run.
            if current_turn > MAX_TURNS:
                return {
                    "messages": [SystemMessage(content="Max turns reached.")],
                    "turn": current_turn
                }
            # Check if we should force reflection
            consecutive_errors = state.get('consecutive_errors', 0)
            should_reflect = (current_turn > 5 and current_turn % REFLECT_EVERY_N_TURNS == 0) or consecutive_errors >= 3
            # Copy so the injected nudges below are not persisted in state.
            messages_to_send = state["messages"].copy()
            # Add tool-forcing message if last turn had no tool call
            if len(messages_to_send) >= 2:
                last_msg = messages_to_send[-1]
                if isinstance(last_msg, AIMessage) and not last_msg.tool_calls:
                    force_msg = SystemMessage(
                        content="⚠️ CRITICAL: You MUST call a tool this turn. NO reasoning text. Pick the most appropriate tool and call it now."
                    )
                    messages_to_send.append(force_msg)
                    print("🚨 Injecting tool-forcing message")
            # Add reflection hint if needed
            if should_reflect:
                hint = SystemMessage(
                    content="⚠️ HINT: Multiple turns without progress. Consider calling reflect_on_progress() or try a different approach."
                )
                messages_to_send.append(hint)
                print("🤔 Injecting reflection hint")
            # Invoke LLM with retries and fallback
            max_retries = 3
            ai_message = None
            for attempt in range(max_retries):
                try:
                    ai_message = self.llm_with_tools.invoke(messages_to_send)
                    # If we got a valid response with tool calls, break
                    if ai_message.tool_calls:
                        break
                    # If no tool calls, this is a problem
                    print(f"⚠️ LLM returned no tool calls on attempt {attempt+1}")
                except Exception as e:
                    error_str = str(e)
                    print(f"⚠️ LLM attempt {attempt+1}/{max_retries} failed: {error_str[:200]}")
                    # If tool_use_failed, try without strict binding
                    if "tool_use_failed" in error_str and attempt < max_retries - 1:
                        print("🔧 Trying without strict tool enforcement...")
                        try:
                            # Plain (unbound) Groq model as a fallback backend.
                            simple_llm = ChatGroq(
                                temperature=0,
                                groq_api_key=os.getenv("GROQ_API_KEY"),
                                model_name="llama-3.3-70b-versatile",
                                max_tokens=4096,
                                timeout=60
                            )
                            # Add explicit tool forcing to the message
                            force_tool_msg = SystemMessage(
                                content="You MUST call a tool. Respond with a tool call, not reasoning text."
                            )
                            ai_message = simple_llm.invoke(messages_to_send + [force_tool_msg])
                            # Try to parse tool calls from content
                            # (module-level helper recovers calls written as
                            # plain text by the unbound model).
                            if ai_message.content and not ai_message.tool_calls:
                                parsed = parse_tool_call_from_string(ai_message.content, self.tools)
                                if parsed:
                                    ai_message.tool_calls = parsed
                                    ai_message.content = ""
                                    print("✓ Fallback parsing succeeded")
                            break
                        except Exception as e2:
                            print(f"⚠️ Fallback also failed: {e2}")
                    if attempt == max_retries - 1:
                        # Last resort: inject a default tool call
                        print("🚨 All attempts failed - forcing think_through_logic")
                        ai_message = AIMessage(
                            content="",
                            tool_calls=[ToolCall(
                                name="think_through_logic",
                                args={"reasoning": "Processing question"},
                                id=str(uuid.uuid4())
                            )]
                        )
                    else:
                        # Exponential backoff before the next attempt.
                        time.sleep(2 ** attempt)
            # If still no tool calls after all attempts, force one
            if not ai_message.tool_calls:
                if isinstance(ai_message.content, str) and ai_message.content.strip():
                    # Try one more parse
                    parsed = parse_tool_call_from_string(ai_message.content, self.tools)
                    if parsed:
                        ai_message.tool_calls = parsed
                        ai_message.content = ""
                        print("✓ Final parse succeeded")
                    else:
                        # Absolute last resort
                        print("🚨 EMERGENCY: Forcing think_through_logic")
                        ai_message.tool_calls = [ToolCall(
                            name="think_through_logic",
                            args={"reasoning": "analyzing question"},
                            id=str(uuid.uuid4())
                        )]
                        ai_message.content = ""
                # NOTE(review): when content is empty AND there are no tool
                # calls, no fallback call is injected here — confirm
                # `should_continue` handles that case downstream.
            # Track tool usage
            tool_history = state.get('tool_history', [])
            has_plan = state.get('has_plan', False)
            if ai_message.tool_calls:
                # Only the first call is logged; the prompt mandates one
                # tool call per turn.
                tool_name = ai_message.tool_calls[0]['name']
                print(f"🔧 Tool Call: {tool_name}")
                tool_history.append(tool_name)
                if tool_name == "create_plan":
                    has_plan = True
            else:
                print(f"⚠️ No tool call (this shouldn't happen!)")
                print(f"💭 Content: {ai_message.content[:200]}...")
            return {
                "messages": [ai_message],
                "turn": current_turn,
                "has_plan": has_plan,
                "tool_history": tool_history,
                "last_tool_was_thinking": ai_message.tool_calls and ai_message.tool_calls[0]['name'] == 'think_through_logic'
            }

        # Tool Node with Error Tracking (FIXED)
        def tool_node_wrapper(state: AgentState):
            """Executes tools and tracks errors.

            Delegates execution to a fresh ToolNode, then increments a
            consecutive-error counter whenever the latest tool message
            contains "Error"/"error"; any clean result resets the counter.
            `agent_node` uses the counter to trigger reflection hints.
            """
            print(f"🔧 Executing tools...")
            # Create fresh ToolNode instance
            tool_executor = ToolNode(self.tools)
            # Invoke properly
            result = tool_executor.invoke(state)
            # Track errors
            consecutive_errors = state.get('consecutive_errors', 0)
            if result.get('messages'):
                last_msg = result['messages'][-1]
                if isinstance(last_msg, ToolMessage):
                    # Substring match on the tool output is the only error
                    # signal available here.
                    if "Error" in last_msg.content or "error" in last_msg.content.lower():
                        consecutive_errors += 1
                        print(f"⚠️ Tool error detected (consecutive: {consecutive_errors})")
                    else:
                        consecutive_errors = 0
            result['consecutive_errors'] = consecutive_errors
            return result

        # Build Graph: agent → (should_continue) → tools/agent/END,
        # with tools always looping back to the agent.
        print("Building graph...")
        graph_builder = StateGraph(AgentState)
        graph_builder.add_node("agent", agent_node)
        graph_builder.add_node("tools", tool_node_wrapper)
        graph_builder.add_edge(START, "agent")
        graph_builder.add_conditional_edges(
            "agent",
            should_continue,
            {
                "tools": "tools",
                "agent": "agent",
                END: END
            }
        )
        graph_builder.add_edge("tools", "agent")
        self.graph = graph_builder.compile()
        print("✅ Graph compiled successfully.")

    def __call__(self, question: str, file_path: Optional[str] = None) -> str:
        """Execute agent on a question.

        Args:
            question: Raw GAIA question text.
            file_path: Optional local path to a downloaded attachment; its
                coarse type (image/audio/data/document) is appended to the
                prompt as a hint to use the matching tool first.

        Returns:
            The cleaned answer string, or the sentinel
            "AGENT FAILED TO PRODUCE ANSWER" / "AGENT ERROR: ..." on failure.
        """
        print(f"\n{'='*70}")
        print(f"🎯 NEW QUESTION")
        print(f"{'='*70}")
        print(f"Q: {question[:200]}{'...' if len(question) > 200 else ''}")
        if file_path:
            print(f"📎 File attached: {file_path}")
        print(f"{'='*70}\n")
        # Enhanced question context with file information
        question_text = question
        if file_path:
            # Classify the attachment by extension so the LLM picks the
            # right tool on its first turn.
            file_ext = Path(file_path).suffix.lower()
            file_type = "unknown"
            if file_ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']:
                file_type = "image"
            elif file_ext in ['.mp3', '.wav', '.m4a', '.flac']:
                file_type = "audio"
            elif file_ext in ['.csv', '.xlsx', '.xls']:
                file_type = "data"
            elif file_ext in ['.txt', '.pdf', '.doc', '.docx']:
                file_type = "document"
            question_text += f"\n\n[FILE ATTACHED: {file_path}]"
            question_text += f"\n[FILE TYPE: {file_type}]"
            question_text += f"\nIMPORTANT: Use the appropriate tool to access this file first!"
        # Initial graph state: system prompt + question, counters zeroed.
        graph_input = {
            "messages": [
                SystemMessage(content=self.system_prompt),
                HumanMessage(content=question_text)
            ],
            "file_path": file_path,
            "turn": 0,
            "has_plan": False,
            "consecutive_errors": 0,
            "tool_history": [],
            "last_tool_was_thinking": False
        }
        final_answer = "AGENT FAILED TO PRODUCE ANSWER"
        all_messages = []
        try:
            # Headroom over MAX_TURNS so the graph itself, not LangGraph's
            # recursion guard, decides when to stop.
            config = {"recursion_limit": MAX_TURNS + 10}
            for event in self.graph.stream(graph_input, stream_mode="values", config=config):
                if not event.get('messages'):
                    continue
                all_messages = event["messages"]
                last_message = all_messages[-1]
                # Check for final answer: the answer is read from the
                # final_answer_tool call's arguments, not the tool's output.
                if isinstance(last_message, AIMessage) and last_message.tool_calls:
                    for tool_call in last_message.tool_calls:
                        if tool_call.get("name") == "final_answer_tool":
                            args = tool_call.get('args', {})
                            if 'answer' in args:
                                final_answer = args['answer']
                                print(f"\n{'='*70}")
                                print(f"✅ FINAL ANSWER: '{final_answer}'")
                                print(f"{'='*70}\n")
                            break
                elif isinstance(last_message, ToolMessage):
                    preview = last_message.content[:200].replace('\n', ' ')
                    print(f"📊 Tool '{last_message.name}' result: {preview}...")
                elif isinstance(last_message, AIMessage) and not last_message.tool_calls:
                    print(f"💭 AI: {last_message.content[:200]}...")
            # If no final answer, try to extract from tool messages:
            # scan newest-first for a short, non-error result from one of
            # the "computational" tools and take its last meaningful line.
            if final_answer == "AGENT FAILED TO PRODUCE ANSWER":
                print("⚠️ No final_answer_tool called. Checking tool results...")
                for msg in reversed(all_messages):
                    if isinstance(msg, ToolMessage):
                        if msg.name in ["calculator", "think_through_logic", "code_interpreter"]:
                            content = msg.content.strip()
                            # Look for short, answer-like content
                            if content and len(content) < 200 and not content.startswith("Error"):
                                # Extract just the result part
                                lines = content.split('\n')
                                for line in reversed(lines):
                                    if line.strip() and not line.startswith(('✅', '⚠️', 'Next', 'Remember')):
                                        final_answer = line.strip()
                                        print(f"📝 Extracted from {msg.name}: '{final_answer}'")
                                        break
                            break
            # Clean the answer: strip boilerplate so the scorer sees only
            # the bare value (GAIA uses exact matching).
            cleaned = str(final_answer).strip()
            # Remove prefixes
            prefixes = [
                "the answer is:", "here is the answer:", "based on",
                "final answer:", "answer:", "the final answer is:",
                "my answer is:", "according to", "i found that",
                "the result is:", "result:"
            ]
            for prefix in prefixes:
                if cleaned.lower().startswith(prefix.lower()):
                    potential = cleaned[len(prefix):].strip()
                    if potential:
                        cleaned = potential
                    break
            # Remove code fences and quotes
            cleaned = remove_fences_simple(cleaned)
            while cleaned.startswith("`") and cleaned.endswith("`"):
                cleaned = cleaned[1:-1].strip()
            if (cleaned.startswith('"') and cleaned.endswith('"')) or \
               (cleaned.startswith("'") and cleaned.endswith("'")):
                cleaned = cleaned[1:-1].strip()
            # Remove trailing period for short answers
            if cleaned.endswith('.') and len(cleaned.split()) < 10:
                cleaned = cleaned[:-1]
            print(f"\n{'='*70}")
            print(f"🎉 RETURNING ANSWER")
            print(f"{'='*70}")
            print(f"{cleaned}")
            print(f"{'='*70}\n")
            return cleaned
        except Exception as e:
            print(f"❌ Graph error: {e}")
            print(traceback.format_exc())
            return f"AGENT ERROR: {e}"
| # ============================================================================= | |
| # GLOBAL AGENT INSTANTIATION | |
| # ============================================================================= | |
# Build the singleton agent at import time so the Gradio callback can reuse
# it across runs; a failed build leaves `agent` as None and the callback
# reports the failure instead of crashing.
agent = None
try:
    initialize_rag_components()
    agent = PlanningReflectionAgent()
    print("✅ Global PlanningReflectionAgent instantiated.")
    # Sanity check: run_and_submit_all invokes the agent via __call__.
    if callable(agent):
        print("✅ Agent is callable.")
    else:
        print("❌ ERROR: Agent not callable!")
        agent = None
    # Warn (but don't fail) when the speech-recognition pipeline is missing.
    if asr_pipeline is None:
        print("⚠️ ASR Pipeline not loaded.")
except Exception as exc:
    print(f"❌ FATAL: Agent initialization failed: {exc}")
    traceback.print_exc()
    agent = None
| # ============================================================================= | |
| # RUN AND SUBMIT FUNCTION | |
| # ============================================================================= | |
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """
    Fetches all questions, runs the PlanningReflectionAgent on them, submits
    all answers to the scoring API, and displays the results.

    Args:
        profile: OAuth profile injected by Gradio's LoginButton; None when
            the user is not logged in (we return early with a login prompt).

    Returns:
        Tuple of (status message for the Textbox, results DataFrame for the
        table). The DataFrame is None when we fail before any question is
        processed.
    """
    space_id = os.getenv("SPACE_ID")
    if profile:
        username = f"{profile.username}"
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    # Use the globally instantiated agent
    global agent
    if agent is None:
        error_msg = "FATAL: Agent failed to initialize at startup. Check logs for errors."
        print(error_msg)
        return error_msg, None
    print("✅ Using globally instantiated PlanningReflectionAgent")
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)
    # Declared before the helpers/loops so the failure helper can close over it.
    results_log = []

    def _fail(status_message, include_traceback=False):
        """Print the framed submission-failure banner shared by every
        submission error handler and build the (status, DataFrame) pair."""
        print(f"\n{'='*70}")
        print(f"❌ SUBMISSION FAILED")
        print(f"{'='*70}")
        print(status_message)
        if include_traceback:
            print(traceback.format_exc())
        print(f"{'='*70}\n")
        return status_message, pd.DataFrame(results_log)

    # 2. Fetch Questions
    print(f"\n{'='*70}")
    print(f"📥 FETCHING QUESTIONS")
    print(f"{'='*70}")
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        print(f"✅ Fetched {len(questions_data)} questions.")
        print(f"{'='*70}\n")
    except requests.exceptions.RequestException as e:
        print(f"❌ Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    except requests.exceptions.JSONDecodeError as e:
        print(f"❌ Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return f"Error decoding server response for questions: {e}", None
    except Exception as e:
        print(f"❌ An unexpected error occurred fetching questions: {e}")
        return f"An unexpected error occurred fetching questions: {e}", None
    # 3. Run your Agent
    print(f"\n{'='*70}")
    print(f"🚀 STARTING EVALUATION")
    print(f"{'='*70}")
    print(f"Total questions to process: {len(questions_data)}")
    print(f"{'='*70}\n")
    answers_payload = []
    for idx, item in enumerate(questions_data, 1):
        print(f"\n{'='*70}")
        print(f"📝 PROCESSING QUESTION {idx}/{len(questions_data)}")
        print(f"{'='*70}")
        task_id = item.get("task_id")
        question_text = item.get("question")
        # Defensive: a malformed item without an id or question cannot be
        # answered or submitted — skip it rather than crash the whole run.
        if not task_id or question_text is None:
            print(f"⚠️ Skipping item with missing task_id or question: {item}")
            continue
        # FIX: the API may omit the answer or send an explicit null, in which
        # case item.get("answer", "N/A") returns None and the .strip()
        # comparison below would raise AttributeError. Normalize to "N/A".
        correct_answer = item.get("answer") or "N/A"
        # Initialize file variables for the current question.
        # Try to download a file for EVERY task (the API returns 404 when
        # the task has no attachment).
        file_download_url = f"{DEFAULT_API_URL}/files/{task_id}"
        local_file_path = None
        try:
            file_response = requests.get(file_download_url, timeout=15)
            if file_response.status_code == 200:
                # Get filename from Content-Disposition header if available
                filename = None
                if 'Content-Disposition' in file_response.headers:
                    cd = file_response.headers['Content-Disposition']
                    filename_match = re.findall('filename="?([^"]+)"?', cd)
                    if filename_match:
                        filename = filename_match[0]
                # If no filename, use task_id with extension from Content-Type
                if not filename:
                    content_type = file_response.headers.get('Content-Type', '')
                    ext_map = {
                        'image/png': '.png',
                        'image/jpeg': '.jpg',
                        'image/gif': '.gif',
                        'audio/mpeg': '.mp3',
                        'audio/wav': '.wav',
                        'text/plain': '.txt',
                        'text/csv': '.csv',
                        'application/pdf': '.pdf',
                        'text/x-python': '.py',
                        'application/x-python-code': '.py',
                    }
                    ext = ext_map.get(content_type, '')
                    filename = f"{task_id}{ext}"
                # Save to current directory, streaming in chunks.
                local_file_path = filename
                with open(local_file_path, 'wb') as f:
                    for chunk in file_response.iter_content(chunk_size=8192):
                        f.write(chunk)
                file_size = os.path.getsize(local_file_path)
                abs_path = os.path.abspath(local_file_path)
                # FIX: print the actual saved filename (the message
                # previously contained the literal placeholder "(unknown)").
                print(f"✅ Downloaded: {filename} ({file_size} bytes)")
                print(f"   Saved to: {abs_path}")
            elif file_response.status_code == 404:
                print(f"ℹ️ No file found for task {task_id} (404), proceeding without file.")
            else:
                print(f"⚠️ Warning: File download for {task_id} failed with status {file_response.status_code}")
        except Exception as e:
            # Download is best-effort: the agent still runs without the file.
            print(f"\n❌ An unexpected error occurred: {e}")
        try:
            # Pass file_path to agent
            submitted_answer = agent(question_text, local_file_path)
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            # Case-insensitive pre-submission check against the reference
            # answer (str() guards against non-string payload values).
            is_correct = str(submitted_answer).strip().lower() == str(correct_answer).strip().lower()
            correctness = "✅ CORRECT" if is_correct else "❌ WRONG"
            # Log with correctness indicator
            print(f"\n{correctness} - Task {task_id}")
            print(f"  Submitted: '{submitted_answer}'")
            print(f"  Expected: '{correct_answer}'")
            results_log.append({
                "Task ID": task_id,
                "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
                "Submitted Answer": submitted_answer,
                "Correct Answer": correct_answer,
                "Status": "✅" if is_correct else "❌"
            })
            print(f"✅ Question {idx}/{len(questions_data)} completed")
        except Exception as e:
            print(f"❌ Error running agent on task {task_id}: {e}")
            print(traceback.format_exc())
            results_log.append({
                "Task ID": task_id,
                "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
                "Submitted Answer": f"AGENT ERROR: {e}",
                "Correct Answer": correct_answer,
                "Status": "❌"
            })
            # Continue with other questions even if one fails
            answers_payload.append({"task_id": task_id, "submitted_answer": f"ERROR: {str(e)[:100]}"})
    # Summary after all questions processed
    print(f"\n{'='*70}")
    print(f"✅ ALL QUESTIONS PROCESSED")
    print(f"{'='*70}")
    print(f"Total answers collected: {len(answers_payload)}")
    # Calculate pre-submission accuracy
    correct_count = sum(1 for log in results_log if log.get("Status") == "✅")
    total_count = len(results_log)
    accuracy = (correct_count / total_count * 100) if total_count > 0 else 0
    print(f"\n{'='*70}")
    print(f"📊 PRE-SUBMISSION SUMMARY")
    print(f"{'='*70}")
    print(f"Correct: {correct_count}/{total_count} ({accuracy:.1f}%)")
    print(f"{'='*70}\n")
    if not answers_payload:
        print("⚠️ Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
    # 4. Prepare Submission
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    # 5. Submit
    print(f"\n{'='*70}")
    print(f"📤 SUBMITTING TO API")
    print(f"{'='*70}")
    print(f"URL: {submit_url}")
    print(f"Username: {username}")
    print(f"Answers to submit: {len(answers_payload)}")
    print(f"{'='*70}\n")
    try:
        print("⏳ Sending POST request...")
        response = requests.post(submit_url, json=submission_data, timeout=60)
        print(f"✅ Got response: Status {response.status_code}")
        response.raise_for_status()
        result_data = response.json()
        print(f"\n{'='*70}")
        print(f"📊 SUBMISSION RESULTS")
        print(f"{'='*70}")
        print(f"Response data: {result_data}")
        print(f"{'='*70}\n")
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print(final_status)
        print("="*70)
        print("✅ Submission successful.")
        results_df = pd.DataFrame(results_log)
        return final_status, results_df
    except requests.exceptions.HTTPError as e:
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        except requests.exceptions.JSONDecodeError:
            error_detail += f" Response: {e.response.text[:500]}"
        return _fail(f"Submission Failed: {error_detail}")
    except requests.exceptions.Timeout:
        return _fail("Submission Failed: The request timed out.")
    except requests.exceptions.RequestException as e:
        return _fail(f"Submission Failed: Network error - {e}")
    except Exception as e:
        return _fail(f"An unexpected error occurred during submission: {e}", include_traceback=True)
# --- Build Gradio Interface using Blocks ---
# Minimal UI: instructions, a HF login button, and one "run" button that
# triggers the full fetch → answer → submit pipeline, writing the status
# string and a per-question results table to the two outputs.
with gr.Blocks() as demo:
    gr.Markdown("# Basic Agent Evaluation Runner")
    gr.Markdown(
        """
**Instructions:**
1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
---
**Disclaimers:**
Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance for the delay process of the submit button, a solution could be to cache the answers and submit in a seperate action or even to answer the questions in async.
"""
    )
    gr.LoginButton()
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
    # No `inputs` wired: Gradio injects the OAuth profile into fn from its
    # `gr.OAuthProfile | None` annotation (presumably — confirm against the
    # Gradio OAuth docs for the pinned version).
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table]
    )
if __name__ == "__main__":
    # Startup diagnostics: report the Space environment (or its absence)
    # before launching the UI.
    banner = " App Starting "
    print("\n" + "-" * 30 + banner + "-" * 30)
    host = os.getenv("SPACE_HOST")
    repo = os.getenv("SPACE_ID")
    if host:
        print(f"✅ SPACE_HOST found: {host}")
        print(f" Runtime URL should be: https://{host}.hf.space")
    else:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
    if repo:
        print(f"✅ SPACE_ID found: {repo}")
        print(f" Repo URL: https://huggingface.co/spaces/{repo}")
        print(f" Repo Tree URL: https://huggingface.co/spaces/{repo}/tree/main")
    else:
        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
    print("-" * (60 + len(banner)) + "\n")
    print("Launching Gradio Interface for Basic Agent Evaluation...")
    demo.launch(debug=True, share=False)