File size: 6,160 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
"""
Database connection and models for EvoAgentX.
"""
# import asyncio
import logging
from datetime import datetime
from enum import Enum
from typing import Optional, List, Dict, Any # , Union
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import ASCENDING, TEXT
from pydantic_core import core_schema
from bson import ObjectId
from pydantic import GetCoreSchemaHandler
from pydantic import Field, BaseModel
from evoagentx.app.config import settings
# Setup logger
logger = logging.getLogger(__name__)
# Custom PyObjectId for MongoDB ObjectId compatibility with Pydantic
class PyObjectId(ObjectId):
    """``ObjectId`` subclass that plugs MongoDB ids into pydantic v2 validation.

    Validation accepts either an existing ``bson.ObjectId`` (the type MongoDB
    returns for ``_id``) or any value ``ObjectId.is_valid`` accepts once it has
    passed pydantic's string schema. The validated value is a plain
    ``bson.ObjectId``.
    """

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
        """Build the pydantic-core schema used for fields of this type.

        Args:
            source_type: The annotated type pydantic is building a schema for
                (unused).
            handler: Pydantic's schema handler (unused; the schema is built
                directly from ``core_schema`` primitives).

        Returns:
            A ``json_or_python`` core schema. JSON input must be a string that
            passes ``validate``. Python input may additionally be an
            ``ObjectId`` instance, which is passed through untouched.
        """
        # String path: pydantic checks "is a string" first, then validate()
        # converts it. This is also the branch used for JSON schema generation,
        # so the field is still advertised as a string.
        from_str = core_schema.no_info_after_validator_function(
            cls.validate, core_schema.str_schema()
        )
        return core_schema.json_or_python_schema(
            json_schema=from_str,
            # A bare str_schema() rejects ObjectId instances, which made it
            # impossible to build a model straight from a MongoDB document.
            python_schema=core_schema.union_schema(
                [core_schema.is_instance_schema(ObjectId), from_str]
            ),
            # Only JSON output is stringified; Python-mode dumps keep the
            # ObjectId so the value can be written back to MongoDB unchanged.
            serialization=core_schema.plain_serializer_function_ser_schema(
                str, when_used="json"
            ),
        )

    @classmethod
    def validate(cls, v):
        """Convert ``v`` to an ``ObjectId``.

        Args:
            v: Candidate value; anything ``ObjectId.is_valid`` accepts.

        Returns:
            ``ObjectId(v)``.

        Raises:
            ValueError: If ``ObjectId.is_valid(v)`` is false.
        """
        if not ObjectId.is_valid(v):
            raise ValueError("Invalid ObjectId")
        return ObjectId(v)
# Base model with ObjectId handling
class MongoBaseModel(BaseModel):
    # Shared base for the models in this module: an optional ObjectId-typed
    # ``id`` field that is populated from / aliased to ``_id``.
    # Documented with comments rather than a docstring so the generated JSON
    # schema description stays unchanged.
    id: Optional[PyObjectId] = Field(default=None, alias="_id")

    model_config = dict(
        # Disable pydantic's protected-namespace checks for field names.
        protected_namespaces=(),
        # Allow construction with either ``id`` or the ``_id`` alias
        # (replaces `allow_population_by_field_name`).
        populate_by_name=True,
        # Keep custom types like ObjectId usable as field types.
        arbitrary_types_allowed=True,
        # Ensure ObjectId is serialized as a string.
        json_encoders={ObjectId: str},
    )
# Status Enums
class AgentStatus(str, Enum):
    # String-valued enumeration of agent states. Members are also ``str``
    # instances, so they compare equal to their raw values.
    # ``Agent.status`` defaults to CREATED.

    CREATED = "created"
    ACTIVE = "active"
    INACTIVE = "inactive"
    ERROR = "error"
class WorkflowStatus(str, Enum):
    # String-valued enumeration of workflow states. Members are also ``str``
    # instances, so they compare equal to their raw values.
    # ``Workflow.status`` defaults to CREATED.

    CREATED = "created"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class ExecutionStatus(str, Enum):
    # String-valued enumeration of workflow-execution states. Members are also
    # ``str`` instances, so they compare equal to their raw values.
    # ``WorkflowExecution.status`` defaults to PENDING.

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    TIMEOUT = "timeout"
    CANCELLED = "cancelled"
# Database Models
class Agent(MongoBaseModel):
    # Agent record. Overrides the base ``id`` with a required string that is
    # aliased to ``_id``.
    # Documented with comments rather than a docstring so the generated JSON
    # schema description stays unchanged.

    # --- identity ---
    id: str = Field(..., alias="_id")
    name: str
    description: Optional[str] = Field(default=None)

    # --- payload: ``config`` is required, the other two start as empty dicts ---
    config: Dict[str, Any]
    state: Dict[str, Any] = Field(default_factory=dict)
    runtime_params: Dict[str, Any] = Field(default_factory=dict)

    status: AgentStatus = Field(default=AgentStatus.CREATED)

    # --- bookkeeping: both timestamps default to ``datetime.utcnow()``,
    # evaluated when the model instance is created ---
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    created_by: Optional[str] = Field(default=None)
    tags: List[str] = Field(default_factory=list)
class Workflow(MongoBaseModel):
    # Workflow record. Overrides the base ``id`` with a required string that is
    # aliased to ``_id``.
    # Documented with comments rather than a docstring so the generated JSON
    # schema description stays unchanged.

    # --- identity ---
    id: str = Field(..., alias="_id")
    name: str
    description: Optional[str] = Field(default=None)

    # --- payload: ``definition`` is required; ``agent_ids`` starts empty ---
    definition: Dict[str, Any]
    agent_ids: List[str] = Field(default_factory=list)

    status: WorkflowStatus = Field(default=WorkflowStatus.CREATED)

    # --- bookkeeping: both timestamps default to ``datetime.utcnow()``,
    # evaluated when the model instance is created ---
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    created_by: Optional[str] = Field(default=None)
    tags: List[str] = Field(default_factory=list)
    version: int = Field(default=1)
class ExecutionLog(MongoBaseModel):
    # Single log entry tied to a workflow id and an execution id. Keeps the
    # optional ObjectId ``id`` inherited from the base model.
    # Documented with comments rather than a docstring so the generated JSON
    # schema description stays unchanged.

    # --- what the entry refers to (step and agent are optional) ---
    workflow_id: str
    execution_id: str
    step_id: Optional[str] = Field(default=None)
    agent_id: Optional[str] = Field(default=None)

    # --- the entry itself; ``timestamp`` defaults to ``datetime.utcnow()``
    # evaluated when the model instance is created ---
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    level: str = Field(default="INFO")
    message: str
    details: Dict[str, Any] = Field(default_factory=dict)
class WorkflowExecution(MongoBaseModel):
    # Record of one execution of a workflow. Keeps the optional ObjectId ``id``
    # inherited from the base model.
    # Documented with comments rather than a docstring so the generated JSON
    # schema description stays unchanged.

    workflow_id: str
    status: ExecutionStatus = Field(default=ExecutionStatus.PENDING)

    # --- timing: unset until assigned by the caller ---
    start_time: Optional[datetime] = Field(default=None)
    end_time: Optional[datetime] = Field(default=None)

    # --- inputs / outputs: all start as empty dicts ---
    input_params: Dict[str, Any] = Field(default_factory=dict)
    results: Dict[str, Any] = Field(default_factory=dict)
    created_by: Optional[str] = Field(default=None)
    step_results: Dict[str, Dict[str, Any]] = Field(default_factory=dict)

    # --- progress / failure information ---
    current_step: Optional[str] = Field(default=None)
    error_message: Optional[str] = Field(default=None)

    # Defaults to ``datetime.utcnow()`` evaluated when the instance is created.
    created_at: datetime = Field(default_factory=datetime.utcnow)
# Database client
class Database:
    """Class-level holder for the Motor client and its collection handles.

    All state is stored on the class itself. ``connect`` populates the
    attributes below and ``disconnect`` resets them to ``None``.
    """

    # Annotation is quoted so it is not evaluated when the class is created.
    client: Optional["AsyncIOMotorClient"] = None
    db = None
    # Collection handles, assigned by connect()
    agents = None
    workflows = None
    executions = None
    logs = None

    @staticmethod
    def _loggable_url(url) -> str:
        """Return ``url`` reduced to ``scheme://hosts`` so it is safe to log.

        Connection strings may embed ``user:password@`` credentials and carry
        sensitive options in the path/query, so everything except the scheme
        and host list is dropped.

        Args:
            url: Connection string (converted with ``str()`` first).

        Returns:
            The scheme and host portion only; a string without ``://`` is
            treated as a bare host list.
        """
        scheme, sep, rest = str(url).partition("://")
        if not sep:
            # No scheme present: the whole string is the authority part.
            scheme, rest = "", str(url)
        authority = rest.split("/", 1)[0].split("?", 1)[0]
        hosts = authority.rpartition("@")[2]
        return f"{scheme}{sep}{hosts}"

    @classmethod
    async def connect(cls):
        """Connect to MongoDB, bind the collection handles and create indexes.

        Raises:
            Exception: Whatever ``_create_indexes`` raises is re-raised after
                the freshly created client has been closed again.
        """
        # Never log the raw URL: it can contain credentials.
        logger.info(
            "Connecting to MongoDB at %s...",
            cls._loggable_url(settings.MONGODB_URL),
        )
        cls.client = AsyncIOMotorClient(settings.MONGODB_URL)
        cls.db = cls.client[settings.MONGODB_DB_NAME]

        # Set up collections
        cls.agents = cls.db.agents
        cls.workflows = cls.db.workflows
        cls.executions = cls.db.workflow_executions
        cls.logs = cls.db.execution_logs

        # Create indexes; on failure do not leave an open client behind.
        try:
            await cls._create_indexes()
        except Exception:
            logger.exception("Failed to create MongoDB indexes; closing client")
            await cls.disconnect()
            raise
        logger.info("Connected to MongoDB successfully")

    @classmethod
    async def disconnect(cls):
        """Close the client, if any, and clear every handle stored on the class."""
        if cls.client is None:
            return
        cls.client.close()
        # Drop the stale handles so nothing keeps using a closed client.
        cls.client = None
        cls.db = None
        cls.agents = cls.workflows = cls.executions = cls.logs = None
        logger.info("Disconnected from MongoDB")

    @classmethod
    async def _create_indexes(cls):
        """Create indexes for collections.

        Expects the collection handles to have been assigned by ``connect``.
        Index creation calls are awaited one after another in the order below.
        """
        # Agent indexes (name is unique; name + description share one text index)
        await cls.agents.create_index([("name", ASCENDING)], unique=True)
        await cls.agents.create_index([("name", TEXT), ("description", TEXT)])
        await cls.agents.create_index([("created_at", ASCENDING)])
        await cls.agents.create_index([("tags", ASCENDING)])

        # Workflow indexes (name is indexed but, unlike agents, not unique)
        await cls.workflows.create_index([("name", ASCENDING)])
        await cls.workflows.create_index([("name", TEXT), ("description", TEXT)])
        await cls.workflows.create_index([("created_at", ASCENDING)])
        await cls.workflows.create_index([("agent_ids", ASCENDING)])
        await cls.workflows.create_index([("tags", ASCENDING)])

        # Execution indexes
        await cls.executions.create_index([("workflow_id", ASCENDING)])
        await cls.executions.create_index([("created_at", ASCENDING)])
        await cls.executions.create_index([("status", ASCENDING)])

        # Log indexes (single-field plus a compound workflow/execution index)
        await cls.logs.create_index([("execution_id", ASCENDING)])
        await cls.logs.create_index([("timestamp", ASCENDING)])
        await cls.logs.create_index(
            [("workflow_id", ASCENDING), ("execution_id", ASCENDING)]
        )