Spaces:
Sleeping
Sleeping
| from smolagents import CodeAgent, LiteLLMModel, Tool, DuckDuckGoSearchTool, WikipediaSearchTool | |
| from token_bucket import Limiter, MemoryStorage | |
| from tenacity import retry, stop_after_attempt, wait_exponential | |
| from langchain_community.document_loaders import ArxivLoader | |
| from sentence_transformers import SentenceTransformer | |
| from bs4 import BeautifulSoup | |
| from datetime import datetime | |
| import pandas as pd | |
| import numpy as np | |
| import requests | |
| import asyncio | |
| import whisper | |
| import yaml | |
| import os | |
| import re | |
| import json | |
| from typing import Optional | |
| # -------------------------- | |
| # Core Tools from Previous Implementation | |
| # -------------------------- | |
class VisitWebpageTool(Tool):
    """Fetch a web page over HTTP and return its content as markdown text."""

    name = "visit_webpage"
    description = "Visits a webpage and returns its content as markdown"
    inputs = {'url': {'type': 'string', 'description': 'The URL to visit'}}
    output_type = "string"

    def forward(self, url: str) -> str:
        """Return the page at ``url`` converted to markdown.

        Never raises. Any failure (network error, HTTP error status, conversion error)
        is returned as an ``"Error fetching webpage: ..."`` string.
        """
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            return self._html_to_markdown(response.text)
        except Exception as e:
            return f"Error fetching webpage: {str(e)}"

    @staticmethod
    def _html_to_markdown(html: str) -> str:
        """Convert HTML to markdown, or to plain text when markdownify is unavailable."""
        # markdownify is not among this file's top-level imports, so a bare reference
        # raised NameError on every call. Import it lazily and degrade to the
        # BeautifulSoup text extraction (bs4 is already imported by the file).
        try:
            from markdownify import markdownify
        except ImportError:
            return BeautifulSoup(html, "html.parser").get_text("\n").strip()
        return markdownify(html).strip()
class DownloadTaskAttachmentTool(Tool):
    """Download a task's attachment from the task API into the local ``downloads/`` directory."""

    name = "download_file"
    description = "Downloads files from the task API"
    inputs = {'task_id': {'type': 'string', 'description': 'The task ID to download'}}
    output_type = "string"

    # MIME type (with parameters such as "; charset=..." removed) -> file extension.
    # Any type not listed here is saved with '.bin'.
    EXTENSIONS = {
        'image/png': '.png',
        'image/jpeg': '.jpg',
        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
        'audio/mpeg': '.mp3',
        'application/pdf': '.pdf',
        'text/x-python': '.py'
    }

    def forward(self, task_id: str) -> str:
        """Download the attachment for ``task_id`` and return the local file path.

        The base URL comes from the ``TASK_API_URL`` environment variable and defaults
        to the course scoring Space. The body is streamed to
        ``downloads/<task_id><ext>``, with the extension chosen from the response's
        Content-Type header.

        Raises:
            RuntimeError: if the request, the status check or the file write fails.
                The underlying exception is chained as ``__cause__``.
        """
        api_url = os.getenv("TASK_API_URL", "https://agents-course-unit4-scoring.hf.space")
        file_url = f"{api_url}/files/{task_id}"
        try:
            # The context manager releases the streamed connection even if writing fails midway.
            with requests.get(file_url, stream=True, timeout=30) as response:
                response.raise_for_status()
                extension = self._get_extension(response.headers.get('Content-Type', ''))
                os.makedirs("downloads", exist_ok=True)
                # basename() stops a task_id containing path separators from escaping downloads/.
                file_path = f"downloads/{os.path.basename(str(task_id))}{extension}"
                with open(file_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            return file_path
        except Exception as e:
            raise RuntimeError(f"Download failed: {str(e)}") from e

    def _get_extension(self, content_type: str) -> str:
        """Map a Content-Type header value to a file extension, or '.bin' if unknown."""
        mime = content_type.split(';')[0].strip().lower()
        return self.EXTENSIONS.get(mime, '.bin')
class ArxivSearchTool(Tool):
    """Search Arxiv through langchain's ``ArxivLoader`` and return the top 3 hits as text."""

    name = "arxiv_search"
    description = "Searches academic papers on Arxiv"
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"

    # Maximum number of characters shown for each paper's summary.
    SUMMARY_CHARS = 500

    def forward(self, query: str) -> str:
        """Return title/authors/summary blocks for up to 3 papers matching ``query``.

        Never raises. Loader failures are returned as ``"Arxiv search failed: ..."``.
        """
        try:
            docs = ArxivLoader(query=query, load_max_docs=3).load()
        except Exception as e:
            return f"Arxiv search failed: {str(e)}"
        if not docs:
            return f"No Arxiv results found for: {query}"
        return "\n\n".join(self._format_doc(doc) for doc in docs)

    @classmethod
    def _format_doc(cls, doc) -> str:
        """Render one loaded document, tolerating missing metadata fields."""
        meta = doc.metadata or {}
        # Prefer the abstract when the loader supplies one. Otherwise fall back to the
        # start of the document text, which is what the original version always used.
        summary = meta.get('Summary') or doc.page_content or ""
        snippet = summary[:cls.SUMMARY_CHARS]
        if len(summary) > cls.SUMMARY_CHARS:
            snippet += "..."
        return (
            f"Title: {meta.get('Title', 'Unknown')}\n"
            f"Authors: {meta.get('Authors', 'Unknown')}\n"
            f"Summary: {snippet}"
        )
class SpeechToTextTool(Tool):
    """Transcribe a local audio file with OpenAI Whisper (the "base" model)."""

    name = "speech_to_text"
    description = "Converts audio files to text"
    inputs = {'audio_path': {'type': 'string', 'description': 'Path to audio file'}}
    output_type = "string"

    def __init__(self):
        # Run the base Tool initializer; Tool.__call__ relies on state it sets up.
        super().__init__()
        # The Whisper model is loaded on first use. This file constructs more than one
        # instance of this tool, and eager loading paid the model-load cost for each.
        self.model = None

    def forward(self, audio_path: str) -> str:
        """Return the transcript text, or a "File not found" message for a missing path."""
        if not os.path.exists(audio_path):
            return f"File not found: {audio_path}"
        if self.model is None:
            self.model = whisper.load_model("base")
        return self.model.transcribe(audio_path).get("text", "")
| # -------------------------- | |
| # Enhanced Tools with Validation | |
| # -------------------------- | |
class ValidatedExcelReader(Tool):
    """Tool that loads an Excel workbook and, if given a schema, validates it first.

    The result is the sheet rendered as a markdown table.
    """

    name = "excel_reader"
    description = "Reads and validates Excel files"
    inputs = {
        'file_path': {'type': 'string', 'description': 'Path to Excel file'},
        'schema': {'type': 'object', 'description': 'Validation schema', 'nullable': True},
    }
    output_type = "string"

    def forward(self, file_path: str, schema: dict = None) -> str:
        """Read ``file_path`` with pandas and return it as markdown.

        Args:
            file_path: Path of the Excel file to read.
            schema: Optional column -> config mapping passed to ``ValidationPipeline``.
                A falsy schema (None or empty) skips validation.

        Raises:
            ValueError: if a schema is given and validation reports errors.
        """
        frame = pd.read_excel(file_path)
        report = ValidationPipeline().validate(frame, schema) if schema else None
        if report is not None and not report['valid']:
            raise ValueError(f"Data validation failed: {report['errors']}")
        return frame.to_markdown()
| # -------------------------- | |
| # Integrated Universal Loader | |
| # -------------------------- | |
class UniversalLoader(Tool):
    """Dispatch a load request to the sub-tool that matches ``source``.

    ``task_id`` also serves as the sub-tool's payload. It is a URL for ``web``, a file
    path for ``excel``/``audio``, a search query for ``arxiv``, and the task ID to
    download for ``attachment``. Any failure falls back to ``CrossVerifiedSearch``.
    """

    name = "universal_loader"
    description = "Loads various file types and web content using appropriate sub-tools."
    inputs = {
        'source': {
            'type': 'string',
            'description': 'Type of source to load (web/excel/audio/arxiv/attachment)'
        },
        'task_id': {
            'type': 'string',
            'description': 'Task ID for attachments (or the URL / file path / query for the selected loader)',
            'nullable': True
        }
    }
    output_type = "string"

    # Downloaded attachment extensions whose content is returned verbatim as text.
    TEXT_EXTENSIONS = {'py', 'txt', 'csv', 'json', 'md'}
    # Downloaded attachment extension -> key in self.loaders.
    FILE_LOADERS = {'xlsx': 'excel', 'mp3': 'audio'}

    def __init__(self):
        # Run the base Tool initializer; Tool.__call__ relies on state it sets up.
        super().__init__()
        self.loaders = {
            'excel': ValidatedExcelReader(),
            'audio': SpeechToTextTool(),
            'arxiv': ArxivSearchTool(),
            'web': VisitWebpageTool()
        }

    def forward(self, source: str, task_id: str = None) -> str:
        """Load ``source`` using ``task_id`` as payload; fall back to a web search on any error."""
        try:
            if source == "attachment":
                file_path = DownloadTaskAttachmentTool()(task_id)
                return self._load_by_type(file_path)
            return self.loaders[source].forward(task_id)
        except Exception:
            # Best-effort by design: an unknown source or failing loader degrades to search.
            return self._fallback(source, task_id)

    def _load_by_type(self, file_path: str) -> str:
        """Read a downloaded attachment with the loader that fits its extension.

        Text-like files are returned as their content. Spreadsheets and audio go to
        their sub-tools. Anything else (images, PDFs, unknown binaries) returns a note
        with the saved path. It is not sent to the web/Arxiv tools, which expect a URL
        or a search query rather than a local path.
        """
        ext = os.path.splitext(file_path)[1].lstrip('.').lower()
        if ext in self.TEXT_EXTENSIONS:
            with open(file_path, encoding='utf-8', errors='replace') as f:
                return f.read()
        loader_key = self.FILE_LOADERS.get(ext)
        if loader_key is not None:
            return self.loaders[loader_key].forward(file_path)
        return f"Attachment saved at {file_path}; no text loader available for '.{ext}' files."

    def _fallback(self, source: str, context: str) -> str:
        """Search the web for "<source> <context>" via CrossVerifiedSearch."""
        return CrossVerifiedSearch()(f"{source} {context}")
| # -------------------------- | |
| # Validation Pipeline | |
| # -------------------------- | |
class ValidationPipeline:
    """Column-level type checks for a pandas DataFrame against a simple schema.

    A schema maps each column name to a config dict. The dict's ``'type'`` key names
    one of ``VALIDATORS``: ``numeric``, ``temporal`` or ``categorical``. A
    ``categorical`` config may also carry ``'categories'``, an iterable of allowed
    values. Missing values are ignored by the categorical check.
    """

    VALIDATORS = {
        'numeric': {
            'check': lambda x: pd.api.types.is_numeric_dtype(x),
            'error': "Non-numeric value found in numeric field"
        },
        'temporal': {
            'check': lambda x: pd.api.types.is_datetime64_any_dtype(x),
            'error': "Invalid date format detected"
        },
        'categorical': {
            # Reduce to a single bool: `not <Series>` raises "truth value is ambiguous".
            # Without an explicit allow-list, every observed value counts as a category.
            'check': lambda x: bool(x.dropna().isin(x.dropna().unique()).all()),
            'error': "Invalid category value detected"
        }
    }

    def validate(self, data, schema: dict):
        """Validate ``data`` against ``schema``.

        Args:
            data: DataFrame whose columns are checked.
            schema: column name -> ``{'type': ..., ['categories': ...]}``.

        Returns:
            dict with ``valid`` (bool), ``errors`` (list of "<field>: <message>") and
            ``confidence`` (fraction of schema fields that passed; 1.0 for an empty schema).
        """
        errors = []
        for field, config in schema.items():
            kind = config.get('type')
            validator = self.VALIDATORS.get(kind)
            if validator is None:
                errors.append(f"{field}: Unknown validation type {kind!r}")
                continue
            if field not in data:
                errors.append(f"{field}: Column not found")
                continue
            column = data[field]
            if kind == 'categorical' and 'categories' in config:
                passed = bool(column.dropna().isin(list(config['categories'])).all())
            else:
                passed = bool(validator['check'](column))
            if not passed:
                errors.append(f"{field}: {validator['error']}")
        return {
            'valid': not errors,
            'errors': errors,
            # Guard the empty schema, which previously divided by zero.
            'confidence': 1.0 - (len(errors) / len(schema)) if schema else 1.0
        }
| # -------------------------- | |
| # Tool Router | |
| # -------------------------- | |
class ToolRouter:
    """Run searches across DuckDuckGo, Wikipedia and Arxiv, and classify questions by topic."""

    def __init__(self):
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        # One reference embedding per topic domain; route() picks the most similar.
        self.domain_embeddings = {
            'music': self.encoder.encode("music album release artist track"),
            'sports': self.encoder.encode("athlete team score tournament"),
            'science': self.encoder.encode("chemistry biology physics research")
        }
        self.ddg = DuckDuckGoSearchTool()
        self.wiki = WikipediaSearchTool()
        self.arxiv = ArxivSearchTool()

    def forward(self, query: str, domain: str = None) -> str:
        """Smart search with domain prioritization.

        ``domain`` of "academic", "general" or "encyclopedic" selects a single backend
        and returns its result directly. Any other value queries all three backends and
        returns a JSON object keyed "web"/"wikipedia"/"arxiv". A backend that raises
        contributes a "Search failed: ..." string instead of aborting the others.
        """
        if domain == "academic":
            return self.arxiv(query)
        elif domain == "general":
            return self.ddg(query)
        elif domain == "encyclopedic":
            return self.wiki(query)
        # Fallback: search all sources.
        results = {}
        for label, search in (("web", self.ddg), ("wikipedia", self.wiki), ("arxiv", self.arxiv)):
            try:
                results[label] = search(query)
            except Exception as e:
                results[label] = f"Search failed: {e}"
        return json.dumps(results)

    def route(self, question: str):
        """Return the key of ``domain_embeddings`` most cosine-similar to ``question``."""
        query_embed = self.encoder.encode(question)
        scores = {
            domain: self._cosine(query_embed, domain_embed)
            for domain, domain_embed in self.domain_embeddings.items()
        }
        return max(scores, key=scores.get)

    @staticmethod
    def _cosine(a, b) -> float:
        """Cosine similarity of two vectors (0.0 if either has zero norm)."""
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        # Normalizing makes the score independent of embedding magnitude.
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        return float(np.dot(a, b) / denom) if denom else 0.0
| # -------------------------- | |
| # Temporal Search | |
| # -------------------------- | |
class HistoricalSearch:
    """Look up archived snapshots through the Wayback Machine availability API."""

    # Wayback Machine availability endpoint.
    API_URL = "https://archive.org/wayback/available"

    def get_historical_content(self, url: str, target_date: str, timeout: float = 30):
        """Return the API's parsed JSON describing the snapshot of ``url`` nearest ``target_date``.

        Args:
            url: Page URL to look up. It is sent as a query parameter, so it is
                URL-encoded by requests.
            target_date: Sent as the ``timestamp`` parameter. Presumably a
                YYYYMMDD[hhmmss] string; TODO confirm against the API docs.
            timeout: Seconds before the HTTP request is abandoned.
        """
        # The previous hand-built query string read "...?url={url}×tamp=...": the
        # "&timestamp" separator had been mangled into "×tamp" (an HTML-entity
        # corruption). Passing params= also encodes any "&"/"?" inside `url`.
        response = requests.get(
            self.API_URL,
            params={"url": url, "timestamp": target_date},
            timeout=timeout,
        )
        return response.json()
| # -------------------------- | |
| # Enhanced Excel Reader | |
| # -------------------------- | |
class EnhancedExcelReader(Tool):
    """Excel reader that infers a per-column schema, validates it and returns markdown."""

    # Tool metadata, declared like every other Tool subclass in this file. Presumably
    # smolagents' Tool checks these when the tool is constructed; the class previously
    # declared none of them.
    name = "enhanced_excel_reader"
    description = "Reads an Excel file, validates each column against its inferred type, and returns it as markdown"
    inputs = {'path': {'type': 'string', 'description': 'Path to Excel file'}}
    output_type = "string"

    def forward(self, path: str) -> str:
        """Read ``path``, validate it against the inferred schema and return a markdown table.

        Raises:
            ValueError: if ``ValidationPipeline`` reports errors.
        """
        df = pd.read_excel(path)
        validation = ValidationPipeline().validate(df, self._detect_schema(df))
        if not validation['valid']:
            raise ValueError(f"Data validation failed: {validation['errors']}")
        return df.to_markdown()

    def _detect_schema(self, df: pd.DataFrame) -> dict:
        """Infer {'type': numeric|temporal|categorical} for each column from its pandas dtype.

        The same dtype predicates back the numeric/temporal validators, so those columns
        pass validation by construction. Everything else is treated as categorical.
        """
        schema = {}
        for col in df.columns:
            if pd.api.types.is_numeric_dtype(df[col]):
                dtype = 'numeric'
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                dtype = 'temporal'
            else:
                dtype = 'categorical'
            schema[col] = {'type': dtype}
        return schema
| # -------------------------- | |
| # Cross-Verified Search | |
| # -------------------------- | |
class CrossVerifiedSearch(Tool):
    """Query several search backends and return the result most of them agree on."""

    name = "cross_verified_search"
    description = "Searches multiple sources and returns consensus results."
    inputs = {'query': {'type': 'string', 'description': 'Search query'}}
    output_type = "string"
    # Backends are instantiated once, at class definition, and shared by all instances.
    SOURCES = [
        DuckDuckGoSearchTool(),
        WikipediaSearchTool(),
        ArxivSearchTool()
    ]

    def __call__(self, query: str, *args, **kwargs):
        """Run the search directly; delegates to ``forward``.

        Extra positional/keyword arguments are accepted and ignored, so callers that
        pass Tool-level options do not hit a TypeError.
        """
        return self.forward(query)

    def forward(self, query: str) -> str:
        """Search every backend (skipping any that raise) and return the consensus result."""
        results = []
        for source in self.SOURCES:
            try:
                results.append(source(query))
            except Exception:
                # Best-effort: a failing backend is skipped; the others still vote.
                continue
        return self._consensus(results)

    def _consensus(self, results):
        """Return the full text of the most common result.

        Results are grouped by their first 100 characters. Ties go to the earliest
        result. Backends rarely return identical text, so in practice this is usually
        the first successful result. An empty ``results`` yields a notice string
        instead of raising.
        """
        if not results:
            return "No search results: all sources failed."
        counts = {}
        representative = {}
        for result in results:
            key = str(result)[:100]  # Simple grouping key for demo-level voting
            counts[key] = counts.get(key, 0) + 1
            representative.setdefault(key, result)
        winner = max(counts, key=counts.get)
        # Return the full result, not the truncated grouping key.
        return str(representative[winner])
| # -------------------------- | |
| # Main Agent Class (Integrated) | |
| # -------------------------- | |
class MagAgent:
    """Question-answering agent built on a smolagents ``CodeAgent`` backed by Gemini via LiteLLM.

    It loads prompt templates (built-in defaults overlaid with ``prompts.yaml``) and
    builds a task prompt from the ``managed_agent`` template. It then runs the agent
    in a worker thread and returns either a length-capped answer string or an
    ``"AGENT ERROR: ..."`` string.
    """

    def __init__(self, rate_limiter: Optional[Limiter] = None):
        """Initialize the MagAgent with rate limiter, model, tools, and prompt templates.

        Args:
            rate_limiter: Optional limiter stored as ``self.rate_limiter``.
                NOTE(review): nothing in this class consults it; confirm whether a
                caller does.

        Raises:
            ValueError: if the loaded prompt templates fail ``_validate_prompt_templates``.
        """
        self.rate_limiter = rate_limiter
        self.model = LiteLLMModel(
            model_id="gemini/gemini-1.5-flash",
            api_key=os.environ.get("GEMINI_KEY"),
            max_tokens=8192
        )
        # Load prompt templates, then fail fast if required pieces are missing.
        self.prompt_templates = self._load_prompt_templates()
        self._validate_prompt_templates()
        # Initialize tools and agent
        self.tools = [
            UniversalLoader(), CrossVerifiedSearch(), ValidatedExcelReader(),
            VisitWebpageTool(), DownloadTaskAttachmentTool(), SpeechToTextTool()
        ]
        # NOTE(review): CodeAgent renders prompt_templates itself and expects its own key
        # layout. Confirm these {"template", "variables"} dicts match the installed
        # smolagents version.
        self.agent = CodeAgent(
            model=self.model,
            tools=self.tools,
            verbosity_level=2,
            prompt_templates=self.prompt_templates,
            max_steps=20,
            add_base_tools=False
        )

    def _load_prompt_templates(self) -> dict:
        """Load default prompt templates, overlaid with overrides from prompts.yaml.

        Each top-level prompts.yaml key that names a default template must map to a
        dict, whose entries replace or extend that template's entries. Other top-level
        keys are ignored. A non-mapping entry is skipped with a printed message. A
        missing or unreadable file, or a non-mapping document, falls back to the defaults.
        """
        defaults = {
            "system_prompt": {"template": "You are Magus...", "variables": []},
            "managed_agent": {
                "template": "Decomposing problem...",
                # Must list every variable _validate_prompt_templates requires;
                # otherwise the defaults alone fail validation.
                "variables": ["task_id", "task", "timestamp", "question_analysis",
                              "subtasks", "validation_rules", "report"]
            },
            "planning": {
                "template": "Step-by-Step Plan...",
                "variables": ["step1", "step2", "step3", "validation_step"],
                "initial_facts": "...", "initial_plan": "...",
                "update_facts_pre_messages": "...", "update_facts_post_messages": "...",
                "update_plan_pre_messages": "...", "update_plan_post_messages": "..."
            },
            "final_answer": {"template": "Final Verified Answer...", "variables": ["sources", "answer"]}
        }
        try:
            with open("prompts.yaml", encoding="utf-8") as f:
                user_prompts = yaml.safe_load(f)
        except Exception as e:
            print(f"Error loading prompts.yaml: {str(e)}. Using defaults")
            return defaults
        if not isinstance(user_prompts, dict):
            print(f"Invalid prompts.yaml structure. Using defaults. Got type: {type(user_prompts)}")
            return defaults
        for key, default in defaults.items():
            override = user_prompts.get(key)
            if isinstance(override, dict):
                default.update(override)
            elif override is not None:
                print(f"Ignoring prompts.yaml entry '{key}': expected a mapping, got {type(override)}")
        return defaults

    def _validate_prompt_templates(self) -> None:
        """Check required templates, managed_agent variables and planning sub-templates.

        Prints a summary of the loaded templates on success.

        Raises:
            ValueError: naming whichever required piece is missing.
        """
        required_keys = {'system_prompt', 'managed_agent', 'planning', 'final_answer'}
        if missing := required_keys - set(self.prompt_templates.keys()):
            raise ValueError(f"Missing required prompt templates: {missing}")
        required_managed_agent_vars = {'task_id', 'task', 'timestamp', 'question_analysis', 'subtasks', 'validation_rules', 'report'}
        declared_vars = set(self.prompt_templates["managed_agent"].get("variables") or [])
        if missing_vars := required_managed_agent_vars - declared_vars:
            raise ValueError(f"Missing required variables in managed_agent template: {missing_vars}")
        required_planning_templates = {'initial_plan', 'initial_facts'}
        if missing := required_planning_templates - set(self.prompt_templates['planning'].keys()):
            raise ValueError(f"Missing required planning templates: {missing}")
        print("Loaded prompt templates:")
        for name, template in self.prompt_templates.items():
            print(f"{name}: Variables={template.get('variables', [])}")
            if 'sub_templates' in template:
                print(f"Sub-templates: {list(template['sub_templates'].keys())}")
        print("---")

    async def __call__(self, question: str, task_id: str) -> str:
        """Execute the agent with a question and task ID.

        Never raises. Failures are recorded in the task context and returned as an
        ``"AGENT ERROR: ..."`` string.
        """
        # Built before the try block so the error handler always has a context.
        context = self._create_context(question, task_id)
        try:
            result = await self._execute_agent(question, task_id)
            return self._validate_and_format(result, context)
        except Exception as e:
            return self._handle_error(e, context)

    def _create_context(self, question: str, task_id: str) -> dict:
        """Create the per-call context dict (task, id, timestamp, validation log)."""
        return {
            "task": question,
            "task_id": task_id,
            "timestamp": datetime.now().isoformat(),
            "validation_checks": [],
            "report": None
        }

    def _build_task_prompt(self, question: str, task_id: str) -> str:
        """Fill the managed_agent template with ``str.format``.

        Raises:
            KeyError: if the template references a placeholder not supplied here.
        """
        template = self.prompt_templates["managed_agent"]["template"]
        variables = {
            "task_id": task_id,
            # `task` is the variable _validate_prompt_templates requires; `question` is
            # kept for templates that use that name.
            "task": question,
            "question": question,
            "timestamp": datetime.now().isoformat(),
            "question_analysis": self._generate_analysis(question),
            "subtasks": self._generate_subtasks(question),
            "validation_rules": self._get_validation_rules(question),
            "report": None
        }
        return template.format(**variables)

    def _generate_analysis(self, question: str) -> str:
        """Generate a bullet-list analysis (type, entities, year range) of the question."""
        return "\n".join([
            f"- Question Type: {self._detect_question_type(question)}",
            f"- Key Entities: {self._extract_entities(question)}",
            f"- Temporal Constraints: {self._find_temporal_limits(question)}"
        ])

    def _generate_subtasks(self, question: str) -> str:
        """Generate subtasks (placeholder implementation)."""
        return "1. Analyze\n2. Verify\n3. Respond"

    def _get_validation_rules(self, question: str) -> str:
        """Return the fixed list of validation rules (independent of the question)."""
        return "\n".join([
            "* Multi-source verification",
            "* Temporal consistency check",
            "* Numerical validation"
        ])

    async def _execute_agent(self, question: str, task_id: str) -> str:
        """Run the blocking ``agent.run`` in a worker thread with the built task prompt."""
        return await asyncio.to_thread(
            self.agent.run,
            task=self._build_task_prompt(question, task_id)
        )

    def _validate_and_format(self, result, context: dict) -> str:
        """Coerce the agent's result to text, cap its length and log success.

        Non-string results (e.g. a numeric final answer) are converted with ``str()``
        rather than rejected. Results longer than 4096 characters are cut to 4090
        characters plus "...".

        Raises:
            ValueError: if the result is None or converts to an empty string.
        """
        if result is None:
            raise ValueError("Invalid agent response")
        if not isinstance(result, str):
            result = str(result)
        if not result:
            raise ValueError("Invalid agent response")
        if len(result) > 4096:
            result = result[:4090] + "..."
        context["validation_checks"].append({"type": "success", "timestamp": datetime.now().isoformat()})
        return result

    def _handle_error(self, error: Exception, context: dict) -> str:
        """Record the error in the context and return an "AGENT ERROR" string."""
        error_type = error.__class__.__name__
        error_msg = str(error)
        context["validation_checks"].append({
            "type": "error", "error_type": error_type,
            "message": error_msg, "timestamp": datetime.now().isoformat()
        })
        return f"AGENT ERROR: {error_type} - {error_msg}"

    def _detect_question_type(self, question: str) -> str:
        """Detect the type of question (placeholder)."""
        return "Unknown"

    def _extract_entities(self, question: str) -> str:
        """Return capitalized words (a crude entity guess), comma-joined, or "None"."""
        return ", ".join(re.findall(r'\b[A-Z][a-z]+\b', question)) or "None"

    def _find_temporal_limits(self, question: str) -> str:
        """Return "min-max" of the standalone 4-digit numbers in the question, or "None"."""
        # Equal-length digit strings compare lexicographically the same as numerically.
        dates = re.findall(r'\b\d{4}\b', question)
        return f"{min(dates)}-{max(dates)}" if dates else "None"

    def _detect_domain(self, question: str) -> str:
        """Detect domain of the question (placeholder)."""
        return "general"