ABAO77's picture
Upload 164 files
e390496 verified
from src.langgraph.config.constant import MongoCfg, RedisCfg
from src.utils.logger import logger
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
from langchain_redis import RedisChatMessageHistory
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel
from typing import Type, Dict, List, Optional
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorCollection
from datetime import datetime, timezone, timedelta
from src.apis.models.post_models import Comment, Reaction, Post
from src.apis.models.destination_models import Destination
from src.apis.models.schedule_models import Schedule
from src.apis.models.hotel_models import BookHotel
from src.apis.interfaces.api_interface import ChatHistoryManagement
from src.apis.models.user_models import User
from src.utils.logger import get_date_time
# Single Motor client shared by every CRUD helper created in this module;
# it is constructed at import time from MongoCfg.MONGODB_URL.
client: AsyncIOMotorClient = AsyncIOMotorClient(MongoCfg.MONGODB_URL)
# Application database (name taken from MongoCfg.MONGO_INDEX) whose
# collections back the module-level CRUD instances defined below.
database = client[MongoCfg.MONGO_INDEX]
class MongoCRUD:
    """Async CRUD helper binding one Motor collection to one Pydantic model.

    Writes are validated through ``model`` before they reach MongoDB; reads
    are re-validated through it and returned as plain dicts whose ``_id`` is a
    string. When ``ttl_seconds`` is set, writes stamp an ``expire_at`` field
    and a TTL index on that field is created lazily on the first write.
    """

    def __init__(
        self,
        collection: AsyncIOMotorCollection,
        model: Type[BaseModel],
        ttl_seconds: Optional[int] = None,
    ):
        """
        Args:
            collection: Collection the documents are stored in.
            model: Pydantic model used to validate and serialize documents.
            ttl_seconds: Optional lifetime. When set, ``create`` and
                non-operator ``update`` calls set ``expire_at`` to
                now + ttl_seconds.
        """
        self.collection = collection
        self.model = model
        self.ttl_seconds = ttl_seconds
        # Remembered per instance so create_index is sent at most once.
        self._index_created = False

    async def _ensure_ttl_index(self):
        """Create the ``expire_at`` TTL index once, if TTL is enabled."""
        if self.ttl_seconds is None or self._index_created:
            return
        # expireAfterSeconds=0: each document expires at the datetime stored
        # in its own expire_at field.
        await self.collection.create_index("expire_at", expireAfterSeconds=0)
        self._index_created = True

    def _order_fields(self, doc: Dict) -> Dict:
        """Return a copy of ``doc`` with bookkeeping fields moved to the end.

        If ``doc`` has an ``id`` key, its value is also stored under ``_id`` as
        an ``ObjectId`` (the ``id`` key itself is kept). ``created_at``,
        ``updated_at`` and ``expire_at``, when present, are appended last in
        that order.
        """
        trailing = ("created_at", "updated_at", "expire_at")
        ordered_doc = {k: v for k, v in doc.items() if k not in trailing}
        if "id" in doc:
            ordered_doc["_id"] = ObjectId(doc["id"])
        for key in trailing:
            if key in doc:
                ordered_doc[key] = doc[key]
        return ordered_doc

    @staticmethod
    def _now() -> datetime:
        """Current time as produced by ``get_date_time`` with tzinfo dropped.

        NOTE(review): tzinfo is stripped without conversion and MongoDB treats
        naive datetimes as UTC. If get_date_time() returns a non-UTC local
        time, stored timestamps (and TTL expiry) are shifted by its UTC
        offset. Confirm.
        """
        return get_date_time().replace(tzinfo=None)

    async def create(self, data: Dict) -> str:
        """Insert one document and return its id as a string.

        ``data`` is validated through the model; a caller-supplied ``id`` is
        used as the document ``_id``. ``created_at``/``updated_at`` (and
        ``expire_at`` when TTL is enabled) are set to the current time. The
        caller's dict is not modified.
        """
        await self._ensure_ttl_index()
        now = self._now()
        # Copy so the caller's dict is not mutated with the timestamps.
        data = {**data, "created_at": now, "updated_at": now}
        if self.ttl_seconds is not None:
            data["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
        document = self.model(**data).model_dump(exclude_unset=True)
        result = await self.collection.insert_one(self._order_fields(document))
        return str(result.inserted_id)

    async def read(self, query: Dict) -> List[Dict]:
        """Return every document matching ``query``, validated via the model."""
        docs = []
        async for doc in self.collection.find(query):
            docs.append(
                {
                    "_id": str(doc["_id"]),
                    **self._order_fields(self.model(**doc).model_dump(exclude={"id"})),
                }
            )
        return docs

    async def read_one(self, query: Dict) -> Optional[Dict]:
        """Return the first document matching ``query``, or None if none match."""
        doc = await self.collection.find_one(query)
        if not doc:
            return None
        doc["_id"] = str(doc["_id"])
        return {
            "_id": doc["_id"],
            **self._order_fields(self.model(**doc).model_dump(exclude={"id"})),
        }

    async def update(self, query: Dict, data: Dict) -> int:
        """Update every document matching ``query``; return ``modified_count``.

        If any top-level key of ``data`` starts with ``$`` (``$set``, ``$inc``,
        ...), ``data`` is sent to MongoDB unchanged, so ``updated_at`` and
        ``expire_at`` are NOT refreshed on that path. Otherwise ``data`` is
        validated by constructing ``model(**data)`` (so it must satisfy the
        model's constructor), stamped with ``updated_at`` (and ``expire_at``
        when TTL is enabled), and applied with ``$set``. The caller's dict is
        not modified.
        """
        await self._ensure_ttl_index()
        if any(key.startswith("$") for key in data):
            update_data = data
        else:
            now = self._now()
            # Copy so the caller's dict is not mutated with the timestamps.
            data = {**data, "updated_at": now}
            if self.ttl_seconds is not None:
                data["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
            update_data = {
                "$set": self._order_fields(
                    self.model(**data).model_dump(exclude_unset=True)
                )
            }
        result = await self.collection.update_many(query, update_data)
        return result.modified_count

    async def delete(self, query: Dict) -> int:
        """Delete every document matching ``query``; return how many were removed."""
        result = await self.collection.delete_many(query)
        return result.deleted_count

    async def delete_one(self, query: Dict) -> int:
        """Delete the first document matching ``query``; return 0 or 1."""
        result = await self.collection.delete_one(query)
        return result.deleted_count

    async def find_by_id(self, id: str) -> Optional[Dict]:
        """Return the document whose ``_id`` is ``ObjectId(id)``, or None.

        ``bson.ObjectId`` raises if ``id`` is not a valid ObjectId string.
        """
        return await self.read_one({"_id": ObjectId(id)})

    async def find_all(self) -> List[Dict]:
        """Return every document in the collection."""
        return await self.read({})

    async def find_many(
        self,
        filter: Dict,
        skip: int = 0,
        limit: int = 0,
        sort: Optional[List[tuple]] = None,
    ) -> List[Dict]:
        """
        Find documents based on filter with pagination support.

        Args:
            filter: MongoDB query filter
            skip: Number of documents to skip (values <= 0 skip nothing)
            limit: Maximum number of documents to return (0 means no limit)
            sort: Optional sorting parameters [(field_name, direction)]
                where direction is 1 for ascending, -1 for descending

        Returns:
            List of documents matching the filter. Documents that fail model
            validation are logged and returned with their raw stored fields.
        """
        cursor = self.collection.find(filter)
        # MongoDB applies sort before skip/limit regardless of the order in
        # which these cursor modifiers are chained.
        if skip > 0:
            cursor = cursor.skip(skip)
        if limit > 0:
            cursor = cursor.limit(limit)
        if sort:
            cursor = cursor.sort(sort)

        docs = []
        async for doc in cursor:
            doc_id = str(doc["_id"])
            try:
                validated_doc = self.model(**{**doc, "_id": doc_id}).model_dump(
                    exclude={"id"}
                )
            except Exception as e:
                # Deliberate best-effort: keep the document, unvalidated.
                logger.error(f"Error validating document {doc_id}: {str(e)}")
                docs.append(
                    {"_id": doc_id, **{k: v for k, v in doc.items() if k != "_id"}}
                )
            else:
                docs.append({"_id": doc_id, **self._order_fields(validated_doc)})
        return docs

    async def count(self, filter: Dict) -> int:
        """
        Count documents matching the filter.

        Args:
            filter: MongoDB query filter

        Returns:
            Number of documents matching the filter, or 0 if the count fails
            (the error is logged, not raised).
        """
        try:
            return await self.collection.count_documents(filter)
        except Exception as e:
            logger.error(f"Error counting documents: {str(e)}")
            return 0
from motor.motor_asyncio import AsyncIOMotorCollection
from bson.son import SON
from datetime import datetime, timedelta
from typing import List, Dict
class PostMongoCRUD(MongoCRUD):
    """MongoCRUD for posts, adding a ranked ("priority score") feed query."""

    # Only this many leading entries of ``top_destinations`` earn a bonus.
    # Rank i (0-based) earns (_MAX_RANKED_DESTINATIONS - i) * 2 points,
    # i.e. 10, 8, 6, 4, 2; later entries earn nothing instead of going negative.
    _MAX_RANKED_DESTINATIONS = 5

    @classmethod
    def _build_priority_pipeline(
        cls,
        filter: Dict,
        top_destinations: List[str],
        limit: int,
        skip: int,
        now: datetime,
    ) -> List[Dict]:
        """Build the aggregation pipeline that scores, sorts and pages posts."""
        ranked = top_destinations[: cls._MAX_RANKED_DESTINATIONS]
        destination_branches = [
            {
                "case": {"$eq": ["$destination_id", did]},
                "then": (cls._MAX_RANKED_DESTINATIONS - i) * 2,
            }
            for i, did in enumerate(ranked)
        ]
        # $switch rejects an empty branch list, so fall back to a constant 0.
        destination_score = (
            {"$switch": {"branches": destination_branches, "default": 0}}
            if destination_branches
            else 0
        )

        # Missing/null counts are treated as 0. Without $ifNull, $add yields
        # null, $min ignores the null, and the post gets the maximum score 5.
        reactions = {"$ifNull": ["$reaction_count", 0]}
        comments = {"$ifNull": ["$comment_count", 0]}
        engagement_score = {
            "$min": [
                {
                    "$divide": [
                        {"$add": [reactions, {"$multiply": [comments, 2]}]},
                        100,
                    ]
                },
                5,
            ]
        }

        # NOTE(review): ``now`` is timezone-aware UTC, while MongoCRUD.create
        # stamps created_at with get_date_time() with tzinfo stripped. If
        # get_date_time() is not UTC, these cut-offs are shifted by its offset.
        freshness_score = {
            "$switch": {
                "branches": [
                    {
                        "case": {"$gte": ["$created_at", now - timedelta(days=days)]},
                        "then": points,
                    }
                    for days, points in ((1, 5), (3, 3), (7, 2))
                ],
                "default": 1,
            }
        }

        pipeline: List[Dict] = [
            {"$match": filter},
            {
                "$addFields": {
                    "BaseScore": 1,
                    "DestinationScore": destination_score,
                    "EngagementScore": engagement_score,
                    "FreshnessScore": freshness_score,
                }
            },
            {
                "$addFields": {
                    "PriorityScore": {
                        "$add": [
                            "$BaseScore",
                            "$DestinationScore",
                            "$EngagementScore",
                            "$FreshnessScore",
                        ]
                    }
                }
            },
            # Highest score first; newest first among equal scores.
            {"$sort": SON([("PriorityScore", -1), ("created_at", -1)])},
        ]
        # Mirror MongoCRUD.find_many: 0 means "no skip" / "no limit".
        # MongoDB rejects $limit: 0 and negative $skip/$limit values.
        if skip > 0:
            pipeline.append({"$skip": skip})
        if limit > 0:
            pipeline.append({"$limit": limit})
        return pipeline

    async def find_many_with_score(
        self,
        filter: Dict,
        top_destinations: List[str],
        limit: int = 10,
        skip: int = 0,
    ) -> List[Dict]:
        """Return posts matching ``filter`` ranked by a computed priority score.

        PriorityScore = BaseScore (1) + DestinationScore (bonus for the first
        five ``top_destinations``) + EngagementScore
        (min((reactions + 2 * comments) / 100, 5)) + FreshnessScore (5 / 3 / 2
        for posts younger than 1 / 3 / 7 days, else 1).

        Args:
            filter: MongoDB ``$match`` filter.
            top_destinations: destination_id values in preference order.
            limit: Maximum number of results (0 or less means no limit).
            skip: Number of ranked results to skip (0 or less skips nothing).

        Returns:
            Raw aggregation documents (not passed through the model) with
            ``_id`` as a string and the score fields added.
        """
        now = datetime.now(timezone.utc)
        pipeline = self._build_priority_pipeline(
            filter, top_destinations, limit, skip, now
        )
        results = []
        async for doc in self.collection.aggregate(pipeline):
            doc["_id"] = str(doc["_id"])
            results.append(doc)
        return results
# Post collection helper; PostMongoCRUD adds find_many_with_score on top of
# the generic CRUD operations.
PostCRUD = PostMongoCRUD(collection=database[MongoCfg.POST], model=Post)
def chat_messages_history(
    session_id: str, number_of_messages: int = MongoCfg.MAX_HISTORY_SIZE, db="mongo"
):
    """Return a LangChain chat-message history backend for ``session_id``.

    Args:
        session_id: Conversation key. A falsy value is replaced by the fixed
            ID "12345678910" and a warning is logged.
            NOTE(review): all callers without a session then share a single
            history; confirm this is intended.
        number_of_messages: Passed to the MongoDB backend as ``history_size``;
            not used by the Redis backend.
        db: "redis" selects RedisChatMessageHistory (ttl=40000, presumably
            seconds); any other value selects MongoDBChatMessageHistory.
    """
    if not session_id:
        session_id = "12345678910"
        logger.warning("Session ID not provided, using default session ID")

    if db != "redis":
        return MongoDBChatMessageHistory(
            session_id=session_id,
            connection_string=MongoCfg.MONGODB_URL,
            database_name=MongoCfg.MONGO_INDEX,
            collection_name=MongoCfg.CHAT_HISTORY,
            history_size=number_of_messages,
        )

    redis_ttl = 40000
    return RedisChatMessageHistory(
        session_id=session_id,
        redis_url=RedisCfg.REDIS_URL,
        ttl=redis_ttl,
    )
# Plain CRUD helpers, one per collection (no TTL).
BookHotelCRUD = MongoCRUD(collection=database[MongoCfg.BOOK_HOTEL], model=BookHotel)
ScheduleCRUD = MongoCRUD(collection=database[MongoCfg.ACTIVITY], model=Schedule)
UserCRUD = MongoCRUD(collection=database[MongoCfg.USER], model=User)
ReactionCRUD = MongoCRUD(collection=database[MongoCfg.REACTION], model=Reaction)
CommentCRUD = MongoCRUD(collection=database[MongoCfg.COMMENT], model=Comment)
DestinationCRUD = MongoCRUD(
    collection=database[MongoCfg.DESTINATION], model=Destination
)

# Chat-session bookkeeping: MongoCRUD sets expire_at to 3600 s after each
# create / non-operator update, and a TTL index on expire_at removes stale docs.
chat_history_management_crud = MongoCRUD(
    collection=database["chat_history_management"],
    model=ChatHistoryManagement,
    ttl_seconds=3600,
)