Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import time | |
| import threading | |
| import uuid | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| from server.dataops_env_environment import DataOpsEnvironment | |
@dataclass
class SessionRecord:
    """Bookkeeping for one session: its environment and last access time.

    Declared as a (mutable) dataclass so it can be built with keyword
    arguments (``SessionRecord(env=..., last_access_at=...)``) and so
    ``last_access_at`` can be refreshed in place on each access.
    """

    # Environment instance owned by this session.
    env: DataOpsEnvironment
    # ``time.monotonic()`` timestamp of the most recent access to the session.
    last_access_at: float
class EnvironmentSessionManager:
    """Small in-memory session store for isolated environment instances.

    Each session id maps to its own ``DataOpsEnvironment``. Sessions idle for
    longer than ``session_timeout_s`` are purged lazily at the start of every
    ``reset_session`` / ``get_session`` call. When the store already holds
    ``max_sessions`` entries, the least-recently-accessed session is evicted
    to make room for a new one. Dictionary bookkeeping happens under
    ``self._lock``; ``env.close()`` and ``env.reset()`` are invoked after the
    lock has been released.
    """

    def __init__(
        self,
        *,
        max_sessions: int = 128,
        session_timeout_s: float = 1800.0,
    ) -> None:
        """Create an empty session store.

        Args:
            max_sessions: Maximum number of concurrent sessions; values below
                1 are clamped to 1.
            session_timeout_s: Idle time in seconds after which a session
                expires; values below 1.0 are clamped to 1.0.
        """
        self._lock = threading.Lock()
        self._sessions: dict[str, SessionRecord] = {}
        self._max_sessions = max(1, max_sessions)
        self._session_timeout_s = max(1.0, session_timeout_s)

    def reset_session(
        self,
        *,
        task_id: str,
        seed: Optional[int],
        episode_id: Optional[str],
        session_id: Optional[str],
    ) -> tuple[str, DataOpsEnvironment, object]:
        """Reset the environment for ``session_id``, creating a session if needed.

        If ``session_id`` is falsy, unknown, or has expired, a fresh
        ``DataOpsEnvironment`` is registered under a newly generated UUID;
        the caller-supplied id is not reused in that case. Otherwise the
        existing environment is reused and its access time refreshed.

        Returns:
            ``(resolved_session_id, env, obs)`` where ``obs`` is whatever
            ``env.reset(seed=..., episode_id=..., task_id=...)`` returned.

        Raises:
            Whatever closing an expired/evicted environment or ``env.reset``
            raises; closing happens before the reset.
        """
        now = time.monotonic()
        to_close: list[DataOpsEnvironment] = []
        with self._lock:
            to_close.extend(self._collect_expired_envs_locked(now))
            record = self._sessions.get(session_id) if session_id else None
            if record is None:
                resolved_session_id = str(uuid.uuid4())
                # Make room first so the store never exceeds max_sessions.
                to_close.extend(self._evict_if_full_locked(now))
                env = DataOpsEnvironment()
                self._sessions[resolved_session_id] = SessionRecord(
                    env=env,
                    last_access_at=now,
                )
            else:
                # A record can only be found when session_id is truthy, so the
                # uuid fallback never triggers; it just keeps the type a str.
                resolved_session_id = session_id or str(uuid.uuid4())
                record.last_access_at = now
                env = record.env
        # Closing and resetting run outside the lock (presumably so slow
        # teardown/reset work does not block other callers).
        self._close_envs(to_close)
        obs = env.reset(seed=seed, episode_id=episode_id, task_id=task_id)
        return resolved_session_id, env, obs

    def get_session(
        self, session_id: Optional[str]
    ) -> tuple[Optional[str], Optional[DataOpsEnvironment]]:
        """Look up a live session and refresh its access time.

        Expired sessions are purged (and closed) first, so an expired id is
        reported the same way as an unknown one.

        Returns:
            ``(None, None)`` if ``session_id`` is falsy,
            ``(session_id, None)`` if it is unknown or expired,
            otherwise ``(session_id, env)``.
        """
        now = time.monotonic()
        to_close: list[DataOpsEnvironment] = []
        with self._lock:
            to_close.extend(self._collect_expired_envs_locked(now))
            if session_id:
                record = self._sessions.get(session_id)
                if record is not None:
                    record.last_access_at = now
                    env = record.env
                else:
                    env = None
                result = (session_id, env)
            else:
                result = (None, None)
        self._close_envs(to_close)
        return result

    def close_all(self) -> None:
        """Remove every session and close its environment.

        Sessions are detached from the store under the lock before any
        ``close()`` runs, so the store is empty even if a close fails.
        """
        with self._lock:
            records = list(self._sessions.values())
            self._sessions.clear()
        self._close_envs([record.env for record in records])

    def _collect_expired_envs_locked(self, now: float) -> list[DataOpsEnvironment]:
        """Remove sessions idle for more than the timeout; return their envs.

        The returned environments are NOT closed yet. Must be called with
        ``self._lock`` held.
        """
        expired_ids = [
            session_id
            for session_id, record in self._sessions.items()
            if now - record.last_access_at > self._session_timeout_s
        ]
        return self._remove_sessions_locked(expired_ids)

    def _evict_if_full_locked(self, now: float) -> list[DataOpsEnvironment]:
        """Evict the least-recently-accessed session if the store is full.

        Returns the evicted environment (unclosed) in a list, or ``[]`` when
        there is spare capacity. ``now`` is currently unused. Must be called
        with ``self._lock`` held.
        """
        if len(self._sessions) < self._max_sessions:
            return []
        oldest_session_id = min(
            self._sessions,
            key=lambda session_id: self._sessions[session_id].last_access_at,
        )
        return self._remove_sessions_locked([oldest_session_id])

    def _remove_sessions_locked(self, session_ids: list[str]) -> list[DataOpsEnvironment]:
        """Pop the given ids (ignoring unknown ones) and return their envs.

        Must be called with ``self._lock`` held.
        """
        removed: list[DataOpsEnvironment] = []
        for session_id in session_ids:
            record = self._sessions.pop(session_id, None)
            if record is not None:
                removed.append(record.env)
        return removed

    def _close_envs(self, envs: list[DataOpsEnvironment]) -> None:
        """Close every environment in ``envs``, even if some closes fail.

        Each ``close()`` is attempted; the first exception raised (if any) is
        re-raised only after all environments have been tried, so one faulty
        environment cannot leak the others (important for ``close_all``,
        which has already forgotten the sessions).
        """
        first_error: Optional[Exception] = None
        for env in envs:
            try:
                env.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    def __del__(self) -> None:
        """Best-effort cleanup when the manager is garbage-collected.

        Errors are deliberately suppressed: exceptions from ``__del__`` cannot
        reach a caller, and ``__init__`` may not have completed (e.g. no
        ``_lock`` attribute), which would surface as an ``AttributeError``.
        """
        try:
            self.close_all()
        except Exception:
            pass