| | """Session manager with automatic cleanup for RAG chains""" |
| |
|
| | import time |
| | import threading |
| | from typing import Optional, Dict, Any |
| | import logging |
| |
|
logger = logging.getLogger(__name__)


class SessionManager:
    """Manages RAG chains with automatic cleanup and TTL.

    Sessions live in an in-memory dict guarded by ``self.lock``. A session
    expires once more than ``ttl_seconds`` have elapsed since it was
    *created* (access does not extend its life). Expired sessions are
    removed by a background daemon thread every ``cleanup_interval``
    seconds, and :meth:`get_session` also treats them as missing between
    sweeps.
    """

    def __init__(self, ttl_seconds: int = 86400, cleanup_interval: float = 300):
        """
        Initialize SessionManager and start the background cleanup thread.

        Args:
            ttl_seconds: Time-to-live for sessions in seconds (default: 24 hours)
            cleanup_interval: Seconds between background cleanup sweeps
                (default: 300, i.e. 5 minutes)

        Raises:
            ValueError: If cleanup_interval is not positive (a zero or
                negative interval would make the cleanup thread spin).
        """
        if cleanup_interval <= 0:
            raise ValueError("cleanup_interval must be positive")

        self.sessions: Dict[str, dict] = {}
        self.ttl = ttl_seconds
        self.cleanup_interval = cleanup_interval
        # Non-reentrant lock: code holding it must not call public methods
        # that acquire it again; use the private *_locked helpers instead.
        self.lock = threading.Lock()

        # Daemon thread so it never blocks interpreter shutdown.
        self.cleanup_thread = threading.Thread(
            target=self._cleanup_loop,
            daemon=True,
        )
        self.cleanup_thread.start()
        logger.info("SessionManager initialized with TTL=%ss", ttl_seconds)

    def _is_expired(self, session_data: dict, now: float) -> bool:
        """Return True if the session is older than the TTL at time ``now``."""
        return now - session_data['created_at'] > self.ttl

    def _session_info_locked(self, session_id: str, session_data: dict, now: float) -> dict:
        """Build the public info dict for one session; caller must hold ``self.lock``."""
        return {
            'session_id': session_id,
            'created_at': session_data['created_at'],
            'last_accessed': session_data['last_accessed'],
            'access_count': session_data['access_count'],
            'age_seconds': now - session_data['created_at'],
            'idle_seconds': now - session_data['last_accessed'],
            'ttl_seconds': self.ttl,
            'expires_at': session_data['created_at'] + self.ttl,
        }

    def create_session(self, session_id: str, rag_chain: Any, metadata: Optional[dict] = None) -> str:
        """
        Create a new session (replacing any existing one with the same id).

        Args:
            session_id: Unique session identifier
            rag_chain: RAG chain object to store
            metadata: Optional metadata dictionary

        Returns:
            session_id for verification
        """
        now = time.time()  # one timestamp so created_at == last_accessed
        with self.lock:
            self.sessions[session_id] = {
                'chain': rag_chain,
                'created_at': now,
                'last_accessed': now,
                'metadata': metadata or {},
                'access_count': 0,
            }
            logger.info("Session created: %s. Total: %d", session_id, len(self.sessions))
        return session_id

    def get_session(self, session_id: str) -> Optional[Any]:
        """
        Retrieve a session's chain and update its access time.

        A session that has outlived its TTL is removed here rather than
        being served until the next background sweep.

        Args:
            session_id: Session to retrieve

        Returns:
            RAG chain object, or None if not found or expired
        """
        with self.lock:
            session_data = self.sessions.get(session_id)
            if session_data is None:
                return None

            now = time.time()
            if self._is_expired(session_data, now):
                del self.sessions[session_id]
                logger.info("Expired session: %s", session_id)
                return None

            session_data['last_accessed'] = now
            session_data['access_count'] += 1
            return session_data['chain']

    def delete_session(self, session_id: str) -> bool:
        """
        Manually delete a session

        Args:
            session_id: Session to delete

        Returns:
            True if deleted, False if not found
        """
        with self.lock:
            if session_id not in self.sessions:
                return False
            del self.sessions[session_id]
            logger.info("Session deleted: %s. Remaining: %d", session_id, len(self.sessions))
            return True

    def get_session_count(self) -> int:
        """Get number of stored sessions (may include not-yet-swept expired ones)."""
        with self.lock:
            return len(self.sessions)

    def get_session_info(self, session_id: str) -> Optional[dict]:
        """
        Get detailed session information without counting it as an access.

        Args:
            session_id: Session to query

        Returns:
            Dictionary with session metadata or None if not stored
        """
        with self.lock:
            session_data = self.sessions.get(session_id)
            if session_data is None:
                return None
            return self._session_info_locked(session_id, session_data, time.time())

    def get_all_sessions_info(self) -> list:
        """Get information about all stored sessions.

        Builds every entry under a single lock acquisition; calling
        get_session_info() here would re-acquire the non-reentrant lock
        and deadlock.
        """
        with self.lock:
            now = time.time()
            return [
                self._session_info_locked(sid, data, now)
                for sid, data in self.sessions.items()
            ]

    def _cleanup_loop(self):
        """Background loop: sweep expired sessions every cleanup_interval seconds."""
        while True:
            try:
                time.sleep(self.cleanup_interval)
                self._cleanup_expired_sessions()
            except Exception as e:
                # Keep the thread alive; a failed sweep is retried next period.
                logger.exception("Cleanup loop error: %s", e)

    def _cleanup_expired_sessions(self):
        """Remove sessions that have exceeded their TTL"""
        with self.lock:
            now = time.time()
            expired = [
                sid for sid, data in self.sessions.items()
                if self._is_expired(data, now)
            ]
            for sid in expired:
                del self.sessions[sid]
                logger.info("Expired session: %s", sid)
            remaining = len(self.sessions)

        if expired:
            logger.info("Cleaned %d expired sessions. Active: %d", len(expired), remaining)

    def cleanup_all(self):
        """Clear all sessions (typically called on shutdown)"""
        with self.lock:
            count = len(self.sessions)
            self.sessions.clear()
        logger.info("Cleaned up all %d sessions", count)

    def get_memory_stats(self) -> dict:
        """Get aggregate age/idle/access statistics over stored sessions."""
        with self.lock:
            now = time.time()
            values = list(self.sessions.values())

        count = len(values)
        total_age = sum(now - d['created_at'] for d in values)
        total_idle = sum(now - d['last_accessed'] for d in values)
        total_accesses = sum(d['access_count'] for d in values)

        return {
            'total_sessions': count,
            'avg_age_seconds': total_age / count if count > 0 else 0,
            'avg_idle_seconds': total_idle / count if count > 0 else 0,
            'total_accesses': total_accesses,
            'ttl_seconds': self.ttl,
        }
| |
|