"""
server/session_manager.py — Thread-safe UUID-based session store.

Each /reset creates a new isolated DataCentricEnvironment instance
identified by a UUID session_id. Multiple concurrent clients can run
episodes without state corruption.
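
The isolation guarantee can be illustrated with a minimal, self-contained
sketch. `TinySessionStore` and `client` are hypothetical names, and a plain
dict stands in for DataCentricEnvironment. Only the pattern (a lock-guarded
dict keyed by `uuid.uuid4().hex`) mirrors this module. Four threads each
create a session and step it independently:

```python
import threading
import uuid


class TinySessionStore:
    # Hypothetical stand-in for SessionManager: a lock-guarded dict keyed by UUID.

    def __init__(self):
        self._sessions = {}
        self._lock = threading.Lock()

    def create(self, env):
        session_id = uuid.uuid4().hex  # 32 hex chars; collisions are negligible
        with self._lock:
            self._sessions[session_id] = {"env": env, "step_count": 0}
        return session_id

    def step(self, session_id):
        # Each client only ever touches its own session's entry.
        with self._lock:
            session = self._sessions[session_id]
            session["step_count"] += 1
            return session["step_count"]


store = TinySessionStore()
results = {}


def client(name, n_steps):
    # Hypothetical env object: a plain dict instead of DataCentricEnvironment.
    sid = store.create({"owner": name})
    for _ in range(n_steps):
        last = store.step(sid)
    results[name] = (sid, last)


threads = [threading.Thread(target=client, args=(f"client{i}", i + 1)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

for name, (sid, steps) in sorted(results.items()):
    print(name, sid[:8], steps)
```

Each thread ends with its own step count (1 to 4) and a distinct
32-character session_id, because no two clients ever share an entry.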

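Sessions expire SESSION_TTL_SECONDS after creation (default: 30 min),
regardless of activity. Expired sessions are purged on each new /reset
call and lazily on lookup; if MAX_CONCURRENT_SESSIONS is reached, the
oldest session is evicted to make room.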
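
The expiry and eviction rules implemented by create_session and get_session
can be sketched as follows. `ExpiringStore` and the fake clock are
hypothetical, not part of this module; the sketch is single-threaded and
omits locking and logging. It shows that age is measured from created_at,
so a recent access does not extend a session's life, and that a full store
evicts the session with the oldest created_at:

```python
import time


class ExpiringStore:
    # Hypothetical sketch: TTL measured from creation time, plus eviction of the
    # oldest session when full, mirroring SessionManager's rules.

    def __init__(self, ttl_seconds, max_sessions, clock=time.monotonic):
        self.ttl = ttl_seconds
        self.max_sessions = max_sessions
        self.clock = clock  # injectable so the example is deterministic
        self.sessions = {}

    def _purge_expired(self):
        now = self.clock()
        expired = [s for s, meta in self.sessions.items() if now - meta["created_at"] > self.ttl]
        for sid in expired:
            del self.sessions[sid]

    def create(self, sid):
        self._purge_expired()
        while self.sessions and len(self.sessions) >= self.max_sessions:
            oldest = min(self.sessions, key=lambda k: self.sessions[k]["created_at"])
            del self.sessions[oldest]
        now = self.clock()
        self.sessions[sid] = {"created_at": now, "last_accessed": now}

    def get(self, sid):
        meta = self.sessions.get(sid)
        if meta is None:
            return None
        if self.clock() - meta["created_at"] > self.ttl:  # age, not idle time
            del self.sessions[sid]
            return None
        meta["last_accessed"] = self.clock()
        return meta


fake_now = [0.0]
store = ExpiringStore(ttl_seconds=1800, max_sessions=2, clock=lambda: fake_now[0])

store.create("a")
fake_now[0] = 100.0
store.create("b")
fake_now[0] = 200.0
store.create("c")              # store is full, so "a" (oldest) is evicted
print(sorted(store.sessions))  # ['b', 'c']

fake_now[0] = 1850.0
store.get("b")                 # age 1750 s < 1800 s: still valid, access refreshed
fake_now[0] = 1950.0
print(store.get("b"))          # None: age 1850 s > 1800 s despite the recent access
```

Because last_accessed is recorded but never consulted for expiry, the TTL is
an absolute lifetime rather than an idle timeout; an idle timeout would
compare against last_accessed instead.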
"""
import threading
import time
import uuid
from typing import Optional
from server.config import cfg
from server.logger import get_logger, log_event

logger = get_logger("session_manager")


class SessionManager:
    def __init__(self):
        self._sessions: dict[str, dict] = {}
        self._lock = threading.Lock()
        self._total_sessions = 0
        self._total_resets = 0

    def create_session(self, env_instance) -> str:
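        """Register a new environment instance and return its session_id.

        Expired sessions are purged first; if the store is still full, the
        oldest sessions (by creation time) are evicted until there is room.
        """
        self._cleanup_expired()

        with self._lock:
            # Evict the oldest session(s); the emptiness check avoids calling
            # min() on an empty dict when MAX_CONCURRENT_SESSIONS <= 0.
            while self._sessions and len(self._sessions) >= cfg.MAX_CONCURRENT_SESSIONS:
                oldest_id = min(self._sessions, key=lambda k: self._sessions[k]["created_at"])
                del self._sessions[oldest_id]
                log_event(logger, "session_evicted", session_id=oldest_id,
                          reason="max_sessions_reached")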

            session_id = uuid.uuid4().hex
            now = time.time()
            self._sessions[session_id] = {
                "env": env_instance,
                "created_at": now,
                "last_accessed": now,
                "step_count": 0,
            }
            self._total_sessions += 1
            self._total_resets += 1
            # Snapshot the count under the lock so the logged value is consistent.
            total_active = len(self._sessions)

        log_event(logger, "session_created", session_id=session_id,
                  total_active=total_active)
        return session_id

    def get_session(self, session_id: str) -> Optional[dict]:
        """Return session dict or None if not found / expired."""
        with self._lock:
            session = self._sessions.get(session_id)
            if session is None:
                return None
            if time.time() - session["created_at"] > cfg.SESSION_TTL_SECONDS:
                del self._sessions[session_id]
                log_event(logger, "session_expired", session_id=session_id)
                return None
            session["last_accessed"] = time.time()
            return session

    def get_env(self, session_id: str):
        """Return the environment for a session_id, or None."""
        session = self.get_session(session_id)
        return session["env"] if session else None

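    def increment_steps(self, session_id: str):
        """Increment a session's step counter; no-op if the session no longer exists."""
        with self._lock:
            if session_id in self._sessions:
                self._sessions[session_id]["step_count"] += 1

    def delete_session(self, session_id: str):
        """Remove a session if present; unknown session_ids are ignored."""
        with self._lock:
            self._sessions.pop(session_id, None)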

    def _cleanup_expired(self):
        """Remove sessions older than TTL."""
        now = time.time()
        with self._lock:
            expired = [
                sid for sid, s in self._sessions.items()
                if now - s["created_at"] > cfg.SESSION_TTL_SECONDS
            ]
            for sid in expired:
                del self._sessions[sid]
                log_event(logger, "session_cleaned_up", session_id=sid)

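    def metrics(self) -> dict:
        """Return a snapshot of session counters for monitoring."""
        with self._lock:
            active = len(self._sessions)
            step_counts = [s["step_count"] for s in self._sessions.values()]
            # Read the lifetime counters under the same lock for a consistent snapshot.
            total_sessions = self._total_sessions
            total_resets = self._total_resets
        return {
            "active_sessions": active,
            "total_sessions_created": total_sessions,
            "total_resets": total_resets,
            "avg_steps_active": round(sum(step_counts) / max(len(step_counts), 1), 2),
        }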


# Global singleton — imported by main.py
session_manager = SessionManager()