Upload 6 files
#1
by Niraya666 - opened
- agent/.DS_Store +0 -0
- agent/__init__.py +20 -0
- agent/enhanced_tools.py +559 -0
- agent/gaia_agent.py +102 -0
- agent/tools.py +429 -0
- app.py +49 -25
agent/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
agent/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GAIA Agent Package
|
| 2 |
+
|
| 3 |
+
A smolagents-based agent for solving GAIA Level 1 benchmark questions.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .gaia_agent import GaiaAgent
|
| 7 |
+
from .tools import (
|
| 8 |
+
web_search, python_execute, file_read,
|
| 9 |
+
youtube_transcript, read_image, read_content
|
| 10 |
+
)
|
| 11 |
+
from .enhanced_tools import (
|
| 12 |
+
sports_data_search, multi_step_search, baseball_reference_lookup
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"GaiaAgent",
|
| 17 |
+
"web_search", "python_execute", "file_read",
|
| 18 |
+
"youtube_transcript", "read_image", "read_content",
|
| 19 |
+
"sports_data_search", "multi_step_search", "baseball_reference_lookup"
|
| 20 |
+
]
|
agent/enhanced_tools.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Enhanced tools for sports data queries and complex multi-step searches
|
| 2 |
+
|
| 3 |
+
This module provides specialized tools for:
|
| 4 |
+
- Sports statistics queries with multi-step search
|
| 5 |
+
- Multi-step verification search
|
| 6 |
+
- Video frame extraction for visual analysis
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
import re
from typing import Any, Dict, List, Optional

from smolagents import tool
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@tool
|
| 16 |
+
def sports_data_search(query: str, data_type: str = "player_stats", year: str = "", team: str = "",
                       player: str = "", stat_category: str = "") -> str:
    """Specialized search for sports statistics with multi-step verification.

    Optimized for baseball/sports data queries that require:
    - Finding a specific player based on statistics
    - Cross-referencing multiple data points
    - Historical data lookup
    - Team and player career statistics

    Args:
        query: The main search query (e.g., "most walks 1977 New York Yankees")
        data_type: Type of data - "player_stats", "team_stats", "game_log", "roster", "career_stats"
        year: Specific season year (e.g., "1977")
        team: Team name (e.g., "New York Yankees", "Yankees", "NYY")
        player: Player name for specific player lookup
        stat_category: Specific stat category - "batting", "pitching", "fielding", "walks", "at_bats", etc.

    Returns:
        Comprehensive search results with player statistics and cross-references,
        or an error string beginning "Sports search error:" on failure.
    """
    try:
        # Imported lazily so a missing duckduckgo_search package degrades to
        # an error string instead of breaking module import.
        from duckduckgo_search import DDGS

        # Build targeted search queries from whatever context was supplied.
        search_queries = []

        # Query 1: direct statistics search with all available context.
        if player and year:
            search_queries.append(f"{player} {year} {stat_category} statistics baseball-reference")
        elif team and year and stat_category:
            search_queries.append(f"{team} {year} {stat_category} leader baseball-reference")
        elif team and year:
            search_queries.append(f"{team} {year} statistics baseball-reference")
        else:
            search_queries.append(query)

        # Query 2: Baseball-Reference specific.
        if year and team:
            team_normalized = team.lower().replace("new york ", "").replace(" ", "")
            search_queries.append(f"baseball-reference.com {year} {team_normalized} batting")

        # Query 3: stat-specific search.
        if stat_category and year:
            search_queries.append(f"{year} MLB {stat_category} leaders baseball-reference")

        # Query 4: player career stats.
        if player and not year:
            search_queries.append(f"{player} career statistics baseball-reference")

        all_results = []
        with DDGS() as ddgs:
            for sq in search_queries:
                try:
                    results = list(ddgs.text(sq, max_results=5))
                    all_results.extend(results)
                except Exception:
                    # One failed sub-query should not abort the others.
                    # (Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    continue

        if not all_results:
            return f"No results found for: {query}"

        # De-duplicate by URL and bucket by source reliability, preferring
        # baseball-reference.com pages.
        seen_urls = set()
        br_results = []
        other_results = []

        for r in all_results:
            url = r.get('href', '')
            if url in seen_urls:
                continue
            seen_urls.add(url)

            title = r.get('title', 'No title')
            body = r.get('body', 'No description')

            entry = f"{title}\n{body}\nURL: {url}\n"

            if 'baseball-reference' in url:
                br_results.append(f"📊 {entry}")
            else:
                other_results.append(entry)

        result_text = "=== Sports Data Search Results ===\n\n"
        result_text += f"Query: {query}\n"
        result_text += f"Type: {data_type}\n"
        if year:
            result_text += f"Year: {year}\n"
        if team:
            result_text += f"Team: {team}\n"
        if player:
            result_text += f"Player: {player}\n"
        if stat_category:
            result_text += f"Stat: {stat_category}\n"
        result_text += "\n"

        # Baseball-Reference results first: most reliable source.
        if br_results:
            result_text += "=== Baseball-Reference Results (Most Reliable) ===\n"
            result_text += "\n".join(br_results[:5])
            result_text += "\n\n"

        if other_results:
            result_text += "=== Other Results ===\n"
            result_text += "\n".join(other_results[:5])

        # Hard-coded guidance for a common benchmark query.
        if "yankee" in query.lower() and "1977" in query:
            result_text += "\n\n=== Specific Guidance for 1977 Yankees ===\n"
            result_text += "Look for: 1977 New York Yankees Batting Statistics\n"
            result_text += "Key players to check: Reggie Jackson, Thurman Munson, Chris Chambliss\n"
            result_text += "Baseball-Reference link: https://www.baseball-reference.com/teams/NYY/1977.shtml\n"

        return result_text

    except Exception as e:
        return f"Sports search error: {str(e)}"
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@tool
|
| 141 |
+
def multi_step_search(primary_query: str, follow_up_queries: List[str]) -> str:
    """Execute a multi-step search with verification.

    For complex queries that require:
    1. Finding initial information
    2. Extracting specific data (names, IDs, numbers)
    3. Following up with additional searches

    Example: Find a player with specific stats, then look up their other attributes.

    Args:
        primary_query: The initial search query
        follow_up_queries: List of follow-up queries (can use {placeholder} for extracted data)

    Returns:
        Combined search results from all steps.
    """
    try:
        from duckduckgo_search import DDGS

        lines = []

        def render(hit):
            # One bullet per hit: title, 200-char snippet, then URL.
            return (f"• {hit.get('title', '')}\n  "
                    f"{hit.get('body', '')[:200]}...\n  {hit.get('href', '')}\n")

        with DDGS() as engine:
            # Step 1: the primary query (header is only emitted on success,
            # matching the per-step error reporting below).
            try:
                step_one = list(engine.text(primary_query, max_results=5))
                lines.append(f"=== Step 1: {primary_query} ===")
                for hit in step_one:
                    lines.append(render(hit))
            except Exception as e:
                lines.append(f"Step 1 error: {e}")

            # Steps 2..N: follow-up queries, keeping only the top three hits.
            # (The agent is expected to substitute placeholder values itself.)
            for step, follow_query in enumerate(follow_up_queries, 2):
                try:
                    hits = list(engine.text(follow_query, max_results=5))
                    lines.append(f"\n=== Step {step}: {follow_query} ===")
                    for hit in hits[:3]:
                        lines.append(render(hit))
                except Exception as e:
                    lines.append(f"Step {step} error: {e}")

        return "\n".join(lines)

    except Exception as e:
        return f"Multi-step search error: {str(e)}"
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@tool
|
| 197 |
+
def video_frame_extract(url: str, timestamps: Optional[List[int]] = None, max_frames: int = 10, analyze: bool = False) -> str:
    """Extract frames from YouTube video for visual analysis.

    For videos where the answer requires visual content not in captions.
    Example: Counting objects (birds, animals), identifying scenes, reading on-screen text.

    Args:
        url: YouTube video URL
        timestamps: List of timestamps in seconds (e.g., [0, 30, 60]). If None, auto-distribute
        max_frames: Maximum number of frames to extract (default: 10, increase for long videos)
        analyze: If True, automatically analyze frames with VLM for object counting

    Returns:
        Information about extracted frames, their paths, and optional VLM analysis.
        Error conditions are reported as human-readable strings, never raised.
    """
    try:
        import subprocess
        import tempfile
        import os

        # Pull the video ID out of either canonical URL form.
        video_id = None
        if "youtube.com/watch?v=" in url:
            video_id = url.split("youtube.com/watch?v=")[1].split("&")[0]
        elif "youtu.be/" in url:
            video_id = url.split("youtu.be/")[1].split("?")[0]

        if not video_id:
            return f"Could not extract video ID from URL: {url}"

        # A fresh temp directory per call keeps frames from different videos apart.
        temp_dir = tempfile.mkdtemp(prefix=f"video_{video_id}_")

        # If no timestamps were provided, spread max_frames evenly over the video.
        if not timestamps:
            duration_cmd = [
                "yt-dlp", "--print", "%(duration)s",
                f"https://www.youtube.com/watch?v={video_id}"
            ]
            try:
                duration_output = subprocess.run(
                    duration_cmd, capture_output=True, text=True, timeout=30
                )
                duration = int(duration_output.stdout.strip())
                if max_frames > 1:
                    timestamps = [int(i * duration / (max_frames - 1)) for i in range(max_frames)]
                else:
                    timestamps = [0]
            except Exception:
                # Duration probe failed (yt-dlp missing, non-numeric output,
                # timeout): assume a 5-minute video and sample evenly.
                # (Was a bare `except:`.)
                timestamps = [int(i * 300 / max_frames) for i in range(max_frames)]

        extracted_frames = []

        # Download the full video once; section downloads are less reliable
        # for frame extraction.
        video_path = os.path.join(temp_dir, "video.mp4")
        try:
            download_cmd = [
                "yt-dlp", "-f", "best[height<=720]",
                "--max-filesize", "100M",  # keep the download bounded
                "-o", video_path,
                f"https://www.youtube.com/watch?v={video_id}"
            ]
            subprocess.run(download_cmd, capture_output=True, text=True, timeout=120)

            if not os.path.exists(video_path):
                # Fallback: accept a lower-quality stream.
                download_cmd = [
                    "yt-dlp", "-f", "worst[height>=360]",
                    "-o", video_path,
                    f"https://www.youtube.com/watch?v={video_id}"
                ]
                subprocess.run(download_cmd, capture_output=True, timeout=120)
        except Exception as e:
            return f"Failed to download video: {e}. Video may be too long or restricted."

        # Grab one still per requested timestamp with ffmpeg.
        if os.path.exists(video_path):
            for ts in timestamps[:max_frames]:
                frame_path = os.path.join(temp_dir, f"frame_{ts:04d}s.jpg")
                try:
                    frame_cmd = [
                        "ffmpeg", "-i", video_path,
                        "-ss", str(ts), "-frames:v", "1",
                        "-q:v", "2", frame_path, "-y"
                    ]
                    subprocess.run(frame_cmd, capture_output=True, timeout=30)

                    if os.path.exists(frame_path):
                        extracted_frames.append((ts, frame_path))
                except Exception:
                    # A single bad timestamp should not stop the rest.
                    continue

        if not extracted_frames:
            return f"Failed to extract frames. Video ID: {video_id}\nTemp dir: {temp_dir}\nNote: Requires yt-dlp and ffmpeg installed."

        # Summarize what was extracted.
        result = "=== Video Frame Extraction ===\n"
        result += f"Video ID: {video_id}\n"
        result += f"Frames extracted: {len(extracted_frames)}\n"
        result += f"Timestamps: {[ts for ts, _ in extracted_frames]}\n\n"

        for ts, path in extracted_frames:
            result += f"Frame at {ts}s: {path}\n"

        # Optionally run the VLM over the first few frames.
        if analyze and extracted_frames:
            try:
                from .tools import read_image

                result += "\n=== Frame Analysis ===\n"
                for ts, path in extracted_frames[:5]:  # analyze first 5 frames only
                    analysis = read_image(path, "Count the number of distinct bird species visible in this frame. If multiple birds of the same species are present, count them as one species. List each species you can identify.")
                    result += f"\nFrame at {ts}s:\n{analysis}\n"
                    result += "-" * 40 + "\n"
            except Exception as e:
                result += f"\nNote: Frame analysis failed: {e}"

        result += "\nUse read_image() to manually analyze specific frames."

        return result

    except ImportError:
        return "Error: Required tools not available. Install: pip install yt-dlp"
    except Exception as e:
        return f"Frame extraction error: {str(e)}"
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
@tool
|
| 330 |
+
def baseball_reference_lookup(player_name: str = "", team: str = "", year: str = "", stat_type: str = "batting") -> str:
    """Specialized lookup for baseball statistics on Baseball-Reference.com.

    Args:
        player_name: Player name (optional)
        team: Team name (e.g., "New York Yankees", "Yankees", "NYY")
        year: Season year (e.g., "1977")
        stat_type: "batting" or "pitching"

    Returns:
        Direct links to Baseball-Reference pages and key statistics.
    """
    try:
        output = []

        # Known franchise abbreviations for building direct team-page URLs.
        abbreviations = {
            "new york yankees": "NYY", "yankees": "NYY",
            "boston red sox": "BOS", "red sox": "BOS",
            "los angeles dodgers": "LAD", "dodgers": "LAD",
            "chicago cubs": "CHC", "cubs": "CHC",
            "san francisco giants": "SFG", "giants": "SFG",
            "st louis cardinals": "STL", "cardinals": "STL",
            "detroit tigers": "DET", "tigers": "DET",
        }

        if team and year:
            abbr = abbreviations.get(team.lower(), "")
            if abbr:
                page = f"https://www.baseball-reference.com/teams/{abbr}/{year}.shtml"
                output.append(f"Team Page: {page}")
                if stat_type == "batting":
                    output.append(f"  → Look at 'Team Batting' table for {stat_type} stats")
                    output.append("  → Key columns: BB (walks), AB (at bats), AVG, HR")
                else:
                    output.append("  → Look at 'Team Pitching' table")

        # Compose a focused web query from whatever context was supplied.
        terms = [t for t in (player_name, team, year) if t]
        query = " ".join(terms) + " baseball-reference"

        from duckduckgo_search import DDGS
        with DDGS() as engine:
            hits = list(engine.text(query, max_results=5))

        output.append("\n=== Search Results ===")
        for hit in hits:
            link = hit.get('href', '')
            if 'baseball-reference' in link:
                output.append(f"📊 {hit.get('title', '')}\n   {link}")

        return "\n".join(output)

    except Exception as e:
        return f"Baseball lookup error: {str(e)}"
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
@tool
|
| 397 |
+
def japanese_baseball_lookup(player_name: str = "", team: str = "", year: str = "",
                             league: str = "npb", find_pitcher_numbers: bool = False) -> str:
    """Specialized lookup for Japanese baseball (NPB) player statistics.

    For queries about Japanese professional baseball players, pitchers, and rosters.
    Supports lookups for Taishō Tamai and other NPB players.

    Args:
        player_name: Player name (supports Japanese or Roman characters, e.g., "Taishō Tamai", "玉井泰正")
        team: Team name (e.g., "Hanshin Tigers", "Yomiuri Giants")
        year: Season year (e.g., "2023")
        league: League code - "npb" (Nippon Professional Baseball), "npb_central", "npb_pacific"
        find_pitcher_numbers: If True, search for team pitcher roster with jersey numbers

    Returns:
        Player statistics, team roster info, and relevant Japanese baseball database links.
    """
    # NOTE(review): `league` is currently unread by the implementation; it is
    # kept in the signature for interface stability — confirm before removing.
    try:
        from duckduckgo_search import DDGS

        results = []
        search_queries = []

        # Decide once whether jersey-number roster guidance is wanted; the
        # request can ask for it explicitly (flag) or implicitly (name text).
        wants_pitcher_numbers = find_pitcher_numbers or (player_name and "pitcher" in player_name.lower())

        # === OPTIMIZED SEARCH QUERIES ===
        # Use more specific keywords to filter out irrelevant results.

        if player_name:
            # Normalize macron vowels both ways ("ō" -> "o" and "ō" -> "ou")
            # so either romanization scheme can match.
            player_clean = player_name.replace("ō", "o").replace("ū", "u").replace("ā", "a")
            player_alt = player_name.replace("ō", "ou").replace("ū", "uu").replace("ā", "aa")

            # Query 1: direct Baseball-Reference search with "japanese" context.
            search_queries.append(f"{player_clean} baseball-reference japanese player")

            # Query 2: NPB pitcher specific.
            search_queries.append(f"{player_clean} NPB pitcher")

            # Query 3: use "Nippon Professional Baseball" (avoid "Japanese" alone).
            search_queries.append(f"{player_clean} Nippon Professional Baseball")
            search_queries.append(f"{player_alt} Nippon Professional Baseball")

            # Query 4: team-specific search if a team was provided.
            if team:
                search_queries.append(f"{player_clean} {team} pitcher")

            # Query 5: year-specific.
            if year:
                search_queries.append(f"{player_clean} {year} NPB")

            # Query 6: jersey number search.
            search_queries.append(f"{player_clean} jersey number")

        if wants_pitcher_numbers:
            if team:
                # Use specific NPB roster keywords.
                search_queries.append(f"{team} NPB roster pitchers")
                search_queries.append(f"{team} {year if year else ''} baseball-reference")
            else:
                search_queries.append("NPB pitchers roster")
                search_queries.append(f"Nippon Professional Baseball pitchers {year if year else '2023'}")

        if team and year:
            search_queries.append(f"{team} {year} NPB roster")

        # === EXECUTE SEARCHES ===
        all_results = []
        with DDGS() as ddgs:
            for sq in search_queries:
                try:
                    results_ddgs = list(ddgs.text(sq, max_results=5))
                    all_results.extend(results_ddgs)
                except Exception:
                    # One failing sub-query must not kill the whole lookup.
                    # (Was a bare `except:`.)
                    continue

        # === FILTER AND PRIORITIZE RESULTS ===
        seen_urls = set()
        br_results = []     # Baseball-Reference (most reliable)
        npb_results = []    # NPB/Japanese baseball specific
        other_results = []  # Other sources

        for r in all_results:
            url = r.get('href', '')
            if url in seen_urls or not url:
                continue
            seen_urls.add(url)

            title = r.get('title', 'No title')
            body = r.get('body', '')
            entry = f"{title}\n  {body[:200]}...\n  {url}\n"

            # Categorize by source reliability.
            if 'baseball-reference.com' in url:
                br_results.append(f"📊 {entry}")
            elif any(x in url.lower() for x in ['npb', 'japanese', 'nippon', 'npb.jp']):
                npb_results.append(f"🇯🇵 {entry}")
            elif any(x in title.lower() for x in ['baseball', 'pitcher', 'roster', 'jersey']) or \
                 any(x in body.lower() for x in ['baseball', 'pitcher', 'roster']):
                other_results.append(entry)

        # === FORMAT OUTPUT ===
        results.append("=== Japanese Baseball Search Results ===\n")

        if player_name:
            results.append(f"Player: {player_name}")
        if team:
            results.append(f"Team: {team}")
        if year:
            results.append(f"Year: {year}")
        results.append("")

        # Prioritize Baseball-Reference results.
        if br_results:
            results.append("=== Baseball-Reference Results (Most Reliable) ===")
            results.extend(br_results[:5])
            results.append("")

        # NPB-specific results.
        if npb_results:
            results.append("=== NPB/Japanese Baseball Results ===")
            results.extend(npb_results[:5])
            results.append("")

        # Other relevant results.
        if other_results:
            results.append("=== Other Results ===")
            results.extend(other_results[:5])
            results.append("")

        # === DIRECT LINKS FOR COMMON QUERIES ===
        if player_name and "tamai" in player_name.lower():
            results.append("=== Quick Links for Taishō Tamai ===")
            results.append("Baseball-Reference Japanese Players: https://www.baseball-reference.com/japanese/")
            results.append("Search tip: Try different romanizations (Tamai Taisho, 玉井泰正)")
            results.append("")

        # === GUIDANCE FOR PITCHER NUMBER QUERIES ===
        if wants_pitcher_numbers:
            results.append("=== Guidance for Pitcher Number Queries ===")
            results.append("To find pitchers before/after a specific number:")
            results.append("1. Look for the team roster page on Baseball-Reference")
            results.append("2. Find the pitcher section with jersey numbers")
            results.append("3. Identify the target pitcher's jersey number")
            results.append("4. Find pitchers with adjacent numbers (n-1 and n+1)")
            results.append("")
            results.append("Common NPB team roster pages:")
            results.append("- Hanshin Tigers: https://www.baseball-reference.com/japanese/")
            results.append("- Yomiuri Giants: https://www.baseball-reference.com/japanese/")
            results.append("- Full NPB: https://www.baseball-reference.com/japanese/")
            results.append("")

        return "\n".join(results)

    except Exception as e:
        return f"Japanese baseball lookup error: {str(e)}"
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
# Export all tools defined in this module. `japanese_baseball_lookup` was
# previously missing here even though it is defined above and imported by
# gaia_agent, so star-imports silently dropped it.
__all__ = [
    'sports_data_search',
    'multi_step_search',
    'video_frame_extract',
    'baseball_reference_lookup',
    'japanese_baseball_lookup',
]
|
agent/gaia_agent.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GAIA Agent Implementation
|
| 2 |
+
|
| 3 |
+
A smolagents-based agent for solving GAIA Level 1 benchmark questions.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from smolagents import CodeAgent, InferenceClientModel, LiteLLMModel
|
| 10 |
+
|
| 11 |
+
from .tools import web_search, python_execute, file_read, youtube_transcript, read_image, read_content
|
| 12 |
+
from .enhanced_tools import sports_data_search, multi_step_search, baseball_reference_lookup, video_frame_extract, japanese_baseball_lookup
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GaiaAgent:
|
| 16 |
+
"""Agent for GAIA benchmark tasks.
|
| 17 |
+
|
| 18 |
+
Uses smolagents CodeAgent with:
|
| 19 |
+
- Web search via DuckDuckGo
|
| 20 |
+
- Python code execution
|
| 21 |
+
- File reading capabilities
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, model_id: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None):
|
| 25 |
+
"""Initialize the GAIA Agent.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
model_id: HuggingFace model ID or OpenAI model name
|
| 29 |
+
api_key: API key for the model service
|
| 30 |
+
base_url: Custom API base URL (for Kimi, etc.)
|
| 31 |
+
"""
|
| 32 |
+
# Try to use custom/OpenAI if key is available, otherwise use HuggingFace
|
| 33 |
+
openai_key = api_key or os.getenv("OPENAI_API_KEY")
|
| 34 |
+
custom_base = base_url or os.getenv("BASE_URL")
|
| 35 |
+
|
| 36 |
+
if openai_key or custom_base:
|
| 37 |
+
# Use LiteLLM for custom/OpenAI-compatible endpoints
|
| 38 |
+
model_kwargs = {"api_key": openai_key} if openai_key else {}
|
| 39 |
+
if custom_base:
|
| 40 |
+
model_kwargs["api_base"] = custom_base
|
| 41 |
+
|
| 42 |
+
model = LiteLLMModel(
|
| 43 |
+
model_id=model_id or "gpt-4o",
|
| 44 |
+
**model_kwargs
|
| 45 |
+
)
|
| 46 |
+
else:
|
| 47 |
+
# Use HuggingFace model
|
| 48 |
+
model = InferenceClientModel(
|
| 49 |
+
model_id=model_id or "Qwen/Qwen2.5-Coder-32B-Instruct",
|
| 50 |
+
token=os.getenv("HF_TOKEN")
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Create the CodeAgent with our tools
|
| 54 |
+
self.agent = CodeAgent(
|
| 55 |
+
tools=[
|
| 56 |
+
web_search, python_execute, file_read,
|
| 57 |
+
youtube_transcript, read_image, read_content,
|
| 58 |
+
sports_data_search, multi_step_search, baseball_reference_lookup,
|
| 59 |
+
video_frame_extract, japanese_baseball_lookup
|
| 60 |
+
],
|
| 61 |
+
model=model,
|
| 62 |
+
additional_authorized_imports=[
|
| 63 |
+
"pandas",
|
| 64 |
+
"numpy",
|
| 65 |
+
"json",
|
| 66 |
+
"requests",
|
| 67 |
+
"math",
|
| 68 |
+
"re",
|
| 69 |
+
"os",
|
| 70 |
+
"sys",
|
| 71 |
+
"io",
|
| 72 |
+
"base64",
|
| 73 |
+
"datetime",
|
| 74 |
+
"itertools",
|
| 75 |
+
"collections",
|
| 76 |
+
"statistics",
|
| 77 |
+
"bs4",
|
| 78 |
+
"BeautifulSoup"
|
| 79 |
+
],
|
| 80 |
+
max_steps=10,
|
| 81 |
+
verbosity_level=1
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
print(f"GaiaAgent initialized with model: {model_id or 'default'}")
|
| 85 |
+
|
| 86 |
+
def __call__(self, question: str) -> str:
    """Answer a single question by delegating to the wrapped CodeAgent.

    Args:
        question: The question to answer.

    Returns:
        The agent's answer as a string. A falsy result becomes the empty
        string; a raised exception becomes an error-description string
        (never propagated to the caller).
    """
    try:
        outcome = self.agent.run(question)
    except Exception as exc:
        # Never let an agent failure crash the evaluation loop; surface
        # the problem as the answer text instead.
        print(f"Agent error: {exc}")
        return f"Error processing question: {exc}"
    return str(outcome) if outcome else ""
|
agent/tools.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tools for GAIA Agent
|
| 2 |
+
|
| 3 |
+
This module provides tools for:
|
| 4 |
+
- Web search using DuckDuckGo
|
| 5 |
+
- Python code execution
|
| 6 |
+
- File reading (txt, py, json, xlsx, mp3, png)
|
| 7 |
+
- YouTube transcript extraction
|
| 8 |
+
- Image understanding via Kimi multimodal
|
| 9 |
+
- Unified content reading
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import io
|
| 14 |
+
import sys
|
| 15 |
+
import json
|
| 16 |
+
import subprocess
|
| 17 |
+
from typing import Any
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
from smolagents import tool
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo.

    Args:
        query: The search query string.

    Returns:
        A string with up to 10 numbered search results (title, snippet,
        URL), or a no-results / error message.
    """
    try:
        from duckduckgo_search import DDGS

        with DDGS() as search_client:
            hits = list(search_client.text(query, max_results=10))

        if not hits:
            return "No search results found."

        entries = [
            f"{idx}. {hit.get('title', 'No title')}\n"
            f"{hit.get('body', 'No description')}\n"
            f"URL: {hit.get('href', '')}\n"
            for idx, hit in enumerate(hits, 1)
        ]
        return "\n".join(entries)

    except Exception as e:
        return f"Search error: {str(e)}"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@tool
def python_execute(code: str) -> str:
    """Execute Python code in a subprocess and return its output.

    The code is written verbatim to a unique temporary file and run with the
    current interpreter; stdout and stderr are captured by the subprocess
    call itself. Common libraries like pandas, numpy, json, requests are
    usable if installed in the environment.

    Args:
        code: Python code to execute.

    Returns:
        The captured stdout (plus stderr and exit code when the script
        fails), "(No output)" when the script prints nothing, or an error
        message on timeout / setup failure.
    """
    script_path = None
    try:
        import tempfile

        # A unique temp file is portable (no hard-coded /tmp path) and safe
        # under concurrent calls. Writing the code verbatim — instead of
        # re-indenting it into a try/except wrapper — keeps triple-quoted
        # strings and other line-sensitive constructs intact; the subprocess
        # already captures output and tracebacks for us.
        with tempfile.NamedTemporaryFile(
            "w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            f.write(code)
            script_path = f.name

        result = subprocess.run(
            [sys.executable, script_path],
            capture_output=True,
            text=True,
            timeout=30,
        )

        output = result.stdout
        if result.stderr:
            output += f"\n[STDERR]: {result.stderr}"
        if result.returncode != 0:
            output += f"\n[Exit code: {result.returncode}]"

        return output if output else "(No output)"

    except subprocess.TimeoutExpired:
        return "Error: Code execution timed out (30s limit)"
    except Exception as e:
        return f"Execution error: {str(e)}"
    finally:
        # Best-effort cleanup of the temporary script.
        if script_path:
            try:
                os.remove(script_path)
            except OSError:
                pass
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@tool
def file_read(filepath: str) -> str:
    """Read file content (txt, py, json, xlsx, mp3, png, etc.).

    Supports multiple file types:
    - Text files (.txt, .py, .md): Returns content directly
    - JSON files (.json): Returns formatted JSON
    - Excel files (.xlsx, .xls): Returns sheet names and preview
    - Audio files (.mp3, .wav): Returns file info and transcription if possible
    - Image files (.png, .jpg): Returns file info (needs VLM for content analysis)

    Args:
        filepath: Path to the file to read.

    Returns:
        File content or description.
    """
    try:
        # Check if file exists
        if not os.path.exists(filepath):
            # Try to find file in current directory or common locations
            possible_paths = [
                filepath,
                os.path.join(".", filepath),
                os.path.join("/tmp", filepath),
            ]

            found = False
            for p in possible_paths:
                if os.path.exists(p):
                    filepath = p
                    found = True
                    break

            if not found:
                return f"File not found: {filepath}"

        # Dispatch on the file extension (lowercased for case-insensitive match)
        ext = Path(filepath).suffix.lower()

        # Text-based files: return raw content.
        # errors='ignore' silently drops undecodable bytes instead of raising.
        if ext in ['.txt', '.py', '.md', '.csv', '.log', '.yaml', '.yml', '.html', '.css', '.js']:
            with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()
            return f"=== File: {filepath} ===\n{content}"

        # JSON files: parse, then pretty-print (keeps non-ASCII characters as-is)
        elif ext == '.json':
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return f"=== JSON File: {filepath} ===\n{json.dumps(data, indent=2, ensure_ascii=False)}"

        # Excel files: first sheet only, shape + columns + 20-row preview
        elif ext in ['.xlsx', '.xls']:
            try:
                import pandas as pd
                df = pd.read_excel(filepath)
                preview = df.head(20).to_string()
                return f"=== Excel File: {filepath} ===\nShape: {df.shape}\nColumns: {list(df.columns)}\n\nPreview (first 20 rows):\n{preview}"
            except ImportError:
                return f"Excel file found but pandas not available for reading: {filepath}"
            except Exception as e:
                return f"Error reading Excel file {filepath}: {e}"

        # Image files: metadata only; content analysis is the vision tool's job
        elif ext in ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp']:
            from PIL import Image
            with Image.open(filepath) as img:
                return f"=== Image File: {filepath} ===\nFormat: {img.format}\nSize: {img.size}\nMode: {img.mode}\n\n(Use a vision model to analyze image content)"

        # Audio files: basic metadata, plus a best-effort transcription
        elif ext in ['.mp3', '.wav', '.ogg', '.flac', '.m4a']:
            # Try to get basic info
            info = f"=== Audio File: {filepath} ===\n"
            info += f"Extension: {ext}\n"
            info += f"Size: {os.path.getsize(filepath)} bytes\n"

            # Try to transcribe with whisper if available.
            # NOTE(review): loads the "base" model on every call — slow, but
            # avoids holding a model in memory when the tool is rarely used.
            try:
                import whisper
                model = whisper.load_model("base")
                result = model.transcribe(filepath)
                info += f"\n=== Transcription ===\n{result['text']}"
            except ImportError:
                info += "\n(Whisper not available for transcription)"
            except Exception as e:
                info += f"\n(Transcription failed: {e})"

            return info

        # Binary files - return basic info
        else:
            size = os.path.getsize(filepath)
            return f"=== Binary File: {filepath} ===\nSize: {size} bytes\nExtension: {ext}\n\n(File type not supported for direct reading)"

    except Exception as e:
        return f"Error reading file {filepath}: {str(e)}"
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@tool
def youtube_transcript(url: str) -> str:
    """Extract transcript/captions from YouTube videos.

    Uses youtube-transcript-api to fetch captions directly without
    downloading the video. Works with auto-generated or manual subtitles.

    Args:
        url: YouTube video URL (e.g., https://www.youtube.com/watch?v=...)

    Returns:
        Transcript text from the video, or error message if unavailable.
    """
    try:
        from youtube_transcript_api import YouTubeTranscriptApi

        # Extract video ID from the supported URL shapes
        video_id = None
        if "youtube.com/watch?v=" in url:
            video_id = url.split("youtube.com/watch?v=")[1].split("&")[0]
        elif "youtu.be/" in url:
            video_id = url.split("youtu.be/")[1].split("?")[0]
        elif "youtube.com/shorts/" in url:
            video_id = url.split("youtube.com/shorts/")[1].split("?")[0]

        if not video_id:
            return f"Could not extract video ID from URL: {url}"

        # youtube-transcript-api v1.x: fetch() is an *instance* method, so an
        # API object must be created first. The previous classmethod-style
        # call passed video_id as `self`, always raised, and made the tool
        # report "No transcript available" for every video.
        api = YouTubeTranscriptApi()
        try:
            # Prefer English captions when available.
            transcript_data = api.fetch(video_id, languages=['en', 'en-US', 'en-GB'])
        except Exception:
            # Fall back to whatever caption language exists.
            try:
                transcript_data = api.fetch(video_id)
            except Exception:
                return "No transcript available for this video"

        # transcript_data iterates over snippet objects exposing .text
        text_parts = [snippet.text for snippet in transcript_data]
        full_text = " ".join(text_parts)

        return f"=== YouTube Transcript (Video ID: {video_id}) ===\n{full_text}"

    except ImportError:
        return "Error: youtube-transcript-api not installed. Run: pip install youtube-transcript-api"
    except Exception as e:
        return f"Error extracting transcript: {str(e)}"
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@tool
def read_image(image_path: str, question: str = "") -> str:
    """Analyze image content using Kimi multimodal capabilities.

    Uses the Kimi vision model to understand and describe image content.
    Supports chess boards, charts, diagrams, screenshots, and general images.

    Args:
        image_path: Path to the image file (.png, .jpg, .jpeg)
        question: Specific question about the image (e.g., "What chess move is shown?")

    Returns:
        Analysis/description of the image content from Kimi vision model.
    """
    try:
        import base64
        from openai import OpenAI

        # Check if file exists
        if not os.path.exists(image_path):
            # Try common locations
            possible_paths = [image_path, os.path.join(".", image_path), os.path.join("/tmp", image_path)]
            found = False
            for p in possible_paths:
                if os.path.exists(p):
                    image_path = p
                    found = True
                    break
            if not found:
                return f"Image file not found: {image_path}"

        # Read and encode image
        with open(image_path, "rb") as f:
            image_data = f.read()

        # Convert to base64 for embedding in a data: URL
        image_base64 = base64.b64encode(image_data).decode('utf-8')

        # Determine MIME type from the extension; default to PNG for
        # anything unrecognized
        ext = Path(image_path).suffix.lower()
        mime_type = {
            '.png': 'image/png',
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.gif': 'image/gif',
            '.webp': 'image/webp'
        }.get(ext, 'image/png')

        # Get API configuration from environment
        # Support both OPENAI_API_KEY (legacy) and API_KEY (Kimi config)
        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
        base_url = os.getenv("BASE_URL", "https://api.moonshot.cn/v1")
        # Support both MULTIMODAL_MODEL and MODEL_NAME
        model = os.getenv("MULTIMODAL_MODEL") or os.getenv("MODEL_NAME", "kimi-k2.5")

        if not api_key:
            return "Error: API key not set. Set OPENAI_API_KEY or API_KEY in environment"

        # Create an OpenAI-compatible client pointed at the configured endpoint
        client = OpenAI(api_key=api_key, base_url=base_url)

        # Default question if not provided
        if not question:
            question = "Describe this image in detail."

        # Call the multimodal chat API with the question plus the inlined image
        response = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{image_base64}"
                            }
                        }
                    ]
                }
            ],
            max_tokens=2000
        )

        analysis = response.choices[0].message.content
        return f"=== Image Analysis: {image_path} ===\n{analysis}"

    except ImportError:
        return "Error: openai package not installed"
    except Exception as e:
        return f"Error analyzing image: {str(e)}"
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
@tool
def read_content(source: str, question: str = "") -> str:
    """Unified content reader that dispatches on the source type.

    Supports:
    - YouTube URLs: extracts the video transcript
    - Web pages (http/https): fetches the page and extracts visible text
    - Image files (.png, .jpg, .jpeg, .gif, .webp): analyzes via read_image
    - Anything else: treated as a local file and delegated to file_read

    Args:
        source: Content source (URL or file path)
        question: Optional question for context (especially useful for images)

    Returns:
        Content text or analysis result.
    """
    try:
        # YouTube links go straight to the transcript tool.
        youtube_markers = ("youtube.com/watch", "youtu.be/", "youtube.com/shorts/")
        if any(marker in source for marker in youtube_markers):
            return youtube_transcript(source)

        # Generic web pages: fetch, strip markup, return cleaned text.
        if source.startswith(("http://", "https://")):
            import requests
            from bs4 import BeautifulSoup

            request_headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.0'
            }
            page = requests.get(source, headers=request_headers, timeout=30)
            page.raise_for_status()

            soup = BeautifulSoup(page.text, 'html.parser')
            # Drop non-visible elements before extracting text.
            for element in soup(["script", "style"]):
                element.decompose()

            raw_text = soup.get_text(separator='\n', strip=True)
            cleaned_text = '\n'.join(
                stripped
                for stripped in (line.strip() for line in raw_text.split('\n'))
                if stripped
            )

            # Keep the result within a model-friendly size.
            if len(cleaned_text) > 8000:
                cleaned_text = cleaned_text[:8000] + "\n... [content truncated]"

            return f"=== Web Content: {source} ===\n{cleaned_text}"

        # Image files get the vision tool (with the optional question).
        if source.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.webp')):
            return read_image(source, question)

        # Otherwise, treat as local file
        return file_read(source)

    except Exception as e:
        return f"Error reading content from {source}: {str(e)}"
|
app.py
CHANGED
|
@@ -4,22 +4,29 @@ import requests
|
|
| 4 |
import inspect
|
| 5 |
import pandas as pd
|
| 6 |
|
|
|
|
|
|
|
| 7 |
# (Keep Constants as is)
|
| 8 |
# --- Constants ---
|
| 9 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 10 |
|
| 11 |
-
# ---
|
| 12 |
-
# -
|
| 13 |
class BasicAgent:
|
|
|
|
| 14 |
def __init__(self):
|
| 15 |
-
print("
|
|
|
|
|
|
|
|
|
|
| 16 |
def __call__(self, question: str) -> str:
|
| 17 |
-
print(f"Agent received question (first
|
| 18 |
-
|
| 19 |
-
print(f"Agent returning
|
| 20 |
-
return
|
| 21 |
|
| 22 |
-
def run_and_submit_all(
|
|
|
|
| 23 |
"""
|
| 24 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
| 25 |
and displays the results.
|
|
@@ -44,6 +51,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 44 |
except Exception as e:
|
| 45 |
print(f"Error instantiating agent: {e}")
|
| 46 |
return f"Error initializing agent: {e}", None
|
|
|
|
| 47 |
# In the case of an app running as a Hugging Face Space, this link points toward your codebase (useful for others, so please keep it public)
|
| 48 |
agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
|
| 49 |
print(agent_code)
|
|
@@ -54,17 +62,20 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 54 |
response = requests.get(questions_url, timeout=15)
|
| 55 |
response.raise_for_status()
|
| 56 |
questions_data = response.json()
|
|
|
|
| 57 |
if not questions_data:
|
| 58 |
-
|
| 59 |
-
|
|
|
|
| 60 |
print(f"Fetched {len(questions_data)} questions.")
|
|
|
|
| 61 |
except requests.exceptions.RequestException as e:
|
| 62 |
print(f"Error fetching questions: {e}")
|
| 63 |
return f"Error fetching questions: {e}", None
|
| 64 |
except requests.exceptions.JSONDecodeError as e:
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
except Exception as e:
|
| 69 |
print(f"An unexpected error occurred fetching questions: {e}")
|
| 70 |
return f"An unexpected error occurred fetching questions: {e}", None
|
|
@@ -72,26 +83,30 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 72 |
# 3. Run your Agent
|
| 73 |
results_log = []
|
| 74 |
answers_payload = []
|
|
|
|
| 75 |
print(f"Running agent on {len(questions_data)} questions...")
|
|
|
|
| 76 |
for item in questions_data:
|
| 77 |
task_id = item.get("task_id")
|
| 78 |
question_text = item.get("question")
|
|
|
|
| 79 |
if not task_id or question_text is None:
|
| 80 |
print(f"Skipping item with missing task_id or question: {item}")
|
| 81 |
continue
|
|
|
|
| 82 |
try:
|
| 83 |
submitted_answer = agent(question_text)
|
| 84 |
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
| 85 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
| 86 |
except Exception as e:
|
| 87 |
-
|
| 88 |
-
|
| 89 |
|
| 90 |
if not answers_payload:
|
| 91 |
print("Agent did not produce any answers to submit.")
|
| 92 |
return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
|
| 93 |
|
| 94 |
-
# 4. Prepare Submission
|
| 95 |
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
|
| 96 |
status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
|
| 97 |
print(status_update)
|
|
@@ -102,6 +117,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 102 |
response = requests.post(submit_url, json=submission_data, timeout=60)
|
| 103 |
response.raise_for_status()
|
| 104 |
result_data = response.json()
|
|
|
|
| 105 |
final_status = (
|
| 106 |
f"Submission Successful!\n"
|
| 107 |
f"User: {result_data.get('username')}\n"
|
|
@@ -109,9 +125,11 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 109 |
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
|
| 110 |
f"Message: {result_data.get('message', 'No message received.')}"
|
| 111 |
)
|
|
|
|
| 112 |
print("Submission successful.")
|
| 113 |
results_df = pd.DataFrame(results_log)
|
| 114 |
return final_status, results_df
|
|
|
|
| 115 |
except requests.exceptions.HTTPError as e:
|
| 116 |
error_detail = f"Server responded with status {e.response.status_code}."
|
| 117 |
try:
|
|
@@ -119,20 +137,24 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 119 |
error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
|
| 120 |
except requests.exceptions.JSONDecodeError:
|
| 121 |
error_detail += f" Response: {e.response.text[:500]}"
|
|
|
|
| 122 |
status_message = f"Submission Failed: {error_detail}"
|
| 123 |
print(status_message)
|
| 124 |
results_df = pd.DataFrame(results_log)
|
| 125 |
return status_message, results_df
|
|
|
|
| 126 |
except requests.exceptions.Timeout:
|
| 127 |
status_message = "Submission Failed: The request timed out."
|
| 128 |
print(status_message)
|
| 129 |
results_df = pd.DataFrame(results_log)
|
| 130 |
return status_message, results_df
|
|
|
|
| 131 |
except requests.exceptions.RequestException as e:
|
| 132 |
status_message = f"Submission Failed: Network error - {e}"
|
| 133 |
print(status_message)
|
| 134 |
results_df = pd.DataFrame(results_log)
|
| 135 |
return status_message, results_df
|
|
|
|
| 136 |
except Exception as e:
|
| 137 |
status_message = f"An unexpected error occurred during submission: {e}"
|
| 138 |
print(status_message)
|
|
@@ -147,11 +169,12 @@ with gr.Blocks() as demo:
|
|
| 147 |
"""
|
| 148 |
**Instructions:**
|
| 149 |
|
| 150 |
-
1.
|
| 151 |
-
2.
|
| 152 |
-
3.
|
| 153 |
|
| 154 |
---
|
|
|
|
| 155 |
**Disclaimers:**
|
| 156 |
Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
|
| 157 |
This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance for the delay process of the submit button, a solution could be to cache the answers and submit in a seperate action or even to answer the questions in async.
|
|
@@ -173,24 +196,25 @@ with gr.Blocks() as demo:
|
|
| 173 |
|
| 174 |
if __name__ == "__main__":
|
| 175 |
print("\n" + "-"*30 + " App Starting " + "-"*30)
|
|
|
|
| 176 |
# Check for SPACE_HOST and SPACE_ID at startup for information
|
| 177 |
space_host_startup = os.getenv("SPACE_HOST")
|
| 178 |
space_id_startup = os.getenv("SPACE_ID") # Get SPACE_ID at startup
|
| 179 |
|
| 180 |
if space_host_startup:
|
| 181 |
print(f"✅ SPACE_HOST found: {space_host_startup}")
|
| 182 |
-
print(f"
|
| 183 |
else:
|
| 184 |
-
print("ℹ️
|
| 185 |
|
| 186 |
if space_id_startup: # Print repo URLs if SPACE_ID is found
|
| 187 |
print(f"✅ SPACE_ID found: {space_id_startup}")
|
| 188 |
-
print(f"
|
| 189 |
-
print(f"
|
| 190 |
else:
|
| 191 |
-
print("ℹ️
|
| 192 |
|
| 193 |
print("-"*(60 + len(" App Starting ")) + "\n")
|
| 194 |
|
| 195 |
print("Launching Gradio Interface for Basic Agent Evaluation...")
|
| 196 |
-
demo.launch(debug=True, share=False)
|
|
|
|
| 4 |
import inspect
|
| 5 |
import pandas as pd
|
| 6 |
|
| 7 |
+
from agent.gaia_agent import GaiaAgent
|
| 8 |
+
|
| 9 |
# (Keep Constants as is)
|
| 10 |
# --- Constants ---
|
| 11 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 12 |
|
| 13 |
+
# --- GAIA Agent Definition ---
|
| 14 |
+
# Using smolagents-based GaiaAgent with web search, Python execution, and file reading
|
| 15 |
class BasicAgent:
    """Wrapper that uses GaiaAgent for processing."""

    def __init__(self):
        """Create and hold the underlying GaiaAgent instance."""
        print("Initializing GaiaAgent...")
        self.agent = GaiaAgent()
        print("GaiaAgent initialized.")

    def __call__(self, question: str) -> str:
        """Forward one question to the wrapped agent and return its answer."""
        print(f"Agent received question (first 100 chars): {question[:100]}...")
        result = self.agent(question)
        print(f"Agent returning answer (first 100 chars): {str(result)[:100]}...")
        return result
|
| 27 |
|
| 28 |
+
def run_and_submit_all(
|
| 29 |
+
profile: gr.OAuthProfile | None):
|
| 30 |
"""
|
| 31 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
| 32 |
and displays the results.
|
|
|
|
| 51 |
except Exception as e:
|
| 52 |
print(f"Error instantiating agent: {e}")
|
| 53 |
return f"Error initializing agent: {e}", None
|
| 54 |
+
|
| 55 |
# In the case of an app running as a Hugging Face Space, this link points toward your codebase (useful for others, so please keep it public)
|
| 56 |
agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
|
| 57 |
print(agent_code)
|
|
|
|
| 62 |
response = requests.get(questions_url, timeout=15)
|
| 63 |
response.raise_for_status()
|
| 64 |
questions_data = response.json()
|
| 65 |
+
|
| 66 |
if not questions_data:
|
| 67 |
+
print("Fetched questions list is empty.")
|
| 68 |
+
return "Fetched questions list is empty or invalid format.", None
|
| 69 |
+
|
| 70 |
print(f"Fetched {len(questions_data)} questions.")
|
| 71 |
+
|
| 72 |
except requests.exceptions.RequestException as e:
|
| 73 |
print(f"Error fetching questions: {e}")
|
| 74 |
return f"Error fetching questions: {e}", None
|
| 75 |
except requests.exceptions.JSONDecodeError as e:
|
| 76 |
+
print(f"Error decoding JSON response from questions endpoint: {e}")
|
| 77 |
+
print(f"Response text: {response.text[:500]}")
|
| 78 |
+
return f"Error decoding server response for questions: {e}", None
|
| 79 |
except Exception as e:
|
| 80 |
print(f"An unexpected error occurred fetching questions: {e}")
|
| 81 |
return f"An unexpected error occurred fetching questions: {e}", None
|
|
|
|
| 83 |
# 3. Run your Agent
|
| 84 |
results_log = []
|
| 85 |
answers_payload = []
|
| 86 |
+
|
| 87 |
print(f"Running agent on {len(questions_data)} questions...")
|
| 88 |
+
|
| 89 |
for item in questions_data:
|
| 90 |
task_id = item.get("task_id")
|
| 91 |
question_text = item.get("question")
|
| 92 |
+
|
| 93 |
if not task_id or question_text is None:
|
| 94 |
print(f"Skipping item with missing task_id or question: {item}")
|
| 95 |
continue
|
| 96 |
+
|
| 97 |
try:
|
| 98 |
submitted_answer = agent(question_text)
|
| 99 |
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
| 100 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
| 101 |
except Exception as e:
|
| 102 |
+
print(f"Error running agent on task {task_id}: {e}")
|
| 103 |
+
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
|
| 104 |
|
| 105 |
if not answers_payload:
|
| 106 |
print("Agent did not produce any answers to submit.")
|
| 107 |
return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
|
| 108 |
|
| 109 |
+
# 4. Prepare Submission
|
| 110 |
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
|
| 111 |
status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
|
| 112 |
print(status_update)
|
|
|
|
| 117 |
response = requests.post(submit_url, json=submission_data, timeout=60)
|
| 118 |
response.raise_for_status()
|
| 119 |
result_data = response.json()
|
| 120 |
+
|
| 121 |
final_status = (
|
| 122 |
f"Submission Successful!\n"
|
| 123 |
f"User: {result_data.get('username')}\n"
|
|
|
|
| 125 |
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
|
| 126 |
f"Message: {result_data.get('message', 'No message received.')}"
|
| 127 |
)
|
| 128 |
+
|
| 129 |
print("Submission successful.")
|
| 130 |
results_df = pd.DataFrame(results_log)
|
| 131 |
return final_status, results_df
|
| 132 |
+
|
| 133 |
except requests.exceptions.HTTPError as e:
|
| 134 |
error_detail = f"Server responded with status {e.response.status_code}."
|
| 135 |
try:
|
|
|
|
| 137 |
error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
|
| 138 |
except requests.exceptions.JSONDecodeError:
|
| 139 |
error_detail += f" Response: {e.response.text[:500]}"
|
| 140 |
+
|
| 141 |
status_message = f"Submission Failed: {error_detail}"
|
| 142 |
print(status_message)
|
| 143 |
results_df = pd.DataFrame(results_log)
|
| 144 |
return status_message, results_df
|
| 145 |
+
|
| 146 |
except requests.exceptions.Timeout:
|
| 147 |
status_message = "Submission Failed: The request timed out."
|
| 148 |
print(status_message)
|
| 149 |
results_df = pd.DataFrame(results_log)
|
| 150 |
return status_message, results_df
|
| 151 |
+
|
| 152 |
except requests.exceptions.RequestException as e:
|
| 153 |
status_message = f"Submission Failed: Network error - {e}"
|
| 154 |
print(status_message)
|
| 155 |
results_df = pd.DataFrame(results_log)
|
| 156 |
return status_message, results_df
|
| 157 |
+
|
| 158 |
except Exception as e:
|
| 159 |
status_message = f"An unexpected error occurred during submission: {e}"
|
| 160 |
print(status_message)
|
|
|
|
| 169 |
"""
|
| 170 |
**Instructions:**
|
| 171 |
|
| 172 |
+
1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
|
| 173 |
+
2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
|
| 174 |
+
3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
|
| 175 |
|
| 176 |
---
|
| 177 |
+
|
| 178 |
**Disclaimers:**
|
| 179 |
Once you click the "submit" button, it can take quite some time (this is the time for the agent to go through all the questions).
|
| 180 |
This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance, for the delay when pressing the submit button, a solution could be to cache the answers and submit them in a separate action, or even to answer the questions asynchronously.
|
|
|
|
| 196 |
|
| 197 |
if __name__ == "__main__":
    # Startup banner.
    print("\n" + "-"*30 + " App Starting " + "-"*30)

    # Read the HF Space environment variables purely for diagnostic logging.
    host_env = os.getenv("SPACE_HOST")
    repo_env = os.getenv("SPACE_ID")

    if not host_env:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
    else:
        print(f"✅ SPACE_HOST found: {host_env}")
        print(f" Runtime URL should be: https://{host_env}.hf.space")

    if not repo_env:
        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
    else:
        # Print repo URLs when SPACE_ID is available.
        repo_base = f"https://huggingface.co/spaces/{repo_env}"
        print(f"✅ SPACE_ID found: {repo_env}")
        print(f" Repo URL: {repo_base}")
        print(f" Repo Tree URL: {repo_base}/tree/main")

    print("-"*(60 + len(" App Starting ")) + "\n")

    print("Launching Gradio Interface for Basic Agent Evaluation...")
    demo.launch(debug=True, share=False)
|