|
|
import asyncio
import contextvars
import inspect
import logging
import uuid
from abc import abstractmethod
from datetime import datetime
from enum import Enum
from typing import Dict, List, Any, Optional, Callable

from pydantic import BaseModel

from ...core.base import LatticeComponent, LatticeError
|
|
|
|
|
class AgentRole(Enum):
    """Predefined agent roles.

    Each member's value is the lowercase string form of the role; this is the
    value reported under ``"agent_role"`` / ``"role"`` elsewhere in the module.
    """

    RESEARCHER = "researcher"
    ANALYST = "analyst"
    VALIDATOR = "validator"
    COORDINATOR = "coordinator"
    EXECUTOR = "executor"
    CUSTOM = "custom"
|
|
|
|
|
class AgentTool(BaseModel):
    """Tool available to agents."""

    # Unique key for the tool; agents index their tools by this name.
    name: str
    # Human-readable summary of what the tool does.
    description: str
    # Callable invoked with keyword arguments when the tool is used.
    function: Callable
    # Permission names associated with the tool.
    # NOTE(review): nothing in this module enforces these -- confirm where
    # they are checked.
    required_permissions: List[str]
    # Optional free-form extra data.
    metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
class AgentConfig(BaseModel):
    """Agent configuration."""

    # Role the agent plays (one of the AgentRole members).
    role: AgentRole
    # Agent name; also used as the suffix of the agent's logger name.
    name: str
    # Human-readable description of the agent.
    description: str
    # System message for the agent.
    # NOTE(review): presumably forwarded to the model by subclasses -- it is
    # not consumed anywhere in this module.
    system_message: str
    # Tools the agent may use.
    tools: List[AgentTool]
    # Model identifier; reported in every task result's metadata.
    model: str = "claude-3-opus"
    # Generation settings.
    # NOTE(review): not read in this module; presumably consumed by subclasses.
    temperature: float = 0.7
    max_tokens: int = 1000
    # Optional free-form extra data.
    metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
class TaskInput(BaseModel):
    """Task input for agents."""

    # Caller-supplied identifier, echoed back in the task result.
    task_id: str
    # Human-readable description of the task.
    description: str
    # Task payload; its schema is defined by the agent implementation.
    inputs: Dict[str, Any]
    # Optional additional context for the task.
    context: Optional[Dict[str, Any]] = None
    # Optional names of tools the task expects; each name must be one of the
    # executing agent's tools or the task fails.
    tools: Optional[List[str]] = None
|
|
|
|
|
class TaskResult(BaseModel):
    """Task execution result."""

    # Identifier copied from the originating TaskInput.
    task_id: str
    # ID of the agent that executed the task.
    agent_id: str
    # Outcome of the run; this module produces "completed" or "failed".
    status: str
    # Payload returned by the agent implementation ({} on failure).
    result: Dict[str, Any]
    # String form of the exception when the task failed, otherwise None.
    error: Optional[str] = None
    # Names of the tools recorded for this run.
    tools_used: List[str]
    # Timestamps taken immediately before and after execution.
    start_time: datetime
    end_time: datetime
    # Extra information about the run (this module stores the agent role and
    # model name here).
    metadata: Dict[str, Any]
|
|
|
|
|
class BaseAgent(LatticeComponent):
    """Base class for agents that execute tasks with a configured tool set.

    Subclasses implement ``_execute_implementation``. Callers go through
    ``execute_task``, which validates the tools requested by the task, times
    the run and reports the outcome as a ``TaskResult`` with status
    ``"completed"`` or ``"failed"``.
    """

    # Holds the list of tool names invoked through ``use_tool`` during the
    # current ``execute_task`` call (``None`` outside of a task). A context
    # variable is used instead of an instance attribute so that concurrent
    # ``execute_task`` calls on the same agent keep separate records.
    _tools_used_var: contextvars.ContextVar = contextvars.ContextVar(
        "lattice_agent_tools_used", default=None
    )

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Build the agent from a raw configuration mapping.

        Args:
            config: Keyword arguments for ``AgentConfig``. ``None`` is treated
                as an empty mapping, which cannot satisfy the required
                ``AgentConfig`` fields.

        Raises:
            Exception: Whatever ``AgentConfig`` raises when the mapping is
                invalid (a pydantic validation error).
        """
        super().__init__(config)
        self.agent_config = AgentConfig(**(config or {}))
        # Index tools by name for O(1) lookup in execute_task / use_tool.
        self.tools = {tool.name: tool for tool in self.agent_config.tools}
        self.agent_id = str(uuid.uuid4())
        self.logger = logging.getLogger(f"lattice.agent.{self.agent_config.name}")

    async def initialize(self) -> None:
        """Check every configured tool and mark the agent as initialized.

        Raises:
            LatticeError: If a tool's ``function`` is not callable or any
                other error occurs. The underlying exception is chained as
                ``__cause__``.
        """
        try:
            for tool in self.agent_config.tools:
                if not callable(tool.function):
                    raise LatticeError(f"Tool {tool.name} function is not callable")
            self._initialized = True
        except Exception as e:
            raise LatticeError(f"Failed to initialize agent: {str(e)}") from e

    async def validate_config(self) -> bool:
        """Validate agent configuration.

        Returns:
            ``True`` when ``self.config`` can be parsed as an ``AgentConfig``;
            ``False`` (after logging the error) otherwise.
        """
        try:
            AgentConfig(**(self.config or {}))
            return True
        except Exception as e:
            self.logger.error(f"Invalid configuration: {str(e)}")
            return False

    async def execute_task(self, task: TaskInput) -> TaskResult:
        """Execute a task and report the outcome.

        Any exception raised while validating the requested tools or running
        the implementation is logged and converted into a ``"failed"``
        result rather than propagated.

        Args:
            task: The task to run. Every name in ``task.tools`` (if given)
                must be one of this agent's tools.

        Returns:
            A ``TaskResult`` whose ``tools_used`` lists, in first-use order
            and without duplicates, the tools invoked through ``use_tool``
            during this call.

        Raises:
            Exception: Whatever ``ensure_initialized`` raises; that check runs
                before the error-to-result conversion starts.
        """
        self.ensure_initialized()
        start_time = datetime.now()
        tools_used: List[str] = []
        # Publish the list so use_tool can record into it for this call only.
        token = self._tools_used_var.set(tools_used)

        try:
            for tool_name in task.tools or ():
                if tool_name not in self.tools:
                    raise LatticeError(f"Tool {tool_name} not available")

            result = await self._execute_implementation(task)

            # Built inside the try block so that an invalid result payload is
            # reported as a failed task instead of escaping to the caller.
            return self._build_result(
                task, "completed", result, None, tools_used, start_time
            )
        except Exception as e:
            self.logger.error(f"Task execution failed: {str(e)}")
            return self._build_result(
                task, "failed", {}, str(e), tools_used, start_time
            )
        finally:
            self._tools_used_var.reset(token)

    def _build_result(
        self,
        task: TaskInput,
        status: str,
        result: Dict[str, Any],
        error: Optional[str],
        tools_used: List[str],
        start_time: datetime,
    ) -> TaskResult:
        """Assemble the TaskResult for a finished (or failed) task."""
        return TaskResult(
            task_id=task.task_id,
            agent_id=self.agent_id,
            status=status,
            result=result,
            error=error,
            # dict.fromkeys de-duplicates while keeping first-use order.
            tools_used=list(dict.fromkeys(tools_used)),
            start_time=start_time,
            end_time=datetime.now(),
            metadata={
                "agent_role": self.agent_config.role.value,
                "model": self.agent_config.model
            }
        )

    async def use_tool(self, tool_name: str, **kwargs) -> Any:
        """Invoke one of the agent's tools.

        Both coroutine functions and plain synchronous callables are
        supported: the tool is called, and its return value is awaited only
        when it is awaitable. A synchronous tool runs inline, so a
        long-blocking one will block the event loop.

        Args:
            tool_name: Name of a tool configured on this agent.
            **kwargs: Keyword arguments forwarded to the tool function.

        Returns:
            Whatever the tool function returns.

        Raises:
            LatticeError: If ``tool_name`` is not one of this agent's tools.
            Exception: Any error raised by the tool is logged and re-raised.
        """
        if tool_name not in self.tools:
            raise LatticeError(f"Tool {tool_name} not available")

        tool = self.tools[tool_name]

        # Record the invocation for the enclosing execute_task call, if any.
        tracker = self._tools_used_var.get()
        if tracker is not None:
            tracker.append(tool_name)

        try:
            result = tool.function(**kwargs)
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            self.logger.error(f"Tool {tool_name} execution failed: {str(e)}")
            raise

    @abstractmethod
    async def _execute_implementation(self, task: TaskInput) -> Dict[str, Any]:
        """Implementation specific task execution.

        Subclasses override this to do the actual work and return the payload
        that ``execute_task`` stores in ``TaskResult.result``.
        """
        pass
|
|
|
|
|
class AgentRegistry:
    """Registry for agent management.

    Keeps initialized agents in an in-memory dict keyed by ``agent_id``.
    """

    def __init__(self):
        """Create an empty registry."""
        self.agents: Dict[str, BaseAgent] = {}
        self.logger = logging.getLogger("lattice.agent.registry")

    async def register_agent(self, agent: BaseAgent) -> str:
        """Initialize an agent and add it to the registry.

        Args:
            agent: The agent to register; its ``initialize`` coroutine is
                awaited first.

        Returns:
            The registered agent's ``agent_id``.

        Raises:
            Exception: Any error raised during initialization is logged and
                re-raised; the agent is not registered in that case.
        """
        try:
            await agent.initialize()
            self.agents[agent.agent_id] = agent
            return agent.agent_id
        except Exception as e:
            self.logger.error(f"Failed to register agent: {str(e)}")
            raise

    def get_agent(self, agent_id: str) -> Optional[BaseAgent]:
        """Return the agent registered under ``agent_id``, or ``None``."""
        return self.agents.get(agent_id)

    def list_agents(self) -> List[Dict[str, Any]]:
        """Summarize every registered agent.

        Returns:
            One dict per agent with the keys ``agent_id``, ``name``, ``role``
            (the role's string value) and ``tools`` (list of tool names).
        """
        return [
            {
                "agent_id": agent.agent_id,
                "name": agent.agent_config.name,
                "role": agent.agent_config.role.value,
                "tools": [t.name for t in agent.agent_config.tools]
            }
            for agent in self.agents.values()
        ]

    async def execute_task(self, agent_id: str, task: TaskInput) -> TaskResult:
        """Execute a task using the specified agent.

        Args:
            agent_id: ID returned by ``register_agent``.
            task: The task to hand to the agent.

        Returns:
            The result of the agent's ``execute_task``.

        Raises:
            LatticeError: If no agent is registered under ``agent_id``.
        """
        agent = self.get_agent(agent_id)
        # Compare against None explicitly: a registered agent that happens to
        # be falsy (e.g. defines __len__ or __bool__) must still be usable.
        if agent is None:
            raise LatticeError(f"Agent {agent_id} not found")

        return await agent.execute_task(task)