| import os |
| import re |
| import time |
| import requests |
| import tempfile |
| from pathlib import Path |
| from typing import Optional |
|
|
| from dotenv import load_dotenv |
| |
| from smolagents import CodeAgent, tool, LiteLLMModel |
|
|
| |
| from tools import ( |
| EnhancedSearchTool, |
| EnhancedWikipediaTool, |
| excel_to_markdown, |
| image_file_info, |
| audio_file_info, |
| code_file_read, |
| extract_youtube_info |
| ) |
|
|
| |
# Load variables from a local .env file into the process environment so the
# os.getenv() lookups in GaiaAgent.__init__ (GEMINI_MODEL, GEMINI_API_KEY)
# can pick them up.
load_dotenv()
|
|
| |
# Base URL of the remote API this module downloads task files from.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# URL prefix for file downloads; process_file() appends the task id to it.
FILE_PATH = DEFAULT_API_URL + "/files/"
|
|
| |
class RateLimitedModel(LiteLLMModel):
    """
    LiteLLMModel wrapper that spaces out LLM calls to avoid 429 Too Many
    Requests errors on free APIs.

    At least ``min_interval`` seconds are left between the end of one call and
    the start of the next. Only the *remaining* part of that interval is slept,
    so the very first call, and calls that follow slow tool executions, are not
    delayed needlessly.
    """

    # Minimum gap, in seconds, between the end of one LLM call and the start
    # of the next. Override on a subclass or instance to tune it.
    min_interval: float = 4.0

    # time.monotonic() reading taken when the previous call finished;
    # None until the first call has completed.
    _last_call_finished: Optional[float] = None

    def __call__(self, messages, stop_sequences=None, grammar=None, **kwargs):
        """Wait out the rest of the rate-limit interval, then delegate to LiteLLMModel.

        Args:
            messages: Chat messages forwarded unchanged to the parent model.
            stop_sequences: Forwarded unchanged to the parent model.
            grammar: Forwarded unchanged to the parent model.
            **kwargs: Any further keyword arguments, forwarded unchanged.

        Returns:
            Whatever the parent ``__call__`` returns.
        """
        if self._last_call_finished is not None:
            elapsed = time.monotonic() - self._last_call_finished
            remaining = self.min_interval - elapsed
            if remaining > 0:
                print(
                    f"\n⏳ [Rate Limit Protection] Pausing for {remaining:.1f} "
                    "seconds before next LLM call..."
                )
                time.sleep(remaining)

        try:
            return super().__call__(
                messages,
                stop_sequences=stop_sequences,
                grammar=grammar,
                **kwargs
            )
        finally:
            # Record the finish time even when the call raised, so a retry
            # after a failed request is still spaced out.
            self._last_call_finished = time.monotonic()
|
|
| |
# NOTE: smolagents' @tool derives the tool description and argument schema
# from the signature and docstring below; the docstring text is what the
# agent's LLM sees, so it is kept verbatim.
@tool
def enhanced_web_search(query: str) -> str:
    """Enhanced web search with intelligent query processing. Use for recent/broad web info.
    Args:
        query: The specific search query string to look up on the web.
    """
    # A new search tool instance is created for every invocation.
    searcher = EnhancedSearchTool()
    search_result = searcher.run(query)
    return search_result
|
|
# NOTE: smolagents' @tool derives the tool description and argument schema
# from the signature and docstring below; the docstring text is what the
# agent's LLM sees, so it is kept verbatim.
@tool
def enhanced_wikipedia(query: str) -> str:
    """Enhanced Wikipedia search. Use this strictly for factual or encyclopedic knowledge.
    Args:
        query: The entity or subject to search for on Wikipedia.
    """
    # A new Wikipedia tool instance is created for every invocation.
    wikipedia = EnhancedWikipediaTool()
    article_text = wikipedia.run(query)
    return article_text
|
|
# NOTE: smolagents' @tool derives the tool description and argument schema
# from the signature and docstring below; the docstring text is what the
# agent's LLM sees, so it is kept verbatim.
@tool
def process_excel(excel_path: str, sheet_name: Optional[str] = None) -> str:
    """Enhanced Excel analysis. Use for spreadsheet-related files (.xlsx, .csv).
    Args:
        excel_path: The absolute local file path to the excel or csv file.
        sheet_name: Optional specific sheet name to analyze.
    """
    # Thin adapter: both arguments are handed to excel_to_markdown by keyword.
    markdown_table = excel_to_markdown(
        excel_path=excel_path,
        sheet_name=sheet_name,
    )
    return markdown_table
|
|
# NOTE: smolagents' @tool derives the tool description and argument schema
# from the signature and docstring below; the docstring text is what the
# agent's LLM sees, so it is kept verbatim.
@tool
def analyze_image(image_path: str, question: str) -> str:
    """Enhanced image file analysis. Use for images (.png, .jpg, etc.).
    Args:
        image_path: The absolute local file path to the image.
        question: What you want to know about the image.
    """
    # Thin adapter: both arguments are handed to image_file_info by keyword.
    image_report = image_file_info(
        image_path=image_path,
        question=question,
    )
    return image_report
|
|
# NOTE: smolagents' @tool derives the tool description and argument schema
# from the signature and docstring below; the docstring text is what the
# agent's LLM sees, so it is kept verbatim.
@tool
def process_audio(audio_path: str) -> str:
    """Enhanced audio processing. Use for sound files (.mp3, .wav, etc.) or transcription.
    Args:
        audio_path: The absolute local file path to the audio file.
    """
    # Thin adapter: the path is handed to audio_file_info by keyword.
    audio_report = audio_file_info(audio_path=audio_path)
    return audio_report
|
|
# NOTE: smolagents' @tool derives the tool description and argument schema
# from the signature and docstring below; the docstring text is what the
# agent's LLM sees, so it is kept verbatim.
@tool
def analyze_code(file_path: str) -> str:
    """Enhanced code file analysis. Use when files like .py, .js, .html are mentioned.
    Args:
        file_path: The absolute local file path to the code file.
    """
    # Thin adapter: the path is handed to code_file_read by keyword.
    code_report = code_file_read(file_path=file_path)
    return code_report
|
|
# NOTE: smolagents' @tool derives the tool description and argument schema
# from the signature and docstring below; the docstring text is what the
# agent's LLM sees, so it is kept verbatim.
@tool
def extract_youtube(url: str) -> str:
    """Extracts transcription from a YouTube video link.
    Args:
        url: The full YouTube URL to extract text from.
    """
    # Thin adapter: the URL is passed positionally to extract_youtube_info.
    video_text = extract_youtube_info(url)
    return video_text
|
|
| |
def detect_file_type(file_path: str) -> Optional[str]:
    """Classify a file by its extension.

    Args:
        file_path: A path or bare file name. Only its final suffix is
            inspected, case-insensitively.

    Returns:
        One of 'excel', 'image', 'audio', 'code', 'text' or 'document', or
        None when the suffix is missing or not recognised.
    """
    extensions_by_category = {
        'excel': ('.xlsx', '.xls', '.csv'),
        'image': ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff', '.webp'),
        'audio': ('.mp3', '.wav', '.ogg', '.flac', '.m4a', '.aac'),
        'code': (
            '.py', '.ipynb', '.js', '.html', '.css', '.java',
            '.cpp', '.c', '.sql', '.r', '.json', '.xml',
        ),
        'text': ('.txt', '.md'),
        'document': ('.pdf', '.doc', '.docx'),
    }

    suffix = Path(file_path).suffix.lower()
    # Each extension appears in exactly one category, so the first hit is the
    # only possible hit.
    for category, extensions in extensions_by_category.items():
        if suffix in extensions:
            return category
    return None
|
|
def _safe_download_name(content_disposition: str, fallback: str) -> str:
    """Return a directory-free file name from a Content-Disposition header, or `fallback`."""
    # Function-scope import keeps this helper self-contained.
    from email.message import Message

    candidate = None
    if content_disposition:
        try:
            # The stdlib header parser handles quoted and unquoted values as
            # well as the RFC 2231/5987 `filename*=charset''value` form.
            header = Message()
            header["content-disposition"] = content_disposition
            candidate = header.get_filename()
        except (ValueError, LookupError):
            candidate = None

    if not candidate:
        return fallback

    # The header is server-controlled: keep only the last path component so a
    # value such as "../../x" cannot escape the per-task directory.
    base_name = Path(candidate.replace("\\", "/")).name.strip()
    if base_name in ("", ".", ".."):
        return fallback
    return base_name


def process_file(task_id: str, question_text: str) -> str:
    """Download the file attached to a task and append its details to the question.

    The file is fetched from ``FILE_PATH + task_id`` and stored under
    ``<system temp dir>/gaia_enhanced_files/<task_id>/<file name>``.

    Args:
        task_id: Identifier of the task; used in the download URL, as the
            storage sub-directory, and as the file name when the response
            does not supply a usable one.
        question_text: The original question.

    Returns:
        ``question_text`` followed by a "FILE INFORMATION" section (local
        path, name, size, detected type) when a file was downloaded and saved.
        ``question_text`` unchanged when the download fails or the file cannot
        be written to disk.
    """
    file_url = f"{FILE_PATH}{task_id}"
    try:
        print(f"[{task_id}] Attempting download: {file_url}")
        response = requests.get(file_url, timeout=30)
        response.raise_for_status()
    except requests.exceptions.RequestException as exc:
        print(f"[{task_id}] No file downloaded: {str(exc)}")
        return question_text

    filename = _safe_download_name(
        response.headers.get("content-disposition", ""),
        task_id,
    )
    payload = response.content

    temp_storage_dir = Path(tempfile.gettempdir()) / "gaia_enhanced_files" / task_id
    file_path = temp_storage_dir / filename
    try:
        temp_storage_dir.mkdir(parents=True, exist_ok=True)
        file_path.write_bytes(payload)
    except (OSError, ValueError) as exc:
        # Degrade the same way as a failed download: answer without the file
        # rather than aborting the whole task.
        print(f"[{task_id}] Could not save downloaded file: {exc}")
        return question_text

    file_size = len(payload)
    file_type = detect_file_type(filename)

    enhanced_question = (
        f"{question_text}\n\n"
        f"==================================================\n"
        f"FILE INFORMATION:\n"
        f"A file was downloaded for this task and saved locally at:\n"
        f"{str(file_path)}\n"
        f"File details:\n"
        f"- Name: {filename}\n"
        f"- Size: {file_size:,} bytes\n"
        f"- Type: {file_type or 'unknown'}\n"
        f"==================================================\n"
    )
    return enhanced_question
|
|
| |
class GaiaAgent:
    """GAIA Agent powered by smolagents CodeAgent"""

    def __init__(self):
        """Build the rate-limited model and the CodeAgent that uses it.

        The model id is read from the GEMINI_MODEL environment variable
        (default "gemini/gemini-2.5-flash") and the key from GEMINI_API_KEY.
        """
        self.model = RateLimitedModel(
            model_id=os.getenv("GEMINI_MODEL", "gemini/gemini-2.5-flash"),
            api_key=os.getenv("GEMINI_API_KEY")
        )

        self.agent = CodeAgent(
            tools=[
                enhanced_web_search,
                enhanced_wikipedia,
                process_excel,
                analyze_image,
                process_audio,
                analyze_code,
                extract_youtube
            ],
            model=self.model,
            max_steps=12,
            additional_authorized_imports=["pandas", "numpy", "re", "math", "json", "collections"]
        )

        print("✓ smolagents CodeAgent Architecture initialized")
        print("✓ 4-Second Rate Limit Protection Active")

    def __call__(self, task_id: str, question: str) -> str:
        """Answer one question and always return a string.

        Args:
            task_id: Identifier of the task; passed to process_file and used
                as a prefix in log output.
            question: The question text.

        Returns:
            The agent's result converted to ``str``. If the agent raises, the
            model is asked the bare question directly; if that also fails or
            yields no content, a "Critical error in execution: ..." message
            is returned.
        """
        print(f"\n{'='*60}")
        print(f"[{task_id}] PROCESSING: {question}")

        # File preparation is best-effort: a failure here must not prevent
        # the agent from attempting the question without the file.
        try:
            processed_question = process_file(task_id, question)
        except Exception as exc:
            print(f"[{task_id}] File preparation failed, continuing without file: {exc}")
            processed_question = question

        try:
            result = self.agent.run(processed_question)
        except Exception as e:
            error_msg = f"Critical error in execution: {str(e)}"
            print(f"[{task_id}] {error_msg}")

            try:
                print("Attempting fallback direct response...")
                reply = self.model(messages=[{"role": "user", "content": question}])
                content = reply.content
            except Exception as fallback_exc:
                # `except Exception` (not a bare except) so Ctrl-C and
                # SystemExit still propagate.
                print(f"[{task_id}] Fallback failed: {fallback_exc}")
                return error_msg

            # Keep the declared `-> str` contract even if the model returned
            # no content.
            if content is None:
                return error_msg
            return str(content)

        print(f"[{task_id}] FINAL ANSWER: {result}")
        print(f"{'='*60}")
        return str(result)
|
|
| |
# Demo entry point: build one agent and run it over a couple of sample
# questions. Answers are printed by GaiaAgent.__call__ itself.
if __name__ == "__main__":
    agent = GaiaAgent()

    sample_questions = [
        "What is the current population of Tokyo?",
        "Tell me about the history of machine learning.",
    ]

    banner = "=" * 80
    print("\n" + banner)
    print("SMOLAGENTS GAIA DEMONSTRATION")
    print(banner)

    for index, question in enumerate(sample_questions):
        # Displayed numbering is 1-based; the demo task id stays 0-based.
        print(f"\nExample {index + 1}: {question}")
        result = agent(f"demo_{index}", question)