"""
Session Management Module
=========================

This module handles user sessions, parallel processing pools, and caching.

Three process-wide objects are created at import time:

* ``session_manager`` -- a :class:`SessionManager` tracking per-session metrics,
* ``parallel_pool``   -- a :class:`ParallelAgentPool` wrapping a thread pool,
* the :class:`AsyncResponseCache` owned by ``session_manager``
  (returned by :func:`get_cache`).
"""

import uuid
import time
import threading
import logging
import concurrent.futures
from typing import List, Dict, Any, Optional, Tuple, Callable
from queue import Queue
from functools import lru_cache
from dataclasses import dataclass, field, fields
from datetime import datetime

logger = logging.getLogger(__name__)


@dataclass
class SessionMetrics:
    """Metrics for a user session.

    :class:`SessionManager` updates these counters while holding its lock;
    the properties below are derived on read.
    """

    session_id: str
    # Creation time as a ``time.time()`` epoch timestamp.
    created_at: float = field(default_factory=time.time)
    total_queries: int = 0
    cache_hits: int = 0
    # Sum of every ``response_time`` passed to SessionManager.record_query.
    total_response_time: float = 0.0
    # Tool name -> accumulated usage count.
    tool_usage: Dict[str, int] = field(default_factory=dict)
    # Error payloads, each augmented with a "timestamp" key by record_error.
    errors: List[Dict[str, Any]] = field(default_factory=list)
    parallel_executions: int = 0

    @property
    def average_response_time(self) -> float:
        """Mean response time per query; 0.0 before any query is recorded."""
        if self.total_queries == 0:
            return 0.0
        return self.total_response_time / self.total_queries

    @property
    def cache_hit_rate(self) -> float:
        """Cache hits as a percentage of total queries; 0.0 with no queries."""
        if self.total_queries == 0:
            return 0.0
        return (self.cache_hits / self.total_queries) * 100

    @property
    def uptime_hours(self) -> float:
        """Hours elapsed since ``created_at``."""
        return (time.time() - self.created_at) / 3600


class AsyncResponseCache:
    """Thread-safe response cache with a TTL and a bounded size.

    Despite its name this cache is synchronous: every public method takes an
    internal re-entrant lock. A daemon thread purges expired entries every
    ``cleanup_interval`` seconds until :meth:`close` is called.

    Entries expire ``ttl_seconds`` after they were last written by
    :meth:`set`; reads do not refresh them. When inserting a *new* key into a
    full cache, the entry written longest ago is evicted.
    """

    def __init__(self, max_size: int = 1000, ttl_seconds: int = 3600,
                 cleanup_interval: float = 60.0):
        """
        Args:
            max_size: Maximum number of entries. A value <= 0 disables caching.
            ttl_seconds: Lifetime of an entry, measured from its last ``set``.
            cleanup_interval: Seconds between background purge sweeps.
        """
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cleanup_interval = cleanup_interval
        # key -> cached value
        self.cache: Dict[str, Any] = {}
        # key -> time.time() of the last set(). Keys are re-inserted on every
        # set(), so dict order tracks write order and the first key is the
        # oldest-written entry (used for O(1) eviction).
        self.timestamps: Dict[str, float] = {}
        self.lock = threading.RLock()
        # Signals the cleanup thread to exit; see close().
        self._stop_event = threading.Event()

        # Start cleanup thread
        self.cleanup_thread = threading.Thread(target=self._cleanup_expired, daemon=True)
        self.cleanup_thread.start()

        logger.info("Initialized AsyncResponseCache with max_size=%s, ttl=%ss",
                    self.max_size, self.ttl_seconds)

    def _cleanup_expired(self):
        """Background loop: purge expired entries until close() is called."""
        # Event.wait is an interruptible sleep: it returns True (ending the
        # loop) as soon as close() sets the event, False on timeout.
        while not self._stop_event.wait(self.cleanup_interval):
            try:
                self._purge_expired()
            except Exception as e:
                # Best effort: never let one failed sweep kill the thread.
                logger.exception("Error in cache cleanup: %s", e)

    def _purge_expired(self) -> int:
        """Remove every expired entry now and return how many were removed."""
        current_time = time.time()
        with self.lock:
            expired_keys = [
                key for key, timestamp in self.timestamps.items()
                if current_time - timestamp > self.ttl_seconds
            ]
            for key in expired_keys:
                self.cache.pop(key, None)
                self.timestamps.pop(key, None)
        if expired_keys:
            logger.debug("Cleaned up %d expired cache entries", len(expired_keys))
        return len(expired_keys)

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or None if absent or expired.

        Note that a cached value of None is indistinguishable from a miss.
        """
        with self.lock:
            if key not in self.cache:
                return None
            if time.time() - self.timestamps[key] <= self.ttl_seconds:
                logger.debug("Cache hit for key: %s...", key[:50])
                return self.cache[key]
            # Expired: drop it now instead of waiting for the next sweep.
            self.cache.pop(key, None)
            self.timestamps.pop(key, None)
            return None

    def set(self, key: str, value: Any):
        """Store ``value`` under ``key`` and (re)start its TTL.

        Overwriting an existing key never evicts another entry; inserting a
        new key into a full cache evicts the oldest-written entry first.
        """
        with self.lock:
            if self.max_size <= 0:
                # Caching disabled; previously this crashed on min() of an
                # empty sequence.
                return
            if key not in self.cache:
                # The guard on self.timestamps keeps the loop finite even if
                # the two dicts ever fall out of sync.
                while len(self.cache) >= self.max_size and self.timestamps:
                    oldest_key = next(iter(self.timestamps))
                    self.cache.pop(oldest_key, None)
                    self.timestamps.pop(oldest_key, None)
            self.cache[key] = value
            # Pop before assigning so the key moves to the end (newest).
            self.timestamps.pop(key, None)
            self.timestamps[key] = time.time()
            logger.debug("Cache set for key: %s...", key[:50])

    def invalidate(self, pattern: Optional[str] = None):
        """Remove entries whose key contains ``pattern`` (all entries if falsy)."""
        with self.lock:
            if pattern:
                keys_to_remove = [key for key in self.cache if pattern in key]
            else:
                keys_to_remove = list(self.cache)
            for key in keys_to_remove:
                self.cache.pop(key, None)
                self.timestamps.pop(key, None)
            logger.info("Invalidated %d cache entries", len(keys_to_remove))

    def get_stats(self) -> Dict[str, Any]:
        """Return size/configuration info plus oldest/newest write timestamps."""
        with self.lock:
            return {
                "size": len(self.cache),
                "max_size": self.max_size,
                "ttl_seconds": self.ttl_seconds,
                "oldest_entry": min(self.timestamps.values()) if self.timestamps else None,
                "newest_entry": max(self.timestamps.values()) if self.timestamps else None,
            }

    def close(self) -> None:
        """Stop the background cleanup thread; cached entries stay usable."""
        self._stop_event.set()


class SessionManager:
    """Manages user sessions and their metrics; all mutation is lock-guarded."""

    def __init__(self):
        self.sessions: Dict[str, SessionMetrics] = {}
        self.cache = AsyncResponseCache()
        self.lock = threading.RLock()
        logger.info("SessionManager initialized")

    def create_session(self, session_id: Optional[str] = None) -> str:
        """Create a session (random UUID4 id if none given) and return its id.

        An existing session with the same id is left untouched.
        """
        if not session_id:
            session_id = str(uuid.uuid4())

        with self.lock:
            if session_id not in self.sessions:
                self.sessions[session_id] = SessionMetrics(session_id=session_id)
                logger.info("Created new session: %s", session_id)
            else:
                logger.warning("Session already exists: %s", session_id)

        return session_id

    def get_session(self, session_id: str) -> Optional[SessionMetrics]:
        """Return the live metrics object for ``session_id``, or None."""
        with self.lock:
            return self.sessions.get(session_id)

    def update_session(self, session_id: str, **kwargs):
        """Overwrite metric fields of an existing session.

        Only dataclass fields of :class:`SessionMetrics` are assigned. Other
        names, including derived read-only properties such as
        ``cache_hit_rate`` (which used to raise AttributeError), are ignored.
        Unknown session ids are ignored.
        """
        assignable = {f.name for f in fields(SessionMetrics)}
        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return
            for key, value in kwargs.items():
                if key in assignable:
                    setattr(session, key, value)
            logger.debug("Updated session %s: %s", session_id, kwargs)

    def record_query(self, session_id: str, response_time: float,
                     tool_usage: Optional[Dict[str, int]] = None):
        """Count one query, add its response time, and merge tool-usage counts."""
        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return
            session.total_queries += 1
            session.total_response_time += response_time
            for tool, count in (tool_usage or {}).items():
                session.tool_usage[tool] = session.tool_usage.get(tool, 0) + count

    def record_cache_hit(self, session_id: str):
        """Increment the session's cache-hit counter (no-op for unknown ids)."""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id].cache_hits += 1

    def record_error(self, session_id: str, error: Dict[str, Any]):
        """Append a copy of ``error`` stamped with the current epoch time."""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id].errors.append({
                    **error,
                    "timestamp": time.time(),
                })

    def get_session_stats(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return a snapshot of a session's statistics, or None if unknown.

        The snapshot is taken under the lock and ``tool_usage`` is copied, so
        callers cannot mutate the session through the returned dict.
        """
        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return None
            return {
                "session_id": session.session_id,
                "created_at": datetime.fromtimestamp(session.created_at).isoformat(),
                "uptime_hours": session.uptime_hours,
                "total_queries": session.total_queries,
                "cache_hits": session.cache_hits,
                "cache_hit_rate": session.cache_hit_rate,
                "average_response_time": session.average_response_time,
                "tool_usage": dict(session.tool_usage),
                "errors": len(session.errors),
                "parallel_executions": session.parallel_executions,
            }

    def cleanup_old_sessions(self, max_age_hours: float = 24.0):
        """Remove sessions created more than ``max_age_hours`` ago."""
        current_time = time.time()
        max_age_seconds = max_age_hours * 3600

        with self.lock:
            sessions_to_remove = [
                session_id for session_id, session in self.sessions.items()
                if current_time - session.created_at > max_age_seconds
            ]
            for session_id in sessions_to_remove:
                del self.sessions[session_id]

        if sessions_to_remove:
            logger.info("Cleaned up %d old sessions", len(sessions_to_remove))

    def get_cache(self) -> AsyncResponseCache:
        """Get the response cache"""
        return self.cache


class ParallelAgentPool:
    """Runs tasks on a thread pool and tracks their futures by task id.

    A task stays listed until its result is retrieved with
    :meth:`get_task_result` (success or failure) or it is cancelled.
    """

    def __init__(self, max_workers: int = 4):
        self.max_workers = max_workers
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        self.active_tasks: Dict[str, concurrent.futures.Future] = {}
        self.lock = threading.RLock()
        logger.info("ParallelAgentPool initialized with %s workers", max_workers)

    def submit_task(self, task_id: str, func: Callable[..., Any],
                    *args, **kwargs) -> concurrent.futures.Future:
        """Schedule ``func(*args, **kwargs)`` and track it under ``task_id``.

        A previous task with the same id is asked to cancel; per
        ``Future.cancel`` semantics this only succeeds if it has not started.
        """
        with self.lock:
            previous = self.active_tasks.get(task_id)
            if previous is not None:
                logger.warning("Task %s already exists, cancelling previous", task_id)
                previous.cancel()

            future = self.executor.submit(func, *args, **kwargs)
            self.active_tasks[task_id] = future
            logger.debug("Submitted task %s for parallel execution", task_id)
            return future

    def _discard(self, task_id: str, future: concurrent.futures.Future) -> None:
        """Stop tracking ``task_id`` only if it still maps to ``future``.

        Guards against dropping a newer task resubmitted under the same id
        while the caller was waiting.
        """
        with self.lock:
            if self.active_tasks.get(task_id) is future:
                del self.active_tasks[task_id]

    def get_task_result(self, task_id: str, timeout: Optional[float] = None) -> Any:
        """Wait for a task and return its result.

        The wait happens *outside* the pool lock, so other threads can keep
        submitting, cancelling and listing tasks meanwhile.

        Raises:
            ValueError: if ``task_id`` is not tracked.
            concurrent.futures.TimeoutError: if the wait times out; the task
                stays tracked so the caller may retry.
            Exception: whatever the task raised (or CancelledError); the task
                is then no longer tracked.
        """
        with self.lock:
            future = self.active_tasks.get(task_id)
        if future is None:
            raise ValueError(f"Task {task_id} not found")

        try:
            result = future.result(timeout=timeout)
        except concurrent.futures.TimeoutError:
            logger.warning("Task %s timed out", task_id)
            raise
        except Exception as e:
            logger.error("Task %s failed: %s", task_id, e)
            self._discard(task_id, future)
            raise

        # Remove completed task
        self._discard(task_id, future)
        return result

    def cancel_task(self, task_id: str):
        """Request cancellation and stop tracking the task.

        A task that is already running is not interrupted; it is only
        forgotten by the pool.
        """
        with self.lock:
            future = self.active_tasks.pop(task_id, None)
            if future is not None:
                future.cancel()
                logger.info("Cancelled task %s", task_id)

    def get_active_tasks(self) -> List[str]:
        """Return ids of tracked tasks (pending, running, or finished but unretrieved)."""
        with self.lock:
            return list(self.active_tasks)

    def shutdown(self, wait: bool = True):
        """Shut down the underlying executor."""
        self.executor.shutdown(wait=wait)
        logger.info("ParallelAgentPool shutdown complete")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()


# Global instances (created at import time; the cache starts its daemon thread).
session_manager = SessionManager()
parallel_pool = ParallelAgentPool()


def get_session_manager() -> SessionManager:
    """Get the global session manager instance"""
    return session_manager


def get_parallel_pool() -> ParallelAgentPool:
    """Get the global parallel pool instance"""
    return parallel_pool


def get_cache() -> AsyncResponseCache:
    """Get the global response cache"""
    return session_manager.get_cache()