# cognichat/utils/session_manager.py
# HYPERXD
# feat: add comprehensive memory leak fixes and upgrade documentation
# ae279de
"""Session manager with automatic cleanup for RAG chains"""
import time
import threading
from typing import Optional, Dict, Any
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class SessionManager:
    """Manages RAG chains with automatic cleanup and TTL.

    Sessions live in an in-memory dict keyed by session id. Each record
    holds the chain object, creation and last-access timestamps, caller
    metadata and an access counter. A daemon thread periodically removes
    records whose age (measured from creation, not from last access)
    exceeds the configured TTL.

    All public methods take ``self.lock``. It is a plain, non-reentrant
    ``threading.Lock``, so code that already holds it must use the private
    ``_describe_session`` helper instead of calling another public method.
    """

    def __init__(self, ttl_seconds: int = 86400, cleanup_interval_seconds: float = 300):
        """
        Initialize SessionManager and start the background cleanup thread.

        Args:
            ttl_seconds: Time-to-live for sessions in seconds (default: 24 hours)
            cleanup_interval_seconds: Pause between two cleanup passes in
                seconds (default: 5 minutes)
        """
        self.sessions: Dict[str, dict] = {}
        self.ttl = ttl_seconds
        self.cleanup_interval = cleanup_interval_seconds
        self.lock = threading.Lock()

        # Daemon thread: it must never keep the interpreter alive on exit.
        self.cleanup_thread = threading.Thread(
            target=self._cleanup_loop,
            daemon=True,
        )
        self.cleanup_thread.start()
        logger.info(f"SessionManager initialized with TTL={ttl_seconds}s")

    def create_session(self, session_id: str, rag_chain: Any, metadata: Optional[dict] = None) -> str:
        """
        Create a new session with automatic expiry.

        An existing session stored under the same id is replaced.

        Args:
            session_id: Unique session identifier
            rag_chain: RAG chain object to store
            metadata: Optional metadata dictionary

        Returns:
            session_id for verification
        """
        # One timestamp for both fields so a fresh session has zero idle time.
        now = time.time()
        with self.lock:
            self.sessions[session_id] = {
                'chain': rag_chain,
                'created_at': now,
                'last_accessed': now,
                'metadata': metadata or {},
                'access_count': 0,
            }
            logger.info(f"Session created: {session_id}. Total: {len(self.sessions)}")
        return session_id

    def get_session(self, session_id: str) -> Optional[Any]:
        """
        Retrieve a session and update its access time and access counter.

        Args:
            session_id: Session to retrieve

        Returns:
            RAG chain object or None if not found
        """
        with self.lock:
            session_data = self.sessions.get(session_id)
            if session_data is None:
                return None
            session_data['last_accessed'] = time.time()
            session_data['access_count'] += 1
            return session_data['chain']

    def delete_session(self, session_id: str) -> bool:
        """
        Manually delete a session.

        Args:
            session_id: Session to delete

        Returns:
            True if deleted, False if not found
        """
        with self.lock:
            if session_id not in self.sessions:
                return False
            del self.sessions[session_id]
            logger.info(f"Session deleted: {session_id}. Remaining: {len(self.sessions)}")
            return True

    def get_session_count(self) -> int:
        """Get number of active sessions."""
        with self.lock:
            return len(self.sessions)

    def _describe_session(self, session_id: str, session_data: dict, current_time: float) -> dict:
        """Build the info snapshot for one session record.

        Does not touch ``self.lock``; the caller must already hold it.
        """
        return {
            'session_id': session_id,
            'created_at': session_data['created_at'],
            'last_accessed': session_data['last_accessed'],
            'access_count': session_data['access_count'],
            'age_seconds': current_time - session_data['created_at'],
            'idle_seconds': current_time - session_data['last_accessed'],
            'ttl_seconds': self.ttl,
            'expires_at': session_data['created_at'] + self.ttl,
        }

    def get_session_info(self, session_id: str) -> Optional[dict]:
        """
        Get detailed session information.

        Unlike get_session, this does not update the access time or counter.

        Args:
            session_id: Session to query

        Returns:
            Dictionary with session metadata or None
        """
        with self.lock:
            session_data = self.sessions.get(session_id)
            if session_data is None:
                return None
            return self._describe_session(session_id, session_data, time.time())

    def get_all_sessions_info(self) -> list:
        """Get information about all active sessions.

        Returns:
            List with one info dictionary per session (same shape as the
            result of get_session_info), in the dict's iteration order.
        """
        with self.lock:
            # Calling get_session_info() here would try to take the
            # non-reentrant lock a second time and block forever, so the
            # snapshots are built through the lock-free helper.
            current_time = time.time()
            return [
                self._describe_session(sid, data, current_time)
                for sid, data in self.sessions.items()
            ]

    def _cleanup_loop(self):
        """Background cleanup thread body: sleep, purge expired sessions, repeat."""
        while True:
            try:
                time.sleep(self.cleanup_interval)
                self._cleanup_expired_sessions()
            except Exception as e:
                # Thread boundary: log with traceback and keep looping so a
                # single failed pass does not stop future cleanups.
                logger.exception(f"Cleanup loop error: {e}")

    def _cleanup_expired_sessions(self):
        """Remove sessions that have exceeded their TTL (age since creation)."""
        current_time = time.time()
        with self.lock:
            expired = [
                session_id
                for session_id, session_data in self.sessions.items()
                if current_time - session_data['created_at'] > self.ttl
            ]
            # Deletion happens after the scan so the dict is never mutated
            # while it is being iterated.
            for session_id in expired:
                del self.sessions[session_id]
                logger.info(f"Expired session: {session_id}")
            if expired:
                logger.info(f"Cleaned {len(expired)} expired sessions. Active: {len(self.sessions)}")

    def cleanup_all(self):
        """Clear all sessions (typically called on shutdown)."""
        with self.lock:
            count = len(self.sessions)
            self.sessions.clear()
            logger.info(f"Cleaned up all {count} sessions")

    def get_memory_stats(self) -> dict:
        """Get aggregate statistics over the stored sessions.

        Returns:
            Dictionary with the session count, average age and idle time in
            seconds (0 when there are no sessions), the summed access count
            and the configured TTL.
        """
        with self.lock:
            current_time = time.time()
            records = list(self.sessions.values())
            count = len(records)
            total_age = sum(current_time - r['created_at'] for r in records)
            total_idle = sum(current_time - r['last_accessed'] for r in records)
            total_accesses = sum(r['access_count'] for r in records)
            return {
                'total_sessions': count,
                'avg_age_seconds': total_age / count if count > 0 else 0,
                'avg_idle_seconds': total_idle / count if count > 0 else 0,
                'total_accesses': total_accesses,
                'ttl_seconds': self.ttl,
            }