Spaces:
Runtime error
Runtime error
| import redis | |
| import os | |
| import json | |
| from datetime import datetime | |
| from dotenv import load_dotenv | |
| from fastapi.responses import JSONResponse | |
| from typing import Optional, List, Dict | |
| from llama_index.storage.chat_store.redis import RedisChatStore | |
| from pymongo.mongo_client import MongoClient | |
| from llama_index.core.memory import ChatMemoryBuffer | |
| from service.dto import ChatMessage | |
| load_dotenv() | |
class ChatStore:
    """Per-session chat history kept in Redis lists, with MongoDB as a fallback.

    Each session id is used both as a Redis list key (one JSON-encoded message
    per list element) and as the name of a MongoDB collection inside the
    ``bot_database`` database. When a session is missing from Redis but has a
    MongoDB collection, its messages are copied into Redis before use.

    Note: several methods return a ``fastapi.responses.JSONResponse`` instead
    of raising when the backing store fails. That behaviour is kept because
    existing callers may rely on it.
    """

    _DB_NAME = "bot_database"
    _SESSION_TTL_SECONDS = 86400  # 24 hours
    _TOKEN_LIMIT = 3000

    def __init__(self):
        """Create the Redis and MongoDB clients from environment variables.

        Reads ``REDIS_HOST``, ``REDIS_PORT``, ``REDIS_USERNAME``,
        ``REDIS_PASSWORD`` and ``MONGO_URI``.
        """
        self.redis_client = redis.Redis(
            host=os.getenv("REDIS_HOST"),
            # Environment values are strings (or missing); fall back to the
            # standard Redis port when REDIS_PORT is unset or empty.
            port=int(os.getenv("REDIS_PORT") or 6379),
            username=os.getenv("REDIS_USERNAME"),
            password=os.getenv("REDIS_PASSWORD"),
        )
        self.client = MongoClient(os.getenv("MONGO_URI"))

    def initialize_memory_bot(self, session_id):
        """Return a ``ChatMemoryBuffer`` bound to ``session_id``.

        If the session has no Redis key but a MongoDB collection with that
        name exists, the stored history is first copied into Redis. In every
        case the returned buffer is backed by the Redis chat store.
        """
        chat_store = RedisChatStore(
            redis_client=self.redis_client,
            ttl=self._SESSION_TTL_SECONDS,
        )

        # EXISTS checks one key; the previous KEYS scan walked (and decoded)
        # the entire keyspace just to test membership.
        if not self.redis_client.exists(session_id):
            db = self.client[self._DB_NAME]
            if session_id in db.list_collection_names():
                self.add_chat_history_to_redis(session_id)

        return ChatMemoryBuffer.from_defaults(
            token_limit=self._TOKEN_LIMIT,
            chat_store=chat_store,
            chat_store_key=session_id,
        )

    def get_messages(self, session_id: str) -> List[dict]:
        """Return every message stored in Redis for ``session_id``.

        Returns an empty list when the session has no messages.
        """
        items = self.redis_client.lrange(session_id, 0, -1)
        # json.loads accepts both bytes and str, so no explicit decode is
        # needed (this also matches how clean_message parses items).
        return [json.loads(item) for item in items]

    def get_last_message(self, session_id: str) -> Optional[Dict]:
        """Return the newest Redis message for ``session_id``, or ``None``."""
        last_message = self.redis_client.lindex(session_id, -1)
        if last_message is None:
            return None
        return json.loads(last_message)

    def get_last_message_mongodb(self, session_id: str):
        """Return the ``content`` of the newest MongoDB document as a string.

        Documents are ordered by ``_id`` descending. Returns ``""`` when the
        collection is empty or the document has no ``content`` field.
        """
        collection = self.client[self._DB_NAME][session_id]
        newest = collection.find().sort("_id", -1).limit(1)
        # The cursor yields at most one document.
        for doc in newest:
            return str(doc.get("content", ""))
        return ""

    def delete_last_message(self, session_id: str) -> Optional[bytes]:
        """Pop the newest message from the Redis list.

        Returns whatever ``RPOP`` returns: the raw stored item, or ``None``
        when the list is empty. The item is not parsed.
        """
        return self.redis_client.rpop(session_id)

    def delete_messages(self, session_id: str) -> None:
        """Delete the session's history from both Redis and MongoDB."""
        self.redis_client.delete(session_id)
        db = self.client[self._DB_NAME]
        # Item access is required here: ``db.session_id`` would address a
        # collection literally named "session_id" rather than this session's.
        db[session_id].drop()
        return None

    def clean_message(self, session_id: str) -> None:
        """Remove tool messages and content-less assistant messages from Redis.

        A message is removed when its ``role`` is ``"tool"``, or when its
        ``role`` is ``"assistant"`` and its ``content`` is ``None``.
        """
        current_list = self.redis_client.lrange(session_id, 0, -1)

        # Parse everything before deleting anything, so a malformed item
        # aborts the call without leaving the list half-cleaned.
        removable = [
            item for item in current_list if self._is_removable(json.loads(item))
        ]

        for item in reversed(removable):
            # LREM with count=1 removes a single occurrence of this value.
            self.redis_client.lrem(session_id, 1, item)

    @staticmethod
    def _is_removable(data: dict) -> bool:
        """Tell whether a parsed message should be dropped by clean_message."""
        role = data.get("role")
        if role == "tool":
            return True
        return role == "assistant" and data.get("content") is None

    def get_keys(self) -> List[str]:
        """Return every Redis key decoded as UTF-8.

        On any failure a ``JSONResponse`` with status 400 is returned instead
        of a list.
        """
        try:
            return [key.decode("utf-8") for key in self.redis_client.keys("*")]
        except Exception:
            return JSONResponse(status_code=400, content="the error when get keys")

    def add_message(self, session_id: str, message: Optional["ChatMessage"]) -> None:
        """Append ``message`` (JSON-encoded) to the session's Redis list."""
        item = json.dumps(self._message_to_dict(message))
        self.redis_client.rpush(session_id, item)

    def _message_to_dict(self, message: Optional["ChatMessage"]) -> dict:
        """Dump ``message`` to a JSON-serialisable dict.

        Uses ``message.model_dump()`` and converts a ``datetime`` value under
        the ``timestamp`` key to its ISO-8601 string.
        """
        message_dict = message.model_dump()
        timestamp = message_dict.get("timestamp")
        if isinstance(timestamp, datetime):
            message_dict["timestamp"] = timestamp.isoformat()
        return message_dict

    def add_chat_history_to_redis(self, session_id: str) -> None:
        """Copy the session's MongoDB history into its Redis list.

        ``_id``, ``timestamp`` and ``None``-valued fields are stripped from
        each document before it is rebuilt as a ``ChatMessage``. The Redis key
        is given a 24-hour expiry. On any failure a ``JSONResponse`` with
        status 500 is returned.
        """
        collection = self.client[self._DB_NAME][session_id]
        skipped_fields = ("_id", "timestamp")

        try:
            # Serialise the whole history first so that a bad document cannot
            # leave a partially copied history in Redis.
            items = [
                json.dumps(
                    self._message_to_dict(
                        ChatMessage(
                            **{
                                key: value
                                for key, value in document.items()
                                if key not in skipped_fields and value is not None
                            }
                        )
                    )
                )
                for document in collection.find()
                if document is not None
            ]

            if not items:
                return None

            # One RPUSH keeps the copy atomic; the expiry is set once.
            self.redis_client.rpush(session_id, *items)
            self.redis_client.expire(session_id, time=self._SESSION_TTL_SECONDS)
        except Exception:
            return JSONResponse(status_code=500, content="Add Database Error")

    def get_all_messages_mongodb(self, session_id):
        """Return all MongoDB documents for ``session_id`` as plain dicts.

        The ``_id`` field and any ``None``-valued fields are omitted. On any
        failure a ``JSONResponse`` with status 500 is returned.
        """
        try:
            collection = self.client[self._DB_NAME][session_id]
            return [
                {
                    key: value
                    for key, value in doc.items()
                    if key != "_id" and value is not None
                }
                for doc in collection.find()
            ]
        except Exception as e:
            return JSONResponse(status_code=500, content=f"An error occurred while retrieving messages: {e}")