"""Session manager with automatic cleanup for RAG chains""" import time import threading from typing import Optional, Dict, Any import logging logger = logging.getLogger(__name__) class SessionManager: """Manages RAG chains with automatic cleanup and TTL""" def __init__(self, ttl_seconds: int = 86400): """ Initialize SessionManager Args: ttl_seconds: Time-to-live for sessions in seconds (default: 24 hours) """ self.sessions: Dict[str, dict] = {} self.ttl = ttl_seconds self.lock = threading.Lock() # Start cleanup thread self.cleanup_thread = threading.Thread( target=self._cleanup_loop, daemon=True ) self.cleanup_thread.start() logger.info(f"SessionManager initialized with TTL={ttl_seconds}s") def create_session(self, session_id: str, rag_chain: Any, metadata: dict = None) -> str: """ Create a new session with automatic expiry Args: session_id: Unique session identifier rag_chain: RAG chain object to store metadata: Optional metadata dictionary Returns: session_id for verification """ with self.lock: self.sessions[session_id] = { 'chain': rag_chain, 'created_at': time.time(), 'last_accessed': time.time(), 'metadata': metadata or {}, 'access_count': 0 } logger.info(f"Session created: {session_id}. Total: {len(self.sessions)}") return session_id def get_session(self, session_id: str) -> Optional[Any]: """ Retrieve a session and update access time Args: session_id: Session to retrieve Returns: RAG chain object or None if not found """ with self.lock: if session_id not in self.sessions: return None session_data = self.sessions[session_id] session_data['last_accessed'] = time.time() session_data['access_count'] += 1 return session_data['chain'] def delete_session(self, session_id: str) -> bool: """ Manually delete a session Args: session_id: Session to delete Returns: True if deleted, False if not found """ with self.lock: if session_id in self.sessions: del self.sessions[session_id] logger.info(f"Session deleted: {session_id}. Remaining: {len(self.sessions)}") return True return False def get_session_count(self) -> int: """Get number of active sessions""" with self.lock: return len(self.sessions) def get_session_info(self, session_id: str) -> Optional[dict]: """ Get detailed session information Args: session_id: Session to query Returns: Dictionary with session metadata or None """ with self.lock: if session_id not in self.sessions: return None session_data = self.sessions[session_id] current_time = time.time() return { 'session_id': session_id, 'created_at': session_data['created_at'], 'last_accessed': session_data['last_accessed'], 'access_count': session_data['access_count'], 'age_seconds': current_time - session_data['created_at'], 'idle_seconds': current_time - session_data['last_accessed'], 'ttl_seconds': self.ttl, 'expires_at': session_data['created_at'] + self.ttl, } def get_all_sessions_info(self) -> list: """Get information about all active sessions""" with self.lock: return [self.get_session_info(sid) for sid in self.sessions.keys()] def _cleanup_loop(self): """Background cleanup thread that runs periodically""" while True: try: time.sleep(300) # Check every 5 minutes self._cleanup_expired_sessions() except Exception as e: logger.error(f"Cleanup loop error: {e}") def _cleanup_expired_sessions(self): """Remove sessions that have exceeded their TTL""" current_time = time.time() expired = [] with self.lock: for session_id, session_data in list(self.sessions.items()): age = current_time - session_data['created_at'] if age > self.ttl: expired.append(session_id) # Remove expired sessions for session_id in expired: del self.sessions[session_id] logger.info(f"Expired session: {session_id}") if expired: logger.info(f"Cleaned {len(expired)} expired sessions. Active: {len(self.sessions)}") def cleanup_all(self): """Clear all sessions (typically called on shutdown)""" with self.lock: count = len(self.sessions) self.sessions.clear() logger.info(f"Cleaned up all {count} sessions") def get_memory_stats(self) -> dict: """Get memory statistics""" with self.lock: total_age = 0 total_idle = 0 total_accesses = 0 current_time = time.time() for session_data in self.sessions.values(): total_age += current_time - session_data['created_at'] total_idle += current_time - session_data['last_accessed'] total_accesses += session_data['access_count'] count = len(self.sessions) return { 'total_sessions': count, 'avg_age_seconds': total_age / count if count > 0 else 0, 'avg_idle_seconds': total_idle / count if count > 0 else 0, 'total_accesses': total_accesses, 'ttl_seconds': self.ttl, }