Spaces:
Sleeping
Sleeping
File size: 6,877 Bytes
172064c b4c9cb7 172064c b4c9cb7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
from src.utils.logger import logger
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel
from typing import Type, Dict, List, Optional
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorCollection
from datetime import datetime, timezone, timedelta
from src.utils.logger import get_date_time
import os
# Shared Motor client for this module, built at import time from the
# MONGO_CONNECTION_STR environment variable.
# NOTE(review): if MONGO_CONNECTION_STR is unset, os.getenv returns None and the
# client presumably falls back to its default host — confirm that is intended.
client: AsyncIOMotorClient = AsyncIOMotorClient(os.getenv("MONGO_CONNECTION_STR"))
# Alternate database name, kept for reference (not active):
# database = client["custom_gpt"]
# Database every CRUD helper in this module operates on.
database = client["prompt_editor"]
class MongoCRUD:
    """Generic async CRUD helper binding a Motor collection to a Pydantic model.

    Documents are validated through ``model`` on the way in (create/update)
    and on the way out (read/read_one/find_many). ``created_at`` and
    ``updated_at`` are stamped automatically. When ``ttl_seconds`` is set,
    each create/update also stamps ``expire_at = now + ttl_seconds``, and a
    TTL index on ``expire_at`` (``expireAfterSeconds=0``) is created so
    MongoDB removes the document once that time has passed.
    """

    def __init__(
        self,
        collection: AsyncIOMotorCollection,
        model: Type[BaseModel],
        ttl_seconds: Optional[int] = None,
    ):
        """
        Args:
            collection: Motor collection all operations run against.
            model: Pydantic model used to validate/serialize documents.
            ttl_seconds: If given, documents expire this many seconds after
                their most recent create/update.
        """
        self.collection = collection
        self.model = model
        self.ttl_seconds = ttl_seconds
        # Set once the TTL index has been created so later calls skip it.
        self._index_created = False

    async def _ensure_ttl_index(self):
        """Create the ``expire_at`` TTL index the first time it is needed.

        No-op when no TTL is configured or the index was already created by
        this instance.
        """
        if self.ttl_seconds is not None and not self._index_created:
            # expireAfterSeconds=0: each document expires at the clock time
            # stored in its own expire_at field.
            await self.collection.create_index("expire_at", expireAfterSeconds=0)
            self._index_created = True

    def _order_fields(self, doc: Dict) -> Dict:
        """Return a copy of ``doc`` with the timestamp fields moved to the end.

        Resulting order: every other field (original order, ``id`` included),
        then ``_id`` (derived from a non-None ``id``), then ``created_at``,
        ``updated_at`` and ``expire_at`` when present.

        Raises:
            bson.errors.InvalidId / TypeError: if ``doc["id"]`` cannot be
                converted to an ObjectId.
        """
        timestamp_keys = ("created_at", "updated_at", "expire_at")
        ordered_doc = {k: v for k, v in doc.items() if k not in timestamp_keys}
        # Map a caller-supplied "id" onto Mongo's primary key. A None id is
        # skipped: ObjectId(None) would mint a brand-new random id, which in an
        # update would try to overwrite the immutable _id.
        if doc.get("id") is not None:
            ordered_doc["_id"] = ObjectId(doc["id"])
        for key in timestamp_keys:
            if key in doc:
                ordered_doc[key] = doc[key]
        return ordered_doc

    def _to_output(self, doc: Dict) -> Dict:
        """Validate a raw stored document and shape it for callers.

        ``_id`` is stringified *before* model validation so ``read``,
        ``read_one`` and ``find_many`` all give the model identical input.
        The result has ``_id`` (str) first, then the model fields (``id``
        excluded) with timestamps last.
        """
        doc_id = str(doc["_id"])
        validated = self.model(**{**doc, "_id": doc_id}).model_dump(exclude={"id"})
        return {"_id": doc_id, **self._order_fields(validated)}

    async def create(self, data: Dict) -> str:
        """Validate ``data`` through the model and insert it as a new document.

        A caller-supplied ``id`` becomes the document's ``_id``; otherwise one
        is generated on insert. ``created_at``/``updated_at`` (and
        ``expire_at`` when a TTL is configured) are stamped here. ``data``
        itself is not modified.

        Returns:
            The inserted document's ``_id`` as a string.
        """
        await self._ensure_ttl_index()
        # NOTE(review): replace(tzinfo=None) drops the offset without
        # converting. BSON treats naive datetimes as UTC, so unless
        # get_date_time() returns UTC the stored times and TTL expiry are
        # shifted — confirm.
        now = get_date_time().replace(tzinfo=None)
        # Build a new dict so the caller's data is left untouched.
        payload = {**data, "created_at": now, "updated_at": now}
        if self.ttl_seconds is not None:
            payload["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
        document = self.model(**payload).model_dump(exclude_unset=True)
        result = await self.collection.insert_one(self._order_fields(document))
        return str(result.inserted_id)

    async def read(
        self, query: Dict, sort: Optional[List[tuple]] = None
    ) -> List[Dict]:
        """Return every document matching ``query``, validated through the model.

        Args:
            query: MongoDB filter document.
            sort: Optional ``[(field_name, direction), ...]`` specification,
                direction 1 for ascending, -1 for descending.

        Returns:
            One dict per document: ``_id`` (str) first, then the model fields
            (``id`` excluded) with timestamps last. A document that fails
            model validation raises.
        """
        cursor = self.collection.find(query)
        if sort:
            cursor = cursor.sort(sort)
        return [self._to_output(doc) async for doc in cursor]

    async def read_one(self, query: Dict) -> Optional[Dict]:
        """Return the first document matching ``query``, or None if none match.

        The result is shaped exactly like the items returned by ``read``.
        """
        doc = await self.collection.find_one(query)
        if doc is None:
            return None
        return self._to_output(doc)

    async def update(self, query: Dict, data: Dict, upsert: bool = False) -> int:
        """Update every document matching ``query``.

        If any top-level key of ``data`` starts with ``$``, ``data`` is sent
        unchanged as a raw update document (no validation, no timestamps).
        Otherwise ``data`` is validated through the model (only explicitly
        provided fields are kept), stamped with ``updated_at`` (plus a
        refreshed ``expire_at`` when a TTL is configured) and applied via
        ``$set``. ``data`` itself is not modified.

        Returns:
            ``modified_count`` of the update. Upserted documents are not
            counted there.
        """
        await self._ensure_ttl_index()
        if any(key.startswith("$") for key in data):
            update_data = data
        else:
            # NOTE(review): the payload is run through the full model, so a
            # partial update must still satisfy the model's required fields.
            now = get_date_time().replace(tzinfo=None)
            payload = {**data, "updated_at": now}
            if self.ttl_seconds is not None:
                payload["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
            fields = self.model(**payload).model_dump(exclude_unset=True)
            update_data = {"$set": self._order_fields(fields)}
        result = await self.collection.update_many(query, update_data, upsert=upsert)
        return result.modified_count

    async def delete(self, query: Dict) -> int:
        """Delete every document matching ``query``; return how many were removed."""
        result = await self.collection.delete_many(query)
        return result.deleted_count

    async def delete_one(self, query: Dict) -> int:
        """Delete the first document matching ``query``; return 0 or 1."""
        result = await self.collection.delete_one(query)
        return result.deleted_count

    async def find_by_id(self, id: str) -> Optional[Dict]:
        """Return the document whose ``_id`` equals ``ObjectId(id)``, or None.

        Raises:
            bson.errors.InvalidId: if ``id`` is not a valid ObjectId string.
        """
        return await self.read_one({"_id": ObjectId(id)})

    async def find_all(self) -> List[Dict]:
        """Return every document in the collection (no filter, no sort)."""
        return await self.read({})

    async def find_many(
        self,
        filter: Dict,
        skip: int = 0,
        limit: int = 0,
        sort: Optional[List[tuple]] = None,
    ) -> List[Dict]:
        """
        Find documents based on filter with pagination support.

        Args:
            filter: MongoDB query filter
            skip: Number of documents to skip
            limit: Maximum number of documents to return (0 means no limit)
            sort: Optional sorting parameters [(field_name, direction)]
                where direction is 1 for ascending, -1 for descending

        Returns:
            List of documents matching the filter. Unlike ``read``, a document
            that fails model validation is logged and returned with its raw
            stored fields instead of raising.
        """
        cursor = self.collection.find(filter)
        # skip/limit/sort are sent together in one find command and MongoDB
        # applies sort before skip/limit, so call order here does not matter.
        if skip > 0:
            cursor = cursor.skip(skip)
        if limit > 0:
            cursor = cursor.limit(limit)
        if sort:
            cursor = cursor.sort(sort)

        docs = []
        async for doc in cursor:
            doc_id = str(doc["_id"])
            try:
                docs.append(self._to_output(doc))
            except Exception as e:
                logger.error(f"Error validating document {doc_id}: {str(e)}")
                # Best effort: keep the document, with its original stored data.
                docs.append(
                    {"_id": doc_id, **{k: v for k, v in doc.items() if k != "_id"}}
                )
        return docs
from src.apis.models.user_models import User
from src.apis.models.prompt_models import Prompt
# Collection-bound CRUD helpers. Neither passes ttl_seconds, so no expire_at is
# stamped and no TTL index is created for these collections.
UserCRUD = MongoCRUD(database["users"], User)
PromptCRUD = MongoCRUD(database["prompt_templates"], Prompt)