Spaces:
Runtime error
Runtime error
| """Chat history management service for the chat agent.""" | |
| import json | |
| import logging | |
| from datetime import datetime | |
| from typing import List, Optional, Dict, Any | |
| import redis | |
| from sqlalchemy.exc import SQLAlchemyError | |
| from sqlalchemy import desc | |
| from ..models.message import Message | |
| from ..models.base import db | |
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
class ChatHistoryError(Exception):
    """Base exception for chat history errors raised by the chat history manager."""
class ChatHistoryManager:
    """Manages chat history with dual storage (Redis cache + PostgreSQL persistence).

    The database (``db.session`` / ``Message``) is the source of truth. For each
    session, Redis keeps a list of JSON-serialized messages under
    ``"<cache_prefix><session_id>"``: oldest first, newest at the tail, capped at
    ``max_cache_messages`` entries and expiring ``cache_ttl`` seconds after the
    last write.

    All Redis work is best-effort: cache failures are logged as warnings and never
    turn a successful database operation into an error. A falsy ``redis_client``
    disables caching entirely.
    """

    def __init__(self, redis_client: redis.Redis, max_cache_messages: int = 20,
                 context_window_size: int = 10, cache_ttl: int = 86400):
        """
        Initialize the chat history manager.

        Args:
            redis_client: Redis client used for caching; a falsy value disables
                the cache.
            max_cache_messages: Maximum number of messages cached per session
                (values <= 0 disable caching).
            context_window_size: Default number of recent messages returned by
                ``get_recent_history`` for LLM context.
            cache_ttl: Expiry in seconds applied to a session's cache list on
                every cache write (default: 24 hours).
        """
        self.redis_client = redis_client
        self.max_cache_messages = max_cache_messages
        self.context_window_size = context_window_size
        self.cache_ttl = cache_ttl
        self.cache_prefix = "chat_history:"

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def store_message(self, session_id: str, role: str, content: str,
                      language: str = 'python',
                      message_metadata: Optional[Dict[str, Any]] = None) -> Message:
        """
        Store a message in the database, then (best-effort) in the cache.

        Args:
            session_id: Session identifier
            role: Message role ('user' or 'assistant')
            content: Message content
            language: Programming language context
            message_metadata: Additional message metadata; only passed through
                for assistant messages.

        Returns:
            Message: The stored message object

        Raises:
            ChatHistoryError: If the role is invalid or the database write fails.
                Cache failures are logged and never raise.
        """
        try:
            if role == 'user':
                message = Message.create_user_message(session_id, content, language)
            elif role == 'assistant':
                message = Message.create_assistant_message(
                    session_id, content, language, message_metadata
                )
            else:
                raise ValueError(f"Invalid role: {role}. Must be 'user' or 'assistant'")
            db.session.add(message)
            db.session.commit()
        except SQLAlchemyError as e:
            db.session.rollback()
            logger.error("Database error storing message: %s", e)
            raise ChatHistoryError(f"Failed to store message: {e}") from e
        except Exception as e:
            db.session.rollback()
            logger.exception("Unexpected error storing message: %s", e)
            raise ChatHistoryError(f"Failed to store message: {e}") from e

        # The message is committed; from here on caching must not report failure.
        self._cache_message(message)
        logger.debug("Stored %s message for session %s", role, session_id)
        return message

    def get_recent_history(self, session_id: str, limit: Optional[int] = None) -> List[Message]:
        """
        Get recent messages for LLM context, checking the cache first, then the database.

        Args:
            session_id: Session identifier
            limit: Maximum number of messages to retrieve (defaults to
                context_window_size). Values <= 0 return an empty list.

        Returns:
            List[Message]: Up to ``limit`` most recent messages, oldest first.

        Raises:
            ChatHistoryError: If history retrieval fails
        """
        if limit is None:
            limit = self.context_window_size
        if limit <= 0:
            return []
        try:
            cached_messages = self._get_cached_messages(session_id, limit)
            # The cache only answers when it holds a full window; otherwise the
            # session may have older messages that only the database knows about.
            if cached_messages and len(cached_messages) >= limit:
                return cached_messages[:limit]

            messages = db.session.query(Message).filter(
                Message.session_id == session_id
            ).order_by(desc(Message.timestamp), desc(Message.id)).limit(limit).all()
            # Query is newest-first so LIMIT keeps the latest rows; flip to chronological.
            messages.reverse()
        except SQLAlchemyError as e:
            db.session.rollback()
            logger.error("Database error getting recent history: %s", e)
            raise ChatHistoryError(f"Failed to get recent history: {e}") from e
        except Exception as e:
            logger.exception("Unexpected error getting recent history: %s", e)
            raise ChatHistoryError(f"Failed to get recent history: {e}") from e

        if messages:
            self._cache_messages(messages)
        logger.debug("Retrieved %d recent messages for session %s", len(messages), session_id)
        return messages

    def get_full_history(self, session_id: str, page: int = 1, page_size: int = 50) -> List[Message]:
        """
        Get complete conversation history from the database with pagination.

        Args:
            session_id: Session identifier
            page: Page number (1-based, must be >= 1)
            page_size: Number of messages per page (must be >= 0)

        Returns:
            List[Message]: The requested page of messages, oldest first.

        Raises:
            ChatHistoryError: If the pagination arguments are invalid or
                history retrieval fails.
        """
        if page < 1 or page_size < 0:
            raise ChatHistoryError(
                f"Failed to get full history: invalid pagination "
                f"(page={page}, page_size={page_size})"
            )
        try:
            offset = (page - 1) * page_size
            # Secondary key on id gives a total order, so pages never overlap or skip rows.
            messages = db.session.query(Message).filter(
                Message.session_id == session_id
            ).order_by(Message.timestamp, Message.id).offset(offset).limit(page_size).all()
        except SQLAlchemyError as e:
            db.session.rollback()
            logger.error("Database error getting full history: %s", e)
            raise ChatHistoryError(f"Failed to get full history: {e}") from e

        logger.debug("Retrieved %d messages (page %d) for session %s",
                     len(messages), page, session_id)
        return messages

    def get_message_count(self, session_id: str) -> int:
        """
        Get total message count for a session.

        Args:
            session_id: Session identifier

        Returns:
            int: Total number of messages in the session

        Raises:
            ChatHistoryError: If count retrieval fails
        """
        try:
            return db.session.query(Message).filter(
                Message.session_id == session_id
            ).count()
        except SQLAlchemyError as e:
            db.session.rollback()
            logger.error("Database error getting message count: %s", e)
            raise ChatHistoryError(f"Failed to get message count: {e}") from e

    def clear_session_history(self, session_id: str) -> int:
        """
        Clear all history for a session from both the database and the cache.

        Args:
            session_id: Session identifier

        Returns:
            int: Number of messages deleted from the database

        Raises:
            ChatHistoryError: If the database delete fails. Cache failures are
                logged and never raise.
        """
        try:
            # Query.delete() returns the number of matched rows, so no separate COUNT.
            count = db.session.query(Message).filter(
                Message.session_id == session_id
            ).delete()
            db.session.commit()
        except SQLAlchemyError as e:
            db.session.rollback()
            logger.error("Database error clearing history: %s", e)
            raise ChatHistoryError(f"Failed to clear history: {e}") from e
        except Exception as e:
            db.session.rollback()
            logger.exception("Unexpected error clearing history: %s", e)
            raise ChatHistoryError(f"Failed to clear history: {e}") from e

        self._clear_cache(session_id)
        logger.info("Cleared %d messages for session %s", count, session_id)
        return count

    def search_messages(self, session_id: str, query: str, limit: int = 20) -> List[Message]:
        """
        Search messages by content (case-insensitive substring) within a session.

        ``%`` and ``_`` in ``query`` are matched literally, not as wildcards.

        Args:
            session_id: Session identifier
            query: Search query string
            limit: Maximum number of results

        Returns:
            List[Message]: Matching messages, newest first.

        Raises:
            ChatHistoryError: If search fails
        """
        pattern = f"%{self._escape_like(query)}%"
        try:
            messages = db.session.query(Message).filter(
                Message.session_id == session_id,
                Message.content.ilike(pattern, escape='\\')
            ).order_by(desc(Message.timestamp), desc(Message.id)).limit(limit).all()
        except SQLAlchemyError as e:
            db.session.rollback()
            logger.error("Database error searching messages: %s", e)
            raise ChatHistoryError(f"Failed to search messages: {e}") from e

        logger.debug("Found %d messages matching %r in session %s",
                     len(messages), query, session_id)
        return messages

    def get_cache_stats(self, session_id: str) -> Dict[str, Any]:
        """
        Get cache statistics for a session.

        Args:
            session_id: Session identifier

        Returns:
            Dict[str, Any]: ``session_id``, ``cached_messages`` (list length),
            ``cache_ttl`` (Redis TTL of the key; -1 when unknown),
            ``max_cache_size`` and ``redis_status``
            ('enabled' / 'disabled' / 'error'); plus ``error`` on Redis failure.
        """
        stats: Dict[str, Any] = {
            'session_id': session_id,
            'cached_messages': 0,
            'cache_ttl': -1,
            'max_cache_size': self.max_cache_messages,
        }
        if not self.redis_client:
            stats['redis_status'] = 'disabled'
            return stats
        try:
            cache_key = self._cache_key(session_id)
            stats['cached_messages'] = self.redis_client.llen(cache_key)
            stats['cache_ttl'] = self.redis_client.ttl(cache_key)
            stats['redis_status'] = 'enabled'
        except redis.RedisError as e:
            logger.warning("Failed to get cache stats for session %s: %s", session_id, e)
            stats['cached_messages'] = 0
            stats['cache_ttl'] = -1
            stats['redis_status'] = 'error'
            stats['error'] = str(e)
        return stats

    # ------------------------------------------------------------------
    # Cache helpers (all best-effort: they log and never raise Redis errors)
    # ------------------------------------------------------------------

    def _cache_key(self, session_id: str) -> str:
        """Return the Redis key of the cached message list for ``session_id``."""
        return f"{self.cache_prefix}{session_id}"

    @staticmethod
    def _serialize_message(message: Message) -> str:
        """Serialize a message to the JSON string stored in Redis."""
        timestamp = message.timestamp
        return json.dumps({
            'id': message.id,
            'session_id': message.session_id,
            'role': message.role,
            'content': message.content,
            'language': message.language,
            'timestamp': timestamp.isoformat() if timestamp is not None else None,
            'message_metadata': message.message_metadata,
        })

    @staticmethod
    def _deserialize_message(raw) -> Message:
        """Rebuild a detached (not session-bound) Message from a cached JSON string."""
        data = json.loads(raw)
        message = Message(
            session_id=data['session_id'],
            role=data['role'],
            content=data['content'],
            language=data['language'],
            message_metadata=data['message_metadata'],
        )
        message.id = data['id']
        timestamp = data['timestamp']
        message.timestamp = datetime.fromisoformat(timestamp) if timestamp else None
        return message

    @staticmethod
    def _escape_like(value: str, escape_char: str = '\\') -> str:
        """Escape LIKE/ILIKE wildcards (``%``, ``_``) and the escape character itself."""
        return (value.replace(escape_char, escape_char * 2)
                .replace('%', escape_char + '%')
                .replace('_', escape_char + '_'))

    def _cache_message(self, message: Message) -> None:
        """Append one message to its session's cache list, trim it and refresh its TTL.

        If the message cannot be cached, the session's list is dropped so later
        reads fall back to the database instead of serving a history with a hole.
        """
        if not self.redis_client or self.max_cache_messages <= 0:
            return  # Caching disabled
        cache_key = self._cache_key(message.session_id)
        try:
            payload = self._serialize_message(message)
            # One transactional round trip: no window where the list is untrimmed
            # or has no TTL.
            with self.redis_client.pipeline() as pipe:
                pipe.rpush(cache_key, payload)
                pipe.ltrim(cache_key, -self.max_cache_messages, -1)
                pipe.expire(cache_key, self.cache_ttl)
                pipe.execute()
        except (redis.RedisError, TypeError, ValueError) as e:
            logger.warning("Failed to cache message %s: %s", message.id, e)
            self._clear_cache(message.session_id)

    def _cache_messages(self, messages: List[Message]) -> None:
        """Replace a session's cache list with ``messages`` (given oldest first).

        All messages are assumed to belong to the session of ``messages[0]``.
        Only the newest ``max_cache_messages`` entries are stored.
        """
        if not messages or not self.redis_client or self.max_cache_messages <= 0:
            return  # Nothing to cache, or caching disabled
        session_id = messages[0].session_id
        cache_key = self._cache_key(session_id)
        try:
            # Serialize first so a bad message never leaves a half-written list.
            payloads = [self._serialize_message(m)
                        for m in messages[-self.max_cache_messages:]]
            with self.redis_client.pipeline() as pipe:
                pipe.delete(cache_key)
                pipe.rpush(cache_key, *payloads)
                pipe.expire(cache_key, self.cache_ttl)
                pipe.execute()
        except (redis.RedisError, TypeError, ValueError) as e:
            logger.warning("Failed to cache messages for session %s: %s", session_id, e)

    def _get_cached_messages(self, session_id: str, limit: int) -> Optional[List[Message]]:
        """Return up to ``limit`` newest cached messages (oldest first), or None.

        None means "no usable cache" (disabled, empty, non-positive limit or a
        Redis error). Undecodable entries are skipped with a warning.
        """
        if not self.redis_client or limit <= 0:
            return None  # lrange(-0, -1) would return the whole list
        try:
            cached_data = self.redis_client.lrange(self._cache_key(session_id), -limit, -1)
        except redis.RedisError as e:
            logger.warning("Failed to get cached messages for session %s: %s", session_id, e)
            return None
        if not cached_data:
            return None

        messages = []
        for raw in cached_data:
            try:
                messages.append(self._deserialize_message(raw))
            except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
                logger.warning("Invalid cached message data: %s", e)
        return messages

    def _trim_cache(self, session_id: str) -> None:
        """Trim a session's cache list to the newest ``max_cache_messages`` entries."""
        if not self.redis_client:
            return  # Caching disabled
        cache_key = self._cache_key(session_id)
        try:
            if self.max_cache_messages > 0:
                self.redis_client.ltrim(cache_key, -self.max_cache_messages, -1)
            else:
                # ltrim(key, -0, -1) would keep everything; a zero cap means no cache.
                self.redis_client.delete(cache_key)
        except redis.RedisError as e:
            logger.warning("Failed to trim cache for session %s: %s", session_id, e)

    def _clear_cache(self, session_id: str) -> None:
        """Delete a session's cache list."""
        if not self.redis_client:
            return  # Caching disabled
        try:
            self.redis_client.delete(self._cache_key(session_id))
        except redis.RedisError as e:
            logger.warning("Failed to clear cache for session %s: %s", session_id, e)
def create_chat_history_manager(redis_client: redis.Redis, max_cache_messages: int = 20,
                                context_window_size: int = 10) -> ChatHistoryManager:
    """
    Build a ChatHistoryManager from the given settings.

    Thin convenience wrapper that forwards every argument to the
    ChatHistoryManager constructor unchanged.

    Args:
        redis_client: Redis client instance
        max_cache_messages: Maximum number of messages to cache per session
        context_window_size: Number of recent messages to use for LLM context

    Returns:
        ChatHistoryManager: Configured chat history manager instance
    """
    manager = ChatHistoryManager(
        redis_client=redis_client,
        max_cache_messages=max_cache_messages,
        context_window_size=context_window_size,
    )
    return manager