"""
Session Management Module
=========================
This module handles user sessions, parallel processing pools, and caching.
"""
import uuid
import time
import threading
import logging
import concurrent.futures
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime
logger = logging.getLogger(__name__)

@dataclass
class SessionMetrics:
    """Metrics for a user session"""
    session_id: str
    created_at: float = field(default_factory=time.time)
    total_queries: int = 0
    cache_hits: int = 0
    total_response_time: float = 0.0
    tool_usage: Dict[str, int] = field(default_factory=dict)
    errors: List[Dict[str, Any]] = field(default_factory=list)
    parallel_executions: int = 0

    @property
    def average_response_time(self) -> float:
        """Calculate average response time"""
        if self.total_queries == 0:
            return 0.0
        return self.total_response_time / self.total_queries

    @property
    def cache_hit_rate(self) -> float:
        """Calculate cache hit rate percentage"""
        if self.total_queries == 0:
            return 0.0
        return (self.cache_hits / self.total_queries) * 100

    @property
    def uptime_hours(self) -> float:
        """Calculate session uptime in hours"""
        return (time.time() - self.created_at) / 3600
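
# A minimal usage sketch showing how SessionMetrics' derived properties
# combine the raw counters. The "_demo" functions in this file are
# illustrative additions, not part of the original module's API, and the
# numbers below are made up.
def _demo_session_metrics() -> None:
    """Illustrate SessionMetrics' derived properties with hypothetical values."""
    metrics = SessionMetrics(session_id="demo-session")
    metrics.total_queries = 4
    metrics.cache_hits = 1
    metrics.total_response_time = 2.0
    # 2.0s total over 4 queries -> 0.5s average
    assert metrics.average_response_time == 0.5
    # 1 hit out of 4 queries -> 25.0%
    assert metrics.cache_hit_rate == 25.0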

class AsyncResponseCache:
    """Response cache with TTL expiry, oldest-entry eviction, and pattern-based invalidation"""

    def __init__(self, max_size: int = 1000, ttl_seconds: int = 3600):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cache = {}
        self.timestamps = {}
        self.lock = threading.RLock()
        # Start cleanup thread
        self.cleanup_thread = threading.Thread(target=self._cleanup_expired, daemon=True)
        self.cleanup_thread.start()
        logger.info(f"Initialized AsyncResponseCache with max_size={self.max_size}, ttl={self.ttl_seconds}s")

    def _cleanup_expired(self):
        """Background thread to clean up expired cache entries"""
        while True:
            try:
                time.sleep(60)  # Run cleanup every minute
                current_time = time.time()
                with self.lock:
                    expired_keys = [
                        key for key, timestamp in self.timestamps.items()
                        if current_time - timestamp > self.ttl_seconds
                    ]
                    for key in expired_keys:
                        self.cache.pop(key, None)
                        self.timestamps.pop(key, None)
                    if expired_keys:
                        logger.debug(f"Cleaned up {len(expired_keys)} expired cache entries")
            except Exception as e:
                logger.error(f"Error in cache cleanup: {e}")

    def get(self, key: str) -> Optional[Any]:
        """Get value from cache if not expired"""
        with self.lock:
            if key in self.cache:
                # Check if expired
                if time.time() - self.timestamps[key] <= self.ttl_seconds:
                    logger.debug(f"Cache hit for key: {key[:50]}...")
                    return self.cache[key]
                else:
                    # Remove expired entry
                    self.cache.pop(key)
                    self.timestamps.pop(key)
            return None
    def set(self, key: str, value: Any):
        """Set value in cache with timestamp"""
        with self.lock:
            # Evict the oldest entry only when inserting a brand-new key into
            # a full cache; overwriting an existing key needs no eviction
            if len(self.cache) >= self.max_size and key not in self.cache:
                oldest_key = min(self.timestamps.keys(), key=lambda k: self.timestamps[k])
                self.cache.pop(oldest_key)
                self.timestamps.pop(oldest_key)
            self.cache[key] = value
            self.timestamps[key] = time.time()
            logger.debug(f"Cache set for key: {key[:50]}...")
    def invalidate(self, pattern: Optional[str] = None):
        """Invalidate cache entries matching pattern"""
        with self.lock:
            if pattern:
                keys_to_remove = [key for key in self.cache.keys() if pattern in key]
            else:
                keys_to_remove = list(self.cache.keys())
            for key in keys_to_remove:
                self.cache.pop(key, None)
                self.timestamps.pop(key, None)
            logger.info(f"Invalidated {len(keys_to_remove)} cache entries")

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics"""
        with self.lock:
            return {
                "size": len(self.cache),
                "max_size": self.max_size,
                "ttl_seconds": self.ttl_seconds,
                "oldest_entry": min(self.timestamps.values()) if self.timestamps else None,
                "newest_entry": max(self.timestamps.values()) if self.timestamps else None
            }
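
# A minimal usage sketch for AsyncResponseCache. The one-second TTL and the
# key/value strings are illustrative assumptions, not values the module uses.
def _demo_response_cache() -> None:
    """Illustrate basic cache set/get/expiry behaviour."""
    cache = AsyncResponseCache(max_size=2, ttl_seconds=1)
    cache.set("query:hello", "cached response")
    assert cache.get("query:hello") == "cached response"
    # A miss simply returns None
    assert cache.get("query:unknown") is None
    # Once the TTL elapses, the entry is treated as expired on read
    time.sleep(1.1)
    assert cache.get("query:hello") is None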

class SessionManager:
    """Manages user sessions and their state"""

    def __init__(self):
        self.sessions: Dict[str, SessionMetrics] = {}
        self.cache = AsyncResponseCache()
        self.lock = threading.RLock()
        logger.info("SessionManager initialized")

    def create_session(self, session_id: Optional[str] = None) -> str:
        """Create a new user session"""
        if not session_id:
            session_id = str(uuid.uuid4())
        with self.lock:
            if session_id not in self.sessions:
                self.sessions[session_id] = SessionMetrics(session_id=session_id)
                logger.info(f"Created new session: {session_id}")
            else:
                logger.warning(f"Session already exists: {session_id}")
        return session_id

    def get_session(self, session_id: str) -> Optional[SessionMetrics]:
        """Get session metrics by ID"""
        with self.lock:
            return self.sessions.get(session_id)

    def update_session(self, session_id: str, **kwargs):
        """Update session metrics"""
        with self.lock:
            if session_id in self.sessions:
                session = self.sessions[session_id]
                for key, value in kwargs.items():
                    if hasattr(session, key):
                        setattr(session, key, value)
                logger.debug(f"Updated session {session_id}: {kwargs}")

    def record_query(self, session_id: str, response_time: float, tool_usage: Optional[Dict[str, int]] = None):
        """Record a query in session metrics"""
        with self.lock:
            if session_id in self.sessions:
                session = self.sessions[session_id]
                session.total_queries += 1
                session.total_response_time += response_time
                if tool_usage:
                    for tool, count in tool_usage.items():
                        session.tool_usage[tool] = session.tool_usage.get(tool, 0) + count

    def record_cache_hit(self, session_id: str):
        """Record a cache hit"""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id].cache_hits += 1

    def record_error(self, session_id: str, error: Dict[str, Any]):
        """Record an error in session"""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id].errors.append({
                    **error,
                    "timestamp": time.time()
                })

    def get_session_stats(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get comprehensive session statistics"""
        session = self.get_session(session_id)
        if not session:
            return None
        return {
            "session_id": session.session_id,
            "created_at": datetime.fromtimestamp(session.created_at).isoformat(),
            "uptime_hours": session.uptime_hours,
            "total_queries": session.total_queries,
            "cache_hits": session.cache_hits,
            "cache_hit_rate": session.cache_hit_rate,
            "average_response_time": session.average_response_time,
            "tool_usage": session.tool_usage,
            "errors": len(session.errors),
            "parallel_executions": session.parallel_executions
        }

    def cleanup_old_sessions(self, max_age_hours: float = 24.0):
        """Remove sessions older than specified age"""
        current_time = time.time()
        max_age_seconds = max_age_hours * 3600
        with self.lock:
            sessions_to_remove = [
                session_id for session_id, session in self.sessions.items()
                if current_time - session.created_at > max_age_seconds
            ]
            for session_id in sessions_to_remove:
                self.sessions.pop(session_id)
            if sessions_to_remove:
                logger.info(f"Cleaned up {len(sessions_to_remove)} old sessions")

    def get_cache(self) -> AsyncResponseCache:
        """Get the response cache"""
        return self.cache
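
# A minimal usage sketch for SessionManager covering the typical
# create/record/report lifecycle. The response times and the "search" tool
# name are made up for illustration.
def _demo_session_manager() -> None:
    """Illustrate recording queries and reading back aggregated stats."""
    manager = SessionManager()
    sid = manager.create_session()
    manager.record_query(sid, response_time=0.3, tool_usage={"search": 1})
    manager.record_query(sid, response_time=0.5, tool_usage={"search": 2})
    manager.record_cache_hit(sid)
    stats = manager.get_session_stats(sid)
    assert stats is not None
    assert stats["total_queries"] == 2
    assert stats["tool_usage"] == {"search": 3}
    # 1 cache hit over 2 queries -> 50.0%
    assert stats["cache_hit_rate"] == 50.0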

class ParallelAgentPool:
    """Manages parallel execution of agent tasks"""

    def __init__(self, max_workers: int = 4):
        self.max_workers = max_workers
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        self.active_tasks: Dict[str, concurrent.futures.Future] = {}
        self.lock = threading.RLock()
        logger.info(f"ParallelAgentPool initialized with {max_workers} workers")

    def submit_task(self, task_id: str, func, *args, **kwargs) -> concurrent.futures.Future:
        """Submit a task for parallel execution"""
        with self.lock:
            if task_id in self.active_tasks:
                logger.warning(f"Task {task_id} already exists, cancelling previous")
                self.active_tasks[task_id].cancel()
            future = self.executor.submit(func, *args, **kwargs)
            self.active_tasks[task_id] = future
            logger.debug(f"Submitted task {task_id} for parallel execution")
            return future
    def get_task_result(self, task_id: str, timeout: Optional[float] = None) -> Any:
        """Get result of a submitted task"""
        with self.lock:
            if task_id not in self.active_tasks:
                raise ValueError(f"Task {task_id} not found")
            future = self.active_tasks[task_id]
        # Wait outside the lock so a long-running task cannot block
        # submit_task/cancel_task calls from other threads
        try:
            result = future.result(timeout=timeout)
            with self.lock:
                # Remove completed task
                self.active_tasks.pop(task_id, None)
            return result
        except concurrent.futures.TimeoutError:
            logger.warning(f"Task {task_id} timed out")
            raise
        except Exception as e:
            logger.error(f"Task {task_id} failed: {e}")
            with self.lock:
                self.active_tasks.pop(task_id, None)
            raise
    def cancel_task(self, task_id: str):
        """Cancel a running task"""
        with self.lock:
            if task_id in self.active_tasks:
                self.active_tasks[task_id].cancel()
                self.active_tasks.pop(task_id)
                logger.info(f"Cancelled task {task_id}")

    def get_active_tasks(self) -> List[str]:
        """Get list of active task IDs"""
        with self.lock:
            return list(self.active_tasks.keys())

    def shutdown(self, wait: bool = True):
        """Shutdown the thread pool"""
        self.executor.shutdown(wait=wait)
        logger.info("ParallelAgentPool shutdown complete")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()
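
# A minimal usage sketch for ParallelAgentPool. The squaring lambda stands in
# for real agent work, and a throwaway pool is used (rather than the
# module-level singleton) so the context manager can shut it down cleanly.
def _demo_parallel_pool() -> None:
    """Illustrate submitting tasks and collecting their results."""
    with ParallelAgentPool(max_workers=2) as pool:
        pool.submit_task("square-3", lambda x: x * x, 3)
        pool.submit_task("square-5", lambda x: x * x, 5)
        assert pool.get_task_result("square-3", timeout=5.0) == 9
        assert pool.get_task_result("square-5", timeout=5.0) == 25
        # Completed tasks are removed from the active set
        assert pool.get_active_tasks() == []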

# Global instances
session_manager = SessionManager()
parallel_pool = ParallelAgentPool()


def get_session_manager() -> SessionManager:
    """Get the global session manager instance"""
    return session_manager


def get_parallel_pool() -> ParallelAgentPool:
    """Get the global parallel pool instance"""
    return parallel_pool


def get_cache() -> AsyncResponseCache:
    """Get the global response cache"""
    return session_manager.get_cache()
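
# Illustrative entry point: runs the hypothetical _demo_* sketches added
# above. Not part of the original module; it only executes when the file is
# run directly.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    _demo_session_metrics()
    _demo_response_cache()
    _demo_session_manager()
    _demo_parallel_pool()
    print("All demo sketches completed")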