Spaces:
Runtime error
Runtime error
| import redis | |
| import os | |
| import json | |
| from fastapi import HTTPException | |
| from uuid import uuid4 | |
| from typing import Optional, List | |
| from llama_index.storage.chat_store.redis import RedisChatStore | |
| from llama_index.core.memory import ChatMemoryBuffer | |
| from service.dto import ChatMessage | |
class ChatStore:
    """Redis-backed storage for per-session chat histories.

    Each session's history is a Redis list stored under the session id as key;
    every element is a JSON-serialized message dict (see ``add_message``).
    """

    # Defaults reproduce the originally hard-coded Redis Cloud endpoint.
    DEFAULT_HOST = "redis-10365.c244.us-east-1-2.ec2.redns.redis-cloud.com"
    DEFAULT_PORT = 10365

    def __init__(
        self,
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        password: Optional[str] = None,
    ):
        """Create the underlying Redis client.

        Args:
            host: Redis host name (defaults to ``DEFAULT_HOST``).
            port: Redis port (defaults to ``DEFAULT_PORT``).
            password: Redis password. When omitted, it is read from the
                ``REDIS_PASSWORD`` environment variable (which may be unset,
                giving ``None``).
        """
        if password is None:
            password = os.environ.get("REDIS_PASSWORD")
        self.redis_client = redis.Redis(host=host, port=port, password=password)

    @staticmethod
    def generate_uuid(use_hex: bool = False) -> str:
        """Return a new random UUID4 as a string.

        Args:
            use_hex: If true, return the 32-character hex form (no dashes);
                otherwise the canonical 36-character dashed form.
        """
        # A staticmethod so both ``ChatStore.generate_uuid()`` and
        # ``self.generate_uuid()`` honour ``use_hex`` instead of binding the
        # instance to it.
        value = uuid4()
        return value.hex if use_hex else str(value)

    def initialize_memory_bot(self, session_id=None, token_limit: int = 3000):
        """Build a chat memory buffer persisted in Redis.

        Args:
            session_id: Key under which the history is stored; a fresh UUID
                is generated when omitted.
            token_limit: Token limit passed to ``ChatMemoryBuffer``.

        Returns:
            A ``ChatMemoryBuffer`` backed by a ``RedisChatStore`` that shares
            this instance's Redis client.
        """
        if session_id is None:
            session_id = self.generate_uuid()
        chat_store = RedisChatStore(redis_client=self.redis_client)
        return ChatMemoryBuffer.from_defaults(
            token_limit=token_limit,
            chat_store=chat_store,
            chat_store_key=session_id,
        )

    @staticmethod
    def _decode(value):
        """Return ``value`` as text, decoding UTF-8 bytes; ``str`` passes through."""
        # redis-py returns bytes unless the client was built with
        # decode_responses=True; accept both.
        return value.decode("utf-8") if isinstance(value, bytes) else value

    @staticmethod
    def _is_prunable(data) -> bool:
        """Return True for a parsed message that ``clean_message`` should drop.

        Tool messages and assistant messages whose ``content`` is missing or
        ``None`` are prunable; anything that is not a dict is kept.
        """
        if not isinstance(data, dict):
            return False
        role = data.get("role")
        return role == "tool" or (role == "assistant" and data.get("content") is None)

    def get_messages(self, session_id: str) -> List[dict]:
        """Return all stored messages for ``session_id`` as parsed dicts.

        Messages come back in list order (oldest first, since ``add_message``
        appends with RPUSH); an empty list is returned for an unknown key.
        """
        raw_items = self.redis_client.lrange(session_id, 0, -1)
        return [json.loads(self._decode(raw)) for raw in raw_items]

    def delete_last_message(self, session_id: str) -> "Optional[ChatMessage]":
        """Pop and return the last stored message for ``session_id``.

        NOTE(review): this returns the raw value from ``rpop`` (a JSON string,
        as bytes for a default client) or ``None`` when the list is empty —
        not a parsed ``ChatMessage`` as the annotation suggests.
        """
        return self.redis_client.rpop(session_id)

    def delete_messages(self, key: str) -> "Optional[List[ChatMessage]]":
        """Delete the whole history stored under ``key``; always returns ``None``."""
        self.redis_client.delete(key)
        return None

    def clean_message(self, session_id: str) -> "Optional[ChatMessage]":
        """Remove tool messages and content-less assistant messages in place.

        The list is read once and matching elements are removed with LREM.
        Returns ``None``.
        """
        stored = self.redis_client.lrange(session_id, 0, -1)
        # Identical raw values always receive the same verdict, so removing
        # every occurrence (count=0) of each distinct flagged value once is
        # equivalent to removing each flagged element, with fewer round trips.
        # dict.fromkeys de-duplicates while keeping first-seen order.
        flagged = dict.fromkeys(
            raw for raw in stored if self._is_prunable(json.loads(raw))
        )
        for raw in flagged:
            self.redis_client.lrem(session_id, 0, raw)

    def get_keys(self) -> List[str]:
        """Return every key in the Redis database as text.

        Raises:
            HTTPException: status 400 if the Redis call or decoding fails.
        """
        # NOTE(review): KEYS blocks the server while it scans the whole
        # keyspace; consider scan_iter() for large databases.
        try:
            return [self._decode(key) for key in self.redis_client.keys("*")]
        except Exception as e:
            # Log the error and raise HTTPException for FastAPI
            print(f"An error occurred in get keys: {e}")
            raise HTTPException(
                status_code=400, detail="the error when get keys"
            ) from e

    def add_message(self, session_id: str, message: "ChatMessage") -> None:
        """Append ``message`` (JSON-serialized) to the history of ``session_id``."""
        item = json.dumps(self._message_to_dict(message))
        self.redis_client.rpush(session_id, item)

    def _message_to_dict(self, message: "ChatMessage") -> dict:
        """Convert a message model to a plain dict via its ``model_dump()``."""
        return message.model_dump()