datacentric-env / server /session_manager.py
Aswini-Kumar's picture
Upload server/session_manager.py with huggingface_hub
00e073a verified
"""
server/session_manager.py — Thread-safe UUID-based session store.
Each /reset creates a new isolated DataCentricEnvironment instance
identified by a UUID session_id. Multiple concurrent clients can run
episodes without state corruption.
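Sessions expire SESSION_TTL_SECONDS after creation (default: 30 min),
regardless of activity. Expired sessions are purged on each new /reset
call; if MAX_CONCURRENT_SESSIONS is still reached, the oldest session is
evicted to make room.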
"""
import threading
import time
import uuid
from typing import Optional
from server.config import cfg
from server.logger import get_logger, log_event
# Module logger, passed to log_event() for every session lifecycle event
# (created, evicted, expired, cleaned up).
logger = get_logger("session_manager")
class SessionManager:
    """Thread-safe registry mapping UUID session ids to environment instances.

    All access to the session table goes through ``self._lock``, so
    concurrent request handlers can create, look up and delete sessions
    without corrupting it.

    A session expires once ``cfg.SESSION_TTL_SECONDS`` have elapsed since it
    was *created* (``created_at``). ``last_accessed`` is refreshed on every
    successful lookup but does not extend the lifetime.
    NOTE(review): ``last_accessed`` is currently unused for expiry/eviction.
    """

    def __init__(self):
        # session_id -> {"env", "created_at", "last_accessed", "step_count"}
        self._sessions: dict[str, dict] = {}
        # threading.Lock is not reentrant: helpers whose name ends in
        # ``_locked`` assume the caller already holds it.
        self._lock = threading.Lock()
        self._total_sessions = 0
        self._total_resets = 0

    def create_session(self, env_instance) -> str:
        """Register a new environment instance and return its session_id.

        Expired sessions are purged first. If the table is still at
        ``cfg.MAX_CONCURRENT_SESSIONS``, the oldest sessions (by
        ``created_at``) are evicted until there is room. The new session is
        always admitted.

        Args:
            env_instance: The environment object to store for this session.

        Returns:
            A fresh hex UUID (``uuid4().hex``) identifying the session.
        """
        with self._lock:
            now = time.time()
            # Purge and insert in one critical section so the capacity check
            # below sees a consistent table.
            self._cleanup_expired_locked(now)
            # `while` + emptiness guard: never call min() on an empty dict
            # (ValueError when the limit is <= 0), and shrink back under the
            # limit even if the table is more than one over it.
            while self._sessions and len(self._sessions) >= cfg.MAX_CONCURRENT_SESSIONS:
                oldest_id = min(
                    self._sessions,
                    key=lambda sid: self._sessions[sid]["created_at"],
                )
                del self._sessions[oldest_id]
                log_event(logger, "session_evicted", session_id=oldest_id,
                          reason="max_sessions_reached")

            session_id = uuid.uuid4().hex
            self._sessions[session_id] = {
                "env": env_instance,
                # One timestamp so both fields agree exactly at creation.
                "created_at": now,
                "last_accessed": now,
                "step_count": 0,
            }
            self._total_sessions += 1
            self._total_resets += 1
            log_event(logger, "session_created", session_id=session_id,
                      total_active=len(self._sessions))
        return session_id

    def get_session(self, session_id: str) -> Optional[dict]:
        """Return the live session dict for *session_id*, or None.

        A session found to be expired is removed and reported as missing.
        On a hit, ``last_accessed`` is refreshed. The returned dict is the
        stored object itself, not a copy.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if session is None:
                return None
            now = time.time()
            if self._is_expired(session, now):
                del self._sessions[session_id]
                log_event(logger, "session_expired", session_id=session_id)
                return None
            session["last_accessed"] = now
            return session

    def get_env(self, session_id: str):
        """Return the environment for *session_id*, or None if missing/expired."""
        session = self.get_session(session_id)
        return session["env"] if session else None

    def increment_steps(self, session_id: str):
        """Add one to the step counter of *session_id*; unknown ids are ignored."""
        with self._lock:
            session = self._sessions.get(session_id)
            if session is not None:
                session["step_count"] += 1

    def delete_session(self, session_id: str):
        """Remove *session_id* if present; deleting an unknown id is a no-op."""
        with self._lock:
            self._sessions.pop(session_id, None)

    @staticmethod
    def _is_expired(session: dict, now: float) -> bool:
        """Return True if *session* is older than ``cfg.SESSION_TTL_SECONDS`` at *now*."""
        return now - session["created_at"] > cfg.SESSION_TTL_SECONDS

    def _cleanup_expired(self):
        """Remove every session older than the TTL (acquires the lock)."""
        with self._lock:
            self._cleanup_expired_locked(time.time())

    def _cleanup_expired_locked(self, now: float):
        """Remove sessions expired at *now*; the caller must hold ``self._lock``."""
        # Collect first: deleting while iterating a dict raises RuntimeError.
        expired = [
            sid for sid, s in self._sessions.items() if self._is_expired(s, now)
        ]
        for sid in expired:
            del self._sessions[sid]
            log_event(logger, "session_cleaned_up", session_id=sid)

    def metrics(self) -> dict:
        """Return a snapshot of session counters.

        Keys: ``active_sessions``, ``total_sessions_created``,
        ``total_resets`` and ``avg_steps_active`` (mean step_count over
        currently stored sessions, rounded to 2 decimals; 0.0 when empty).
        """
        with self._lock:
            active = len(self._sessions)
            step_counts = [s["step_count"] for s in self._sessions.values()]
            return {
                "active_sessions": active,
                "total_sessions_created": self._total_sessions,
                "total_resets": self._total_resets,
                # max(..., 1) avoids division by zero with no active sessions.
                "avg_steps_active": round(sum(step_counts) / max(len(step_counts), 1), 2),
            }
# Global singleton — imported by main.py.
# Created once at module import, so every importer shares the same session
# table and the same lock.
session_manager = SessionManager()