File size: 7,752 Bytes
06c11b0
 
a365309
06c11b0
a365309
 
06c11b0
a365309
93eb118
06c11b0
a365309
06c11b0
a365309
93eb118
06c11b0
 
 
6d95d0c
06c11b0
 
 
 
 
 
 
 
a365309
06c11b0
 
a365309
06c11b0
 
a365309
06c11b0
 
 
 
 
a365309
06c11b0
 
 
 
cd45b78
 
4ac0c53
cd45b78
4ac0c53
cd45b78
4ac0c53
cd45b78
 
4ac0c53
cd45b78
4ac0c53
cd45b78
4ac0c53
cd45b78
4ac0c53
cd45b78
 
 
4ac0c53
cd45b78
 
 
 
 
 
 
4ac0c53
cd45b78
 
 
 
 
 
4ac0c53
cd45b78
 
 
6d95d0c
 
 
 
4ac0c53
6d95d0c
 
 
 
 
 
 
 
 
cd45b78
 
 
 
 
4ac0c53
cd45b78
 
 
 
 
 
4ac0c53
 
 
 
 
 
 
 
cd45b78
 
4ac0c53
cd45b78
 
a365309
06c11b0
a365309
 
4ac0c53
06c11b0
a365309
 
 
4ac0c53
 
 
93eb118
06c11b0
 
 
4ac0c53
 
 
 
 
 
06c11b0
a365309
06c11b0
 
 
 
 
a365309
06c11b0
 
 
a365309
06c11b0
 
 
 
a365309
06c11b0
a365309
06c11b0
 
 
a365309
06c11b0
 
 
 
 
a365309
06c11b0
 
 
 
 
a365309
06c11b0
 
 
 
 
a365309
06c11b0
 
 
 
 
a365309
06c11b0
a365309
06c11b0
 
 
 
 
 
 
a365309
06c11b0
 
 
 
 
 
a365309
06c11b0
 
 
 
 
 
 
 
a365309
06c11b0
 
 
 
 
 
a365309
06c11b0
 
 
 
 
 
a365309
06c11b0
 
 
 
 
 
a365309
06c11b0
 
a365309
06c11b0
 
 
a365309
06c11b0
 
 
a365309
 
06c11b0
 
a365309
 
 
 
06c11b0
a365309
 
06c11b0
a365309
 
 
 
06c11b0
a365309
06c11b0
a365309
 
 
 
 
6d95d0c
a365309
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
"""
状态管理模块
管理所有全局状态和 Session 生命周期。

GLOBAL_SESSIONS 中存储的是 ProcessSessionProxy,而不是 OracleSession。
实际的 OracleSession 运行在独立工作进程中,通过代理对象进行通信。
"""

import logging
import threading

from process_session import ProcessSessionProxy

LOGGER = logging.getLogger("robomme.state_manager")

# --- 全局会话存储 ---
GLOBAL_SESSIONS = {}
ACTIVE_SESSION_SLOTS = set()

# --- 任务索引存储(用于进度显示) ---
TASK_INDEX_MAP = {}  # {uid: {"task_index": int, "total_tasks": int}}

# --- UI阶段存储 ---
UI_PHASE_MAP = {}  # {uid: "watching_demo" | "executing_task"}

# --- Execute 次数跟踪 ---
EXECUTE_COUNTS = {}  # {"{uid}:{env_id}:{episode_idx}": count}

# --- 任务开始时间跟踪 ---
TASK_START_TIMES = {}  # {"{uid}:{env_id}:{episode_idx}": iso_timestamp}

# --- 播放按钮状态跟踪 ---
PLAY_BUTTON_CLICKED = {}  # {uid: bool}

_state_lock = threading.Lock()


def get_session(uid):
    """Return the ProcessSessionProxy registered for ``uid``, or None if absent."""
    with _state_lock:
        proxy = GLOBAL_SESSIONS.get(uid)
    return proxy


def _try_reserve_session_slot_locked(uid, session_concurrency_limit):
    """
    Reserve a concurrency slot for ``uid``; the caller must hold ``_state_lock``.

    Args:
        uid: Session key to reserve a slot for.
        session_concurrency_limit: Maximum number of slots; coerced with ``int()``.

    Returns:
        bool: True if ``uid`` already held a slot or a free one was taken,
        False if the limit has been reached.
    """
    if uid in ACTIVE_SESSION_SLOTS:
        # Re-reserving an already-held slot is a no-op success.
        return True

    in_use = len(ACTIVE_SESSION_SLOTS)
    if in_use < int(session_concurrency_limit):
        ACTIVE_SESSION_SLOTS.add(uid)
        LOGGER.info(
            "try_reserve_session_slot acquired uid=%s active_slots=%s",
            uid,
            len(ACTIVE_SESSION_SLOTS),
        )
        return True

    LOGGER.info(
        "try_reserve_session_slot rejected uid=%s active_slots=%s limit=%s",
        uid,
        in_use,
        session_concurrency_limit,
    )
    return False


def try_reserve_session_slot(uid):
    """
    Reserve a concurrency slot for ``uid`` without blocking.

    The limit is read from ``config.SESSION_CONCURRENCY_LIMIT`` on every call.

    Returns:
        bool: True if ``uid`` now holds a slot (newly or already), else False.

    Raises:
        ValueError: if ``uid`` is empty.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")

    from config import SESSION_CONCURRENCY_LIMIT as limit

    with _state_lock:
        acquired = _try_reserve_session_slot_locked(uid, limit)
    return acquired


def release_session_slot(uid):
    """
    Release the concurrency slot held by ``uid``, if any.

    An empty ``uid`` or a uid that holds no slot is silently ignored; a
    successful release is logged.
    """
    if not uid:
        return

    with _state_lock:
        try:
            ACTIVE_SESSION_SLOTS.remove(uid)
        except KeyError:
            # uid held no slot; nothing to release.
            return
        LOGGER.info(
            "release_session_slot uid=%s active_slots=%s",
            uid,
            len(ACTIVE_SESSION_SLOTS),
        )


def try_create_session(uid):
    """
    Try to create a ProcessSessionProxy for ``uid`` without waiting for a slot.

    If a session is already registered for ``uid`` this returns True without
    doing anything. Otherwise a concurrency slot is reserved (an existing slot
    held by ``uid`` is reused) and a new proxy is stored in GLOBAL_SESSIONS.

    Args:
        uid: Session key; must be non-empty.

    Returns:
        bool: True if a session is ready for ``uid``; False if no slot could
        be reserved because the concurrency limit has been reached.

    Raises:
        ValueError: if ``uid`` is empty.
        Exception: whatever ``ProcessSessionProxy()`` raises; the uid's slot is
            discarded before re-raising.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")

    # Resolve the limit before taking the lock (as try_reserve_session_slot
    # does) so the config import never runs inside the critical section that
    # every other state accessor contends on.
    session_concurrency_limit = _get_session_concurrency_limit()

    with _state_lock:
        if GLOBAL_SESSIONS.get(uid) is not None:
            return True
        if not _try_reserve_session_slot_locked(uid, session_concurrency_limit):
            return False
        try:
            # NOTE(review): the proxy is constructed while _state_lock is held,
            # so every other accessor in this module waits until it returns.
            # Moving it out would need extra handling for concurrent creates
            # of the same uid.
            GLOBAL_SESSIONS[uid] = ProcessSessionProxy()
        except Exception:
            # NOTE(review): this also drops a slot the uid may have held before
            # this call (e.g. reserved via try_reserve_session_slot).
            ACTIVE_SESSION_SLOTS.discard(uid)
            raise
        # Snapshot under the lock so the log line reflects a consistent state.
        total_sessions = len(GLOBAL_SESSIONS)

    LOGGER.info("try_create_session uid=%s total_sessions=%s", uid, total_sessions)
    return True


def create_session(uid):
    """
    Create a ProcessSessionProxy for the given session key.

    Fails immediately when the concurrency limit is exceeded; no queueing.

    Returns:
        The same ``uid`` that was passed in.

    Raises:
        ValueError: if ``uid`` is empty.
        RuntimeError: if no session slot is available.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")

    if try_create_session(uid):
        LOGGER.info("create_session uid=%s total_sessions=%s", uid, len(GLOBAL_SESSIONS))
        return uid
    raise RuntimeError("No session slots available")


def _get_session_concurrency_limit():
    """Return ``config.SESSION_CONCURRENCY_LIMIT``, imported lazily on each call."""
    from config import SESSION_CONCURRENCY_LIMIT as limit

    return limit


def get_task_index(uid):
    """Return the stored task-index dict for ``uid``, or None if never set."""
    with _state_lock:
        entry = TASK_INDEX_MAP.get(uid)
    return entry


def set_task_index(uid, task_index, total_tasks):
    """Store ``task_index`` and ``total_tasks`` for ``uid``, replacing any prior entry."""
    entry = dict(task_index=task_index, total_tasks=total_tasks)
    with _state_lock:
        TASK_INDEX_MAP[uid] = entry


def get_ui_phase(uid):
    """Return the UI phase for ``uid``; defaults to "watching_demo" when unset."""
    with _state_lock:
        phase = UI_PHASE_MAP.get(uid, "watching_demo")
    return phase


def set_ui_phase(uid, phase):
    """Record ``phase`` as the current UI phase for ``uid``."""
    with _state_lock:
        UI_PHASE_MAP.update({uid: phase})


def reset_ui_phase(uid):
    """Put ``uid`` back into the initial "watching_demo" UI phase."""
    initial_phase = "watching_demo"
    with _state_lock:
        UI_PHASE_MAP[uid] = initial_phase


def set_play_button_clicked(uid, clicked=True):
    """Record whether the play button has been clicked for ``uid`` (default True)."""
    with _state_lock:
        PLAY_BUTTON_CLICKED.update({uid: clicked})


def get_play_button_clicked(uid):
    """Return the stored play-button flag for ``uid``; False when never set."""
    with _state_lock:
        clicked = PLAY_BUTTON_CLICKED.get(uid, False)
    return clicked


def reset_play_button_clicked(uid):
    """Forget the play-button flag for ``uid`` (no-op if none is stored)."""
    with _state_lock:
        if uid in PLAY_BUTTON_CLICKED:
            del PLAY_BUTTON_CLICKED[uid]


def _get_task_key(uid, env_id, episode_idx):
    return f"{uid}:{env_id}:{episode_idx}"


def get_execute_count(uid, env_id, episode_idx):
    """Return how many times the given task has been executed (0 if never)."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        count = EXECUTE_COUNTS.get(key, 0)
    return count


def increment_execute_count(uid, env_id, episode_idx):
    """Increase the task's execute count by one and return the new value."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        updated = EXECUTE_COUNTS.get(key, 0) + 1
        EXECUTE_COUNTS[key] = updated
    return updated


def reset_execute_count(uid, env_id, episode_idx):
    """Set the task's execute count to 0 (the entry is kept, not removed)."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        EXECUTE_COUNTS.update({key: 0})


def get_task_start_time(uid, env_id, episode_idx):
    """Return the recorded start time for the task, or None if none was set."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        started_at = TASK_START_TIMES.get(key)
    return started_at


def set_task_start_time(uid, env_id, episode_idx, start_time):
    """Record ``start_time`` as the task's start time, overwriting any earlier value."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        TASK_START_TIMES.update({key: start_time})


def clear_task_start_time(uid, env_id, episode_idx):
    """Drop the task's start-time record; missing records are ignored."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        if key in TASK_START_TIMES:
            del TASK_START_TIMES[key]


def cleanup_session(uid):
    """
    Release every resource tracked for ``uid``.

    Under the lock, this removes the uid's proxy and all per-uid and per-task
    bookkeeping. Outside the lock, it closes the proxy; close errors are logged,
    not raised. It then releases the uid's concurrency slot and finally calls
    ``user_manager.cleanup_session(uid)``. An empty ``uid`` is a no-op.
    """
    if not uid:
        return

    # Per-task keys have the form "{uid}:{env_id}:{episode_idx}" (_get_task_key).
    # NOTE(review): prefix matching would also hit another uid that itself
    # starts with "{uid}:" — confirm uids never contain ":".
    task_prefix = f"{uid}:"

    with _state_lock:
        session = GLOBAL_SESSIONS.pop(uid, None)
        TASK_INDEX_MAP.pop(uid, None)
        UI_PHASE_MAP.pop(uid, None)
        PLAY_BUTTON_CLICKED.pop(uid, None)

        for store in (EXECUTE_COUNTS, TASK_START_TIMES):
            # Collect first: a dict cannot be mutated while iterating it.
            stale_keys = [key for key in store if key.startswith(task_prefix)]
            for key in stale_keys:
                del store[key]

    if session is not None:
        # Closed outside the lock, presumably so a slow close does not stall
        # other sessions' state access.
        try:
            LOGGER.info("cleanup_session uid=%s closing ProcessSessionProxy", uid)
            session.close()
            LOGGER.info("cleanup_session uid=%s proxy closed", uid)
        except Exception as exc:
            LOGGER.exception("cleanup_session uid=%s proxy close failed: %s", uid, exc)

    # Release the slot atomically, and only if no new session was registered
    # for this uid while the old proxy was closing. try_create_session reuses
    # a slot the uid still holds. An unconditional release here would strip the
    # slot from that new session and let the concurrency limit be exceeded.
    with _state_lock:
        if GLOBAL_SESSIONS.get(uid) is None and uid in ACTIVE_SESSION_SLOTS:
            ACTIVE_SESSION_SLOTS.remove(uid)
            LOGGER.info(
                "release_session_slot uid=%s active_slots=%s",
                uid,
                len(ACTIVE_SESSION_SLOTS),
            )

    from user_manager import user_manager

    user_manager.cleanup_session(uid)
    LOGGER.info("cleanup_session uid=%s done", uid)