Spaces:
Build error
Build error
| """ | |
| Session Management Module | |
| ========================= | |
| This module handles user sessions, parallel processing pools, and caching. | |
| """ | |
| import uuid | |
| import time | |
| import threading | |
| import logging | |
| import concurrent.futures | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from queue import Queue | |
| from functools import lru_cache | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| logger = logging.getLogger(__name__) | |
@dataclass
class SessionMetrics:
    """Per-session usage metrics: query counts, cache hits, timings, errors.

    Bug fix: the original class used ``field(default_factory=...)`` without
    the ``@dataclass`` decorator, so ``SessionMetrics(session_id=...)`` (as
    called by ``SessionManager.create_session``) raised TypeError and the
    mutable-default fields were shared class attributes. Adding ``@dataclass``
    restores the intended generated ``__init__`` and per-instance defaults.
    """

    session_id: str                                              # unique session identifier
    created_at: float = field(default_factory=time.time)         # creation time, epoch seconds
    total_queries: int = 0                                       # queries recorded so far
    cache_hits: int = 0                                          # queries served from cache
    total_response_time: float = 0.0                             # cumulative response time (s)
    tool_usage: Dict[str, int] = field(default_factory=dict)     # tool name -> invocation count
    errors: List[Dict[str, Any]] = field(default_factory=list)   # recorded error payloads
    parallel_executions: int = 0                                 # tasks run via the parallel pool

    def average_response_time(self) -> float:
        """Return the mean response time in seconds (0.0 before any queries)."""
        if self.total_queries == 0:
            return 0.0
        return self.total_response_time / self.total_queries

    def cache_hit_rate(self) -> float:
        """Return the cache hit rate as a percentage (0.0 before any queries)."""
        if self.total_queries == 0:
            return 0.0
        return (self.cache_hits / self.total_queries) * 100

    def uptime_hours(self) -> float:
        """Return hours elapsed since the session was created."""
        return (time.time() - self.created_at) / 3600
class AsyncResponseCache:
    """Thread-safe response cache with per-entry TTL and oldest-entry eviction.

    A daemon thread sweeps expired entries once a minute; reads also drop an
    expired entry lazily when they encounter one.
    """

    def __init__(self, max_size: int = 1000, ttl_seconds: int = 3600):
        """
        Args:
            max_size: maximum number of entries kept before eviction.
            ttl_seconds: entry lifetime; older entries are treated as missing.
        """
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cache: Dict[str, Any] = {}
        self.timestamps: Dict[str, float] = {}
        self.lock = threading.RLock()
        # Daemon thread: dies with the process, so no explicit shutdown is needed.
        self.cleanup_thread = threading.Thread(target=self._cleanup_expired, daemon=True)
        self.cleanup_thread.start()
        logger.info(f"Initialized AsyncResponseCache with max_size={self.max_size}, ttl={self.ttl_seconds}s")

    def _cleanup_expired(self):
        """Background loop: every 60 seconds, drop entries older than the TTL."""
        while True:
            try:
                time.sleep(60)  # Run cleanup every minute
                current_time = time.time()
                with self.lock:
                    expired_keys = [
                        key for key, timestamp in self.timestamps.items()
                        if current_time - timestamp > self.ttl_seconds
                    ]
                    for key in expired_keys:
                        self.cache.pop(key, None)
                        self.timestamps.pop(key, None)
                    if expired_keys:
                        logger.debug(f"Cleaned up {len(expired_keys)} expired cache entries")
            except Exception as e:
                # Keep the sweeper alive; a failed sweep only delays cleanup.
                logger.error(f"Error in cache cleanup: {e}")

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None if absent or expired."""
        with self.lock:
            if key in self.cache:
                if time.time() - self.timestamps[key] <= self.ttl_seconds:
                    logger.debug(f"Cache hit for key: {key[:50]}...")
                    return self.cache[key]
                else:
                    # Lazily drop the expired entry on access.
                    self.cache.pop(key)
                    self.timestamps.pop(key)
            return None

    def set(self, key: str, value: Any):
        """Store *value* under *key*, evicting the oldest entry when full.

        Bug fix: eviction now happens only when inserting a NEW key at
        capacity. The original evicted the oldest entry even when merely
        overwriting an existing key, which could silently drop an unrelated
        entry (or shrink the cache) on a simple refresh.
        """
        with self.lock:
            if key not in self.cache and len(self.cache) >= self.max_size:
                # Evict the entry with the oldest timestamp.
                oldest_key = min(self.timestamps.keys(), key=lambda k: self.timestamps[k])
                self.cache.pop(oldest_key)
                self.timestamps.pop(oldest_key)
            self.cache[key] = value
            self.timestamps[key] = time.time()
            logger.debug(f"Cache set for key: {key[:50]}...")

    def invalidate(self, pattern: Optional[str] = None):
        """Remove entries whose key contains *pattern*; all entries when None."""
        with self.lock:
            if pattern:
                keys_to_remove = [key for key in self.cache.keys() if pattern in key]
            else:
                keys_to_remove = list(self.cache.keys())
            for key in keys_to_remove:
                self.cache.pop(key, None)
                self.timestamps.pop(key, None)
            logger.info(f"Invalidated {len(keys_to_remove)} cache entries")

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot: current size, limits, oldest/newest entry times."""
        with self.lock:
            return {
                "size": len(self.cache),
                "max_size": self.max_size,
                "ttl_seconds": self.ttl_seconds,
                "oldest_entry": min(self.timestamps.values()) if self.timestamps else None,
                "newest_entry": max(self.timestamps.values()) if self.timestamps else None
            }
class SessionManager:
    """Manages user sessions and their metrics (thread-safe via an RLock)."""

    def __init__(self):
        self.sessions: Dict[str, SessionMetrics] = {}
        self.cache = AsyncResponseCache()
        self.lock = threading.RLock()
        logger.info("SessionManager initialized")

    def create_session(self, session_id: Optional[str] = None) -> str:
        """Create a session and return its id; generates a UUID when omitted.

        An already-existing session is left untouched (a warning is logged).
        """
        if not session_id:
            session_id = str(uuid.uuid4())
        with self.lock:
            if session_id not in self.sessions:
                self.sessions[session_id] = SessionMetrics(session_id=session_id)
                logger.info(f"Created new session: {session_id}")
            else:
                logger.warning(f"Session already exists: {session_id}")
        return session_id

    def get_session(self, session_id: str) -> Optional[SessionMetrics]:
        """Return the metrics object for *session_id*, or None if unknown."""
        with self.lock:
            return self.sessions.get(session_id)

    def update_session(self, session_id: str, **kwargs):
        """Set existing attributes on the session's metrics.

        Unknown session ids and attribute names are silently ignored.
        """
        with self.lock:
            if session_id in self.sessions:
                session = self.sessions[session_id]
                for key, value in kwargs.items():
                    if hasattr(session, key):
                        setattr(session, key, value)
                logger.debug(f"Updated session {session_id}: {kwargs}")

    def record_query(self, session_id: str, response_time: float, tool_usage: Optional[Dict[str, int]] = None):
        """Accumulate one query's response time and per-tool usage counts."""
        with self.lock:
            if session_id in self.sessions:
                session = self.sessions[session_id]
                session.total_queries += 1
                session.total_response_time += response_time
                if tool_usage:
                    for tool, count in tool_usage.items():
                        session.tool_usage[tool] = session.tool_usage.get(tool, 0) + count

    def record_cache_hit(self, session_id: str):
        """Increment the session's cache-hit counter."""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id].cache_hits += 1

    def record_error(self, session_id: str, error: Dict[str, Any]):
        """Append *error*, stamped with the current time, to the session."""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id].errors.append({
                    **error,
                    "timestamp": time.time()
                })

    def get_session_stats(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return a statistics snapshot for *session_id*, or None if unknown.

        Bug fix: the original stored the bound methods ``uptime_hours``,
        ``cache_hit_rate`` and ``average_response_time`` instead of calling
        them, so the returned dict contained method objects, not numbers.
        """
        session = self.get_session(session_id)
        if not session:
            return None
        return {
            "session_id": session.session_id,
            "created_at": datetime.fromtimestamp(session.created_at).isoformat(),
            "uptime_hours": session.uptime_hours(),
            "total_queries": session.total_queries,
            "cache_hits": session.cache_hits,
            "cache_hit_rate": session.cache_hit_rate(),
            "average_response_time": session.average_response_time(),
            "tool_usage": session.tool_usage,
            "errors": len(session.errors),
            "parallel_executions": session.parallel_executions
        }

    def cleanup_old_sessions(self, max_age_hours: float = 24.0):
        """Remove sessions created more than *max_age_hours* ago."""
        current_time = time.time()
        max_age_seconds = max_age_hours * 3600
        with self.lock:
            sessions_to_remove = [
                session_id for session_id, session in self.sessions.items()
                if current_time - session.created_at > max_age_seconds
            ]
            for session_id in sessions_to_remove:
                self.sessions.pop(session_id)
            if sessions_to_remove:
                logger.info(f"Cleaned up {len(sessions_to_remove)} old sessions")

    def get_cache(self) -> AsyncResponseCache:
        """Return the shared response cache."""
        return self.cache
class ParallelAgentPool:
    """Runs agent tasks on a shared thread pool, tracked by task id."""

    def __init__(self, max_workers: int = 4):
        """
        Args:
            max_workers: size of the underlying ThreadPoolExecutor.
        """
        self.max_workers = max_workers
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        self.active_tasks: Dict[str, concurrent.futures.Future] = {}
        self.lock = threading.RLock()
        logger.info(f"ParallelAgentPool initialized with {max_workers} workers")

    def submit_task(self, task_id: str, func, *args, **kwargs) -> concurrent.futures.Future:
        """Submit ``func(*args, **kwargs)`` under *task_id*.

        A previous task with the same id is cancelled (cancellation only
        succeeds if it has not started running) and replaced.
        """
        with self.lock:
            if task_id in self.active_tasks:
                logger.warning(f"Task {task_id} already exists, cancelling previous")
                self.active_tasks[task_id].cancel()
            future = self.executor.submit(func, *args, **kwargs)
            self.active_tasks[task_id] = future
            logger.debug(f"Submitted task {task_id} for parallel execution")
            return future

    def get_task_result(self, task_id: str, timeout: Optional[float] = None) -> Any:
        """Wait for and return the result of *task_id*.

        Raises:
            ValueError: unknown task id.
            concurrent.futures.TimeoutError: wait timed out (task is kept
                so the result can still be collected later).
            Exception: the task's own exception is re-raised (task removed).

        Bug fix: the original held the pool lock while blocking on
        ``future.result()``, freezing submit/cancel/status for every other
        thread for the whole wait — and deadlocking outright if the task
        itself touched the pool. The blocking wait now happens OUTSIDE the
        lock; removal identity-checks the future so a task resubmitted under
        the same id while we waited is not clobbered.
        """
        with self.lock:
            if task_id not in self.active_tasks:
                raise ValueError(f"Task {task_id} not found")
            future = self.active_tasks[task_id]
        try:
            # Block outside the lock so other pool operations can proceed.
            result = future.result(timeout=timeout)
        except concurrent.futures.TimeoutError:
            logger.warning(f"Task {task_id} timed out")
            raise
        except Exception as e:
            logger.error(f"Task {task_id} failed: {e}")
            self._discard(task_id, future)
            raise
        self._discard(task_id, future)
        return result

    def _discard(self, task_id: str, future: "concurrent.futures.Future"):
        """Remove *task_id* only if it still maps to *future*."""
        with self.lock:
            if self.active_tasks.get(task_id) is future:
                self.active_tasks.pop(task_id)

    def cancel_task(self, task_id: str):
        """Cancel and forget *task_id* (no-op for unknown ids)."""
        with self.lock:
            if task_id in self.active_tasks:
                self.active_tasks[task_id].cancel()
                self.active_tasks.pop(task_id)
                logger.info(f"Cancelled task {task_id}")

    def get_active_tasks(self) -> List[str]:
        """Return ids of tasks submitted but not yet collected or cancelled."""
        with self.lock:
            return list(self.active_tasks.keys())

    def shutdown(self, wait: bool = True):
        """Shut down the thread pool; with wait=True, block until tasks finish."""
        self.executor.shutdown(wait=wait)
        logger.info("ParallelAgentPool shutdown complete")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()
# Global singleton instances, created at import time.
# NOTE: constructing these has side effects — SessionManager builds an
# AsyncResponseCache (which starts a cleanup daemon thread) and
# ParallelAgentPool starts a ThreadPoolExecutor.
session_manager = SessionManager()
parallel_pool = ParallelAgentPool()

def get_session_manager() -> SessionManager:
    """Get the global session manager instance"""
    return session_manager

def get_parallel_pool() -> ParallelAgentPool:
    """Get the global parallel pool instance"""
    return parallel_pool

def get_cache() -> AsyncResponseCache:
    """Get the global response cache"""
    # Delegates to the session manager so callers share one cache instance.
    return session_manager.get_cache()