File size: 4,071 Bytes
f89b1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from __future__ import annotations

import logging
import threading
import time
import uuid
from dataclasses import dataclass
from typing import Optional

from server.dataops_env_environment import DataOpsEnvironment


@dataclass()
class SessionRecord:
    """Book-keeping entry for a single live session.

    Pairs the session's environment instance with the moment it was last
    used, so the owning manager can expire idle sessions and pick the
    least recently used one for eviction.
    """

    # The isolated environment instance owned by this session.
    env: DataOpsEnvironment
    # Timestamp of the most recent access; EnvironmentSessionManager writes
    # time.monotonic() readings here.
    last_access_at: float


class EnvironmentSessionManager:
    """Small in-memory session store for isolated environment instances.

    Each session id maps to its own ``DataOpsEnvironment``. Sessions that have
    not been accessed for more than ``session_timeout_s`` seconds are expired
    lazily, at the start of the next ``reset_session``/``get_session`` call.
    When the store already holds ``max_sessions`` sessions, creating a new one
    first evicts the least recently accessed session.

    The session table is guarded by ``self._lock``; environments are closed and
    reset *outside* the lock so slow environment calls do not block other
    callers. NOTE(review): as a consequence, an environment handed out by
    ``reset_session``/``get_session`` can be expired or evicted (and closed) by
    another thread while the caller is still using it -- confirm whether
    callers serialize per-session access or whether that window is acceptable.
    """

    def __init__(
        self,
        *,
        max_sessions: int = 128,
        session_timeout_s: float = 1800.0,
    ) -> None:
        """Create an empty session store.

        Args:
            max_sessions: Maximum number of live sessions; clamped to >= 1.
            session_timeout_s: Idle time in seconds after which a session
                expires; clamped to >= 1.0.
        """
        self._lock = threading.Lock()
        self._sessions: dict[str, SessionRecord] = {}
        self._max_sessions = max(1, max_sessions)
        self._session_timeout_s = max(1.0, session_timeout_s)

    def reset_session(
        self,
        *,
        task_id: str,
        seed: Optional[int],
        episode_id: Optional[str],
        session_id: Optional[str],
    ) -> tuple[str, DataOpsEnvironment, object]:
        """Reset the environment of a session, creating the session if needed.

        If ``session_id`` names a live session, its environment is reused and
        its access time refreshed. Otherwise (``None``, empty, unknown,
        expired or evicted) a brand-new session is created with a fresh UUID4
        id and a new ``DataOpsEnvironment``; the caller-supplied id is never
        adopted for a new session.

        Args:
            task_id: Forwarded to ``env.reset``.
            seed: Forwarded to ``env.reset``.
            episode_id: Forwarded to ``env.reset``.
            session_id: Id of the session to reset, or ``None`` for a new one.

        Returns:
            ``(resolved_session_id, env, observation)``, where ``observation``
            is whatever ``env.reset(...)`` returned.

        Raises:
            Whatever ``env.reset`` raises. If the session was created by this
            call, it is discarded and its environment closed first, so no
            orphaned session (whose id the caller never received) is left
            occupying a slot.
        """
        now = time.monotonic()
        to_close: list[DataOpsEnvironment] = []

        with self._lock:
            to_close.extend(self._collect_expired_envs_locked(now))

            record = self._sessions.get(session_id) if session_id else None
            if session_id and record is not None:
                # Existing live session: keep its id and environment.
                resolved_session_id = session_id
                record.last_access_at = now
                env = record.env
                created = False
            else:
                resolved_session_id = str(uuid.uuid4())
                # Make room before inserting so the store never exceeds its cap.
                to_close.extend(self._evict_if_full_locked(now))
                env = DataOpsEnvironment()
                self._sessions[resolved_session_id] = SessionRecord(
                    env=env,
                    last_access_at=now,
                )
                created = True

        # Closing and resetting happen outside the lock (see class docstring).
        self._close_envs(to_close)
        try:
            obs = env.reset(seed=seed, episode_id=episode_id, task_id=task_id)
        except Exception:
            if created:
                self._discard_session(resolved_session_id, env)
            raise
        return resolved_session_id, env, obs

    def get_session(
        self, session_id: Optional[str]
    ) -> tuple[Optional[str], Optional[DataOpsEnvironment]]:
        """Look up a live session and refresh its last-access time.

        Expired sessions are purged (and their environments closed) first.

        Returns:
            ``(session_id, env)`` when the session is live,
            ``(session_id, None)`` when it is unknown, expired or evicted, and
            ``(None, None)`` when ``session_id`` is ``None`` or empty.
        """
        now = time.monotonic()
        env: Optional[DataOpsEnvironment] = None

        with self._lock:
            to_close = self._collect_expired_envs_locked(now)
            if session_id:
                record = self._sessions.get(session_id)
                if record is not None:
                    record.last_access_at = now
                    env = record.env

        self._close_envs(to_close)
        # A falsy id (None or "") is reported as None, matching the "no id" case.
        return (session_id or None), env

    def close_all(self) -> None:
        """Drop every session and close all of their environments.

        The table is emptied under the lock first; the environments are then
        closed outside it.
        """
        with self._lock:
            records = list(self._sessions.values())
            self._sessions.clear()

        self._close_envs([record.env for record in records])

    def _collect_expired_envs_locked(self, now: float) -> list[DataOpsEnvironment]:
        """Remove sessions idle for longer than the timeout; caller holds the lock.

        Returns:
            The removed environments, which the caller must close after
            releasing the lock.
        """
        expired_ids = [
            session_id
            for session_id, record in self._sessions.items()
            if now - record.last_access_at > self._session_timeout_s
        ]
        return self._remove_sessions_locked(expired_ids)

    def _evict_if_full_locked(self, now: float) -> list[DataOpsEnvironment]:
        """Evict the least recently accessed session if the store is full.

        Caller holds the lock. ``now`` is unused but kept for signature
        symmetry with ``_collect_expired_envs_locked``.

        Returns:
            The evicted environment (at most one) for the caller to close.
        """
        if len(self._sessions) < self._max_sessions:
            return []

        oldest_session_id, _ = min(
            self._sessions.items(),
            key=lambda item: item[1].last_access_at,
        )
        return self._remove_sessions_locked([oldest_session_id])

    def _remove_sessions_locked(self, session_ids: list[str]) -> list[DataOpsEnvironment]:
        """Pop the given ids from the table; unknown ids are ignored.

        Caller holds the lock.

        Returns:
            Environments of the sessions actually removed.
        """
        removed: list[DataOpsEnvironment] = []
        for session_id in session_ids:
            record = self._sessions.pop(session_id, None)
            if record is not None:
                removed.append(record.env)
        return removed

    def _discard_session(self, session_id: str, env: DataOpsEnvironment) -> None:
        """Remove ``session_id`` and close ``env``, if the session still owns ``env``.

        If another thread already expired or evicted the session, that thread
        is responsible for closing the environment, so nothing is done here.
        """
        with self._lock:
            record = self._sessions.get(session_id)
            owned = record is not None and record.env is env
            if owned:
                del self._sessions[session_id]
        if owned:
            self._close_envs([env])

    def _close_envs(self, envs: list[DataOpsEnvironment]) -> None:
        """Close every environment in ``envs``, even if some ``close()`` calls fail.

        Failures are logged instead of raised, so one misbehaving environment
        neither leaks the remaining ones nor fails the unrelated request that
        triggered the cleanup.
        """
        for env in envs:
            try:
                env.close()
            except Exception:
                logging.getLogger(__name__).exception(
                    "Failed to close a DataOps environment"
                )

    def __del__(self) -> None:
        """Best-effort cleanup when the manager is garbage-collected.

        Errors are deliberately suppressed: exceptions raised from ``__del__``
        cannot propagate to any caller anyway.
        """
        try:
            self.close_all()
        except Exception:
            pass