dataops-env / server /session_manager.py
visheshrathi's picture
Upload folder using huggingface_hub
f89b1ac verified
from __future__ import annotations

import logging
import threading
import time
import uuid
from dataclasses import dataclass
from typing import Optional

from server.dataops_env_environment import DataOpsEnvironment
@dataclass
class SessionRecord:
    """A single live session tracked by the session manager.

    Pairs the session's environment instance with the moment it was last
    reset or looked up, so idle sessions can be expired and the least
    recently used one can be evicted when the store is full.
    """

    # Environment instance owned by this session.
    env: DataOpsEnvironment
    # time.monotonic() reading (seconds) of the most recent reset/lookup.
    last_access_at: float
class EnvironmentSessionManager:
    """Small in-memory session store for isolated environment instances.

    Each session id maps to its own ``DataOpsEnvironment``. Sessions idle for
    longer than ``session_timeout_s`` (measured with ``time.monotonic()``) are
    dropped on the next ``reset_session``/``get_session`` call. When the store
    already holds ``max_sessions`` entries, the least recently accessed session
    is evicted to make room for a new one. Dropped environments are closed
    after the internal lock has been released.

    ``self._lock`` guards the session table only; environments returned to
    callers are used outside the lock. NOTE(review): a caller may still hold
    an environment that another thread has just expired/evicted and closed —
    confirm whether the server serialises requests per session.
    """

    def __init__(
        self,
        *,
        max_sessions: int = 128,
        session_timeout_s: float = 1800.0,
    ) -> None:
        """Create an empty session store.

        Args:
            max_sessions: Maximum number of concurrent sessions; values below 1
                are clamped to 1.
            session_timeout_s: Idle time in seconds after which a session
                expires; values below 1.0 are clamped to 1.0.
        """
        self._lock = threading.Lock()
        self._sessions: dict[str, SessionRecord] = {}
        self._max_sessions = max(1, max_sessions)
        self._session_timeout_s = max(1.0, session_timeout_s)

    def reset_session(
        self,
        *,
        task_id: str,
        seed: Optional[int],
        episode_id: Optional[str],
        session_id: Optional[str],
    ) -> tuple[str, DataOpsEnvironment, object]:
        """Reset the environment for ``session_id``, creating a session if needed.

        If ``session_id`` is empty/``None``, unknown, or has expired, a new
        session with a freshly generated UUID is created, possibly evicting the
        least recently accessed session. Otherwise the existing environment is
        reused and its access time refreshed.

        Args:
            task_id: Forwarded to ``env.reset``.
            seed: Forwarded to ``env.reset``.
            episode_id: Forwarded to ``env.reset``.
            session_id: Session to reset, or ``None`` to start a new one.

        Returns:
            ``(resolved_session_id, env, observation)``, where ``observation``
            is whatever ``env.reset`` returns.
        """
        now = time.monotonic()
        to_close: list[DataOpsEnvironment] = []
        try:
            with self._lock:
                to_close.extend(self._collect_expired_envs_locked(now))
                if session_id and session_id in self._sessions:
                    resolved_session_id = session_id
                    record = self._sessions[session_id]
                    record.last_access_at = now
                    env = record.env
                else:
                    # Construct before evicting: if the constructor raises, no
                    # live session loses its slot for nothing.
                    env = DataOpsEnvironment()
                    to_close.extend(self._evict_if_full_locked(now))
                    resolved_session_id = str(uuid.uuid4())
                    self._sessions[resolved_session_id] = SessionRecord(
                        env=env,
                        last_access_at=now,
                    )
        finally:
            # Environments already removed from the table must be closed even
            # if session creation failed, otherwise they leak.
            self._close_envs(to_close)
        obs = env.reset(seed=seed, episode_id=episode_id, task_id=task_id)
        return resolved_session_id, env, obs

    def get_session(
        self, session_id: Optional[str]
    ) -> tuple[Optional[str], Optional[DataOpsEnvironment]]:
        """Look up a live session and refresh its access time.

        Expired sessions are purged first, so an expired id yields no env.

        Returns:
            ``(session_id, env)`` when ``session_id`` is non-empty, with ``env``
            set to ``None`` if the id is unknown or expired. Returns
            ``(None, None)`` when ``session_id`` is empty or ``None``.
        """
        now = time.monotonic()
        env: Optional[DataOpsEnvironment] = None
        with self._lock:
            to_close = self._collect_expired_envs_locked(now)
            if session_id:
                record = self._sessions.get(session_id)
                if record is not None:
                    record.last_access_at = now
                    env = record.env
        self._close_envs(to_close)
        if not session_id:
            return None, None
        return session_id, env

    def close_all(self) -> None:
        """Remove every session and close its environment."""
        with self._lock:
            envs = [record.env for record in self._sessions.values()]
            self._sessions.clear()
        self._close_envs(envs)

    def _collect_expired_envs_locked(self, now: float) -> list[DataOpsEnvironment]:
        """Remove sessions idle longer than the timeout; return their envs.

        Caller must hold ``self._lock``.
        """
        expired_ids = [
            session_id
            for session_id, record in self._sessions.items()
            if now - record.last_access_at > self._session_timeout_s
        ]
        return self._remove_sessions_locked(expired_ids)

    def _evict_if_full_locked(self, now: float) -> list[DataOpsEnvironment]:
        """Evict the least recently accessed session if the store is full.

        Returns the evicted environment (at most one), or an empty list when
        there is still room. ``now`` is accepted for signature symmetry with
        ``_collect_expired_envs_locked`` and is not used. Caller must hold
        ``self._lock``.
        """
        if len(self._sessions) < self._max_sessions:
            return []
        oldest_session_id = min(
            self._sessions,
            key=lambda sid: self._sessions[sid].last_access_at,
        )
        return self._remove_sessions_locked([oldest_session_id])

    def _remove_sessions_locked(self, session_ids: list[str]) -> list[DataOpsEnvironment]:
        """Pop the given ids from the table and return their envs (unknown ids are skipped).

        Caller must hold ``self._lock``.
        """
        removed: list[DataOpsEnvironment] = []
        for session_id in session_ids:
            record = self._sessions.pop(session_id, None)
            if record is not None:
                removed.append(record.env)
        return removed

    def _close_envs(self, envs: list[DataOpsEnvironment]) -> None:
        """Close each environment, continuing past individual failures.

        A failure to close one environment is logged instead of raised. This
        way the remaining environments are still closed, and a broken
        environment from an unrelated expired session cannot fail another
        caller's request.
        """
        for env in envs:
            try:
                env.close()
            except Exception:
                logging.getLogger(__name__).exception(
                    "Failed to close environment during session cleanup"
                )

    def __del__(self) -> None:
        # Best-effort cleanup from the finalizer; a finalizer must never raise
        # (attributes/modules may already be torn down at interpreter exit).
        try:
            self.close_all()
        except Exception:
            pass