# Source: sema-chat — app/services/session_manager.py
# Commit 639f3bb ("Initial Commit") by kamau1
"""
Session Manager - handles conversation sessions and message history.

Sessions are kept in Redis when a Redis URL is configured and reachable,
and in process memory otherwise.
"""
import asyncio
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import json
import uuid
from ..core.config import settings
from ..core.logging import LoggerMixin
from ..models.schemas import ChatMessage, ConversationHistory, SessionInfo
class SessionManager(LoggerMixin):
    """
    Manages chat sessions and conversation history.

    Sessions are stored either in a process-local dict or in Redis. Redis is
    used when ``settings.redis_url`` is set and the connection succeeds in
    :meth:`initialize`; otherwise the manager falls back to memory storage.
    """

    # Prefix of the Redis keys that hold serialized sessions ("session:<id>").
    _SESSION_KEY_PREFIX = "session:"

    def __init__(self):
        # In-memory backend: session_id -> ConversationHistory.
        self.sessions: Dict[str, ConversationHistory] = {}
        self.redis_client = None
        self.use_redis = bool(settings.redis_url)
        self.session_timeout = settings.session_timeout * 60  # Convert to seconds
        self.max_sessions_per_user = settings.max_sessions_per_user
        self.max_messages_per_session = settings.max_messages_per_session
        # Background task that evicts idle in-memory sessions.
        self._cleanup_task: Optional[asyncio.Task] = None

    async def initialize(self) -> bool:
        """
        Initialize the session manager.

        Connects to Redis when configured (falling back to memory storage if
        that fails) and starts the background cleanup task. Calling this again
        does not create a second Redis client or a second cleanup task.

        Returns:
            bool: True on success, False if initialization raised.
        """
        try:
            if self.use_redis and self.redis_client is None:
                await self._initialize_redis()
            # Start cleanup task (only if one is not already running)
            if self._cleanup_task is None or self._cleanup_task.done():
                self._cleanup_task = asyncio.create_task(self._cleanup_expired_sessions())
            self.log_info("Session manager initialized",
                          storage_type="redis" if self.use_redis else "memory",
                          session_timeout=self.session_timeout)
            return True
        except Exception as e:
            self.log_error("Failed to initialize session manager", error=str(e))
            return False

    async def _initialize_redis(self):
        """
        Create the Redis client and verify it with a PING.

        If the ``redis`` package is missing or the connection fails, switches
        to memory storage (``use_redis = False``). On connection failure the
        half-created client is also closed and dropped.
        """
        try:
            import redis.asyncio as redis
            self.redis_client = redis.from_url(settings.redis_url)
            # Test connection
            await self.redis_client.ping()
            # Redact credentials: Redis URLs may embed a password.
            self.log_info("Redis connection established",
                          url=self._redact_url(settings.redis_url))
        except ImportError:
            self.log_warning("Redis not available, falling back to memory storage")
            self.use_redis = False
        except Exception as e:
            self.log_error("Redis connection failed", error=str(e))
            self.use_redis = False
            await self._discard_redis_client()

    async def _discard_redis_client(self):
        """Close and forget the Redis client, if any; close errors are logged, not raised."""
        client, self.redis_client = self.redis_client, None
        if client is None:
            return
        try:
            await client.close()
        except Exception as e:
            self.log_warning("Failed to close Redis client", error=str(e))

    @staticmethod
    def _redact_url(url: str) -> str:
        """
        Return ``url`` with any userinfo (username/password) replaced by ``***``.

        Used so connection strings such as ``redis://:secret@host:6379/0`` are
        never written to the logs verbatim.
        """
        from urllib.parse import urlsplit, urlunsplit

        try:
            parts = urlsplit(url)
        except ValueError:
            return "<unparseable url>"
        _, sep, hostport = parts.netloc.rpartition("@")
        if not sep:
            return url
        return urlunsplit(parts._replace(netloc=f"***@{hostport}"))

    async def shutdown(self):
        """Cancel the cleanup task and close the Redis client (errors are logged)."""
        try:
            if self._cleanup_task:
                self._cleanup_task.cancel()
                try:
                    await self._cleanup_task
                except asyncio.CancelledError:
                    pass
                self._cleanup_task = None
            if self.redis_client:
                # Drop the reference first so a failed close is not retried later.
                client, self.redis_client = self.redis_client, None
                await client.close()
            self.log_info("Session manager shutdown complete")
        except Exception as e:
            self.log_error("Session manager shutdown failed", error=str(e))

    async def create_session(self, session_id: str, user_id: Optional[str] = None) -> bool:
        """
        Create a new chat session.

        Args:
            session_id: Unique session identifier
            user_id: Optional user identifier for session limits

        Returns:
            bool: True if the session was created or already exists; False if
            the per-user limit was hit or storage failed.
        """
        try:
            # Check if session already exists
            if await self.session_exists(session_id):
                self.log_info("Session already exists", session_id=session_id)
                return True

            # Check user session limits if user_id provided
            if user_id and self.max_sessions_per_user > 0:
                user_sessions = await self.get_user_sessions(user_id)
                if len(user_sessions) >= self.max_sessions_per_user:
                    self.log_warning("User session limit exceeded",
                                     user_id=user_id,
                                     limit=self.max_sessions_per_user)
                    return False

            # Single timestamp so a fresh session has created_at == updated_at.
            now = datetime.utcnow()
            session = ConversationHistory(
                session_id=session_id,
                messages=[],
                created_at=now,
                updated_at=now,
                message_count=0
            )
            await self._store_session(session)
            self.log_info("Session created", session_id=session_id, user_id=user_id)
            return True
        except Exception as e:
            self.log_error("Failed to create session", error=str(e), session_id=session_id)
            return False

    async def add_message(self, session_id: str, message: ChatMessage) -> bool:
        """
        Append a message to a session, creating the session if needed.

        When ``max_messages_per_session`` is positive, the oldest messages are
        dropped so the session never exceeds that many messages.

        Args:
            session_id: Session identifier
            message: Message to add

        Returns:
            bool: True if the message was added and stored successfully
        """
        try:
            # Get or create session
            session = await self.get_session(session_id)
            if not session:
                await self.create_session(session_id)
                session = await self.get_session(session_id)
                if not session:
                    self.log_error("Failed to create session", session_id=session_id)
                    return False

            # Check message limit
            if (self.max_messages_per_session > 0 and
                    len(session.messages) >= self.max_messages_per_session):
                # Remove oldest messages so that, after the append below, the
                # session holds exactly max_messages_per_session messages.
                messages_to_remove = len(session.messages) - self.max_messages_per_session + 1
                session.messages = session.messages[messages_to_remove:]
                self.log_info("Trimmed old messages",
                              session_id=session_id,
                              removed_count=messages_to_remove)

            # Add message
            session.messages.append(message)
            session.message_count = len(session.messages)
            session.updated_at = datetime.utcnow()

            # Store updated session
            await self._store_session(session)
            self.log_debug("Message added to session",
                           session_id=session_id,
                           message_role=message.role,
                           total_messages=session.message_count)
            return True
        except Exception as e:
            self.log_error("Failed to add message", error=str(e), session_id=session_id)
            return False

    async def get_session(self, session_id: str) -> Optional[ConversationHistory]:
        """
        Get a session by ID.

        Args:
            session_id: Session identifier

        Returns:
            ConversationHistory or None if not found (or on storage error)
        """
        try:
            if self.use_redis:
                return await self._get_session_from_redis(session_id)
            return self.sessions.get(session_id)
        except Exception as e:
            self.log_error("Failed to get session", error=str(e), session_id=session_id)
            return None

    async def session_exists(self, session_id: str) -> bool:
        """Check if a session exists"""
        session = await self.get_session(session_id)
        return session is not None

    async def delete_session(self, session_id: str) -> bool:
        """
        Delete a session (deleting a missing session is not an error).

        Args:
            session_id: Session identifier

        Returns:
            bool: True if the delete call succeeded
        """
        try:
            if self.use_redis:
                await self.redis_client.delete(self._session_key(session_id))
            else:
                self.sessions.pop(session_id, None)
            self.log_info("Session deleted", session_id=session_id)
            return True
        except Exception as e:
            self.log_error("Failed to delete session", error=str(e), session_id=session_id)
            return False

    async def get_session_messages(self, session_id: str, limit: Optional[int] = None) -> List[ChatMessage]:
        """
        Get messages from a session.

        Args:
            session_id: Session identifier
            limit: If positive, return only the last ``limit`` messages

        Returns:
            A new list of ChatMessage objects (empty if the session is unknown).
            It is a copy, so mutating it does not alter stored history.
        """
        session = await self.get_session(session_id)
        if not session:
            return []
        messages = session.messages
        if limit and limit > 0:
            messages = messages[-limit:]  # Get last N messages
        # Copy: with the memory backend session.messages is the stored list itself.
        return list(messages)

    async def get_active_sessions(self) -> List[SessionInfo]:
        """Get information about all stored sessions ([] on storage error)."""
        try:
            if not self.use_redis:
                # Get from memory
                return [self._session_to_info(session) for session in self.sessions.values()]

            sessions = []
            seen = set()
            prefix = self._SESSION_KEY_PREFIX
            # SCAN rather than KEYS: KEYS blocks the Redis server while it walks
            # the whole keyspace. SCAN may yield a key more than once, hence `seen`.
            async for key in self.redis_client.scan_iter(match=f"{prefix}*"):
                if isinstance(key, bytes):
                    key = key.decode()
                # Strip the prefix exactly once (str.replace would also remove
                # any "session:" occurring inside the session id itself).
                session_id = key[len(prefix):]
                if session_id in seen:
                    continue
                seen.add(session_id)
                session = await self.get_session(session_id)
                if session:
                    sessions.append(self._session_to_info(session))
            return sessions
        except Exception as e:
            self.log_error("Failed to get active sessions", error=str(e))
            return []

    async def get_user_sessions(self, user_id: str) -> List[SessionInfo]:
        """
        Get sessions for a specific user.

        Simplified implementation: a session belongs to a user when its id
        starts with ``"<user_id>-"``. A real system would keep an explicit
        user_id -> session_id mapping.
        """
        all_sessions = await self.get_active_sessions()
        return [s for s in all_sessions if s.session_id.startswith(f"{user_id}-")]

    def _session_to_info(self, session: ConversationHistory) -> SessionInfo:
        """Convert ConversationHistory to SessionInfo"""
        return SessionInfo(
            session_id=session.session_id,
            created_at=session.created_at,
            updated_at=session.updated_at,
            message_count=session.message_count,
            model_name=settings.model_name,  # Current model
            is_active=True
        )

    def _session_key(self, session_id: str) -> str:
        """Return the Redis key under which ``session_id`` is stored."""
        return f"{self._SESSION_KEY_PREFIX}{session_id}"

    async def _store_session(self, session: ConversationHistory):
        """Store session in the appropriate backend"""
        if self.use_redis:
            await self._store_session_in_redis(session)
        else:
            self.sessions[session.session_id] = session

    async def _store_session_in_redis(self, session: ConversationHistory):
        """Serialize ``session`` to JSON and store it with a TTL of ``session_timeout`` seconds."""
        data = {
            "session_id": session.session_id,
            "messages": [
                {
                    "role": msg.role,
                    "content": msg.content,
                    "timestamp": msg.timestamp.isoformat(),
                    "metadata": msg.metadata or {}
                }
                for msg in session.messages
            ],
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "message_count": session.message_count
        }
        # SETEX refreshes the TTL on every write, so idle sessions expire.
        await self.redis_client.setex(
            self._session_key(session.session_id),
            self.session_timeout,
            json.dumps(data, default=str)
        )

    async def _get_session_from_redis(self, session_id: str) -> Optional[ConversationHistory]:
        """Load and deserialize a session from Redis; None if missing or unparseable."""
        data = await self.redis_client.get(self._session_key(session_id))
        if not data:
            return None
        try:
            session_data = json.loads(data)
            messages = [
                ChatMessage(
                    role=msg["role"],
                    content=msg["content"],
                    timestamp=datetime.fromisoformat(msg["timestamp"]),
                    metadata=msg.get("metadata")
                )
                for msg in session_data["messages"]
            ]
            return ConversationHistory(
                session_id=session_data["session_id"],
                messages=messages,
                created_at=datetime.fromisoformat(session_data["created_at"]),
                updated_at=datetime.fromisoformat(session_data["updated_at"]),
                message_count=session_data["message_count"]
            )
        except Exception as e:
            self.log_error("Failed to parse session from Redis", error=str(e), session_id=session_id)
            return None

    async def _cleanup_expired_sessions(self):
        """
        Background loop: every 5 minutes, evict in-memory sessions idle longer
        than ``session_timeout`` seconds. Errors are logged and the loop keeps
        running until the task is cancelled.
        """
        while True:
            try:
                await asyncio.sleep(300)  # Run every 5 minutes
                if self.use_redis:
                    continue  # Redis handles expiration via the SETEX TTL

                now = datetime.utcnow()
                expired_sessions = [
                    session_id
                    for session_id, session in self.sessions.items()
                    if (now - session.updated_at).total_seconds() > self.session_timeout
                ]
                for session_id in expired_sessions:
                    del self.sessions[session_id]
                    self.log_debug("Expired session cleaned up", session_id=session_id)
                if expired_sessions:
                    self.log_info("Cleaned up expired sessions", count=len(expired_sessions))
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.log_error("Session cleanup failed", error=str(e))
# Module-level (global) SessionManager shared by the accessor coroutines below.
session_manager = SessionManager()


async def get_session_manager() -> SessionManager:
    """Return the module-level ``SessionManager`` instance."""
    return session_manager


async def initialize_session_manager() -> bool:
    """
    Initialize the module-level session manager.

    Returns:
        bool: The result reported by ``SessionManager.initialize()``.
    """
    initialized = await session_manager.initialize()
    return initialized


async def shutdown_session_manager():
    """Shut down the module-level session manager."""
    await session_manager.shutdown()