# Test_Magus / agent11.py
# Author: SergeyO7; commit dad18de (verified): "Rename agent.py to agent11.py"
from smolagents import CodeAgent, LiteLLMModel, Tool, DuckDuckGoSearchTool, WikipediaSearchTool
from token_bucket import Limiter, MemoryStorage
from tenacity import retry, stop_after_attempt, wait_exponential
from langchain_community.document_loaders import ArxivLoader
from sentence_transformers import SentenceTransformer
from bs4 import BeautifulSoup
from datetime import datetime
import pandas as pd
import numpy as np
import requests
import asyncio
import whisper
import yaml
import os
import re
import json
from typing import Optional
# --------------------------
# Core Tools from Previous Implementation
# --------------------------
class VisitWebpageTool(Tool):
    """Fetch a web page and return its visible text content."""

    name = "visit_webpage"
    description = "Visits a webpage and returns its readable text content"
    inputs = {'url': {'type': 'string', 'description': 'The URL to visit'}}
    output_type = "string"

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    def _fetch(self, url: str) -> str:
        """GET *url* and return the raw response body.

        Raises on network errors and non-2xx statuses so the ``@retry``
        decorator can retry; after the last attempt the original exception
        is re-raised.
        """
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.text

    def forward(self, url: str) -> str:
        """Return the page's visible text, or ``"Error fetching webpage: ..."``.

        Args:
            url: Absolute URL to fetch.
        """
        # Retrying happens inside _fetch: catching exceptions here would
        # otherwise hide them from the retry decorator.
        try:
            html = self._fetch(url)
        except Exception as e:
            return f"Error fetching webpage: {str(e)}"
        soup = BeautifulSoup(html, "html.parser")
        # Remove elements whose text content is code/styling, not page prose.
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        lines = (line.strip() for line in soup.get_text(separator="\n").splitlines())
        return "\n".join(line for line in lines if line)
class DownloadTaskAttachmentTool(Tool):
    """Download a task's attachment from the scoring API into ``downloads/``."""

    name = "download_file"
    description = "Downloads files from the task API"
    inputs = {'task_id': {'type': 'string', 'description': 'The task ID to download'}}
    output_type = "string"

    def forward(self, task_id: str) -> str:
        """Fetch ``{TASK_API_URL}/files/{task_id}`` and save it locally.

        Args:
            task_id: Identifier of the task whose attachment is downloaded.

        Returns:
            Path of the written file: ``downloads/<task_id><ext>``, where
            ``ext`` is derived from the response Content-Type.

        Raises:
            RuntimeError: if the request or the file write fails.
        """
        api_url = os.getenv("TASK_API_URL", "https://agents-course-unit4-scoring.hf.space")
        file_url = f"{api_url}/files/{task_id}"
        # basename() stops an id such as "../x" from writing outside downloads/.
        safe_name = os.path.basename(str(task_id))
        try:
            # The context manager releases the streamed connection even on error.
            with requests.get(file_url, stream=True, timeout=30) as response:
                response.raise_for_status()
                extension = self._get_extension(response.headers.get('Content-Type', ''))
                os.makedirs("downloads", exist_ok=True)
                file_path = f"downloads/{safe_name}{extension}"
                with open(file_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:  # skip keep-alive empty chunks
                            f.write(chunk)
            return file_path
        except Exception as e:
            raise RuntimeError(f"Download failed: {str(e)}") from e

    def _get_extension(self, content_type: str) -> str:
        """Map a Content-Type header value to a file extension (``.bin`` if unknown)."""
        type_map = {
            'image/png': '.png',
            'image/jpeg': '.jpg',
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
            'audio/mpeg': '.mp3',
            'application/pdf': '.pdf',
            'text/x-python': '.py'
        }
        # Drop parameters such as "; charset=utf-8"; MIME types are case-insensitive.
        mime = content_type.split(';')[0].strip().lower()
        return type_map.get(mime, '.bin')
class ArxivSearchTool(Tool):
    """Search Arxiv and summarize up to three matching documents."""

    name = "arxiv_search"
    description = "Searches academic papers on Arxiv"
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return title, authors and a 500-character excerpt per result.

        Args:
            query: Free-text Arxiv query.

        Returns:
            Blank-line-separated entries, a "No Arxiv results" message when the
            loader finds nothing, or ``"Arxiv search failed: ..."`` on error.
        """
        try:
            docs = ArxivLoader(query=query, load_max_docs=3).load()
        except Exception as e:
            return f"Arxiv search failed: {str(e)}"
        if not docs:
            # An empty string gives the calling agent nothing to reason about.
            return f"No Arxiv results found for: {query}"
        entries = []
        for doc in docs:
            meta = doc.metadata or {}
            content = doc.page_content or ""
            excerpt = content[:500] + ("..." if len(content) > 500 else "")
            entries.append(
                f"Title: {meta.get('Title', 'Unknown')}\n"
                f"Authors: {meta.get('Authors', 'Unknown')}\n"
                f"Summary: {excerpt}"
            )
        return "\n\n".join(entries)
class SpeechToTextTool(Tool):
    """Transcribe an audio file with the Whisper "base" model."""

    name = "speech_to_text"
    description = "Converts audio files to text"
    inputs = {'audio_path': {'type': 'string', 'description': 'Path to audio file'}}
    output_type = "string"

    def __init__(self):
        # Run Tool.__init__ so the base-class state used by Tool.__call__ exists.
        super().__init__()
        # Loaded on first use: several tools construct this class, and loading
        # Whisper for each instance up front is expensive.
        self.model = None

    def _get_model(self):
        """Return the Whisper model, loading it on first call."""
        if self.model is None:
            self.model = whisper.load_model("base")
        return self.model

    def forward(self, audio_path: str) -> str:
        """Return the transcript of *audio_path*, or a "File not found" message."""
        if not os.path.exists(audio_path):
            return f"File not found: {audio_path}"
        return self._get_model().transcribe(audio_path).get("text", "").strip()
# --------------------------
# Enhanced Tools with Validation
# --------------------------
class ValidatedExcelReader(Tool):
    """Load an Excel file, optionally check it against a schema, render as markdown."""

    name = "excel_reader"
    description = "Reads and validates Excel files"
    inputs = {
        'file_path': {'type': 'string', 'description': 'Path to Excel file'},
        'schema': {'type': 'object', 'description': 'Validation schema', 'nullable': True},
    }
    output_type = "string"

    def forward(self, file_path: str, schema: dict = None) -> str:
        """Return the sheet as a markdown table.

        Raises:
            ValueError: when *schema* is given and ValidationPipeline rejects the data.
        """
        frame = pd.read_excel(file_path)
        if not schema:
            return frame.to_markdown()
        report = ValidationPipeline().validate(frame, schema)
        if report['valid']:
            return frame.to_markdown()
        raise ValueError(f"Data validation failed: {report['errors']}")
# --------------------------
# Integrated Universal Loader
# --------------------------
class UniversalLoader(Tool):
    """Dispatch a load request to a sub-tool; fall back to cross-verified search on error."""

    name = "universal_loader"
    description = "Loads various file types and web content using appropriate sub-tools."
    inputs = {
        'source': {
            'type': 'string',
            'description': 'Type of source to load (web/excel/audio/arxiv/attachment)'
        },
        'task_id': {
            'type': 'string',
            'description': 'Task ID for attachments',
            'nullable': True
        }
    }
    output_type = "string"

    # Downloaded-file extension -> key in self.loaders.
    _EXTENSION_LOADERS = {'xlsx': 'excel', 'mp3': 'audio'}
    # Extensions whose raw file content is returned as text.
    _TEXT_EXTENSIONS = {'py', 'txt', 'csv', 'json', 'md'}

    def __init__(self):
        # Run Tool.__init__ so the base-class state used by Tool.__call__ exists.
        super().__init__()
        self.loaders = {
            'excel': ValidatedExcelReader(),
            'audio': SpeechToTextTool(),
            'arxiv': ArxivSearchTool(),
            'web': VisitWebpageTool()
        }

    def forward(self, source: str, task_id: str = None) -> str:
        """Load content for *source*.

        Args:
            source: ``"attachment"`` to download and parse the task's file, or a
                key of ``self.loaders`` whose tool receives *task_id* as input.
            task_id: Task ID for attachments, otherwise the sub-tool's argument.

        Returns:
            The loaded content; on any exception, the result of a
            cross-verified search for ``"{source} {task_id}"``.
        """
        try:
            if source == "attachment":
                if not task_id:
                    return "A task_id is required to load an attachment."
                file_path = DownloadTaskAttachmentTool()(task_id)
                return self._load_by_type(file_path)
            return self.loaders[source].forward(task_id)
        except Exception:
            # Best effort by design: any loader failure degrades to a search.
            return self._fallback(source, task_id)

    def _load_by_type(self, file_path: str) -> str:
        """Parse a downloaded file according to its extension."""
        ext = file_path.rsplit('.', 1)[-1].lower()
        loader_key = self._EXTENSION_LOADERS.get(ext)
        if loader_key is not None:
            return self.loaders[loader_key].forward(file_path)
        if ext in self._TEXT_EXTENSIONS:
            with open(file_path, encoding="utf-8", errors="replace") as f:
                return f.read()
        # No parser for this type (e.g. images, PDFs): return the path so the
        # agent can choose another tool instead of searching/visiting a local path.
        return f"File downloaded to {file_path}; no loader available for '.{ext}' files."

    def _fallback(self, source: str, context: str) -> str:
        """Search the web for the failed request as a last resort."""
        return CrossVerifiedSearch()(f"{source} {context}")
# --------------------------
# Validation Pipeline
# --------------------------
class ValidationPipeline:
    """Validate DataFrame columns against a ``{column: {'type': ...}}`` schema."""

    # Each check receives (column Series, that field's schema config) and returns a bool.
    VALIDATORS = {
        'numeric': {
            'check': lambda x, config: pd.api.types.is_numeric_dtype(x),
            'error': "Non-numeric value found in numeric field"
        },
        'temporal': {
            'check': lambda x, config: pd.api.types.is_datetime64_any_dtype(x),
            'error': "Invalid date format detected"
        },
        'categorical': {
            # With a 'categories' list every non-null value must belong to it;
            # without one any observed value is accepted. bool(...all()) avoids
            # evaluating a whole Series in boolean context.
            'check': lambda x, config: (
                config.get('categories') is None
                or bool(x.dropna().isin(config['categories']).all())
            ),
            'error': "Invalid category value detected"
        }
    }

    def validate(self, data, schema: dict):
        """Check each schema field of *data*.

        Args:
            data: A pandas DataFrame.
            schema: Maps column name -> config dict with a ``'type'`` key
                (``numeric``/``temporal``/``categorical``); categorical configs
                may add ``'categories'``.

        Returns:
            dict with ``valid`` (bool), ``errors`` (list of str, at most one per
            field) and ``confidence`` (fraction of fields without errors; 1.0
            for an empty schema).
        """
        errors = []
        for field, config in schema.items():
            field_type = config.get('type')
            validator = self.VALIDATORS.get(field_type)
            if validator is None:
                errors.append(f"{field}: Unknown validator type {field_type!r}")
                continue
            if field not in data:
                errors.append(f"{field}: Column not found")
                continue
            if not validator['check'](data[field], config):
                errors.append(f"{field}: {validator['error']}")
        confidence = 1.0 - (len(errors) / len(schema)) if schema else 1.0
        return {
            'valid': not errors,
            'errors': errors,
            'confidence': confidence
        }
# --------------------------
# Tool Router
# --------------------------
class ToolRouter:
    """Send queries to search backends and classify questions by topical domain."""

    def __init__(self):
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.domain_embeddings = {
            'music': self.encoder.encode("music album release artist track"),
            'sports': self.encoder.encode("athlete team score tournament"),
            'science': self.encoder.encode("chemistry biology physics research")
        }
        self.ddg = DuckDuckGoSearchTool()
        self.wiki = WikipediaSearchTool()
        self.arxiv = ArxivSearchTool()

    @staticmethod
    def _safe_search(tool, query: str) -> str:
        """Call *tool* with *query*, turning an exception into an error string."""
        try:
            return tool(query)
        except Exception as e:
            return f"Search failed: {str(e)}"

    def forward(self, query: str, domain: str = None) -> str:
        """Smart search with domain prioritization.

        Args:
            query: Search text.
            domain: ``"academic"`` (Arxiv), ``"general"`` (DuckDuckGo) or
                ``"encyclopedic"`` (Wikipedia); anything else searches all three.

        Returns:
            The single backend's result, or a JSON object keyed by
            ``web``/``wikipedia``/``arxiv`` for the search-all fallback.
        """
        if domain == "academic":
            return self.arxiv(query)
        elif domain == "general":
            return self.ddg(query)
        elif domain == "encyclopedic":
            return self.wiki(query)
        # Fallback: query every backend; one failing backend must not discard the others.
        backends = (("web", self.ddg), ("wikipedia", self.wiki), ("arxiv", self.arxiv))
        results = {name: self._safe_search(tool, query) for name, tool in backends}
        return json.dumps(results, ensure_ascii=False)

    def route(self, question: str):
        """Return the key of ``domain_embeddings`` that scores highest for *question*."""
        query_embed = self.encoder.encode(question)
        # Raw dot product; this equals cosine similarity only if the encoder
        # returns unit-norm vectors (TODO confirm for the chosen model).
        scores = {
            domain: np.dot(query_embed, domain_embed)
            for domain, domain_embed in self.domain_embeddings.items()
        }
        return max(scores, key=scores.get)
# --------------------------
# Temporal Search
# --------------------------
class HistoricalSearch:
    """Query the Internet Archive Wayback Machine availability endpoint."""

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
    def get_historical_content(self, url: str, target_date: str, timeout: float = 30):
        """Return the parsed JSON from ``/wayback/available`` for *url* near *target_date*.

        Despite the name, this returns snapshot-availability metadata, not page content.

        Args:
            url: Page URL to look up.
            target_date: Timestamp passed as the ``timestamp`` query parameter
                (presumably YYYYMMDD[hhmmss] — confirm against the Wayback API docs).
            timeout: Seconds before the request is abandoned (and retried).

        Raises:
            requests.HTTPError: on a non-2xx response (after retries are exhausted,
                tenacity raises its RetryError wrapping it).
        """
        # params= URL-encodes url/timestamp; interpolating them into the query
        # string broke URLs that contain '&', '?' or '#'.
        response = requests.get(
            "https://archive.org/wayback/available",
            params={"url": url, "timestamp": target_date},
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
# --------------------------
# Enhanced Excel Reader
# --------------------------
class EnhancedExcelReader(Tool):
    """Excel reader that infers a per-column schema and validates the data against it."""

    # Tool subclasses must declare these attributes, or construction fails validation.
    name = "enhanced_excel_reader"
    description = "Reads an Excel file and validates it against a schema inferred from its column dtypes"
    inputs = {'path': {'type': 'string', 'description': 'Path to Excel file'}}
    output_type = "string"

    def forward(self, path: str) -> str:
        """Return the sheet at *path* as a markdown table.

        Raises:
            ValueError: if ValidationPipeline rejects the inferred schema.
        """
        df = pd.read_excel(path)
        validation = ValidationPipeline().validate(df, self._detect_schema(df))
        if not validation['valid']:
            raise ValueError(f"Data validation failed: {validation['errors']}")
        return df.to_markdown()

    def _detect_schema(self, df: pd.DataFrame) -> dict:
        """Map each column to ``numeric``, ``temporal`` or (otherwise) ``categorical``.

        NOTE(review): these are the same dtype predicates the numeric/temporal
        validators use, so those fields always pass; only the categorical
        validator can reject data inferred this way.
        """
        schema = {}
        for col in df.columns:
            if pd.api.types.is_numeric_dtype(df[col]):
                dtype = 'numeric'
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                dtype = 'temporal'
            else:
                dtype = 'categorical'
            schema[col] = {'type': dtype}
        return schema
# --------------------------
# Cross-Verified Search
# --------------------------
class CrossVerifiedSearch(Tool):
    """Query several search tools and return the result most of them agree on."""

    name = "cross_verified_search"
    description = "Searches multiple sources and returns consensus results."
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"
    # Created once at class-definition time and shared by all instances.
    SOURCES = [
        DuckDuckGoSearchTool(),
        WikipediaSearchTool(),
        ArxivSearchTool()
    ]

    def __call__(self, query: str):
        """Direct-call entry point; delegates to forward so the two cannot diverge."""
        return self.forward(query)

    def forward(self, query: str) -> str:
        """Run *query* on every source (skipping ones that raise) and vote."""
        results = []
        for source in self.SOURCES:
            try:
                results.append(source(query))
            except Exception:
                # Best effort: a failing backend is skipped; the rest still vote.
                continue
        return self._consensus(results)

    def _consensus(self, results):
        """Return the full text of the most frequent result.

        Results are grouped by their first 100 characters. On a tie, including
        the common case where every result differs, the earliest source wins.
        """
        if not results:
            return "No search results: all sources failed."
        groups = {}  # 100-char key -> [vote count, first full result with that key]
        for result in results:
            key = str(result)[:100]
            if key in groups:
                groups[key][0] += 1
            else:
                groups[key] = [1, result]
        # max() keeps the first maximal entry, i.e. insertion (source) order on ties.
        best = max(groups.values(), key=lambda group: group[0])
        return str(best[1])
# --------------------------
# Main Agent Class (Integrated)
# --------------------------
class MagAgent:
    """Async question-answering wrapper around a smolagents ``CodeAgent``.

    Builds a Gemini model via LiteLLM and loads prompt templates (defaults merged
    with an optional ``prompts.yaml``). It registers this file's tools and exposes
    ``await agent(question, task_id)``, which returns the answer text or an
    ``"AGENT ERROR: ..."`` string.
    """

    class _PromptVariables(dict):
        """Mapping for ``str.format_map`` that leaves unknown ``{placeholders}`` intact."""

        def __missing__(self, key):
            return "{" + key + "}"

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, rate_limiter: "Optional[Limiter]" = None):
        """Initialize the MagAgent with rate limiter, model, tools, and prompt templates.

        Raises:
            ValueError: if the loaded prompt templates fail validation.
        """
        # NOTE(review): rate_limiter is stored but never consumed in this class;
        # confirm the caller throttles requests, or consume it before agent.run.
        self.rate_limiter = rate_limiter
        self.model = LiteLLMModel(
            model_id="gemini/gemini-1.5-flash",
            api_key=os.environ.get("GEMINI_KEY"),
            max_tokens=8192
        )
        self.prompt_templates = self._load_prompt_templates()
        self._validate_prompt_templates()
        self.tools = [
            UniversalLoader(), CrossVerifiedSearch(), ValidatedExcelReader(),
            VisitWebpageTool(), DownloadTaskAttachmentTool(), SpeechToTextTool()
        ]
        # NOTE(review): recent smolagents releases expect prompt_templates entries
        # such as "system_prompt" to be template strings (with specific sub-keys for
        # planning/managed_agent/final_answer), not {"template", "variables"} dicts;
        # confirm against the installed smolagents version.
        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            verbosity_level=2,
            prompt_templates=self.prompt_templates,
            max_steps=20,
            add_base_tools=False
        )

    def _load_prompt_templates(self) -> dict:
        """Load default and custom prompt templates from prompts.yaml.

        Each top-level section of prompts.yaml that matches a default section and
        is a mapping is merged key-by-key over that default. Other top-level keys
        are ignored. Any load/parse error falls back to the defaults.
        """
        defaults = {
            "system_prompt": {"template": "You are Magus...", "variables": []},
            "managed_agent": {
                "template": "Decomposing problem...",
                # Must cover every name _validate_prompt_templates requires.
                "variables": ["task_id", "task", "timestamp", "question_analysis",
                              "subtasks", "validation_rules", "report"]
            },
            "planning": {
                "template": "Step-by-Step Plan...",
                "variables": ["step1", "step2", "step3", "validation_step"],
                "initial_facts": "...", "initial_plan": "...",
                "update_facts_pre_messages": "...", "update_facts_post_messages": "...",
                "update_plan_pre_messages": "...", "update_plan_post_messages": "..."
            },
            "final_answer": {"template": "Final Verified Answer...", "variables": ["sources", "answer"]}
        }
        try:
            with open("prompts.yaml", encoding="utf-8") as f:
                user_prompts = yaml.safe_load(f)
        except Exception as e:
            print(f"Error loading prompts.yaml: {str(e)}. Using defaults")
            return defaults
        if not isinstance(user_prompts, dict):
            print(f"Invalid prompts.yaml structure. Using defaults. Got type: {type(user_prompts)}")
            return defaults
        for key, section in defaults.items():
            override = user_prompts.get(key)
            if override is None:
                continue
            if not isinstance(override, dict):
                print(f"Ignoring prompts.yaml entry '{key}': expected a mapping, got {type(override)}")
                continue
            section.update(override)
        return defaults

    def _validate_prompt_templates(self) -> None:
        """Validate the loaded prompt templates and print a summary.

        Raises:
            ValueError: if a required section, managed_agent variable, or planning
                sub-template is missing.
        """
        required_keys = {'system_prompt', 'managed_agent', 'planning', 'final_answer'}
        if missing := required_keys - set(self.prompt_templates.keys()):
            raise ValueError(f"Missing required prompt templates: {missing}")
        required_managed_agent_vars = {'task_id', 'task', 'timestamp', 'question_analysis', 'subtasks', 'validation_rules', 'report'}
        declared_vars = set(self.prompt_templates["managed_agent"].get("variables", []))
        if missing_vars := required_managed_agent_vars - declared_vars:
            raise ValueError(f"Missing required variables in managed_agent template: {missing_vars}")
        required_planning_templates = {'initial_plan', 'initial_facts'}
        if missing := required_planning_templates - set(self.prompt_templates['planning'].keys()):
            raise ValueError(f"Missing required planning templates: {missing}")
        print("Loaded prompt templates:")
        for name, template in self.prompt_templates.items():
            print(f"{name}: Variables={template.get('variables', [])}")
            if 'sub_templates' in template:
                print(f"Sub-templates: {list(template['sub_templates'].keys())}")
        print("---")

    async def __call__(self, question: str, task_id: str) -> str:
        """Execute the agent with a question and task ID.

        Returns:
            The (possibly truncated) answer text, or ``"AGENT ERROR: <type> - <msg>"``.
        """
        # Built before the try so the error handler always has a context.
        context = self._create_context(question, task_id)
        try:
            result = await self._execute_agent(question, task_id)
            return self._validate_and_format(result, context)
        except Exception as e:
            return self._handle_error(e, context)

    def _create_context(self, question: str, task_id: str) -> dict:
        """Create context dictionary for the task."""
        return {
            "task": question,
            "task_id": task_id,
            "timestamp": datetime.now().isoformat(),
            "validation_checks": [],
            "report": None
        }

    def _build_task_prompt(self, question: str, task_id: str) -> str:
        """Fill the managed_agent template.

        The question is exposed as both ``{task}`` (the declared variable name)
        and ``{question}``. Unknown placeholders are left as-is instead of
        raising KeyError.
        """
        template = self.prompt_templates["managed_agent"]["template"]
        variables = {
            "task_id": task_id,
            "task": question,
            "question": question,
            "timestamp": datetime.now().isoformat(),
            "question_analysis": self._generate_analysis(question),
            "subtasks": self._generate_subtasks(question),
            "validation_rules": self._get_validation_rules(question),
            "report": None
        }
        return template.format_map(self._PromptVariables(variables))

    def _generate_analysis(self, question: str) -> str:
        """Generate analysis for the question."""
        return "\n".join([
            f"- Question Type: {self._detect_question_type(question)}",
            f"- Key Entities: {self._extract_entities(question)}",
            f"- Temporal Constraints: {self._find_temporal_limits(question)}"
        ])

    def _generate_subtasks(self, question: str) -> str:
        """Generate subtasks (placeholder implementation)."""
        return "1. Analyze\n2. Verify\n3. Respond"

    def _get_validation_rules(self, question: str) -> str:
        """Generate validation rules."""
        return "\n".join([
            "* Multi-source verification",
            "* Temporal consistency check",
            "* Numerical validation"
        ])

    async def _execute_agent(self, question: str, task_id: str) -> str:
        """Run the blocking agent.run in a worker thread with the built task prompt."""
        return await asyncio.to_thread(
            self.agent.run,
            task=self._build_task_prompt(question, task_id)
        )

    def _validate_and_format(self, result, context: dict) -> str:
        """Coerce the agent's result to text, truncate it, and record success.

        Non-string results (e.g. a numeric final answer) are converted with str().

        Raises:
            ValueError: if the result is None or blank.
        """
        if result is None:
            raise ValueError("Invalid agent response")
        text = result if isinstance(result, str) else str(result)
        if not text.strip():
            raise ValueError("Invalid agent response")
        if len(text) > 4096:
            text = text[:4090] + "..."
        context["validation_checks"].append({"type": "success", "timestamp": datetime.now().isoformat()})
        return text

    def _handle_error(self, error: Exception, context: dict) -> str:
        """Handle errors with context-aware formatting."""
        error_type = error.__class__.__name__
        error_msg = str(error)
        context["validation_checks"].append({
            "type": "error", "error_type": error_type,
            "message": error_msg, "timestamp": datetime.now().isoformat()
        })
        return f"AGENT ERROR: {error_type} - {error_msg}"

    def _detect_question_type(self, question: str) -> str:
        """Detect the type of question (placeholder)."""
        return "Unknown"

    def _extract_entities(self, question: str) -> str:
        """Return distinct capitalized words in first-seen order, or "None"."""
        names = re.findall(r'\b[A-Z][a-z]+\b', question)
        return ", ".join(dict.fromkeys(names)) or "None"

    def _find_temporal_limits(self, question: str) -> str:
        """Return "min-max" of the 4-digit numbers in the question, or "None"."""
        dates = re.findall(r'\b\d{4}\b', question)
        return f"{min(dates)}-{max(dates)}" if dates else "None"

    def _detect_domain(self, question: str) -> str:
        """Detect domain of the question (placeholder)."""
        return "general"