# Source: scratch_chat/chat_agent/services/chat_history.py
# Uploaded by WebashalarForML ("Upload 178 files", commit 330b6e4, verified).
"""Chat history management service for the chat agent."""
import json
import logging
from datetime import datetime
from typing import List, Optional, Dict, Any
import redis
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import desc
from ..models.message import Message
from ..models.base import db
logger = logging.getLogger(__name__)
class ChatHistoryError(Exception):
    """Base exception for chat history errors.

    Raised by the chat history service when storing, retrieving, counting,
    searching or clearing messages fails.
    """
class ChatHistoryManager:
    """Manages chat history with dual storage (Redis cache + PostgreSQL persistence).

    The database (accessed through ``db.session``) is the source of truth.
    Redis holds a per-session list of the most recent messages, serialized as
    JSON, and is treated as strictly best-effort: a cache failure is logged
    and never turns a successful database write into an error.
    """

    # Time-to-live, in seconds, applied to each session's cache key (24 hours).
    CACHE_TTL_SECONDS = 86400

    def __init__(self, redis_client: Optional["redis.Redis"], max_cache_messages: int = 20,
                 context_window_size: int = 10):
        """
        Initialize the chat history manager.

        Args:
            redis_client: Redis client instance for caching. A falsy value
                (e.g. ``None``) disables caching; every cache helper then
                becomes a no-op.
            max_cache_messages: Maximum number of messages to cache per session
            context_window_size: Number of recent messages to use for LLM context
        """
        self.redis_client = redis_client
        self.max_cache_messages = max_cache_messages
        self.context_window_size = context_window_size
        self.cache_prefix = "chat_history:"

    def store_message(self, session_id: str, role: str, content: str,
                      language: str = 'python',
                      message_metadata: Optional[Dict[str, Any]] = None) -> "Message":
        """
        Store a message in the database and, best-effort, in the cache.

        Args:
            session_id: Session identifier
            role: Message role ('user' or 'assistant')
            content: Message content
            language: Programming language context
            message_metadata: Additional message metadata. Only forwarded for
                'assistant' messages; it is not passed on for 'user' messages.

        Returns:
            Message: The stored message object

        Raises:
            ChatHistoryError: If the role is invalid or the database write fails
        """
        # Phase 1: persistence. Only failures here are reported to the caller.
        try:
            if role == 'user':
                message = Message.create_user_message(session_id, content, language)
            elif role == 'assistant':
                message = Message.create_assistant_message(
                    session_id, content, language, message_metadata
                )
            else:
                raise ValueError(f"Invalid role: {role}. Must be 'user' or 'assistant'")

            db.session.add(message)
            db.session.commit()
        except SQLAlchemyError as e:
            db.session.rollback()
            logger.error(f"Database error storing message: {e}")
            raise ChatHistoryError(f"Failed to store message: {e}") from e
        except Exception as e:
            db.session.rollback()
            logger.error(f"Unexpected error storing message: {e}")
            raise ChatHistoryError(f"Failed to store message: {e}") from e

        # Phase 2: caching. The row is already committed, so nothing raised
        # here may surface as a storage failure (a caller retry would create a
        # duplicate message). Deliberately broad: this is a best-effort boundary.
        try:
            self._cache_message(message)
            self._trim_cache(session_id)
        except Exception as e:
            logger.warning(f"Message stored but cache update failed: {e}")

        logger.debug(f"Stored {role} message for session {session_id}")
        return message

    def get_recent_history(self, session_id: str, limit: Optional[int] = None) -> List["Message"]:
        """
        Get recent messages for LLM context, checking cache first then database.

        The cache is used only when it can supply at least ``limit`` messages;
        otherwise the database is queried and the cache is rebuilt from the
        result.

        Args:
            session_id: Session identifier
            limit: Maximum number of messages to retrieve (defaults to context_window_size)

        Returns:
            List[Message]: Recent messages in chronological (oldest-first)
            order. Empty when ``limit`` is zero or negative.

        Raises:
            ChatHistoryError: If history retrieval fails
        """
        if limit is None:
            limit = self.context_window_size
        if limit <= 0:
            # Nothing requested; also avoids negative Redis indexes / SQL LIMIT.
            return []

        try:
            cached_messages = self._get_cached_messages(session_id, limit)
            if cached_messages and len(cached_messages) >= limit:
                # Take from the tail so the newest messages are the ones kept.
                return cached_messages[-limit:]

            # Newest first so LIMIT keeps the most recent rows...
            messages = db.session.query(Message).filter(
                Message.session_id == session_id
            ).order_by(desc(Message.timestamp)).limit(limit).all()
            # ...then flip back to chronological order for the caller.
            messages.reverse()

            if messages:
                self._cache_messages(messages)

            logger.debug(f"Retrieved {len(messages)} recent messages for session {session_id}")
            return messages
        except SQLAlchemyError as e:
            logger.error(f"Database error getting recent history: {e}")
            raise ChatHistoryError(f"Failed to get recent history: {e}") from e
        except Exception as e:
            logger.error(f"Unexpected error getting recent history: {e}")
            raise ChatHistoryError(f"Failed to get recent history: {e}") from e

    def get_full_history(self, session_id: str, page: int = 1, page_size: int = 50) -> List["Message"]:
        """
        Get complete conversation history from database with pagination.

        Args:
            session_id: Session identifier
            page: Page number (1-based)
            page_size: Number of messages per page

        Returns:
            List[Message]: List of messages ordered by timestamp (oldest first)

        Raises:
            ChatHistoryError: If ``page`` is below 1, ``page_size`` is
                negative, or history retrieval fails
        """
        if page < 1 or page_size < 0:
            # A page below 1 would yield a negative OFFSET.
            raise ChatHistoryError(
                f"Invalid pagination: page={page}, page_size={page_size}"
            )

        try:
            offset = (page - 1) * page_size
            # Message.id breaks timestamp ties so consecutive pages neither
            # repeat nor skip rows.
            messages = db.session.query(Message).filter(
                Message.session_id == session_id
            ).order_by(Message.timestamp, Message.id).offset(offset).limit(page_size).all()

            logger.debug(f"Retrieved {len(messages)} messages (page {page}) for session {session_id}")
            return messages
        except SQLAlchemyError as e:
            logger.error(f"Database error getting full history: {e}")
            raise ChatHistoryError(f"Failed to get full history: {e}") from e

    def get_message_count(self, session_id: str) -> int:
        """
        Get total message count for a session.

        Args:
            session_id: Session identifier

        Returns:
            int: Total number of messages in the session

        Raises:
            ChatHistoryError: If count retrieval fails
        """
        try:
            return db.session.query(Message).filter(
                Message.session_id == session_id
            ).count()
        except SQLAlchemyError as e:
            logger.error(f"Database error getting message count: {e}")
            raise ChatHistoryError(f"Failed to get message count: {e}") from e

    def clear_session_history(self, session_id: str) -> int:
        """
        Clear all history for a session from both cache and database.

        Args:
            session_id: Session identifier

        Returns:
            int: Number of messages deleted, as reported by the bulk DELETE

        Raises:
            ChatHistoryError: If history clearing fails
        """
        try:
            # Use the row count of the DELETE itself rather than a separate
            # COUNT query, which could disagree if rows change in between.
            count = db.session.query(Message).filter(
                Message.session_id == session_id
            ).delete()
            db.session.commit()

            # Redis errors are swallowed (logged) inside _clear_cache.
            self._clear_cache(session_id)

            logger.info(f"Cleared {count} messages for session {session_id}")
            return count
        except SQLAlchemyError as e:
            db.session.rollback()
            logger.error(f"Database error clearing history: {e}")
            raise ChatHistoryError(f"Failed to clear history: {e}") from e
        except Exception as e:
            db.session.rollback()
            logger.error(f"Unexpected error clearing history: {e}")
            raise ChatHistoryError(f"Failed to clear history: {e}") from e

    def search_messages(self, session_id: str, query: str, limit: int = 20) -> List["Message"]:
        """
        Search messages by content within a session.

        The match is a case-insensitive substring test; ``%``, ``_`` and
        backslashes in ``query`` are matched literally.

        Args:
            session_id: Session identifier
            query: Search query string
            limit: Maximum number of results

        Returns:
            List[Message]: Matching messages, newest first

        Raises:
            ChatHistoryError: If search fails
        """
        try:
            pattern = f'%{self._escape_like(query)}%'
            messages = db.session.query(Message).filter(
                Message.session_id == session_id,
                Message.content.ilike(pattern, escape='\\')
            ).order_by(desc(Message.timestamp)).limit(limit).all()

            logger.debug(f"Found {len(messages)} messages matching '{query}' in session {session_id}")
            return messages
        except SQLAlchemyError as e:
            logger.error(f"Database error searching messages: {e}")
            raise ChatHistoryError(f"Failed to search messages: {e}") from e

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE wildcards in *term*, using backslash as the escape character."""
        # The backslash must be doubled first so later escapes are not re-escaped.
        return term.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')

    @staticmethod
    def _serialize_message(message: "Message") -> str:
        """Serialize the cached fields of *message* to a JSON string."""
        return json.dumps({
            'id': message.id,
            'session_id': message.session_id,
            'role': message.role,
            'content': message.content,
            'language': message.language,
            'timestamp': message.timestamp.isoformat(),
            'message_metadata': message.message_metadata,
        })

    def _cache_message(self, message: "Message") -> None:
        """Append a single message to the session's Redis list and refresh its TTL."""
        if not self.redis_client:
            return  # Skip caching if Redis is not available

        try:
            cache_key = f"{self.cache_prefix}{message.session_id}"
            # Add to Redis list (most recent at the end)
            self.redis_client.rpush(cache_key, self._serialize_message(message))
            self.redis_client.expire(cache_key, self.CACHE_TTL_SECONDS)
        except (redis.RedisError, TypeError, ValueError) as e:
            # TypeError/ValueError: json.dumps could not serialize a field.
            logger.warning(f"Failed to cache message {message.id}: {e}")

    def _cache_messages(self, messages: List["Message"]) -> None:
        """Replace the session's cached list with *messages*.

        The session is taken from the first message; all messages are assumed
        to belong to it. Only the newest ``max_cache_messages`` entries are
        written so the cache size limit holds.
        """
        if not messages or not self.redis_client:
            return  # Skip caching if Redis is not available

        session_id = messages[0].session_id
        cache_key = f"{self.cache_prefix}{session_id}"

        # Guard the slice: messages[-0:] would select everything.
        if self.max_cache_messages > 0:
            to_cache = messages[-self.max_cache_messages:]
        else:
            to_cache = []

        try:
            # Clear existing cache for this session
            self.redis_client.delete(cache_key)

            payloads = [self._serialize_message(message) for message in to_cache]
            if payloads:
                self.redis_client.rpush(cache_key, *payloads)
                self.redis_client.expire(cache_key, self.CACHE_TTL_SECONDS)
        except (redis.RedisError, TypeError, ValueError) as e:
            logger.warning(f"Failed to cache messages for session {session_id}: {e}")

    def _get_cached_messages(self, session_id: str, limit: int) -> Optional[List["Message"]]:
        """Get up to *limit* of the most recent messages from the Redis cache.

        Returns:
            ``None`` when caching is disabled, the key is empty/missing or
            Redis fails; otherwise the decoded messages in chronological
            order. Entries that cannot be decoded are skipped, so the list
            may be shorter than what is stored.
        """
        if not self.redis_client:
            return None  # Skip cache lookup if Redis is not available
        if limit <= 0:
            # -0 is index 0, which would make LRANGE return the whole list.
            return []

        try:
            cache_key = f"{self.cache_prefix}{session_id}"
            # Get the most recent messages (from the end of the list)
            cached_data = self.redis_client.lrange(cache_key, -limit, -1)
            if not cached_data:
                return None

            messages = []
            for data in cached_data:
                try:
                    message_data = json.loads(data)
                    # Rebuild a Message object from the cached fields.
                    message = Message(
                        session_id=message_data['session_id'],
                        role=message_data['role'],
                        content=message_data['content'],
                        language=message_data['language'],
                        message_metadata=message_data['message_metadata']
                    )
                    message.id = message_data['id']
                    message.timestamp = datetime.fromisoformat(message_data['timestamp'])
                    messages.append(message)
                except (json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
                    logger.warning(f"Invalid cached message data: {e}")
                    continue
            return messages
        except redis.RedisError as e:
            logger.warning(f"Failed to get cached messages for session {session_id}: {e}")
            return None

    def _trim_cache(self, session_id: str) -> None:
        """Trim the session's cached list to at most ``max_cache_messages`` entries."""
        if not self.redis_client:
            return  # Skip cache trimming if Redis is not available

        try:
            cache_key = f"{self.cache_prefix}{session_id}"
            if self.max_cache_messages > 0:
                # Keep only the most recent messages
                self.redis_client.ltrim(cache_key, -self.max_cache_messages, -1)
            else:
                # LTRIM with a start of -0 (== 0) would keep the whole list,
                # so a non-positive limit drops the key instead.
                self.redis_client.delete(cache_key)
        except redis.RedisError as e:
            logger.warning(f"Failed to trim cache for session {session_id}: {e}")

    def _clear_cache(self, session_id: str) -> None:
        """Delete the session's cache key; Redis errors are logged, not raised."""
        if not self.redis_client:
            return  # Skip cache clearing if Redis is not available

        try:
            self.redis_client.delete(f"{self.cache_prefix}{session_id}")
        except redis.RedisError as e:
            logger.warning(f"Failed to clear cache for session {session_id}: {e}")

    def get_cache_stats(self, session_id: str) -> Dict[str, Any]:
        """
        Get cache statistics for a session.

        Args:
            session_id: Session identifier

        Returns:
            Dict[str, Any]: Cache statistics with keys 'session_id',
            'cached_messages', 'cache_ttl', 'max_cache_size' and
            'redis_status' ('disabled', 'connected' or 'error'). On a Redis
            failure an additional 'error' key carries the error text.
        """
        stats: Dict[str, Any] = {
            'session_id': session_id,
            'cached_messages': 0,
            'cache_ttl': -1,
            'max_cache_size': self.max_cache_messages,
        }

        if not self.redis_client:
            stats['redis_status'] = 'disabled'
            return stats

        try:
            cache_key = f"{self.cache_prefix}{session_id}"
            stats['cached_messages'] = self.redis_client.llen(cache_key)
            stats['cache_ttl'] = self.redis_client.ttl(cache_key)
            stats['redis_status'] = 'connected'
            return stats
        except redis.RedisError as e:
            logger.warning(f"Failed to get cache stats for session {session_id}: {e}")
            # Report zeroed figures rather than a half-filled result.
            stats['cached_messages'] = 0
            stats['cache_ttl'] = -1
            stats['redis_status'] = 'error'
            stats['error'] = str(e)
            return stats
def create_chat_history_manager(redis_client: redis.Redis, max_cache_messages: int = 20,
                                context_window_size: int = 10) -> ChatHistoryManager:
    """
    Build a ChatHistoryManager from the given settings.

    Args:
        redis_client: Redis client instance
        max_cache_messages: Maximum number of messages to cache per session
        context_window_size: Number of recent messages to use for LLM context

    Returns:
        ChatHistoryManager: Configured chat history manager instance
    """
    manager = ChatHistoryManager(
        redis_client=redis_client,
        max_cache_messages=max_cache_messages,
        context_window_size=context_window_size,
    )
    return manager