Spaces:
Sleeping
Sleeping
File size: 12,863 Bytes
3973360 5426b49 3973360 7691ec7 3973360 7691ec7 3973360 02d6bde 8c4ead2 02d6bde 8c4ead2 02d6bde 8c4ead2 3973360 02d6bde 3973360 1815cc8 3973360 e390496 3973360 5426b49 3973360 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 |
from src.langgraph.config.constant import MongoCfg, RedisCfg
from src.utils.logger import logger
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
from langchain_redis import RedisChatMessageHistory
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel
from typing import Type, Dict, List, Optional
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorCollection
from datetime import datetime, timezone, timedelta
from src.apis.models.post_models import Comment, Reaction, Post
from src.apis.models.destination_models import Destination
from src.apis.models.schedule_models import Schedule
from src.apis.models.hotel_models import BookHotel
from src.apis.interfaces.api_interface import ChatHistoryManagement
from src.apis.models.user_models import User
from src.utils.logger import get_date_time
# Module-wide async MongoDB client and the application database handle; the
# CRUD helpers defined below are all bound to collections of ``database``.
client: AsyncIOMotorClient = AsyncIOMotorClient(MongoCfg.MONGODB_URL)
database = client[MongoCfg.MONGO_INDEX]
class MongoCRUD:
    """Async CRUD helper that binds a Motor collection to a pydantic model.

    Writes are validated through ``model`` before being stored, and reads are
    re-validated through ``model`` before being returned. When ``ttl_seconds``
    is set, writes also stamp an ``expire_at`` datetime, and a TTL index on
    that field is requested on the first write.
    """

    def __init__(
        self,
        collection: AsyncIOMotorCollection,
        model: Type[BaseModel],
        ttl_seconds: Optional[int] = None,
    ):
        """Bind ``collection`` to ``model``.

        Args:
            collection: Motor collection every operation targets.
            model: Pydantic model used to validate and serialize documents.
            ttl_seconds: Optional lifetime. When set, ``create`` and plain
                (non-operator) ``update`` calls set ``expire_at`` to
                "now + ttl_seconds".
        """
        self.collection = collection
        self.model = model
        self.ttl_seconds = ttl_seconds
        # Set once the TTL index has been requested, so later writes skip it.
        self._index_created = False

    async def _ensure_ttl_index(self):
        """Request the ``expire_at`` TTL index once, if a TTL is configured."""
        if self.ttl_seconds is not None and not self._index_created:
            # expireAfterSeconds=0: MongoDB removes a document once the
            # datetime stored in its ``expire_at`` field has passed.
            await self.collection.create_index("expire_at", expireAfterSeconds=0)
            self._index_created = True

    def _order_fields(self, doc: Dict) -> Dict:
        """Return a copy of ``doc`` with the timestamp fields moved to the end.

        Non-timestamp keys keep their relative order. If ``doc`` has an ``id``
        key, an ``_id`` ObjectId built from it is set as well (the ``id`` key
        itself is kept). ``created_at``, ``updated_at`` and ``expire_at``
        follow, in that order, when present.
        """
        timestamp_fields = ("created_at", "updated_at", "expire_at")
        ordered_doc = {k: v for k, v in doc.items() if k not in timestamp_fields}
        if "id" in doc:
            # NOTE: ObjectId(None) generates a fresh id, and a string that is
            # not a valid ObjectId raises bson.errors.InvalidId.
            ordered_doc["_id"] = ObjectId(doc["id"])
        for field in timestamp_fields:
            if field in doc:
                ordered_doc[field] = doc[field]
        return ordered_doc

    async def create(self, data: Dict) -> str:
        """Validate ``data`` through the model and insert it as a new document.

        ``created_at`` and ``updated_at`` (plus ``expire_at`` when a TTL is
        configured) are stamped on a copy of ``data``. The caller's dict is
        not modified. If the validated document contains an ``id``, it is
        used as the document's ``_id``.

        Returns:
            The inserted document's ``_id`` as a string.
        """
        await self._ensure_ttl_index()
        # NOTE(review): tzinfo is stripped, so the stored value is a naive
        # datetime in whatever zone get_date_time() uses, and PyMongo treats
        # naive datetimes as UTC. If get_date_time() is not UTC, TTL expiry
        # is shifted by the zone offset. Confirm.
        now = get_date_time().replace(tzinfo=None)
        # Work on a copy so repeated/shared caller dicts are not polluted.
        data = {**data, "created_at": now, "updated_at": now}
        if self.ttl_seconds is not None:
            data["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
        document = self.model(**data).model_dump(exclude_unset=True)
        result = await self.collection.insert_one(self._order_fields(document))
        return str(result.inserted_id)

    async def read(self, query: Dict) -> List[Dict]:
        """Return every document matching ``query``, validated through the model.

        Each result starts with ``_id`` (as a string), followed by the model's
        fields except ``id``. A document that fails validation raises.
        """
        return [
            {
                "_id": str(doc["_id"]),
                **self._order_fields(self.model(**doc).model_dump(exclude={"id"})),
            }
            async for doc in self.collection.find(query)
        ]

    async def read_one(self, query: Dict) -> Optional[Dict]:
        """Return the first document matching ``query``, or None if none match.

        The result has the same shape as an element of ``read``.
        """
        doc = await self.collection.find_one(query)
        if doc is None:
            return None
        # Stringify before validation so the model receives a plain str _id.
        doc["_id"] = str(doc["_id"])
        return {
            "_id": doc["_id"],
            **self._order_fields(self.model(**doc).model_dump(exclude={"id"})),
        }

    async def update(self, query: Dict, data: Dict) -> int:
        """Update all documents matching ``query``.

        If any key of ``data`` starts with ``$`` (``$set``, ``$inc``, ...),
        ``data`` is sent to MongoDB unchanged. That path does no validation
        and does not refresh ``updated_at``/``expire_at``. Otherwise a copy of
        ``data`` gets a fresh ``updated_at`` (and ``expire_at`` when a TTL is
        configured), is validated through the model and is applied via
        ``$set``. The caller's dict is never modified.

        NOTE(review): the plain path validates ``data`` through the full
        model, so any required model fields must be present in ``data``.

        Returns:
            Number of documents actually modified.
        """
        await self._ensure_ttl_index()
        if any(key.startswith("$") for key in data):
            update_data = data
        else:
            now = get_date_time().replace(tzinfo=None)
            data = {**data, "updated_at": now}
            if self.ttl_seconds is not None:
                data["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
            update_data = {
                "$set": self._order_fields(
                    self.model(**data).model_dump(exclude_unset=True)
                )
            }
        result = await self.collection.update_many(query, update_data)
        return result.modified_count

    async def delete(self, query: Dict) -> int:
        """Delete all documents matching ``query``; return how many were deleted."""
        result = await self.collection.delete_many(query)
        return result.deleted_count

    async def delete_one(self, query: Dict) -> int:
        """Delete the first document matching ``query``; return 0 or 1."""
        result = await self.collection.delete_one(query)
        return result.deleted_count

    async def find_by_id(self, id: str) -> Optional[Dict]:
        """Return the document whose ``_id`` is the ObjectId ``id``, or None.

        Raises bson.errors.InvalidId if ``id`` is not a valid ObjectId string.
        """
        return await self.read_one({"_id": ObjectId(id)})

    async def find_all(self) -> List[Dict]:
        """Return every document in the collection (see ``read``)."""
        return await self.read({})

    async def find_many(
        self,
        filter: Dict,
        skip: int = 0,
        limit: int = 0,
        sort: Optional[List[tuple]] = None,
    ) -> List[Dict]:
        """Find documents matching ``filter`` with optional pagination and sorting.

        Args:
            filter: MongoDB query filter.
            skip: Number of documents to skip (ignored when <= 0).
            limit: Maximum number of documents to return (<= 0 means no limit).
            sort: Optional ``[(field_name, direction)]`` pairs, where direction
                is 1 for ascending and -1 for descending.

        Returns:
            Documents with ``_id`` as a string. A document that fails model
            validation is logged and returned with its raw stored fields.
        """
        cursor = self.collection.find(filter)
        if skip > 0:
            cursor = cursor.skip(skip)
        if limit > 0:
            cursor = cursor.limit(limit)
        if sort:
            # The server applies sort before skip/limit regardless of the
            # order these cursor modifiers are chained in.
            cursor = cursor.sort(sort)

        docs = []
        async for doc in cursor:
            doc_id = str(doc["_id"])
            try:
                validated_doc = self.model(**{**doc, "_id": doc_id}).model_dump(
                    exclude={"id"}
                )
                docs.append({"_id": doc_id, **self._order_fields(validated_doc)})
            except Exception as e:
                # Best-effort: keep the document in the page rather than drop it.
                logger.error(f"Error validating document {doc_id}: {str(e)}")
                docs.append(
                    {"_id": doc_id, **{k: v for k, v in doc.items() if k != "_id"}}
                )
        return docs

    async def count(self, filter: Dict) -> int:
        """Count documents matching ``filter``.

        Best-effort: any error is logged and reported as 0.
        """
        try:
            return await self.collection.count_documents(filter)
        except Exception as e:
            logger.error(f"Error counting documents: {str(e)}")
            return 0
from motor.motor_asyncio import AsyncIOMotorCollection
from bson.son import SON
from datetime import datetime, timedelta
from typing import List, Dict
class PostMongoCRUD(MongoCRUD):
    """MongoCRUD for posts, adding a priority-scored listing query."""

    # Number of leading ``top_destinations`` entries that earn a bonus; the
    # entry at 0-based rank r scores (_MAX_SCORED_DESTINATIONS - r) * 2.
    _MAX_SCORED_DESTINATIONS = 5

    async def find_many_with_score(
        self,
        filter: Dict,
        top_destinations: List[str],
        limit: int = 10,
        skip: int = 0,
    ) -> List[Dict]:
        """Return posts matching ``filter`` ordered by a computed PriorityScore.

        PriorityScore = BaseScore (1) + DestinationScore + EngagementScore
        + FreshnessScore. Ties are broken by the newest ``created_at``.

        * DestinationScore: 10, 8, 6, 4, 2 when ``destination_id`` equals the
          1st..5th entry of ``top_destinations``, otherwise 0. Entries after
          the 5th are ignored rather than scored negatively.
        * EngagementScore: ``(reaction_count + 2 * comment_count) / 100``,
          capped at 5. Missing counts are treated as 0.
        * FreshnessScore: 5 if created within the last day, 3 within 3 days,
          2 within 7 days, else 1.

        Args:
            filter: MongoDB ``$match`` filter applied before scoring.
            top_destinations: Destination ids in preference order; may be
                empty or None.
            limit: Maximum number of posts to return; <= 0 means no limit.
            skip: Number of ranked posts to skip; <= 0 means none.

        Returns:
            Raw post documents (``_id`` as str), including the computed score
            fields. Unlike ``find_many``, results are not validated through
            the model.
        """
        # Use the same clock and naive-datetime convention MongoCRUD.create()
        # uses to stamp created_at, so the freshness windows compare like
        # with like.
        now = get_date_time().replace(tzinfo=None)

        scored_destinations = (top_destinations or [])[
            : self._MAX_SCORED_DESTINATIONS
        ]
        if scored_destinations:
            destination_score = {
                "$switch": {
                    "branches": [
                        {
                            "case": {"$eq": ["$destination_id", destination_id]},
                            "then": (self._MAX_SCORED_DESTINATIONS - rank) * 2,
                        }
                        for rank, destination_id in enumerate(scored_destinations)
                    ],
                    "default": 0,
                }
            }
        else:
            # $switch rejects an empty branch list, so use a constant instead.
            destination_score = 0

        # $ifNull guards missing counts: otherwise $add yields null, and $min
        # ignores nulls, which would award such posts the maximum score of 5.
        engagement_score = {
            "$min": [
                {
                    "$divide": [
                        {
                            "$add": [
                                {"$ifNull": ["$reaction_count", 0]},
                                {
                                    "$multiply": [
                                        {"$ifNull": ["$comment_count", 0]},
                                        2,
                                    ]
                                },
                            ]
                        },
                        100,
                    ]
                },
                5,
            ]
        }

        freshness_score = {
            "$switch": {
                "branches": [
                    {
                        "case": {"$gte": ["$created_at", now - timedelta(days=days)]},
                        "then": score,
                    }
                    for days, score in ((1, 5), (3, 3), (7, 2))
                ],
                "default": 1,
            }
        }

        pipeline = [
            {"$match": filter},
            {
                "$addFields": {
                    "BaseScore": 1,
                    "DestinationScore": destination_score,
                    "EngagementScore": engagement_score,
                    "FreshnessScore": freshness_score,
                }
            },
            {
                "$addFields": {
                    "PriorityScore": {
                        "$add": [
                            "$BaseScore",
                            "$DestinationScore",
                            "$EngagementScore",
                            "$FreshnessScore",
                        ]
                    }
                }
            },
            {"$sort": SON([("PriorityScore", -1), ("created_at", -1)])},
        ]
        # $limit must be positive and $skip non-negative; treat non-positive
        # values as "not set", matching MongoCRUD.find_many.
        if skip > 0:
            pipeline.append({"$skip": skip})
        if limit > 0:
            pipeline.append({"$limit": limit})

        results = []
        async for doc in self.collection.aggregate(pipeline):
            doc["_id"] = str(doc["_id"])
            results.append(doc)
        return results
# Post collection helper; also exposes find_many_with_score for scored listings.
PostCRUD = PostMongoCRUD(collection=database[MongoCfg.POST], model=Post)
def chat_messages_history(
    session_id: str, number_of_messages: int = MongoCfg.MAX_HISTORY_SIZE, db="mongo"
):
    """Build a LangChain chat-message history store for ``session_id``.

    Args:
        session_id: Conversation identifier. A falsy value falls back to the
            fixed id "12345678910" and logs a warning.
        number_of_messages: Forwarded as ``history_size`` to the MongoDB
            backend; not used by the Redis backend.
        db: "redis" selects the Redis backend (created with ``ttl=40000``,
            presumably seconds; confirm against langchain_redis). Any other
            value selects MongoDB.

    Returns:
        A ``MongoDBChatMessageHistory`` or a ``RedisChatMessageHistory``.
    """
    if not session_id:
        # NOTE(review): every caller without a session id shares this one
        # fixed history, so their conversations mix. Confirm that is intended.
        logger.warning("Session ID not provided, using default session ID")
        session_id = "12345678910"

    if db != "redis":
        return MongoDBChatMessageHistory(
            session_id=session_id,
            connection_string=MongoCfg.MONGODB_URL,
            database_name=MongoCfg.MONGO_INDEX,
            collection_name=MongoCfg.CHAT_HISTORY,
            history_size=number_of_messages,
        )

    return RedisChatMessageHistory(
        session_id=session_id,
        redis_url=RedisCfg.REDIS_URL,
        ttl=40000,
    )
# Non-expiring CRUD helpers, one per collection/model pair.
BookHotelCRUD = MongoCRUD(collection=database[MongoCfg.BOOK_HOTEL], model=BookHotel)
ScheduleCRUD = MongoCRUD(collection=database[MongoCfg.ACTIVITY], model=Schedule)
UserCRUD = MongoCRUD(collection=database[MongoCfg.USER], model=User)
ReactionCRUD = MongoCRUD(collection=database[MongoCfg.REACTION], model=Reaction)
CommentCRUD = MongoCRUD(collection=database[MongoCfg.COMMENT], model=Comment)
DestinationCRUD = MongoCRUD(
    collection=database[MongoCfg.DESTINATION], model=Destination
)

# Chat-history bookkeeping: documents get expire_at = write time + 3600 s
# on create and on plain (non-operator) update.
chat_history_management_crud = MongoCRUD(
    collection=database["chat_history_management"],
    model=ChatHistoryManagement,
    ttl_seconds=3600,
)
|