Spaces:
Sleeping
Sleeping
| """ | |
| server/session_manager.py — Thread-safe UUID-based session store. | |
| Each /reset creates a new isolated DataCentricEnvironment instance | |
| identified by a UUID session_id. Multiple concurrent clients can run | |
| episodes without state corruption. | |
| Sessions expire after SESSION_TTL_SECONDS (default: 30 min). | |
| Old sessions are cleaned up on each new /reset call. | |
| """ | |
| import threading | |
| import time | |
| import uuid | |
| from typing import Optional | |
| from server.config import cfg | |
| from server.logger import get_logger, log_event | |
# Module-level logger handed to log_event() for every session lifecycle event.
logger = get_logger("session_manager")
class SessionManager:
    """Lock-protected registry mapping UUID session ids to environment instances.

    Each session is stored as a dict with the keys ``env``, ``created_at``,
    ``last_accessed`` and ``step_count``. Every read or write of the registry
    happens while holding a single ``threading.Lock``.

    Sessions whose age since creation exceeds ``cfg.SESSION_TTL_SECONDS`` are
    dropped lazily: when they are looked up, and at the start of every
    ``create_session`` call.
    """

    def __init__(self):
        self._sessions: dict[str, dict] = {}
        # threading.Lock is not reentrant: methods that take it must never
        # call each other while holding it.
        self._lock = threading.Lock()
        # Lifetime counters reported by metrics(); both are bumped once per
        # create_session() call.
        self._total_sessions = 0
        self._total_resets = 0

    def create_session(self, env_instance) -> str:
        """Register a new environment instance and return its session id.

        Expired sessions are purged first. If the registry is still at or
        above ``cfg.MAX_CONCURRENT_SESSIONS``, the sessions with the oldest
        ``created_at`` are evicted until there is room for the new one.

        Args:
            env_instance: The environment object to associate with the session.

        Returns:
            The new session id (32-character hex form of a UUID4).
        """
        # Must run before taking the lock: _cleanup_expired acquires it itself.
        self._cleanup_expired()
        with self._lock:
            # Loop (rather than a single check) so the cap is actually
            # restored if the registry is more than one entry over it, and
            # guard on non-emptiness so min() is never called on an empty
            # dict when the configured cap is <= 0.
            while self._sessions and len(self._sessions) >= cfg.MAX_CONCURRENT_SESSIONS:
                oldest_id = min(
                    self._sessions,
                    key=lambda k: self._sessions[k]["created_at"],
                )
                del self._sessions[oldest_id]
                log_event(logger, "session_evicted", session_id=oldest_id,
                          reason="max_sessions_reached")

            session_id = uuid.uuid4().hex
            # One timestamp for both fields so a fresh session starts with
            # created_at == last_accessed.
            now = time.time()
            self._sessions[session_id] = {
                "env": env_instance,
                "created_at": now,
                "last_accessed": now,
                "step_count": 0,
            }
            self._total_sessions += 1
            self._total_resets += 1
            log_event(logger, "session_created", session_id=session_id,
                      total_active=len(self._sessions))
        return session_id

    def get_session(self, session_id: str) -> Optional[dict]:
        """Return the session dict, or None if it is unknown or expired.

        An expired session is removed from the registry as a side effect.
        A successful lookup refreshes the session's ``last_accessed`` time.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if session is None:
                return None
            now = time.time()
            # Expiry is measured from creation; last_accessed is recorded
            # but not consulted here.
            if now - session["created_at"] > cfg.SESSION_TTL_SECONDS:
                del self._sessions[session_id]
                log_event(logger, "session_expired", session_id=session_id)
                return None
            session["last_accessed"] = now
            return session

    def get_env(self, session_id: str):
        """Return the environment for a session_id, or None."""
        session = self.get_session(session_id)
        return session["env"] if session else None

    def increment_steps(self, session_id: str):
        """Add one to the session's step counter; unknown ids are ignored."""
        with self._lock:
            session = self._sessions.get(session_id)
            if session is not None:
                session["step_count"] += 1

    def delete_session(self, session_id: str):
        """Remove a session if present; deleting an unknown id is a no-op."""
        with self._lock:
            self._sessions.pop(session_id, None)

    def _cleanup_expired(self):
        """Remove every session whose age since creation exceeds the TTL."""
        now = time.time()
        with self._lock:
            # Collect ids first: the dict must not change size while it is
            # being iterated.
            expired = [
                sid for sid, s in self._sessions.items()
                if now - s["created_at"] > cfg.SESSION_TTL_SECONDS
            ]
            for sid in expired:
                del self._sessions[sid]
                log_event(logger, "session_cleaned_up", session_id=sid)

    def metrics(self) -> dict:
        """Return a snapshot of registry statistics.

        Returns:
            A dict with ``active_sessions``, ``total_sessions_created``,
            ``total_resets`` and ``avg_steps_active`` (mean ``step_count``
            over active sessions, rounded to 2 decimals; 0.0 when empty).
        """
        # Read everything under the lock so the counters and the session
        # snapshot describe the same instant.
        with self._lock:
            step_counts = [s["step_count"] for s in self._sessions.values()]
            total_sessions = self._total_sessions
            total_resets = self._total_resets
        return {
            "active_sessions": len(step_counts),
            "total_sessions_created": total_sessions,
            "total_resets": total_resets,
            # max(..., 1) avoids dividing by zero when no session is active.
            "avg_steps_active": round(sum(step_counts) / max(len(step_counts), 1), 2),
        }
# Shared module-level SessionManager instance (global singleton); main.py imports it.
session_manager = SessionManager()