Spaces:
Sleeping
Sleeping
| from src.langgraph.config.constant import MongoCfg, RedisCfg | |
| from src.utils.logger import logger | |
| from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory | |
| from langchain_redis import RedisChatMessageHistory | |
| from motor.motor_asyncio import AsyncIOMotorClient | |
| from pydantic import BaseModel | |
| from typing import Type, Dict, List, Optional | |
| from bson import ObjectId | |
| from motor.motor_asyncio import AsyncIOMotorCollection | |
| from datetime import datetime, timezone, timedelta | |
| from src.apis.models.post_models import Comment, Reaction, Post | |
| from src.apis.models.destination_models import Destination | |
| from src.apis.models.schedule_models import Schedule | |
| from src.apis.models.hotel_models import BookHotel | |
| from src.apis.interfaces.api_interface import ChatHistoryManagement | |
| from src.apis.models.user_models import User | |
| from src.utils.logger import get_date_time | |
# One Motor client is created at import time and shared by every CRUD helper
# defined in this module.
client: AsyncIOMotorClient = AsyncIOMotorClient(MongoCfg.MONGODB_URL)

# Handle to the application database; the name is read from MongoCfg.MONGO_INDEX.
database = client[MongoCfg.MONGO_INDEX]
class MongoCRUD:
    """Async CRUD helper around a Motor collection with pydantic validation.

    Documents are validated through ``model`` before being written and again
    when they are read back. When ``ttl_seconds`` is set, written documents
    carry an ``expire_at`` timestamp and a TTL index on that field is created
    lazily on the first ``create``/``update`` call.
    """

    def __init__(
        self,
        collection: AsyncIOMotorCollection,
        model: Type[BaseModel],
        ttl_seconds: Optional[int] = None,
    ):
        """Bind the helper to a collection and a validation model.

        Args:
            collection: Motor collection that all operations run against.
            model: Pydantic model used to validate documents.
            ttl_seconds: Lifetime of a document in seconds, counted from its
                last create/update through this helper. ``None`` disables
                expiry.
        """
        self.collection = collection
        self.model = model
        self.ttl_seconds = ttl_seconds
        # Set once the TTL index has been created by this instance.
        self._index_created = False

    async def _ensure_ttl_index(self):
        """Create the TTL index on ``expire_at`` once per instance, if TTL is on."""
        if self.ttl_seconds is not None and not self._index_created:
            # expireAfterSeconds=0: documents expire at the time stored in
            # their own ``expire_at`` field.
            await self.collection.create_index("expire_at", expireAfterSeconds=0)
            self._index_created = True

    def _order_fields(self, doc: Dict) -> Dict:
        """Return a copy of ``doc`` with the timestamp fields moved to the end.

        ``created_at``, ``updated_at`` and ``expire_at`` (when present) are
        appended last, in that order. If ``doc`` has an ``id`` key, an ``_id``
        entry holding ``ObjectId(doc["id"])`` is added as well.
        """
        trailing_keys = ("created_at", "updated_at", "expire_at")
        ordered_doc = {k: v for k, v in doc.items() if k not in trailing_keys}
        if "id" in doc:
            # NOTE(review): the original "id" entry is kept next to "_id";
            # confirm whether storing both is intended.
            ordered_doc["_id"] = ObjectId(doc["id"])
        for key in trailing_keys:
            if key in doc:
                ordered_doc[key] = doc[key]
        return ordered_doc

    def _serialize(self, doc: Dict) -> Dict:
        """Validate a raw MongoDB document and return it with a string ``_id``.

        The ``_id`` is stringified *before* validation so that ``read``,
        ``read_one`` and ``find_many`` all feed the model the same input.
        """
        doc_id = str(doc["_id"])
        validated = self.model(**{**doc, "_id": doc_id}).model_dump(exclude={"id"})
        return {"_id": doc_id, **self._order_fields(validated)}

    async def create(self, data: Dict) -> str:
        """Insert a new document and return its id as a string.

        ``created_at``/``updated_at`` (and ``expire_at`` when TTL is enabled)
        are added to a copy of ``data``; the caller's dict is not modified.
        A caller-supplied ``id`` is used as the document ``_id``.

        Args:
            data: Field values for the new document.

        Returns:
            The inserted document's ``_id`` converted to ``str``.
        """
        await self._ensure_ttl_index()
        now = get_date_time().replace(tzinfo=None)
        # Build a new dict instead of writing timestamps into the caller's one.
        payload = {**data, "created_at": now, "updated_at": now}
        if self.ttl_seconds is not None:
            payload["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
        document = self.model(**payload).model_dump(exclude_unset=True)
        result = await self.collection.insert_one(self._order_fields(document))
        return str(result.inserted_id)

    async def read(self, query: Dict) -> List[Dict]:
        """Return every document matching ``query``, validated and ordered."""
        return [self._serialize(doc) async for doc in self.collection.find(query)]

    async def read_one(self, query: Dict) -> Optional[Dict]:
        """Return the first document matching ``query``, or ``None`` if absent."""
        doc = await self.collection.find_one(query)
        if not doc:
            return None
        return self._serialize(doc)

    async def update(self, query: Dict, data: Dict) -> int:
        """Update every document matching ``query`` and return the modified count.

        If any top-level key of ``data`` starts with ``$`` the dict is sent to
        MongoDB unchanged (no validation, no timestamp refresh). Otherwise a
        copy of ``data`` gets a fresh ``updated_at`` (and ``expire_at`` when
        TTL is enabled), is validated through the model and applied with
        ``$set``. The caller's dict is not modified.

        Args:
            query: MongoDB filter selecting the documents to update.
            data: Either an update-operator document or plain field values.

        Returns:
            Number of documents modified.
        """
        await self._ensure_ttl_index()
        if any(key.startswith("$") for key in data):
            # Caller supplied update operators ($inc, $set, ...): pass through.
            update_data = data
        else:
            now = get_date_time().replace(tzinfo=None)
            payload = {**data, "updated_at": now}
            if self.ttl_seconds is not None:
                payload["expire_at"] = now + timedelta(seconds=self.ttl_seconds)
            validated = self.model(**payload).model_dump(exclude_unset=True)
            update_data = {"$set": self._order_fields(validated)}
        result = await self.collection.update_many(query, update_data)
        return result.modified_count

    async def delete(self, query: Dict) -> int:
        """Delete every document matching ``query``; return the deleted count."""
        result = await self.collection.delete_many(query)
        return result.deleted_count

    async def delete_one(self, query: Dict) -> int:
        """Delete at most one document matching ``query``; return the deleted count."""
        result = await self.collection.delete_one(query)
        return result.deleted_count

    async def find_by_id(self, id: str) -> Optional[Dict]:
        """Return the document whose ``_id`` is ``ObjectId(id)``, or ``None``.

        ``ObjectId(id)`` is evaluated before the query, so an ``id`` it cannot
        parse raises from that constructor.
        """
        return await self.read_one({"_id": ObjectId(id)})

    async def find_all(self) -> List[Dict]:
        """Return every document in the collection."""
        return await self.read({})

    async def find_many(
        self,
        filter: Dict,
        skip: int = 0,
        limit: int = 0,
        sort: Optional[List[tuple]] = None,
    ) -> List[Dict]:
        """Find documents matching ``filter`` with optional paging and sorting.

        Args:
            filter: MongoDB query filter.
            skip: Number of documents to skip (applied only when > 0).
            limit: Maximum number of documents to return (0 means no limit).
            sort: Optional list of ``(field_name, direction)`` pairs, where
                direction is 1 for ascending and -1 for descending.

        Returns:
            Matching documents with a string ``_id``. A document that fails
            model validation is logged and returned with its stored fields
            instead of being dropped.
        """
        cursor = self.collection.find(filter)
        if skip > 0:
            cursor = cursor.skip(skip)
        if limit > 0:
            cursor = cursor.limit(limit)
        if sort:
            cursor = cursor.sort(sort)

        docs = []
        async for doc in cursor:
            doc_id = str(doc["_id"])
            try:
                docs.append(self._serialize(doc))
            except Exception as e:
                logger.error(f"Error validating document {doc_id}: {e}")
                # Best effort: keep the document, but with its original data.
                docs.append(
                    {"_id": doc_id, **{k: v for k, v in doc.items() if k != "_id"}}
                )
        return docs

    async def count(self, filter: Dict) -> int:
        """Count documents matching ``filter``.

        Args:
            filter: MongoDB query filter.

        Returns:
            Number of matching documents, or 0 if the count fails (the error
            is logged rather than raised).
        """
        try:
            return await self.collection.count_documents(filter)
        except Exception as e:
            logger.error(f"Error counting documents: {e}")
            return 0
| from motor.motor_asyncio import AsyncIOMotorCollection | |
| from bson.son import SON | |
| from datetime import datetime, timedelta | |
| from typing import List, Dict | |
class PostMongoCRUD(MongoCRUD):
    """MongoCRUD variant that can rank posts with an aggregation pipeline."""

    async def find_many_with_score(
        self,
        filter: Dict,
        top_destinations: List[str],
        limit: int = 10,
        skip: int = 0,
    ) -> List[Dict]:
        """Return posts matching ``filter`` ordered by a computed priority score.

        ``PriorityScore`` is the sum of four parts added to each document:

        * ``BaseScore``: always 1.
        * ``DestinationScore``: 10, 8, 6, ... for posts whose
          ``destination_id`` equals the 1st, 2nd, 3rd, ... entry of
          ``top_destinations``; never below 0; 0 for all other posts.
        * ``EngagementScore``: ``(reaction_count + 2 * comment_count) / 100``
          capped at 5. Missing or null counters count as 0.
        * ``FreshnessScore``: 5, 3 or 2 when ``created_at`` is within the last
          1, 3 or 7 days respectively, otherwise 1.

        Results are sorted by ``PriorityScore`` then ``created_at``, both
        descending.

        Args:
            filter: MongoDB ``$match`` filter.
            top_destinations: Destination ids, most preferred first.
            limit: Maximum number of posts to return (0 means no limit).
            skip: Number of posts to skip (applied only when > 0).

        Returns:
            Raw aggregation documents (not validated through the model),
            including the score fields, with ``_id`` converted to ``str``.
        """
        # NOTE(review): MongoCRUD.create stores get_date_time() with tzinfo
        # stripped; confirm those values are UTC so the thresholds line up.
        now = datetime.now(timezone.utc)

        # Clamp at 0 so a low-ranked top destination is never scored below a
        # destination that is not in the list at all (the $switch default).
        destination_score_branches = [
            {
                "case": {"$eq": ["$destination_id", did]},
                "then": max((5 - i) * 2, 0),
            }
            for i, did in enumerate(top_destinations)
        ]
        destination_score = (
            {"$switch": {"branches": destination_score_branches, "default": 0}}
            if top_destinations
            else 0
        )

        # $add/$divide yield null when a counter is missing or null, and $min
        # ignores null operands -- which would hand such posts the maximum
        # score of 5. Default the counters to 0 instead.
        reactions = {"$ifNull": ["$reaction_count", 0]}
        comments = {"$ifNull": ["$comment_count", 0]}
        engagement_score = {
            "$min": [
                {"$divide": [{"$add": [reactions, {"$multiply": [comments, 2]}]}, 100]},
                5,
            ]
        }

        # (maximum age in days, score) checked from freshest to oldest.
        freshness_branches = [
            {
                "case": {"$gte": ["$created_at", now - timedelta(days=days)]},
                "then": score,
            }
            for days, score in ((1, 5), (3, 3), (7, 2))
        ]
        freshness_score = {"$switch": {"branches": freshness_branches, "default": 1}}

        pipeline = [
            {"$match": filter},
            {
                "$addFields": {
                    "BaseScore": 1,
                    "DestinationScore": destination_score,
                    "EngagementScore": engagement_score,
                    "FreshnessScore": freshness_score,
                }
            },
            {
                "$addFields": {
                    "PriorityScore": {
                        "$add": [
                            "$BaseScore",
                            "$DestinationScore",
                            "$EngagementScore",
                            "$FreshnessScore",
                        ]
                    }
                }
            },
            {"$sort": SON([("PriorityScore", -1), ("created_at", -1)])},
        ]
        # Same convention as MongoCRUD.find_many: non-positive values mean
        # "no skip" / "no limit" ($limit itself rejects 0).
        if skip > 0:
            pipeline.append({"$skip": skip})
        if limit > 0:
            pipeline.append({"$limit": limit})

        results = []
        async for doc in self.collection.aggregate(pipeline):
            doc["_id"] = str(doc["_id"])
            results.append(doc)
        return results
# Module-level helper for the posts collection; no TTL is configured.
PostCRUD = PostMongoCRUD(collection=database[MongoCfg.POST], model=Post)
def chat_messages_history(
    session_id: str, number_of_messages: int = MongoCfg.MAX_HISTORY_SIZE, db="mongo"
):
    """Build the chat-message history store for a session.

    Args:
        session_id: Conversation identifier. A falsy value is replaced by a
            fixed default id and a warning is logged.
        number_of_messages: Passed as ``history_size`` to the MongoDB-backed
            history; not used by the Redis-backed one.
        db: ``"redis"`` selects ``RedisChatMessageHistory``; any other value
            selects ``MongoDBChatMessageHistory``.

    Returns:
        The history object for the chosen backend.
    """
    if not session_id:
        logger.warning("Session ID not provided, using default session ID")
        session_id = "12345678910"

    if db != "redis":
        return MongoDBChatMessageHistory(
            session_id=session_id,
            connection_string=MongoCfg.MONGODB_URL,
            database_name=MongoCfg.MONGO_INDEX,
            collection_name=MongoCfg.CHAT_HISTORY,
            history_size=number_of_messages,
        )

    # NOTE(review): ttl is presumably in seconds (~11 hours) -- confirm
    # against the langchain-redis documentation.
    return RedisChatMessageHistory(
        session_id=session_id,
        redis_url=RedisCfg.REDIS_URL,
        ttl=40000,
    )
# Module-level CRUD helpers, one per collection, each bound to the pydantic
# model used to validate its documents. None of these set a TTL.
BookHotelCRUD = MongoCRUD(collection=database[MongoCfg.BOOK_HOTEL], model=BookHotel)
ScheduleCRUD = MongoCRUD(collection=database[MongoCfg.ACTIVITY], model=Schedule)
UserCRUD = MongoCRUD(collection=database[MongoCfg.USER], model=User)
ReactionCRUD = MongoCRUD(collection=database[MongoCfg.REACTION], model=Reaction)
CommentCRUD = MongoCRUD(collection=database[MongoCfg.COMMENT], model=Comment)
DestinationCRUD = MongoCRUD(
    collection=database[MongoCfg.DESTINATION], model=Destination
)

# Chat-history bookkeeping documents get a TTL of 3600 seconds (one hour)
# counted from their last create/update through this helper.
chat_history_management_crud = MongoCRUD(
    collection=database["chat_history_management"],
    model=ChatHistoryManagement,
    ttl_seconds=3600,
)