from __future__ import annotations

import threading
import time
import uuid
from dataclasses import dataclass
from typing import Optional

from server.dataops_env_environment import DataOpsEnvironment


@dataclass
class SessionRecord:
    """A live session: its environment and the time it was last accessed."""

    env: DataOpsEnvironment
    # time.monotonic() reading taken at the most recent reset/get of this session.
    last_access_at: float


class EnvironmentSessionManager:
    """Small in-memory session store for isolated environment instances.

    Each session owns its own ``DataOpsEnvironment``. Sessions idle for longer
    than ``session_timeout_s`` seconds (measured with ``time.monotonic``) are
    dropped and closed lazily on the next ``reset_session``/``get_session``
    call. When creating a new session would exceed ``max_sessions``, the
    least-recently-accessed session is evicted and closed.

    The session map is guarded by a lock, but environment methods
    (``reset``/``close``) are invoked outside the lock. Callers sharing one
    session across threads must serialize their own use of its environment.
    """

    def __init__(
        self,
        *,
        max_sessions: int = 128,
        session_timeout_s: float = 1800.0,
    ) -> None:
        """Create an empty manager.

        Args:
            max_sessions: Upper bound on concurrently stored sessions
                (clamped to at least 1).
            session_timeout_s: Idle time in seconds after which a session
                expires (clamped to at least 1.0).
        """
        self._lock = threading.Lock()
        self._sessions: dict[str, SessionRecord] = {}
        self._max_sessions = max(1, max_sessions)
        self._session_timeout_s = max(1.0, session_timeout_s)

    def reset_session(
        self,
        *,
        task_id: str,
        seed: Optional[int],
        episode_id: Optional[str],
        session_id: Optional[str],
    ) -> tuple[str, DataOpsEnvironment, object]:
        """Reset the environment for a session, creating the session if needed.

        If ``session_id`` is empty/None or does not name a live session, a new
        session is created under a freshly generated UUID. A caller-supplied
        id that is unknown is not adopted.

        Returns:
            ``(resolved_session_id, env, obs)`` where ``obs`` is whatever
            ``env.reset`` returned.

        Raises:
            Any exception raised by closing expired/evicted environments or
            by ``env.reset``.
        """
        now = time.monotonic()
        to_close: list[DataOpsEnvironment] = []
        with self._lock:
            to_close.extend(self._collect_expired_envs_locked(now))
            record = self._sessions.get(session_id) if session_id else None
            if record is None:
                resolved_session_id = str(uuid.uuid4())
                to_close.extend(self._evict_if_full_locked(now))
                env = DataOpsEnvironment()
                self._sessions[resolved_session_id] = SessionRecord(
                    env=env,
                    last_access_at=now,
                )
            else:
                # A record was found, so session_id is a non-empty live key.
                resolved_session_id = session_id
                record.last_access_at = now
                env = record.env
        # Close and reset outside the lock so slow environment work does not
        # block other sessions.
        self._close_envs(to_close)
        obs = env.reset(seed=seed, episode_id=episode_id, task_id=task_id)
        return resolved_session_id, env, obs

    def get_session(
        self, session_id: Optional[str]
    ) -> tuple[Optional[str], Optional[DataOpsEnvironment]]:
        """Look up a live session and refresh its last-access time.

        Returns:
            ``(None, None)`` when ``session_id`` is empty/None. Otherwise
            ``(session_id, env)``, where ``env`` is None if the session does
            not exist or has expired.
        """
        now = time.monotonic()
        to_close: list[DataOpsEnvironment] = []
        with self._lock:
            to_close.extend(self._collect_expired_envs_locked(now))
            if session_id:
                record = self._sessions.get(session_id)
                if record is not None:
                    record.last_access_at = now
                    env = record.env
                else:
                    env = None
                result = (session_id, env)
            else:
                result = (None, None)
        self._close_envs(to_close)
        return result

    def close_all(self) -> None:
        """Drop every session and close all of their environments."""
        with self._lock:
            records = list(self._sessions.values())
            self._sessions.clear()
        self._close_envs([record.env for record in records])

    def _collect_expired_envs_locked(self, now: float) -> list[DataOpsEnvironment]:
        """Remove idle-timed-out sessions; return their envs for closing.

        Caller must hold ``self._lock``.
        """
        expired_ids = [
            session_id
            for session_id, record in self._sessions.items()
            if now - record.last_access_at > self._session_timeout_s
        ]
        return self._remove_sessions_locked(expired_ids)

    def _evict_if_full_locked(self, now: float) -> list[DataOpsEnvironment]:
        """Evict the least-recently-accessed session if at capacity.

        Caller must hold ``self._lock``. ``now`` is currently unused.
        """
        if len(self._sessions) < self._max_sessions:
            return []
        oldest_session_id = min(
            self._sessions,
            key=lambda session_id: self._sessions[session_id].last_access_at,
        )
        return self._remove_sessions_locked([oldest_session_id])

    def _remove_sessions_locked(self, session_ids: list[str]) -> list[DataOpsEnvironment]:
        """Pop the given sessions (ignoring missing ids); return their envs.

        Caller must hold ``self._lock``.
        """
        removed: list[DataOpsEnvironment] = []
        for session_id in session_ids:
            record = self._sessions.pop(session_id, None)
            if record is not None:
                removed.append(record.env)
        return removed

    def _close_envs(self, envs: list[DataOpsEnvironment]) -> None:
        """Close every env, even if some ``close()`` calls fail.

        The envs have already been removed from the session map, so an
        unclosed env would be leaked. Every close is attempted, and the first
        error (if any) is re-raised afterwards.
        """
        first_error: Optional[BaseException] = None
        for env in envs:
            try:
                env.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    def __del__(self) -> None:
        # Best-effort cleanup at garbage collection. Errors are deliberately
        # swallowed because raising from __del__ is only reported as a
        # warning, and attributes may be missing if __init__ failed.
        try:
            self.close_all()
        except Exception:
            pass