File size: 6,877 Bytes
172064c
 
 
 
 
 
 
 
 
 
 
 
b4c9cb7
172064c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4c9cb7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from src.utils.logger import logger
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel
from typing import Type, Dict, List, Optional
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorCollection
from datetime import datetime, timezone, timedelta
from src.utils.logger import get_date_time
import os

# Module-level Motor client, created once at import time from the
# MONGO_CONNECTION_STR environment variable.
# NOTE(review): os.getenv returns None when the variable is unset, in which case
# the driver presumably falls back to its default host — confirm that this
# silent fallback is acceptable rather than failing fast.
client: AsyncIOMotorClient = AsyncIOMotorClient(os.getenv("MONGO_CONNECTION_STR"))
# Database shared by every MongoCRUD instance in this module (it previously
# pointed at the "custom_gpt" database).
database = client["prompt_editor"]


class MongoCRUD:
    """Generic async CRUD helper binding a Motor collection to a Pydantic model.

    Documents are validated through ``model`` on the way in (``create`` and
    plain-field ``update``) and on the way out (``read``, ``read_one`` and
    ``find_many``). Bookkeeping timestamps are filled in automatically:
    ``created_at``/``updated_at`` always, and ``expire_at`` when a TTL is
    configured. All of them come from ``get_date_time()`` with tzinfo stripped.

    NOTE(review): MongoDB treats naive datetimes as UTC. Confirm that
    get_date_time() returns UTC, otherwise TTL expiry is shifted by the offset.
    """

    def __init__(
        self,
        collection: AsyncIOMotorCollection,
        model: Type[BaseModel],
        ttl_seconds: Optional[int] = None,
    ):
        """
        Args:
            collection: Motor collection this instance operates on.
            model: Pydantic model used to validate and serialize documents.
            ttl_seconds: When set, every created or plain-field-updated document
                gets ``expire_at = now + ttl_seconds``. A TTL index on
                ``expire_at`` is also created lazily on the first write.
        """
        self.collection = collection
        self.model = model
        self.ttl_seconds = ttl_seconds
        # Set by _ensure_ttl_index so the index is only requested once per instance.
        self._index_created = False

    async def _ensure_ttl_index(self):
        """Create the TTL index on ``expire_at`` once, if a TTL is configured.

        ``expireAfterSeconds=0`` tells MongoDB to remove each document once the
        clock time stored in its own ``expire_at`` field has passed.
        """
        if self.ttl_seconds is None or self._index_created:
            return
        await self.collection.create_index("expire_at", expireAfterSeconds=0)
        self._index_created = True

    def _order_fields(self, doc: Dict) -> Dict:
        """Return a new dict with bookkeeping fields moved to the end.

        Regular fields keep their original order. After them come, each only
        when present in ``doc``:

        - ``_id``, built as ``ObjectId(doc["id"])`` when an ``id`` key exists;
        - ``created_at``, ``updated_at`` and ``expire_at``, in that order.
        """
        trailing = ("created_at", "updated_at", "expire_at")
        ordered_doc = {k: v for k, v in doc.items() if k not in trailing}
        if "id" in doc:
            # NOTE(review): the plain "id" key is kept too, so a document created
            # with an explicit id stores both "id" and "_id" — confirm intended.
            ordered_doc["_id"] = ObjectId(doc["id"])
        for key in trailing:
            if key in doc:
                ordered_doc[key] = doc[key]
        return ordered_doc

    def _serialize(self, doc: Dict) -> Dict:
        """Validate a raw stored document and return it with a string ``_id`` first.

        ``_id`` is stringified before validation, so every reader hands the
        model the same ``_id`` type.

        Raises:
            Whatever ``self.model(...)`` raises when the document does not validate.
        """
        doc_id = str(doc["_id"])
        validated = self.model(**{**doc, "_id": doc_id}).model_dump(exclude={"id"})
        return {"_id": doc_id, **self._order_fields(validated)}

    async def create(self, data: Dict) -> str:
        """Insert a new document and return its id as a string.

        If ``data`` contains an ``id`` key, it is used as the document's
        ``_id`` (converted to an ObjectId). The caller's ``data`` dict is left
        unmodified; timestamps are added to a copy.

        Args:
            data: Field values for the new document, validated through the model.

        Returns:
            The inserted document's ``_id`` as a string.
        """
        await self._ensure_ttl_index()
        now = get_date_time().replace(tzinfo=None)
        # Work on a copy so bookkeeping fields do not leak into the caller's dict.
        payload = {**data, "created_at": now, "updated_at": now}
        if self.ttl_seconds is not None:
            payload["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
        document = self.model(**payload).model_dump(exclude_unset=True)
        result = await self.collection.insert_one(self._order_fields(document))
        return str(result.inserted_id)

    async def read(
        self, query: Dict, sort: Optional[List[tuple]] = None
    ) -> List[Dict]:
        """Return every document matching ``query``, validated through the model.

        Args:
            query: MongoDB filter.
            sort: Optional ``[(field_name, direction)]`` list passed to ``cursor.sort``.

        Returns:
            A list of dicts, each with a string ``_id`` first.
        """
        cursor = self.collection.find(query)
        if sort:
            cursor = cursor.sort(sort)
        return [self._serialize(doc) async for doc in cursor]

    async def read_one(self, query: Dict) -> Optional[Dict]:
        """Return the first document matching ``query``, or None if there is none.

        The document is validated through the model and returned with a string
        ``_id`` first.
        """
        doc = await self.collection.find_one(query)
        if not doc:
            return None
        return self._serialize(doc)

    async def update(self, query: Dict, data: Dict, upsert: bool = False) -> int:
        """Update every document matching ``query``.

        If any top-level key of ``data`` starts with ``$``, ``data`` is treated
        as a raw update document and sent unchanged. In that case no timestamps
        are touched.

        Otherwise ``data`` is handled as plain field values:

        - ``updated_at`` (and ``expire_at`` when a TTL is configured) is added
          to a copy of ``data``;
        - the copy is validated through the model, keeping only the fields that
          were set;
        - the result is applied with ``$set``.

        NOTE(review): validating a partial ``data`` through the full model only
        works if the model's other fields are optional — confirm.

        Args:
            query: MongoDB filter selecting the documents to update.
            data: Either plain field values or a raw update document.
            upsert: Passed through to ``update_many``.

        Returns:
            ``modified_count`` from MongoDB. Documents inserted by an upsert
            are not counted.
        """
        await self._ensure_ttl_index()

        if any(key.startswith("$") for key in data):
            update_data = data
        else:
            now = get_date_time().replace(tzinfo=None)
            # Copy so the caller's dict is not mutated with bookkeeping fields.
            payload = {**data, "updated_at": now}
            if self.ttl_seconds is not None:
                payload["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
            update_data = {
                "$set": self._order_fields(
                    self.model(**payload).model_dump(exclude_unset=True)
                )
            }

        result = await self.collection.update_many(query, update_data, upsert=upsert)
        return result.modified_count

    async def delete(self, query: Dict) -> int:
        """Delete every document matching ``query`` and return how many were removed."""
        result = await self.collection.delete_many(query)
        return result.deleted_count

    async def delete_one(self, query: Dict) -> int:
        """Delete the first document matching ``query``; return 1 if one was removed, else 0."""
        result = await self.collection.delete_one(query)
        return result.deleted_count

    async def find_by_id(self, id: str) -> Optional[Dict]:
        """Return the document whose ``_id`` equals ``ObjectId(id)``, or None.

        ``ObjectId(id)`` raises if ``id`` is not a valid ObjectId string.
        """
        return await self.read_one({"_id": ObjectId(id)})

    async def find_all(self) -> List[Dict]:
        """Return every document in the collection, validated through the model."""
        return await self.read({})

    async def find_many(
        self,
        filter: Dict,
        skip: int = 0,
        limit: int = 0,
        sort: Optional[List[tuple]] = None,
    ) -> List[Dict]:
        """
        Find documents based on filter with pagination support.

        Documents that fail model validation are not dropped. The error is
        logged and the raw stored fields are returned instead.

        Args:
            filter: MongoDB query filter
            skip: Number of documents to skip (values <= 0 skip nothing)
            limit: Maximum number of documents to return (values <= 0 mean no limit)
            sort: Optional sorting parameters [(field_name, direction)]
                  where direction is 1 for ascending, -1 for descending

        Returns:
            List of documents matching the filter, each with a string ``_id`` first
        """
        cursor = self.collection.find(filter)

        # Apply pagination
        if skip > 0:
            cursor = cursor.skip(skip)
        if limit > 0:
            cursor = cursor.limit(limit)

        # MongoDB applies sort before skip/limit, so the call order here is irrelevant.
        if sort:
            cursor = cursor.sort(sort)

        docs = []
        async for doc in cursor:
            try:
                serialized = self._serialize(doc)
            except Exception as e:
                doc_id = str(doc["_id"])
                logger.error(f"Error validating document {doc_id}: {str(e)}")
                # Best effort: return the stored data rather than dropping the document.
                serialized = {
                    "_id": doc_id,
                    **{k: v for k, v in doc.items() if k != "_id"},
                }
            docs.append(serialized)

        return docs


from src.apis.models.user_models import User
from src.apis.models.prompt_models import Prompt

# Shared CRUD accessors, one per collection; neither sets a TTL.
UserCRUD = MongoCRUD(collection=database["users"], model=User)
PromptCRUD = MongoCRUD(collection=database["prompt_templates"], model=Prompt)