File size: 6,160 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""
Database connection and models for EvoAgentX.
"""
# import asyncio
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Optional, List, Dict, Any  # , Union

from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import Field, BaseModel
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from pymongo import ASCENDING, TEXT

from evoagentx.app.config import settings

# Module-level logger, named after this module so handlers/levels can be
# configured per module through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)

# Custom PyObjectId for MongoDB ObjectId compatibility with Pydantic
class PyObjectId(ObjectId):
    """``bson.ObjectId`` usable as a Pydantic field type.

    Accepted inputs:

    * an ``ObjectId`` instance (Python-mode validation only), e.g. the ``_id``
      value of a document read back from MongoDB;
    * a string for which ``ObjectId.is_valid`` is true, converted by
      :meth:`validate`.

    When serialising to JSON the value is emitted as its string form.
    """

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
        """Build the pydantic-core schema (``source_type``/``handler`` are unused)."""
        # Strings are type-checked by pydantic first, then parsed by `validate`.
        from_str = core_schema.no_info_after_validator_function(
            cls.validate, core_schema.str_schema()
        )
        return core_schema.json_or_python_schema(
            # JSON input can only carry the string form.
            json_schema=from_str,
            # Python input may also be a real ObjectId; a str-only schema
            # rejects such values (including MongoDB's own `_id`).
            python_schema=core_schema.union_schema(
                [core_schema.is_instance_schema(ObjectId), from_str]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                str, return_schema=core_schema.str_schema(), when_used="json"
            ),
        )

    @classmethod
    def validate(cls, v):
        """Convert ``v`` to an ``ObjectId``.

        Raises:
            ValueError: if ``ObjectId.is_valid(v)`` is false.
        """
        if not ObjectId.is_valid(v):
            raise ValueError("Invalid ObjectId")
        return ObjectId(v)

# Base model with ObjectId handling
class MongoBaseModel(BaseModel):
    """Common base for models persisted as MongoDB documents.

    The document key ``_id`` is exposed as the ``id`` attribute; because
    ``populate_by_name`` is enabled, input may use either name.
    """

    # Primary key, read from / written to the "_id" key via the alias.
    id: Optional[PyObjectId] = Field(default=None, alias="_id")

    model_config = dict(
        # Disable pydantic's protected "model_" field-name namespace.
        protected_namespaces=(),
        # Accept the field name (`id`) as well as its alias (`_id`) on input.
        populate_by_name=True,
        # Permit non-pydantic types such as ObjectId in field annotations.
        arbitrary_types_allowed=True,
        # Emit ObjectId values as strings in JSON output.
        json_encoders={ObjectId: str},
    )

# Status Enums
class AgentStatus(str, Enum):
    """Allowed values for ``Agent.status``.

    Mixes in ``str`` so members compare equal to their plain string values
    (e.g. ``AgentStatus.ACTIVE == "active"``).
    """
    CREATED = "created"
    ACTIVE = "active"
    INACTIVE = "inactive"
    ERROR = "error"

class WorkflowStatus(str, Enum):
    """Allowed values for ``Workflow.status``.

    Mixes in ``str`` so members compare equal to their plain string values
    (e.g. ``WorkflowStatus.RUNNING == "running"``).
    """
    CREATED = "created"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

class ExecutionStatus(str, Enum):
    """Allowed values for ``WorkflowExecution.status``.

    Mixes in ``str`` so members compare equal to their plain string values
    (e.g. ``ExecutionStatus.PENDING == "pending"``).
    """
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    TIMEOUT = "timeout"
    CANCELLED = "cancelled"

# Database Models
class Agent(MongoBaseModel):
    """Agent record (presumably stored in the ``agents`` collection).

    Unlike the base class, ``id`` is a required string supplied by the caller
    (populated from / serialised to ``_id``).
    """

    id: str = Field(..., alias="_id")
    name: str
    description: Optional[str] = None
    config: Dict[str, Any]
    state: Dict[str, Any] = Field(default_factory=dict)
    runtime_params: Dict[str, Any] = Field(default_factory=dict)
    status: AgentStatus = AgentStatus.CREATED
    # Naive UTC timestamps: the same values datetime.utcnow() produced, without
    # its Python 3.12+ deprecation warning.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    created_by: Optional[str] = None
    tags: List[str] = Field(default_factory=list)

class Workflow(MongoBaseModel):
    """Workflow record (presumably stored in the ``workflows`` collection).

    ``id`` is a required string supplied by the caller (populated from /
    serialised to ``_id``); ``agent_ids`` lists the agents it references.
    """

    id: str = Field(..., alias="_id")
    name: str
    description: Optional[str] = None
    definition: Dict[str, Any]
    agent_ids: List[str] = Field(default_factory=list)
    status: WorkflowStatus = WorkflowStatus.CREATED
    # Naive UTC timestamps: the same values datetime.utcnow() produced, without
    # its Python 3.12+ deprecation warning.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    created_by: Optional[str] = None
    tags: List[str] = Field(default_factory=list)
    version: int = 1

class ExecutionLog(MongoBaseModel):
    """Log entry tied to a workflow execution (``workflow_id``/``execution_id``).

    Inherits the optional ``ObjectId`` ``id`` from ``MongoBaseModel``;
    ``step_id`` and ``agent_id`` optionally narrow the entry's origin.
    """

    workflow_id: str
    execution_id: str
    step_id: Optional[str] = None
    agent_id: Optional[str] = None
    # Naive UTC time: the same value datetime.utcnow() produced, without its
    # Python 3.12+ deprecation warning.
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    level: str = "INFO"
    message: str
    details: Dict[str, Any] = Field(default_factory=dict)

class WorkflowExecution(MongoBaseModel):
    """A single run of a workflow, identified by ``workflow_id``.

    Inherits the optional ``ObjectId`` ``id`` from ``MongoBaseModel``.
    ``step_results`` maps a step key to that step's result dict (key meaning
    presumably a step id; confirm against the execution engine).
    """

    workflow_id: str
    status: ExecutionStatus = ExecutionStatus.PENDING
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    input_params: Dict[str, Any] = Field(default_factory=dict)
    results: Dict[str, Any] = Field(default_factory=dict)
    created_by: Optional[str] = None
    step_results: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    current_step: Optional[str] = None
    error_message: Optional[str] = None
    # Naive UTC time: the same value datetime.utcnow() produced, without its
    # Python 3.12+ deprecation warning.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )

# Database client
class Database:
    """Class-level holder for the MongoDB client, database and collections.

    All state is kept on the class itself: ``connect()`` stores the Motor
    client, the database handle and the collection handles as class
    attributes, and ``disconnect()`` closes the client and clears them.
    """

    client: Optional[AsyncIOMotorClient] = None
    db = None

    # Collections (bound by connect(), reset to None by disconnect())
    agents = None
    workflows = None
    executions = None
    logs = None

    @staticmethod
    def _redact_url(url) -> str:
        """Return ``url`` as a string with any ``user:password@`` part masked.

        Only used for log output, so credentials embedded in the connection
        string are not written to the logs.
        """
        text = str(url)
        scheme, sep, rest = text.partition("://")
        if not sep:
            return text
        # Userinfo, if present, precedes the last "@" of the host section,
        # which ends at the first "/".
        hosts, slash, tail = rest.partition("/")
        if "@" in hosts:
            hosts = "***@" + hosts.rpartition("@")[2]
        return f"{scheme}{sep}{hosts}{slash}{tail}"

    @classmethod
    async def connect(cls):
        """Connect to MongoDB, bind the collection handles and create indexes.

        If a client from an earlier ``connect()`` is still held it is closed
        first, so repeated calls do not leave old clients open.
        """
        if cls.client is not None:
            cls.client.close()

        logger.info(
            "Connecting to MongoDB at %s...", cls._redact_url(settings.MONGODB_URL)
        )
        cls.client = AsyncIOMotorClient(settings.MONGODB_URL)
        cls.db = cls.client[settings.MONGODB_DB_NAME]

        # Set up collections
        cls.agents = cls.db.agents
        cls.workflows = cls.db.workflows
        cls.executions = cls.db.workflow_executions
        cls.logs = cls.db.execution_logs

        # Create indexes
        await cls._create_indexes()

        logger.info("Connected to MongoDB successfully")

    @classmethod
    async def disconnect(cls):
        """Close the MongoDB client, if any, and clear all cached handles."""
        if cls.client is not None:
            cls.client.close()
            logger.info("Disconnected from MongoDB")
        # Drop handles to the closed client so later use fails visibly
        # instead of silently reaching a closed connection.
        cls.client = None
        cls.db = None
        cls.agents = None
        cls.workflows = None
        cls.executions = None
        cls.logs = None

    @classmethod
    async def _create_indexes(cls):
        """Create the indexes used by the four collections.

        Must run after ``connect()`` has bound the collection handles.
        """
        # Agent indexes (agent names must be unique)
        await cls.agents.create_index([("name", ASCENDING)], unique=True)
        await cls.agents.create_index([("name", TEXT), ("description", TEXT)])
        await cls.agents.create_index([("created_at", ASCENDING)])
        await cls.agents.create_index([("tags", ASCENDING)])

        # Workflow indexes (names are not unique here, unlike agents)
        await cls.workflows.create_index([("name", ASCENDING)])
        await cls.workflows.create_index([("name", TEXT), ("description", TEXT)])
        await cls.workflows.create_index([("created_at", ASCENDING)])
        await cls.workflows.create_index([("agent_ids", ASCENDING)])
        await cls.workflows.create_index([("tags", ASCENDING)])

        # Execution indexes
        await cls.executions.create_index([("workflow_id", ASCENDING)])
        await cls.executions.create_index([("created_at", ASCENDING)])
        await cls.executions.create_index([("status", ASCENDING)])

        # Log indexes
        await cls.logs.create_index([("execution_id", ASCENDING)])
        await cls.logs.create_index([("timestamp", ASCENDING)])
        await cls.logs.create_index([("workflow_id", ASCENDING), ("execution_id", ASCENDING)])