# Source: agentCourse / app.py (Hugging Face Space, author: gabejavitt)
# Commit 8ffc44e "Update app.py" (verified); file size: 70 kB
import os
import io
import subprocess
import json
import re
import traceback
import contextlib
import uuid
import time
import ast
from typing import List, Optional, TypedDict, Annotated, Dict
from pathlib import Path
from collections import Counter
import gradio as gr
import pandas as pd
import numpy as np
import torch
from pydantic import BaseModel, Field
# Multimodal & Web Tools
from transformers import pipeline
from youtube_transcript_api import YouTubeTranscriptApi
from bs4 import BeautifulSoup
import requests
from PIL import Image
import base64
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
# LangChain & LangGraph
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, AnyMessage, ToolCall
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langgraph.graph import START, END, StateGraph
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.llms import HuggingFaceHub
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
# RAG
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.documents import Document
# =============================================================================
# CONFIGURATION
# =============================================================================
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
MAX_TURNS = 25
MAX_MESSAGE_LENGTH = 8000
REFLECT_EVERY_N_TURNS = 5
# =============================================================================
# GLOBAL RAG COMPONENTS
# =============================================================================
# Module-wide RAG singletons; populated on demand by initialize_rag_components().
global_embeddings = None
global_text_splitter = None


def initialize_rag_components():
    """Create the shared embedding model and text splitter if they are missing.

    Each component is only built while its module-level global is still
    ``None``, so calling this repeatedly does not rebuild anything.

    Returns:
        bool: ``True`` once both components exist; ``False`` when building
        the embedding model raised (the splitter is then not attempted).
    """
    global global_embeddings, global_text_splitter

    embeddings_missing = global_embeddings is None
    if embeddings_missing:
        print("Initializing RAG embeddings...")
        try:
            global_embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2",
                model_kwargs={'device': 'cpu'},
            )
            print("✅ Embeddings initialized.")
        except Exception as e:
            print(f"⚠️ Failed to initialize embeddings: {e}")
            return False

    splitter_missing = global_text_splitter is None
    if splitter_missing:
        print("Initializing text splitter...")
        splitter_options = dict(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            # Separators are tried in this order: paragraphs, lines, sentences, words, chars.
            separators=["\n\n", "\n", ". ", " ", ""],
        )
        global_text_splitter = RecursiveCharacterTextSplitter(**splitter_options)
        print("✅ Text splitter initialized.")

    return True
# =============================================================================
# ASR INITIALIZATION
# =============================================================================
# Whisper ASR pipeline shared by the audio tool; stays None when loading fails
# so callers can test for availability.
asr_pipeline = None
try:
    print("Loading ASR (Whisper) pipeline globally...")
    # Device index 0 selects the first CUDA device; -1 selects the CPU.
    if torch.cuda.is_available():
        device, device_name = 0, "cuda:0"
    else:
        device, device_name = -1, "cpu"
    print(f"Attempting to use device: {device_name} for ASR.")
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        # Half precision is only requested when running on the GPU.
        torch_dtype=torch.float16 if device == 0 else torch.float32,
        device=device,
    )
    print("✅ ASR (Whisper) pipeline loaded successfully.")
except Exception as e:
    # Loading is best-effort: report the problem and leave the tool disabled.
    print(f"⚠️ Warning: Could not load ASR pipeline globally. Error: {e}")
    asr_pipeline = None
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def remove_fences_simple(text):
    """Strip a surrounding Markdown code fence from ``text``.

    When the stripped text both starts and ends with three backticks, the
    fence markers are removed. If the first remaining line looks like a
    language tag (alphanumeric/underscore, under 15 characters) it is dropped
    too. Text that is not fenced is returned exactly as passed in.
    """
    stripped = text.strip()
    is_fenced = stripped.startswith("```") and stripped.endswith("```")
    if not is_fenced:
        # Not fenced: hand back the caller's string untouched (no stripping).
        return text

    inner = stripped[3:-3].strip()
    head, newline, tail = inner.partition('\n')
    if not newline:
        # Single-line content: nothing that could be a language tag line.
        return inner

    tag = head.strip()
    looks_like_language = tag.replace('_', '').isalnum() and len(tag) < 15
    if looks_like_language:
        return tail.strip()
    return inner
def truncate_if_needed(content: str, max_length: int = MAX_MESSAGE_LENGTH) -> str:
    """Return ``content`` cut to ``max_length`` characters when it is longer.

    A trailing marker stating the original length is appended to truncated
    text; shorter content is returned unchanged.
    """
    total = len(content)
    if total <= max_length:
        return content
    return f"{content[:max_length]}\n...[truncated, {total} total chars]"
def find_file(path: str) -> Optional[Path]:
    """Locate an existing regular file for ``path``.

    Candidates are tried in order:
      1. ``path`` relative to the current working directory,
      2. ``path`` as given,
      3. only the file name of ``path`` inside the current working directory.

    Args:
        path: Relative or absolute path supplied by the caller.

    Returns:
        The first candidate that is a regular file, or ``None`` when no
        candidate is. Directories are never returned, so an empty string
        or a folder name yields ``None``.
    """
    cwd = Path.cwd()
    requested = Path(path)
    posix_path = requested.as_posix()
    candidates = (
        cwd / posix_path,
        Path(posix_path),
        cwd / requested.name,
    )
    for candidate in candidates:
        # is_file() rather than exists(): callers read the result as a file.
        if candidate.is_file():
            return candidate
    return None
# =============================================================================
# PLANNING & REFLECTION TOOLS
# =============================================================================
class ThinkInput(BaseModel):
    # Argument schema for think_through_logic (a single free-text field).
    reasoning: str = Field(description="Brief reasoning summary (under 150 chars)")


@tool(args_schema=ThinkInput)
def think_through_logic(reasoning: str) -> str:
    """
    Use this to work through logic puzzles, riddles, or reasoning problems.
    Call this when:
    - The question is a riddle or brain teaser
    - You need to reason through a logical problem
    - No external information is needed, just thinking
    After thinking, use calculator if math is involved, then validate and submit answer.
    """
    # NOTE(review): the docstring above is presumably surfaced to the LLM as the
    # tool description (tool.description is read by the prompt builder), so its
    # wording is kept verbatim.
    preview = reasoning[:100]  # only the first 100 characters are logged
    print(f"🧠 Thinking: {preview}...")

    # The tool does no computation; it returns fixed guidance for the next turn.
    guidance = """✅ Logic reasoning recorded.
Next steps:
1. If math needed → use calculator()
2. Once you have answer → use validate_answer()
3. Then → use final_answer_tool()
Remember: You MUST call another tool. Do not output reasoning text."""
    return guidance
class PlanInput(BaseModel):
    # Argument schema for create_plan.
    task_summary: str = Field(description="Very brief task summary (under 80 chars)")


@tool(args_schema=PlanInput)
def create_plan(task_summary: str) -> str:
    """
    Creates a plan for multi-step questions. Use for complex tasks only.
    Keep the summary VERY brief to avoid errors.
    """
    # The log line shows at most 80 characters; the returned text echoes the
    # full summary followed by a fixed planning framework.
    logged_summary = task_summary[:80]
    print(f"📋 Planning: {logged_summary}...")

    plan_text = f"""✅ Plan created for: {task_summary}
FRAMEWORK:
1. What info do I need?
2. What tools will I use?
3. In what order?
Now execute step 1. You MUST call a tool next."""
    return plan_text
class ReflectInput(BaseModel):
    # Argument schema for reflect_on_progress.
    situation: str = Field(description="Brief situation summary (under 80 chars)")


@tool(args_schema=ReflectInput)
def reflect_on_progress(situation: str) -> str:
    """
    Reflects on progress when stuck. Use after 5+ turns without progress.
    Keep situation summary VERY brief.
    """
    # Log a shortened situation, then return a fixed set of self-check
    # questions prefixed with the full situation text.
    logged_situation = situation[:80]
    print(f"🤔 Reflecting: {logged_situation}...")

    reflection_text = f"""🔍 REFLECTION on: {situation}
QUESTIONS:
1. Am I using the right approach?
2. Should I try a different tool?
3. Do I actually have the answer already?
Take a DIFFERENT approach now. You MUST call a tool next."""
    return reflection_text
class ValidateInput(BaseModel):
    # Argument schema for validate_answer.
    proposed_answer: str = Field(description="The answer to validate")
    original_question: str = Field(description="Original question (first 100 chars)")


@tool(args_schema=ValidateInput)
def validate_answer(proposed_answer: str, original_question: str) -> str:
    """
    Validates answer format before submission. ALWAYS use before final_answer_tool.
    """
    # Returns one of three messages: a hard failure listing issues, a soft
    # warning list, or a pass message telling the model to submit.
    print(f"✓ Validating: '{proposed_answer[:50]}...'")
    answer_lower = proposed_answer.lower()
    question_lower = original_question.lower()
    issues = []
    warnings = []

    # Conversational fluff. Phrases are matched on word boundaries so that
    # e.g. "there is" does not trigger the "here is" rule.
    fluff = ["the answer is", "based on", "according to", "i found", "here is"]
    if any(re.search(rf"\b{re.escape(phrase)}\b", answer_lower) for phrase in fluff):
        issues.append("❌ Remove conversational text. Answer only.")

    # Code fences never belong in a final answer.
    if "```" in proposed_answer:
        issues.append("❌ Remove code fences (```).")

    # Very long answers usually contain more than the answer itself.
    if len(proposed_answer) > 500:
        warnings.append("⚠️ Answer very long. Just the answer?")

    # Numeric questions. Whole-word matching keeps "country" / "account" from
    # being mistaken for "count".
    asks_for_number = re.search(r"\b(?:how many|what number|count)\b", question_lower)
    if asks_for_number and not any(ch.isdigit() for ch in proposed_answer):
        warnings.append("⚠️ Question asks for number but answer has no digits.")

    if issues:
        return "🚫 VALIDATION FAILED:\n" + "\n".join(issues) + "\n\nFix then retry."
    if warnings:
        return "⚠️ WARNINGS:\n" + "\n".join(warnings) + "\n\nConsider fixing, or proceed if confident."
    return "✅ VALIDATION PASSED! Now call final_answer_tool() with this answer."
# =============================================================================
# CORE TOOLS
# =============================================================================
class SearchInput(BaseModel):
    # Argument schema for search_tool.
    query: str = Field(description="Search query (concise)")


@tool(args_schema=SearchInput)
def search_tool(query: str) -> str:
    """Searches web via DuckDuckGo. Use for facts, recent info."""
    # Reject non-strings and blank queries before touching the network.
    query_is_usable = isinstance(query, str) and bool(query.strip())
    if not query_is_usable:
        return "Error: Invalid query."

    print(f"🔍 Searching: {query}")
    try:
        raw_result = DuckDuckGoSearchRun().run(query)
        return truncate_if_needed(raw_result)
    except Exception as e:
        # Failures are reported as text so the agent can react to them.
        return f"Search error: {str(e)}"
class CalcInput(BaseModel):
    # Argument schema for calculator.
    expression: str = Field(description="Math expression (e.g., '2+2', 'sqrt(16)')")


@tool(args_schema=CalcInput)
def calculator(expression: str) -> str:
    """
    Evaluates math expressions. Use for ANY calculations.
    Supports: +, -, *, /, **, sqrt, sin, cos, log, pi, e, etc.
    """
    # Returns str(result) on success, or an "Error: ..." / "Calculation error ..."
    # message; this function never raises.
    if not isinstance(expression, str) or not expression.strip():
        return "Error: Invalid expression."
    print(f"🧮 Calculating: {expression}")
    try:
        import math
        safe_names = {
            'sqrt': math.sqrt, 'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
            'log': math.log, 'log10': math.log10, 'exp': math.exp,
            'pi': math.pi, 'e': math.e, 'abs': abs, 'round': round,
            'pow': pow, 'sum': sum, 'min': min, 'max': max
        }
        # SECURITY: the expression comes from the model, i.e. untrusted input.
        # Emptying __builtins__ alone does not sandbox eval(), because attribute
        # access such as ().__class__ can reach arbitrary objects. Parse first
        # and refuse attribute access and dunder names before evaluating.
        # strip() mirrors eval()'s tolerance of surrounding whitespace.
        tree = ast.parse(expression.strip(), mode="eval")
        for node in ast.walk(tree):
            if isinstance(node, ast.Attribute):
                raise ValueError("attribute access is not allowed")
            if isinstance(node, ast.Name) and node.id.startswith("__"):
                raise ValueError(f"name '{node.id}' is not allowed")
        # One namespace (globals only) so comprehensions and generator
        # expressions can also see the math names.
        namespace = {"__builtins__": {}, **safe_names}
        result = eval(compile(tree, "<calculator>", "eval"), namespace)
        return str(result)
    except Exception as e:
        return f"Calculation error for '{expression}': {str(e)}"
class CodeInput(BaseModel):
    # Argument schema for code_interpreter.
    code: str = Field(description="Python code (MUST include print() for output)")


@tool(args_schema=CodeInput)
def code_interpreter(code: str) -> str:
    """
    Executes Python code. Use for data processing, complex logic.
    Available: pandas, numpy, json, re, datetime
    CRITICAL: Always use print() to output results!
    """
    # Returns captured stdout (truncated), an error report built from
    # stderr/traceback, or a reminder to print when nothing was written.
    if not isinstance(code, str):
        return "Error: code must be string."

    # SECURITY: this is a substring blacklist, not a sandbox. exec() below runs
    # model-generated code with full builtins; treat it as arbitrary code
    # execution and only run in an environment where that is acceptable.
    lowered = code.lower()
    dangerous = ['__import__', 'eval(', 'compile(', 'subprocess', 'os.system', 'exec(']
    if any(marker in lowered for marker in dangerous):
        return "Error: Dangerous operation not allowed."
    write_modes = ["'w'", '"w"', "'a'", '"a"']
    if 'open(' in lowered and any(mode in code for mode in write_modes):
        return "Error: File writing not allowed. Use write_file tool."

    print(f"💻 Executing code ({len(code)} chars)...")
    output_stream = io.StringIO()
    error_stream = io.StringIO()
    try:
        with contextlib.redirect_stdout(output_stream), contextlib.redirect_stderr(error_stream):
            namespace = {
                "pd": pd,
                "np": np,
                "json": json,
                "re": re,
                "__builtins__": __builtins__
            }
            # A single dict serves as both globals and locals. With separate
            # dicts, top-level assignments land in locals while functions and
            # comprehensions defined by the snippet look names up in globals,
            # producing spurious NameErrors for perfectly valid code.
            exec(code, namespace)
        stdout = output_stream.getvalue()
        stderr = error_stream.getvalue()
        if stderr:
            return f"Error:\n{stderr}\n\nStdout:\n{stdout}"
        if stdout:
            return truncate_if_needed(stdout)
        return "Code executed but no output. Remember to use print()!"
    except Exception:
        return f"Execution failed:\n{traceback.format_exc()}"
class ReadFileInput(BaseModel):
    # Argument schema for read_file.
    path: str = Field(description="File path")


@tool(args_schema=ReadFileInput)
def read_file(path: str) -> str:
    """Reads file content."""
    # Returns the UTF-8 text (truncated), or an error string for a bad path,
    # a missing file, undecodable bytes, or any other read failure.
    path_is_usable = isinstance(path, str) and bool(path.strip())
    if not path_is_usable:
        return "Error: Invalid path."

    print(f"📄 Reading: {path}")
    resolved = find_file(path)
    if resolved is None:
        # Listing the working directory helps the agent pick a valid name.
        return f"Error: File not found: '{path}'\nCWD files: {os.listdir('.')}"

    try:
        text = resolved.read_text(encoding='utf-8')
        return truncate_if_needed(text)
    except UnicodeDecodeError:
        size = resolved.stat().st_size
        return f"Error: Binary file. Size: {size} bytes. Try audio_transcription_tool for audio."
    except Exception as e:
        return f"Read error: {str(e)}"
class WriteFileInput(BaseModel):
    # Argument schema for write_file.
    path: str = Field(description="File path")
    content: str = Field(description="Content to write")


@tool(args_schema=WriteFileInput)
def write_file(path: str, content: str) -> str:
    """Writes content to file."""
    # The target is resolved against the current working directory; missing
    # parent directories are created. Returns a confirmation or an error string.
    inputs_ok = bool(path) and isinstance(content, str)
    if not inputs_ok:
        return "Error: Invalid inputs."

    print(f"✍️ Writing: {path}")
    try:
        target = Path.cwd() / path
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(content, encoding='utf-8')
        written = len(content)
        return f"Wrote {written} chars to '{path}'."
    except Exception as e:
        return f"Write error: {str(e)}"
class ListDirInput(BaseModel):
    # Argument schema for list_directory; defaults to the working directory.
    path: str = Field(description="Directory path", default=".")


@tool(args_schema=ListDirInput)
def list_directory(path: str = ".") -> str:
    """Lists directory contents."""
    # Returns a text listing (directories first, then files with sizes), or an
    # error / "empty" message.
    print(f"📁 Listing: {path}")
    try:
        if path == ".":
            target_dir = Path.cwd()
        else:
            target_dir = Path.cwd() / path

        if not target_dir.is_dir():
            return f"Error: '{path}' not a directory."

        entries = sorted(target_dir.iterdir())
        if not entries:
            return f"Directory '{path}' is empty."

        dir_lines = []
        file_lines = []
        for entry in entries:
            if entry.is_dir():
                dir_lines.append(f"📁 {entry.name}/")
                continue
            file_lines.append(f"📄 {entry.name} ({entry.stat().st_size} bytes)")

        sections = [f"Contents of '{path}':\n\n"]
        if dir_lines:
            sections.append("Directories:\n" + "\n".join(dir_lines) + "\n\n")
        if file_lines:
            sections.append("Files:\n" + "\n".join(file_lines))
        return "".join(sections)
    except Exception as e:
        return f"List error: {str(e)}"
class AudioInput(BaseModel):
    # Argument schema for audio_transcription_tool.
    file_path: str = Field(description="Audio file path")


@tool(args_schema=AudioInput)
def audio_transcription_tool(file_path: str) -> str:
    """Transcribes audio using Whisper."""
    # Uses the module-level asr_pipeline; returns the transcription text
    # (truncated) or an error string.
    if not file_path:
        return "Error: Invalid file path."

    print(f"🎤 Transcribing: {file_path}")
    if asr_pipeline is None:
        # The pipeline failed to load at import time.
        return "Error: ASR not available."

    located = find_file(file_path)
    if not located:
        return f"Error: Audio file not found: '{file_path}'"

    try:
        asr_output = asr_pipeline(str(located))
        spoken_text = asr_output.get("text", "")
        if spoken_text:
            return f"Transcription:\n{truncate_if_needed(spoken_text)}"
        return "Error: Transcription empty."
    except Exception as e:
        return f"Transcription error: {str(e)}"
class ImageAnalysisInput(BaseModel):
    # Argument schema for analyze_image.
    file_path: str = Field(description="Image file path")
    query: str = Field(description="What to analyze in the image")


@tool(args_schema=ImageAnalysisInput)
def analyze_image(file_path: str, query: str) -> str:
    """
    Analyzes images using Google Gemini Vision API.
    Use for: chess positions, diagrams, charts, photos, screenshots.
    Provide the EXACT file path from [FILE ATTACHED: ...] in the question.
    """
    # Returns "Image Analysis:\n<text>" on success, otherwise an error string;
    # this function never raises.
    if not file_path or not query:
        return "Error: file_path and query required."
    print(f"🖼️ Analyzing image: {file_path}")
    print(f" Query: {query[:100]}...")

    # Try the usual lookup first, then the literal path (e.g. files under /tmp).
    image_path = find_file(file_path)
    if not image_path and os.path.exists(file_path):
        image_path = Path(file_path)
    if not image_path or not image_path.exists():
        return f"Error: Image not found at '{file_path}'. Check [FILE ATTACHED: ...] in question for correct path."
    print(f"✓ Found image at: {image_path}")

    try:
        GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
        if not GOOGLE_API_KEY:
            return "Error: GEMINI_API_KEY not set."

        # Load and re-encode as JPEG. The context manager closes the file handle.
        with Image.open(image_path) as img:
            print(f" Image size: {img.size}, mode: {img.mode}")
            # JPEG has no alpha channel: every mode other than plain RGB,
            # including RGBA (typical for PNG screenshots), must be converted
            # or Pillow refuses to write the file.
            rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
            buffered = io.BytesIO()
            rgb_img.save(buffered, format="JPEG")
        img_base64 = base64.b64encode(buffered.getvalue()).decode()
        print(f" Encoded image: {len(img_base64)} bytes")

        # Use Gemini Vision
        vision_llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-exp",
            google_api_key=GOOGLE_API_KEY,
            temperature=0
        )
        message = HumanMessage(
            content=[
                {"type": "text", "text": query},
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{img_base64}"
                }
            ]
        )
        print(" Sending to Gemini Vision...")
        response = vision_llm.invoke([message])
        # Message content may be a list of parts rather than a plain string;
        # normalise it so truncation and formatting always work.
        analysis = response.content if isinstance(response.content, str) else str(response.content)
        print(f"✓ Got response: {len(analysis)} chars")
        return f"Image Analysis:\n{truncate_if_needed(analysis)}"
    except Exception as e:
        error_msg = f"Image analysis error: {str(e)}"
        print(f"❌ {error_msg}")
        print(traceback.format_exc())
        return error_msg
class YoutubeInput(BaseModel):
    # Argument schema for get_youtube_transcript.
    video_url: str = Field(description="YouTube URL")


@tool(args_schema=YoutubeInput)
def get_youtube_transcript(video_url: str) -> str:
    """Fetches YouTube video transcript using yt-dlp."""
    # Runs yt-dlp to download English subtitles as VTT into the working
    # directory, flattens them to plain text and removes the files again.
    # Returns "Transcript:\n<text>" or an error string; never raises.
    import glob

    if not video_url:
        return "Error: Invalid URL."
    print(f"📺 YouTube transcript: {video_url}")

    # Extract video ID
    video_id = None
    if "watch?v=" in video_url:
        video_id = video_url.split("v=")[1].split("&")[0]
    elif "youtu.be/" in video_url:
        video_id = video_url.split("youtu.be/")[1].split("?")[0]
    # The id is used as an output file name and inside a glob pattern whose
    # matches get deleted, so only plain id characters are accepted (no path
    # separators or wildcards).
    if not video_id or not re.fullmatch(r"[A-Za-z0-9_-]+", video_id):
        return "Error: Could not extract video ID."

    subtitle_pattern = f"{video_id}*.vtt"
    try:
        # Use yt-dlp to get subtitles
        cmd = [
            'yt-dlp',
            '--skip-download',
            '--write-auto-subs',
            '--write-subs',
            '--sub-lang', 'en',
            '--sub-format', 'vtt',
            '--output', video_id,
            video_url
        ]
        print(f"🔧 Running: {' '.join(cmd)}")
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=45)
        if result.returncode != 0:
            print(f"⚠️ yt-dlp stderr: {result.stderr}")
            return f"Error: Could not fetch subtitles - {result.stderr[:200]}"

        # The exact file name varies (language suffix), so search by prefix.
        vtt_files = glob.glob(subtitle_pattern)
        if not vtt_files:
            return "Error: No English subtitles found for this video."
        subtitle_file = vtt_files[0]
        print(f"✓ Found subtitle file: {subtitle_file}")

        with open(subtitle_file, 'r', encoding='utf-8') as f:
            content = f.read()

        # Keep only caption text: drop headers, cue numbers and timestamps.
        transcript_parts = []
        for raw_line in content.split('\n'):
            # Cue lines can carry inline tags such as <c> or <00:00:01.000>.
            line = re.sub(r"<[^>]+>", "", raw_line).strip()
            if not line:
                continue
            if line.startswith(('WEBVTT', 'Kind:', 'Language:')):
                continue
            if '-->' in line or line.isdigit():
                continue
            # Rolling captions repeat the previous line in the next cue.
            if transcript_parts and transcript_parts[-1] == line:
                continue
            transcript_parts.append(line)
        full_transcript = " ".join(transcript_parts)

        if not full_transcript:
            return "Error: Transcript was empty."
        print(f"✓ Transcript extracted: {len(full_transcript)} chars")
        return f"Transcript:\n{truncate_if_needed(full_transcript)}"
    except subprocess.TimeoutExpired:
        return "Error: yt-dlp timed out after 45 seconds."
    except FileNotFoundError:
        return "Error: yt-dlp not installed. Add 'yt-dlp' to requirements.txt"
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        print(traceback.format_exc())
        return f"Transcript error: {str(e)}"
    finally:
        # Best-effort cleanup on every path, including errors and timeouts.
        for leftover in glob.glob(subtitle_pattern):
            try:
                os.remove(leftover)
            except OSError as cleanup_error:
                print(f"⚠️ Could not remove subtitle file '{leftover}': {cleanup_error}")
class ScrapeInput(BaseModel):
    # Argument schema for scrape_and_retrieve.
    url: str = Field(description="URL (must start with http:// or https://)")
    query: str = Field(description="What to find on the page")


@tool(args_schema=ScrapeInput)
def scrape_and_retrieve(url: str, query: str) -> str:
    """
    Scrapes webpage and uses RAG to find relevant info.
    Use when you need specific info from a known URL.
    """
    # Pipeline: fetch page -> strip boilerplate tags -> split into chunks ->
    # index chunks in an in-memory FAISS store -> return the 5 best matches.
    has_http_scheme = url.startswith(('http://', 'https://'))
    if not has_http_scheme:
        return "Error: Invalid URL format."
    if not query:
        return "Error: Query required."

    # Lazily build the shared embedding model / splitter if still missing.
    rag_ready = global_embeddings is not None and global_text_splitter is not None
    if not rag_ready and not initialize_rag_components():
        return "Error: RAG not initialized."

    print(f"🌐 Scraping: {url}")
    try:
        request_headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
        response = requests.get(url, headers=request_headers, timeout=20)
        response.raise_for_status()

        soup = BeautifulSoup(response.text, 'html.parser')
        boilerplate_tags = ["script", "style", "nav", "footer", "aside", "header", "iframe"]
        for element in soup(boilerplate_tags):
            element.extract()

        # Prefer <main>, then <article>, then the whole <body>.
        main = soup.find('main') or soup.find('article') or soup.body
        if not main:
            return "Error: No main content found."

        raw_text = main.get_text(separator='\n', strip=True)
        non_empty_lines = [row.strip() for row in raw_text.splitlines() if row.strip()]
        text = '\n'.join(non_empty_lines)
        if len(text) < 50:
            return f"Error: Content too short ({len(text)} chars)."

        chunks = global_text_splitter.split_text(text)
        if not chunks:
            return "Error: Could not chunk text."

        docs = [Document(page_content=chunk, metadata={"source": url}) for chunk in chunks]
        index = FAISS.from_documents(docs, global_embeddings)
        retrieved = index.as_retriever(search_kwargs={"k": 5}).invoke(query)
        if not retrieved:
            return f"No relevant info found for: '{query}'"

        labelled_chunks = [
            f"[Chunk {position + 1}]\n{doc.page_content}"
            for position, doc in enumerate(retrieved)
        ]
        context = "\n\n---\n\n".join(labelled_chunks)
        return truncate_if_needed(f"From {url}:\n\n{context}")
    except requests.RequestException as e:
        return f"Fetch error: {str(e)}"
    except Exception as e:
        return f"Scrape error: {str(e)}\n{traceback.format_exc()}"
class FinalAnswerInput(BaseModel):
    # Argument schema for final_answer_tool.
    answer: str = Field(description="Final answer - EXACTLY what was asked, nothing more")


@tool(args_schema=FinalAnswerInput)
def final_answer_tool(answer: str) -> str:
    """
    Submit final answer. CRITICAL RULES:
    1. ALWAYS call validate_answer() first
    2. Answer must be EXACTLY what was asked
    3. NO conversational text
    4. NO explanations
    5. Match requested format exactly
    """
    # Echoes the answer back (coerced to str if needed) after logging it.
    answer_text = answer if isinstance(answer, str) else str(answer)
    print(f"✅ FINAL ANSWER SUBMITTED: {answer_text}")
    return answer_text
# =============================================================================
# DEFINED TOOLS LIST
# =============================================================================
# Tools grouped by purpose; the combined order below matches the order in
# which they are listed to the model.
_planning_tools = [think_through_logic, create_plan, reflect_on_progress, validate_answer]
_core_tools = [search_tool, calculator, code_interpreter]
_file_tools = [read_file, write_file, list_directory]
_specialized_tools = [
    audio_transcription_tool,
    analyze_image,
    get_youtube_transcript,
    scrape_and_retrieve,
]

defined_tools = [
    *_planning_tools,
    *_core_tools,
    *_file_tools,
    *_specialized_tools,
    # Terminal tool: submitting the final answer.
    final_answer_tool,
]
# =============================================================================
# AGENT STATE
# =============================================================================
class AgentState(TypedDict):
    """State dictionary passed between the graph nodes."""

    # Conversation so far; annotated with the add_messages reducer.
    messages: Annotated[List[AnyMessage], add_messages]
    # Number of agent turns taken; compared against MAX_TURNS.
    turn: int
    # NOTE(review): presumably set once create_plan has been used — confirm against the nodes.
    has_plan: bool
    # Read by the agent node to decide when to inject a reflection hint.
    consecutive_errors: int
    # NOTE(review): presumably the names of tools called so far — confirm.
    tool_history: List[str]
    # NOTE(review): presumably flags that the previous tool was a thinking/planning tool — confirm.
    last_tool_was_thinking: bool
# =============================================================================
# ENHANCED FALLBACK PARSER
# =============================================================================
def _decode_tool_args(raw: str) -> Optional[dict]:
    """Parse the JSON object at the start of ``raw``; return ``None`` on failure.

    The text is tried as-is first. If that fails, backslash escapes are
    resolved (for payloads that arrive as ``{\\"key\\": ...}``) and parsing is
    retried. Anything after the first JSON value, such as ``</function>`` or a
    trailing comma, is ignored.
    """
    candidates = [raw]
    try:
        # latin-1 + backslashreplace keeps non-ASCII characters intact through
        # the unicode_escape round trip (a plain .encode() would mangle them).
        candidates.append(raw.encode("latin-1", "backslashreplace").decode("unicode_escape"))
    except UnicodeDecodeError:
        pass  # malformed escape sequence: only the raw text is tried

    decoder = json.JSONDecoder()
    for candidate in candidates:
        try:
            parsed, _end = decoder.raw_decode(candidate)
        except ValueError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]:
    """Recover a tool call from raw model text when native tool calling failed.

    Strategies are tried in order until one yields a tool name and arguments:
      1. ``<function=name{...}>`` (Groq style),
      2. ``<function(name)>{...}`` / ``<function=name>{...}``,
      3. a fenced python code block, routed to ``code_interpreter``,
      4. a plain mention of a known tool name (required args get placeholders),
      5. longer non-error text is wrapped in ``think_through_logic``.

    Args:
        content: Raw text produced by the model.
        tools: Available tool objects; each must expose ``name`` and ``args_schema``.

    Returns:
        A one-element list holding the reconstructed ``ToolCall``, or an empty
        list when nothing usable was found or the tool name is unknown.
    """
    print(f"🔧 Fallback parsing (first 300 chars):\n{content[:300]}")
    tool_name = None
    tool_input = None
    content_lower = content.lower()

    # STRATEGY 1: Groq's <function=name{...}> format
    groq_match = re.search(r"<function=(\w+)\s*(\{.*?\})\s*(?:>|</function>)", content, re.DOTALL)
    if groq_match:
        parsed_args = _decode_tool_args(groq_match.group(2).strip())
        if parsed_args is not None:
            tool_name = groq_match.group(1).strip()
            tool_input = parsed_args
            print(f"✓ Parsed Groq format: {tool_name}")

    # STRATEGY 2: Standard <function(name)>{...} format
    if not tool_name:
        # The name stops at the first ')' or '>' so a trailing </function>
        # cannot be swallowed into it.
        func_match = re.search(r"<function[(=]\s*([^)>]+)\s*[)>](.*)", content, re.DOTALL | re.IGNORECASE)
        if func_match:
            candidate_name = func_match.group(1).strip().replace("'", "").replace('"', '')
            remaining = func_match.group(2)
            json_start = remaining.find('{')
            parsed_args = _decode_tool_args(remaining[json_start:]) if json_start != -1 else None
            # Only commit when both a name and arguments were found; otherwise
            # leave tool_name unset so the later strategies still run.
            if candidate_name and parsed_args is not None:
                tool_name = candidate_name
                tool_input = parsed_args
                print(f"✓ Parsed standard format: {tool_name}")

    # STRATEGY 3: Tool mention with code block → wrap in code_interpreter
    if not tool_name and "```python" in content:
        code_match = re.search(r"```python\n(.*?)```", content, re.DOTALL)
        if code_match:
            tool_name = "code_interpreter"
            tool_input = {"code": code_match.group(1).strip()}
            print("✓ Extracted Python code → code_interpreter")

    # STRATEGY 4: Direct tool mention → create minimal valid call
    if not tool_name:
        for candidate in tools:
            if candidate.name.lower() not in content_lower:
                continue
            tool_name = candidate.name
            tool_input = {}
            if candidate.args_schema:
                schema = candidate.args_schema.model_json_schema()
                required = schema.get('required', [])
                for prop in schema.get('properties', {}):
                    if prop in required:
                        # Arguments are not recovered from the text; a
                        # placeholder keeps the call schema-valid.
                        tool_input[prop] = "auto_extracted"
            print(f"✓ Found mention of '{tool_name}' → creating default call")
            break

    # STRATEGY 5: Emergency - if no tool detected, force a reasonable one
    if not tool_name:
        looks_like_error = any(kw in content_lower for kw in ["error", "failed", "invalid"])
        if len(content) > 50 and not looks_like_error:
            tool_name = "think_through_logic"
            tool_input = {"reasoning": content[:150]}
            print("⚠️ No tool detected → forcing think_through_logic")

    # Validate and create tool call
    if tool_name and tool_input is not None:
        if any(t.name == tool_name for t in tools):
            return [ToolCall(name=tool_name, args=tool_input, id=str(uuid.uuid4()))]
        print(f"❌ Tool '{tool_name}' not in available tools")

    print("❌ All parsing strategies failed")
    return []
# =============================================================================
# CONDITIONAL EDGE FUNCTION
# =============================================================================
def should_continue(state: AgentState):
    """Choose the next graph step from the most recent message.

    Args:
        state: Current agent state; only ``messages`` and ``turn`` are read.

    Returns:
        ``"agent"`` to run the model again, ``"tools"`` to execute the pending
        tool call, or ``END`` to stop. A routing key is returned on every path.
    """
    messages = state.get('messages', [])
    if not messages:
        return "agent"
    last_message = messages[-1]
    current_turn = state.get('turn', 0)

    # Debug: Print what we're checking
    msg_type = type(last_message).__name__
    print(f"📍 Conditional check - Turn {current_turn}, Last msg type: {msg_type}")

    # 1. Check turn limit
    if current_turn >= MAX_TURNS:
        print(f"🛑 Max turns ({MAX_TURNS}) reached")
        return END

    # 2. If last message is ToolMessage, agent needs to process it
    if isinstance(last_message, ToolMessage):
        print(f"📨 Tool result received from '{last_message.name}' → back to agent")
        return "agent"

    if isinstance(last_message, AIMessage):
        # 3. AIMessage with tool calls: only the FIRST call decides the route.
        if last_message.tool_calls:
            first_tool = last_message.tool_calls[0]
            if first_tool.get("name", "") == "final_answer_tool":
                return END
            return "tools"

        # 4. AIMessage without tool calls (reasoning text). Two in a row means
        #    the model is looping, so stop instead of retrying again.
        previous = messages[-2] if len(messages) >= 2 else None
        if isinstance(previous, AIMessage) and not previous.tool_calls:
            print("⚠️ Loop detected: 2 consecutive AI messages without tools")
            return END
        print("💭 AI message without tool call → continuing to agent (will force tool)")
        return "agent"

    # 5. Default: continue to agent. The explicit return matters: falling off
    #    the end would yield None, which is not a routing key.
    print("🔄 Default → continuing to agent")
    return "agent"
# =============================================================================
# ENHANCED AGENT CLASS
# =============================================================================
class PlanningReflectionAgent:
def __init__(self):
print("🧠 PlanningReflectionAgent initializing...")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
raise ValueError("GROQ_API_KEY not set!")
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HUGGINGFACEHUB_API_TOKEN:
raise ValueError("HUGGINGFACEHUB_API_TOKEN secret is not set! Please add it to your Space secrets.")
GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
if not GOOGLE_API_KEY:
raise ValueError("GOOGLE_API_KEY not set!")
self.tools = defined_tools
# Initialize RAG
if not initialize_rag_components():
print("⚠️ RAG components failed to initialize.")
# Build tool descriptions
tool_desc_list = []
for tool in self.tools:
if tool.args_schema:
schema = tool.args_schema.model_json_schema()
args_desc = [f" - {p}: {d.get('description', '')}"
for p, d in schema.get('properties', {}).items()]
desc = f"- {tool.name}:\n {tool.description}\n" + "\n".join(args_desc)
else:
desc = f"- {tool.name}: {tool.description}"
tool_desc_list.append(desc)
tool_descriptions = "\n".join(tool_desc_list)
# ULTRA-AGGRESSIVE SYSTEM PROMPT
self.system_prompt = f"""You are an elite AI agent for GAIA benchmark. Your ONLY job: provide the EXACT answer requested.
═══════════════════════════════════════════════════════════════
⚠️ ABSOLUTE RULES - VIOLATE THESE AND YOU FAIL:
═══════════════════════════════════════════════════════════════
1. **EVERY TURN MUST CALL EXACTLY ONE TOOL** - No exceptions
2. **NEVER OUTPUT REASONING TEXT WITHOUT A TOOL CALL** - You will fail
3. **IDENTIFY QUESTION TYPE FIRST** - Logic? Factual? Data? Math?
4. **LOGIC PUZZLES**: think_through_logic → calculator (if needed) → validate → final_answer
5. **FACTUAL QUESTIONS**: search_tool → validate → final_answer
6. **DATA QUESTIONS**: read_file → code_interpreter → validate → final_answer
7. **ALWAYS VALIDATE**: Call validate_answer() before final_answer_tool()
8. **FINAL ANSWER FORMAT**: EXACTLY what was asked. NO "The answer is..." or explanations
═══════════════════════════════════════════════════════════════
📋 QUESTION TYPE GUIDE:
═══════════════════════════════════════════════════════════════
**RIDDLES/LOGIC PUZZLES** (No web search needed):
- Brain teasers, puzzles, logical deduction
- Strategy: think_through_logic → calculator (if math) → validate → final_answer
- Example: "If 200 coins, 30 face-down, divide into equal piles..."
Turn 1: think_through_logic("Adventurer takes 30 coins and flips them")
Turn 2: calculator("30") [if needed]
Turn 3: validate_answer("30", question)
Turn 4: final_answer_tool("30")
**FACTUAL/RESEARCH** (Need web):
- Who, what, when, where questions
- Strategy: search_tool → scrape_and_retrieve → validate → final_answer
- Example: "What was Einstein's birthplace population in 1900?"
Turn 1: search_tool("Albert Einstein birthplace")
Turn 2: search_tool("Ulm Germany population 1900")
Turn 3: validate_answer("50000", question)
Turn 4: final_answer_tool("50000")
**DATA ANALYSIS** (Need files):
- CSV/Excel questions
- Strategy: list_directory → read_file → code_interpreter → validate → final_answer
**SIMPLE MATH**:
- Calculations
- Strategy: calculator() → validate_answer() → final_answer_tool()
═══════════════════════════════════════════════════════════════
🎓 CRITICAL EXAMPLES:
═══════════════════════════════════════════════════════════════
Example 1: Logic Puzzle
Q: "Coin riddle with 200 coins, 30 face-down..."
✅ CORRECT:
Turn 1: think_through_logic("Take 30 coins, flip all")
Turn 2: validate_answer("30", "coin riddle...")
Turn 3: final_answer_tool("30")
❌ WRONG:
Turn 1: [reasoning text without tool] ← FAILS!
Example 2: Letter Bank Puzzle
Q: "Use letters to spell sentences, which letters need changing?"
✅ CORRECT:
Turn 1: code_interpreter("code to count letters...")
Turn 2: validate_answer("A, B, C", question)
Turn 3: final_answer_tool("A, B, C")
Example 3: Math Problem
Q: "System of equations to solve..."
✅ CORRECT:
Turn 1: code_interpreter("import numpy; solve equations...")
Turn 2: validate_answer("0, 1, 2", question)
Turn 3: final_answer_tool("0, 1, 2")
═══════════════════════════════════════════════════════════════
📚 AVAILABLE TOOLS:
═══════════════════════════════════════════════════════════════
{tool_descriptions}
═══════════════════════════════════════════════════════════════
⚡ EXECUTION RULES:
═══════════════════════════════════════════════════════════════
- If you output text without a tool call, you have FAILED
- If you're unsure, use think_through_logic() to organize thoughts
- ALWAYS call a tool - preferably the right one for the question type
- After EVERY tool result, decide: "Do I have the answer? → validate → submit"
- If stuck after 3 turns: call reflect_on_progress()
REMEMBER: One tool per turn. No reasoning without tools. Exact answer format.
═══════════════════════════════════════════════════════════════
"""
#. Initialize the LLM ()
# print("Initializing Groq LLM...")
# try:
# self.llm_with_tools = ChatGroq(
# temperature=0,
# groq_api_key=GROQ_API_KEY,
# model_name="llama-3.1-8b-instant",
# max_tokens=4096,
# timeout=60
# ).bind_tools(self.tools, tool_choice="auto")
# print("✅ LLM initialized without FORCED tool usage.")
#
# except Exception as e:
# print(f"❌ Error initializing HuggingFace: {e}")
# raise
# print("Initializing LLM Endpoint...")
print("Initializing HuggingFace LLM...")
llm = HuggingFaceEndpoint(
repo_id="meta-llama/Llama-3.1-70B-Instruct", # Free on HF Inference API
huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
max_new_tokens=4096,
temperature=0.01,
)
chat_llm = ChatHuggingFace(llm=llm)
print("✅ HuggingFace LLM Endpoint initialized.")
# Bind tools to the LLM
self.llm_with_tools = chat_llm.bind_tools(self.tools)
print("✅ Tools bound to LLM.")
# print("Initializing Google Gemini LLM...")
# try:
# self.llm_with_tools = ChatGoogleGenerativeAI(
# model="gemini-2.5-flash", # Latest model
# google_api_key=GOOGLE_API_KEY,
# temperature=0,
# max_output_tokens=8192,
# timeout=60,
# convert_system_message_to_human=True # Important for Gemini
# ).bind_tools(self.tools, tool_choice="auto")
# print("✅ Gemini LLM initialized.")
# except Exception as e:
# print(f"❌ Error initializing Gemini: {e}")
# raise
# Agent Node with AGGRESSIVE tool forcing
def agent_node(state: AgentState):
    """Run one agent turn and make sure the resulting AI message carries a tool call.

    This is a closure: it uses ``self.llm_with_tools`` and ``self.tools`` from
    the enclosing scope.

    Steps:
      1. Increment the turn counter; once it exceeds ``MAX_TURNS`` return a
         "Max turns reached." system message without calling the LLM.
      2. Copy the conversation and optionally append a tool-forcing message
         (previous AI message had no tool call) and/or a reflection hint.
      3. Invoke the LLM up to three times. On a ``tool_use_failed`` error a
         plain ``ChatGroq`` model is tried and its text is parsed for a tool
         call with ``parse_tool_call_from_string``.
      4. If there is still no tool call, parse the text once more and
         otherwise force a ``think_through_logic`` call.

    Args:
        state: Current graph state. Reads ``turn``, ``consecutive_errors``,
            ``messages``, ``tool_history`` and ``has_plan``.

    Returns:
        A partial state update with ``messages``, ``turn`` and (except on the
        max-turns path) ``has_plan``, ``tool_history`` and
        ``last_tool_was_thinking``.
    """

    def forced_think_call(reasoning):
        # Synthetic tool call used when the model refuses to produce one.
        return ToolCall(
            name="think_through_logic",
            args={"reasoning": reasoning},
            id=str(uuid.uuid4())
        )

    current_turn = state.get('turn', 0) + 1
    print(f"\n{'='*70}")
    print(f"🤖 AGENT TURN {current_turn}/{MAX_TURNS}")
    print('='*70)

    if current_turn > MAX_TURNS:
        return {
            "messages": [SystemMessage(content="Max turns reached.")],
            "turn": current_turn
        }

    # Check if we should force reflection
    consecutive_errors = state.get('consecutive_errors', 0)
    should_reflect = (
        (current_turn > 5 and current_turn % REFLECT_EVERY_N_TURNS == 0)
        or consecutive_errors >= 3
    )

    # Work on a copy so the nudging messages below never enter the stored history.
    messages_to_send = state["messages"].copy()

    # Add tool-forcing message if last turn had no tool call
    if len(messages_to_send) >= 2:
        last_msg = messages_to_send[-1]
        if isinstance(last_msg, AIMessage) and not last_msg.tool_calls:
            force_msg = SystemMessage(
                content="⚠️ CRITICAL: You MUST call a tool this turn. NO reasoning text. Pick the most appropriate tool and call it now."
            )
            messages_to_send.append(force_msg)
            print("🚨 Injecting tool-forcing message")

    # Add reflection hint if needed
    if should_reflect:
        hint = SystemMessage(
            content="⚠️ HINT: Multiple turns without progress. Consider calling reflect_on_progress() or try a different approach."
        )
        messages_to_send.append(hint)
        print("🤔 Injecting reflection hint")

    # Invoke LLM with retries and fallback
    max_retries = 3
    ai_message = None
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            ai_message = self.llm_with_tools.invoke(messages_to_send)
            # If we got a valid response with tool calls, we are done
            if ai_message.tool_calls:
                break
            # If no tool calls, this is a problem; retry
            print(f"⚠️ LLM returned no tool calls on attempt {attempt+1}")
        except Exception as e:
            error_str = str(e)
            print(f"⚠️ LLM attempt {attempt+1}/{max_retries} failed: {error_str[:200]}")

            # If tool_use_failed, try a model without tool binding and parse its text
            if "tool_use_failed" in error_str and not is_last_attempt:
                print("🔧 Trying without strict tool enforcement...")
                try:
                    simple_llm = ChatGroq(
                        temperature=0,
                        groq_api_key=os.getenv("GROQ_API_KEY"),
                        model_name="llama-3.3-70b-versatile",
                        max_tokens=4096,
                        timeout=60
                    )
                    # Add explicit tool forcing to the message
                    force_tool_msg = SystemMessage(
                        content="You MUST call a tool. Respond with a tool call, not reasoning text."
                    )
                    ai_message = simple_llm.invoke(messages_to_send + [force_tool_msg])

                    # Try to parse tool calls from content
                    if ai_message.content and not ai_message.tool_calls:
                        parsed = parse_tool_call_from_string(ai_message.content, self.tools)
                        if parsed:
                            ai_message.tool_calls = parsed
                            ai_message.content = ""
                            print("✓ Fallback parsing succeeded")

                    # Only stop retrying when the fallback actually produced a tool call.
                    if ai_message.tool_calls:
                        break
                except Exception as e2:
                    print(f"⚠️ Fallback also failed: {e2}")

            if is_last_attempt:
                # Last resort: inject a default tool call
                print("🚨 All attempts failed - forcing think_through_logic")
                ai_message = AIMessage(
                    content="",
                    tool_calls=[forced_think_call("Processing question")]
                )
            else:
                # Exponential backoff before the next attempt
                time.sleep(2 ** attempt)

    # If still no tool calls after all attempts, force one
    if not ai_message.tool_calls:
        parsed = None
        if isinstance(ai_message.content, str) and ai_message.content.strip():
            # Try one more parse
            parsed = parse_tool_call_from_string(ai_message.content, self.tools)
        if parsed:
            ai_message.tool_calls = parsed
            print("✓ Final parse succeeded")
        else:
            # Absolute last resort; applies to empty or non-string content as well,
            # so the message leaving this node always has a tool call.
            print("🚨 EMERGENCY: Forcing think_through_logic")
            ai_message.tool_calls = [forced_think_call("analyzing question")]
        ai_message.content = ""

    # Track tool usage. Copy the list so the incoming state is not mutated in place.
    tool_history = list(state.get('tool_history') or [])
    has_plan = state.get('has_plan', False)
    first_tool_name = None
    if ai_message.tool_calls:
        first_tool_name = ai_message.tool_calls[0]['name']
        print(f"🔧 Tool Call: {first_tool_name}")
        tool_history.append(first_tool_name)
        if first_tool_name == "create_plan":
            has_plan = True
    else:
        print(f"⚠️ No tool call (this shouldn't happen!)")
        print(f"💭 Content: {str(ai_message.content)[:200]}...")

    return {
        "messages": [ai_message],
        "turn": current_turn,
        "has_plan": has_plan,
        "tool_history": tool_history,
        # Always a real bool (the original could store an empty list here).
        "last_tool_was_thinking": first_tool_name == 'think_through_logic'
    }
# Tool Node with Error Tracking (FIXED)
def tool_node_wrapper(state: AgentState):
    """Execute the pending tool calls and keep a running count of consecutive errors.

    This is a closure: it builds a ``ToolNode`` from ``self.tools`` of the
    enclosing scope on every call.

    Args:
        state: Current graph state; ``consecutive_errors`` is read from it
            (defaulting to 0).

    Returns:
        The ``ToolNode`` result with an added ``consecutive_errors`` entry.
        The counter is incremented when the last tool message mentions
        "error" (case-insensitive substring match) and reset to 0 otherwise.
        It is left unchanged when the last message is not a ``ToolMessage``
        or there are no messages.
    """
    print(f"🔧 Executing tools...")

    # Create fresh ToolNode instance
    tool_executor = ToolNode(self.tools)

    # Invoke properly
    result = tool_executor.invoke(state)

    # Track errors
    consecutive_errors = state.get('consecutive_errors', 0)
    tool_messages = result.get('messages')
    if tool_messages:
        last_msg = tool_messages[-1]
        if isinstance(last_msg, ToolMessage):
            # Message content is not guaranteed to be a plain string; coerce it
            # so the substring check cannot raise AttributeError.
            raw_content = last_msg.content
            content_text = raw_content if isinstance(raw_content, str) else str(raw_content)
            # A single lower-cased check covers both "Error" and "error".
            if "error" in content_text.lower():
                consecutive_errors += 1
                print(f"⚠️ Tool error detected (consecutive: {consecutive_errors})")
            else:
                consecutive_errors = 0

    result['consecutive_errors'] = consecutive_errors
    return result
# Build Graph
print("Building graph...")
graph_builder = StateGraph(AgentState)
graph_builder.add_node("agent", agent_node)
graph_builder.add_node("tools", tool_node_wrapper)
graph_builder.add_edge(START, "agent")
graph_builder.add_conditional_edges(
"agent",
should_continue,
{
"tools": "tools",
"agent": "agent",
END: END
}
)
graph_builder.add_edge("tools", "agent")
self.graph = graph_builder.compile()
print("✅ Graph compiled successfully.")
def __call__(self, question: str, file_path: Optional[str] = None) -> str:
    """Execute the agent graph on a question and return a cleaned answer string.

    Args:
        question: The question text.
        file_path: Optional path of an attached file. When given, the path and
            a coarse file type (derived from the extension) are appended to
            the prompt.

    Returns:
        The ``answer`` argument of the last observed ``final_answer_tool``
        call, or a line salvaged from a recent calculator /
        think_through_logic / code_interpreter result, after stripping common
        prefixes, code fences, quotes and a trailing period.
        ``"AGENT FAILED TO PRODUCE ANSWER"`` is returned when neither source
        yields an answer, and ``"AGENT ERROR: ..."`` when the graph raises.
    """
    rule = '=' * 70
    print(f"\n{rule}")
    print(f"🎯 NEW QUESTION")
    print(f"{rule}")
    print(f"Q: {question[:200]}{'...' if len(question) > 200 else ''}")
    if file_path:
        print(f"📎 File attached: {file_path}")
    print(f"{rule}\n")

    # Enhanced question context with file information
    question_text = question
    if file_path:
        extension_groups = {
            "image": ('.jpg', '.jpeg', '.png', '.gif', '.bmp'),
            "audio": ('.mp3', '.wav', '.m4a', '.flac'),
            "data": ('.csv', '.xlsx', '.xls'),
            "document": ('.txt', '.pdf', '.doc', '.docx'),
        }
        file_ext = Path(file_path).suffix.lower()
        file_type = next(
            (kind for kind, exts in extension_groups.items() if file_ext in exts),
            "unknown",
        )
        question_text += f"\n\n[FILE ATTACHED: {file_path}]"
        question_text += f"\n[FILE TYPE: {file_type}]"
        question_text += f"\nIMPORTANT: Use the appropriate tool to access this file first!"

    graph_input = {
        "messages": [
            SystemMessage(content=self.system_prompt),
            HumanMessage(content=question_text)
        ],
        "file_path": file_path,
        "turn": 0,
        "has_plan": False,
        "consecutive_errors": 0,
        "tool_history": [],
        "last_tool_was_thinking": False
    }

    no_answer = "AGENT FAILED TO PRODUCE ANSWER"
    final_answer = no_answer
    all_messages = []

    try:
        config = {"recursion_limit": MAX_TURNS + 10}
        for event in self.graph.stream(graph_input, stream_mode="values", config=config):
            if not event.get('messages'):
                continue
            all_messages = event["messages"]
            last_message = all_messages[-1]

            # Check for final answer
            if isinstance(last_message, AIMessage) and last_message.tool_calls:
                for tool_call in last_message.tool_calls:
                    if tool_call.get("name") != "final_answer_tool":
                        continue
                    args = tool_call.get('args', {})
                    if 'answer' in args:
                        final_answer = args['answer']
                        print(f"\n{rule}")
                        print(f"✅ FINAL ANSWER: '{final_answer}'")
                        print(f"{rule}\n")
                        break
            elif isinstance(last_message, ToolMessage):
                # str() guards against non-string message content.
                preview = str(last_message.content)[:200].replace('\n', ' ')
                print(f"📊 Tool '{last_message.name}' result: {preview}...")
            elif isinstance(last_message, AIMessage) and not last_message.tool_calls:
                print(f"💭 AI: {str(last_message.content)[:200]}...")

        # If no final answer, try to extract from tool messages
        if final_answer == no_answer:
            print("⚠️ No final_answer_tool called. Checking tool results...")
            answer_like_tools = ("calculator", "think_through_logic", "code_interpreter")
            status_markers = ('✅', '⚠️', 'Next', 'Remember')
            for msg in reversed(all_messages):
                if not isinstance(msg, ToolMessage) or msg.name not in answer_like_tools:
                    continue
                content = str(msg.content).strip()
                # Look for short, answer-like content
                if not content or len(content) >= 200 or content.startswith("Error"):
                    continue
                # Extract just the result part: last line that is not a status marker.
                for line in reversed(content.split('\n')):
                    candidate = line.strip()
                    # Compare the stripped line so indented marker lines are
                    # rejected too (the original tested the raw line).
                    if candidate and not candidate.startswith(status_markers):
                        final_answer = candidate
                        print(f"📝 Extracted from {msg.name}: '{final_answer}'")
                        break
                # Only the most recent short, non-error result is inspected.
                break

        # Clean the answer
        cleaned = str(final_answer).strip()

        # Remove prefixes (literals are already lower-case)
        prefixes = [
            "the answer is:", "here is the answer:", "based on",
            "final answer:", "answer:", "the final answer is:",
            "my answer is:", "according to", "i found that",
            "the result is:", "result:"
        ]
        lowered = cleaned.lower()
        for prefix in prefixes:
            if lowered.startswith(prefix):
                potential = cleaned[len(prefix):].strip()
                # Keep the original text if stripping would leave nothing.
                if potential:
                    cleaned = potential
                break

        # Remove code fences and quotes
        cleaned = remove_fences_simple(cleaned)
        while cleaned.startswith("`") and cleaned.endswith("`"):
            cleaned = cleaned[1:-1].strip()
        if (cleaned.startswith('"') and cleaned.endswith('"')) or \
           (cleaned.startswith("'") and cleaned.endswith("'")):
            cleaned = cleaned[1:-1].strip()

        # Remove trailing period for short answers
        if cleaned.endswith('.') and len(cleaned.split()) < 10:
            cleaned = cleaned[:-1]

        print(f"\n{rule}")
        print(f"🎉 RETURNING ANSWER")
        print(f"{rule}")
        print(f"{cleaned}")
        print(f"{rule}\n")
        return cleaned

    except Exception as e:
        # Top-level boundary: report the failure as the answer instead of raising.
        print(f"❌ Graph error: {e}")
        print(traceback.format_exc())
        return f"AGENT ERROR: {e}"
# =============================================================================
# GLOBAL AGENT INSTANTIATION
# =============================================================================
# Module-level singleton used by run_and_submit_all; it stays None when startup fails
# so the UI can report the failure instead of crashing at import time.
agent = None
try:
# initialize_rag_components() and asr_pipeline are referenced here but defined
# elsewhere in this file.
initialize_rag_components()
agent = PlanningReflectionAgent()
print("✅ Global PlanningReflectionAgent instantiated.")
# Verify it's callable
if not callable(agent):
print("❌ ERROR: Agent not callable!")
agent = None
else:
print("✅ Agent is callable.")
# A missing ASR pipeline is only reported as a warning; it does not reset the agent.
if asr_pipeline is None:
print("⚠️ ASR Pipeline not loaded.")
except Exception as e:
# Any initialization failure is logged with a traceback and leaves agent as None.
print(f"❌ FATAL: Agent initialization failed: {e}")
traceback.print_exc()
agent = None
# =============================================================================
# RUN AND SUBMIT FUNCTION
# =============================================================================
def _download_task_file(task_id):
    """Download the attachment for ``task_id`` into the current directory.

    Returns the local file path, or ``None`` when the task has no file, the
    server answers with a non-200 status, or any error occurs.
    """
    file_download_url = f"{DEFAULT_API_URL}/files/{task_id}"
    try:
        file_response = requests.get(file_download_url, timeout=15)

        if file_response.status_code == 404:
            print(f"ℹ️ No file found for task {task_id} (404), proceeding without file.")
            return None
        if file_response.status_code != 200:
            print(f"⚠️ Warning: File download for {task_id} failed with status {file_response.status_code}")
            return None

        # Get filename from Content-Disposition header if available
        filename = None
        cd = file_response.headers.get('Content-Disposition')
        if cd:
            filename_match = re.findall(r'filename="?([^"]+)"?', cd)
            if filename_match:
                # The header is server-controlled: keep only the final path
                # component so the file cannot be written outside the
                # current directory.
                filename = Path(filename_match[0].strip().replace('\\', '/')).name

        # If no filename, use task_id with extension from Content-Type
        if not filename:
            ext_map = {
                'image/png': '.png',
                'image/jpeg': '.jpg',
                'image/gif': '.gif',
                'audio/mpeg': '.mp3',
                'audio/wav': '.wav',
                'text/plain': '.txt',
                'text/csv': '.csv',
                'application/pdf': '.pdf',
                'text/x-python': '.py',
                'application/x-python-code': '.py',
            }
            content_type = file_response.headers.get('Content-Type', '')
            # Drop parameters such as "; charset=utf-8" before the lookup.
            media_type = content_type.split(';', 1)[0].strip().lower()
            filename = f"{task_id}{ext_map.get(media_type, '')}"

        # Save to current directory
        local_file_path = filename
        with open(local_file_path, 'wb') as f:
            for chunk in file_response.iter_content(chunk_size=8192):
                f.write(chunk)

        file_size = os.path.getsize(local_file_path)
        abs_path = os.path.abspath(local_file_path)
        print(f"✅ Downloaded: {filename} ({file_size} bytes)")
        print(f"   Saved to: {abs_path}")
        return local_file_path
    except Exception as e:
        # Best effort: a failed download must not abort the evaluation run.
        print(f"\n❌ An unexpected error occurred: {e}")
        return None


def _submission_failure(status_message, results_log, show_traceback=False):
    """Print the failure banner and return the ``(status, DataFrame)`` pair for the UI."""
    print(f"\n{'='*70}")
    print(f"❌ SUBMISSION FAILED")
    print(f"{'='*70}")
    print(status_message)
    if show_traceback:
        print(traceback.format_exc())
    print(f"{'='*70}\n")
    return status_message, pd.DataFrame(results_log)


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """
    Fetches all questions, runs the global agent on them, submits all answers,
    and displays the results.

    Args:
        profile: The logged-in Hugging Face profile, or ``None`` when the user
            is not logged in.

    Returns:
        A ``(status_message, results)`` tuple. ``results`` is a pandas
        DataFrame of per-question results, or ``None`` when the run stops
        before any question is processed.
    """
    global agent

    space_id = os.getenv("SPACE_ID")

    if not profile:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    username = f"{profile.username}"
    print(f"User logged in: {username}")

    # Use the globally instantiated agent
    if agent is None:
        error_msg = "FATAL: Agent failed to initialize at startup. Check logs for errors."
        print(error_msg)
        return error_msg, None
    print("✅ Using globally instantiated PlanningReflectionAgent")

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    # 2. Fetch Questions
    print(f"\n{'='*70}")
    print(f"📥 FETCHING QUESTIONS")
    print(f"{'='*70}")
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        print(f"✅ Fetched {len(questions_data)} questions.")
        print(f"{'='*70}\n")
    except requests.exceptions.JSONDecodeError as e:
        # Must come before RequestException: requests' JSONDecodeError is a
        # subclass of it, so in the other order this handler never runs.
        print(f"❌ Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return f"Error decoding server response for questions: {e}", None
    except requests.exceptions.RequestException as e:
        print(f"❌ Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    except Exception as e:
        print(f"❌ An unexpected error occurred fetching questions: {e}")
        return f"An unexpected error occurred fetching questions: {e}", None

    # 3. Run your Agent
    total_questions = len(questions_data)
    print(f"\n{'='*70}")
    print(f"🚀 STARTING EVALUATION")
    print(f"{'='*70}")
    print(f"Total questions to process: {total_questions}")
    print(f"{'='*70}\n")

    results_log = []
    answers_payload = []
    for idx, item in enumerate(questions_data, 1):
        print(f"\n{'='*70}")
        print(f"📝 PROCESSING QUESTION {idx}/{total_questions}")
        print(f"{'='*70}")

        task_id = item.get("task_id")
        question_text = item.get("question")
        correct_answer = item.get("answer", "N/A")  # Reference answer, if the API provides one

        # An item without an id or question can be neither answered nor submitted.
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue

        question_preview = question_text[:100] + "..." if len(question_text) > 100 else question_text

        # Try to download a file for EVERY task (not just if file_path exists)
        local_file_path = _download_task_file(task_id)

        try:
            # Pass file_path to agent
            submitted_answer = agent(question_text, local_file_path)
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})

            # Check if answer is correct; str() tolerates a null or non-string reference answer.
            is_correct = str(submitted_answer).strip().lower() == str(correct_answer).strip().lower()
            correctness = "✅ CORRECT" if is_correct else "❌ WRONG"

            # Log with correctness indicator
            print(f"\n{correctness} - Task {task_id}")
            print(f"   Submitted: '{submitted_answer}'")
            print(f"   Expected:  '{correct_answer}'")

            results_log.append({
                "Task ID": task_id,
                "Question": question_preview,
                "Submitted Answer": submitted_answer,
                "Correct Answer": correct_answer,
                "Status": "✅" if is_correct else "❌"
            })
            print(f"✅ Question {idx}/{total_questions} completed")
        except Exception as e:
            print(f"❌ Error running agent on task {task_id}: {e}")
            print(traceback.format_exc())
            results_log.append({
                "Task ID": task_id,
                "Question": question_preview,
                "Submitted Answer": f"AGENT ERROR: {e}",
                "Correct Answer": correct_answer,
                "Status": "❌"
            })
            # Continue with other questions even if one fails
            answers_payload.append({"task_id": task_id, "submitted_answer": f"ERROR: {str(e)[:100]}"})

    # Summary after all questions processed
    print(f"\n{'='*70}")
    print(f"✅ ALL QUESTIONS PROCESSED")
    print(f"{'='*70}")
    print(f"Total answers collected: {len(answers_payload)}")

    # Calculate pre-submission accuracy
    correct_count = sum(1 for log in results_log if log.get("Status") == "✅")
    total_count = len(results_log)
    accuracy = (correct_count / total_count * 100) if total_count > 0 else 0
    print(f"\n{'='*70}")
    print(f"📊 PRE-SUBMISSION SUMMARY")
    print(f"{'='*70}")
    print(f"Correct: {correct_count}/{total_count} ({accuracy:.1f}%)")
    print(f"{'='*70}\n")

    if not answers_payload:
        print("⚠️ Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # 4. Prepare Submission
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}

    # 5. Submit
    print(f"\n{'='*70}")
    print(f"📤 SUBMITTING TO API")
    print(f"{'='*70}")
    print(f"URL: {submit_url}")
    print(f"Username: {username}")
    print(f"Answers to submit: {len(answers_payload)}")
    print(f"{'='*70}\n")

    try:
        print("⏳ Sending POST request...")
        response = requests.post(submit_url, json=submission_data, timeout=60)
        print(f"✅ Got response: Status {response.status_code}")
        response.raise_for_status()
        result_data = response.json()

        print(f"\n{'='*70}")
        print(f"📊 SUBMISSION RESULTS")
        print(f"{'='*70}")
        print(f"Response data: {result_data}")
        print(f"{'='*70}\n")

        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print(final_status)
        print("="*70)
        print("✅ Submission successful.")
        return final_status, pd.DataFrame(results_log)
    except requests.exceptions.HTTPError as e:
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        except requests.exceptions.JSONDecodeError:
            error_detail += f" Response: {e.response.text[:500]}"
        return _submission_failure(f"Submission Failed: {error_detail}", results_log)
    except requests.exceptions.Timeout:
        return _submission_failure("Submission Failed: The request timed out.", results_log)
    except requests.exceptions.RequestException as e:
        return _submission_failure(f"Submission Failed: Network error - {e}", results_log)
    except Exception as e:
        return _submission_failure(
            f"An unexpected error occurred during submission: {e}",
            results_log,
            show_traceback=True,
        )
# --- Build Gradio Interface using Blocks ---
# Layout: title, instructions, login button, run button, status textbox and results table.
with gr.Blocks() as demo:
gr.Markdown("# Basic Agent Evaluation Runner")
gr.Markdown(
"""
**Instructions:**
1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
---
**Disclaimers:**
Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance for the delay process of the submit button, a solution could be to cache the answers and submit in a seperate action or even to answer the questions in async.
"""
)
gr.LoginButton()
run_button = gr.Button("Run Evaluation & Submit All Answers")
# Read-only textbox for the status string returned by run_and_submit_all.
status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
# Table for the second return value of run_and_submit_all (a DataFrame or None).
results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
# No explicit inputs are wired; run_and_submit_all declares a single
# `gr.OAuthProfile | None` parameter.
# NOTE(review): presumably Gradio supplies the profile based on that annotation - confirm.
run_button.click(
fn=run_and_submit_all,
outputs=[status_output, results_table]
)
if __name__ == "__main__":
    # Script entry point: print a startup banner, report Space metadata taken
    # from the environment, then launch the Gradio app.
    startup_title = " App Starting "
    side_rule = "-" * 30
    print("\n" + side_rule + startup_title + side_rule)

    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")

    # Runtime host information
    if not space_host_startup:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
    else:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
        print(f" Runtime URL should be: https://{space_host_startup}.hf.space")

    # Repository information
    if not space_id_startup:
        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
    else:
        repo_url = f"https://huggingface.co/spaces/{space_id_startup}"
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f" Repo URL: {repo_url}")
        print(f" Repo Tree URL: {repo_url}/tree/main")

    # Closing rule has the same width as the opening banner line.
    print("-" * (60 + len(startup_title)) + "\n")

    print("Launching Gradio Interface for Basic Agent Evaluation...")
    demo.launch(debug=True, share=False)