File size: 7,752 Bytes
06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 93eb118 06c11b0 a365309 06c11b0 a365309 93eb118 06c11b0 6d95d0c 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 6d95d0c 4ac0c53 6d95d0c cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 4ac0c53 cd45b78 a365309 06c11b0 a365309 4ac0c53 06c11b0 a365309 4ac0c53 93eb118 06c11b0 4ac0c53 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 06c11b0 a365309 6d95d0c a365309 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 | """
状态管理模块
管理所有全局状态和 Session 生命周期。
GLOBAL_SESSIONS 中存储的是 ProcessSessionProxy,而不是 OracleSession。
实际的 OracleSession 运行在独立工作进程中,通过代理对象进行通信。
"""
import logging
import threading
from process_session import ProcessSessionProxy
LOGGER = logging.getLogger("robomme.state_manager")

# --- Session registry ---
# uid -> ProcessSessionProxy for every registered session.
GLOBAL_SESSIONS = dict()
# uids currently occupying one of the bounded concurrency slots.
ACTIVE_SESSION_SLOTS = set()

# --- Task index store (used for progress display) ---
# {uid: {"task_index": int, "total_tasks": int}}
TASK_INDEX_MAP = dict()

# --- UI phase store ---
# {uid: "watching_demo" | "executing_task"}
UI_PHASE_MAP = dict()

# --- Per-task execute counters ---
# {"{uid}:{env_id}:{episode_idx}": count}
EXECUTE_COUNTS = dict()

# --- Per-task start timestamps ---
# {"{uid}:{env_id}:{episode_idx}": iso_timestamp}
TASK_START_TIMES = dict()

# --- Play-button click tracking ---
# {uid: bool}
PLAY_BUTTON_CLICKED = dict()

# Lock the accessor functions in this module take around reads/writes of the
# stores above.
_state_lock = threading.Lock()
def get_session(uid):
    """Return the ProcessSessionProxy registered under ``uid``, or None if absent."""
    with _state_lock:
        proxy = GLOBAL_SESSIONS.get(uid)
    return proxy
def _try_reserve_session_slot_locked(uid, session_concurrency_limit):
    """Reserve a concurrency slot for ``uid`` without taking any lock itself.

    Intended to be called with ``_state_lock`` already held (both callers in
    this module do so).

    Args:
        uid: Session identifier.
        session_concurrency_limit: Maximum number of simultaneously held
            slots; coerced with ``int()`` for the comparison.

    Returns:
        bool: True if ``uid`` already held a slot or one was just granted;
        False if the limit has been reached.
    """
    # A uid that already holds a slot keeps it; no second slot is consumed.
    if uid in ACTIVE_SESSION_SLOTS:
        return True

    occupied = len(ACTIVE_SESSION_SLOTS)
    capacity = int(session_concurrency_limit)
    if occupied < capacity:
        ACTIVE_SESSION_SLOTS.add(uid)
        LOGGER.info(
            "try_reserve_session_slot acquired uid=%s active_slots=%s",
            uid,
            len(ACTIVE_SESSION_SLOTS),
        )
        return True

    LOGGER.info(
        "try_reserve_session_slot rejected uid=%s active_slots=%s limit=%s",
        uid,
        occupied,
        session_concurrency_limit,
    )
    return False
def try_reserve_session_slot(uid):
    """Try to reserve a session slot for ``uid`` without blocking.

    The concurrency limit is read through ``_get_session_concurrency_limit``,
    the same helper ``try_create_session`` uses. Both reservation paths
    therefore share one source for the limit.

    Args:
        uid: Session identifier; must be non-empty.

    Returns:
        bool: whether the slot was acquired. This is also True when ``uid``
        already holds a slot.

    Raises:
        ValueError: if ``uid`` is empty.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")
    # Resolve the limit before taking the lock, so the lazy config import
    # never runs while other threads are waiting on _state_lock.
    limit = _get_session_concurrency_limit()
    with _state_lock:
        return _try_reserve_session_slot_locked(uid, limit)
def release_session_slot(uid):
    """Give back the concurrency slot held by ``uid``.

    A falsy ``uid``, or one that holds no slot, is a silent no-op. A log line
    is written only when a slot is actually freed.
    """
    if uid:
        with _state_lock:
            if uid not in ACTIVE_SESSION_SLOTS:
                return
            ACTIVE_SESSION_SLOTS.discard(uid)
            LOGGER.info(
                "release_session_slot uid=%s active_slots=%s",
                uid,
                len(ACTIVE_SESSION_SLOTS),
            )
def try_create_session(uid):
    """Create a ProcessSessionProxy for ``uid`` unless one is already registered.

    Never waits for a slot: if the concurrency limit is reached it returns
    False immediately.

    Returns:
        bool: True when a session for ``uid`` exists on return; False when no
        slot was available.

    Raises:
        ValueError: if ``uid`` is empty.
        Exception: anything raised by ``ProcessSessionProxy()``, re-raised
            after ``uid`` is removed from the active slots.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")
    with _state_lock:
        if GLOBAL_SESSIONS.get(uid) is None:
            limit = _get_session_concurrency_limit()
            if not _try_reserve_session_slot_locked(uid, limit):
                return False
            try:
                proxy = ProcessSessionProxy()
            except Exception:
                # Roll back the slot so a failed construction does not leak
                # capacity.
                # NOTE(review): this also drops a slot uid held before this call
                # (e.g. one reserved via try_reserve_session_slot); confirm
                # that is intended.
                ACTIVE_SESSION_SLOTS.discard(uid)
                raise
            GLOBAL_SESSIONS[uid] = proxy
            LOGGER.info("try_create_session uid=%s total_sessions=%s", uid, len(GLOBAL_SESSIONS))
    return True
def create_session(uid):
    """Create (or reuse) the ProcessSessionProxy for session key ``uid``.

    Fails fast when the concurrency limit is reached; requests are never
    queued.

    Returns:
        The ``uid`` that was passed in.

    Raises:
        ValueError: if ``uid`` is empty.
        RuntimeError: if no session slot is available.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")
    if try_create_session(uid):
        LOGGER.info("create_session uid=%s total_sessions=%s", uid, len(GLOBAL_SESSIONS))
        return uid
    raise RuntimeError("No session slots available")
def _get_session_concurrency_limit():
    """Return ``SESSION_CONCURRENCY_LIMIT`` from the ``config`` module.

    The import happens at call time rather than module load; presumably this
    avoids an import cycle or picks up late config changes (unverified).
    """
    from config import SESSION_CONCURRENCY_LIMIT as limit
    return limit
def get_task_index(uid):
    """Return the stored task-index dict for ``uid``, or None when unset."""
    with _state_lock:
        entry = TASK_INDEX_MAP.get(uid)
    return entry
def set_task_index(uid, task_index, total_tasks):
    """Record ``uid``'s current task position and the total number of tasks."""
    entry = dict(task_index=task_index, total_tasks=total_tasks)
    with _state_lock:
        TASK_INDEX_MAP[uid] = entry
def get_ui_phase(uid):
    """Return ``uid``'s UI phase; "watching_demo" when no phase is stored."""
    fallback = "watching_demo"
    with _state_lock:
        current = UI_PHASE_MAP.get(uid, fallback)
    return current
def set_ui_phase(uid, phase):
    """Store ``phase`` as the UI phase for ``uid``, replacing any previous one."""
    with _state_lock:
        UI_PHASE_MAP.update({uid: phase})
def reset_ui_phase(uid):
    """Put ``uid`` back into the initial "watching_demo" UI phase."""
    initial_phase = "watching_demo"
    with _state_lock:
        UI_PHASE_MAP[uid] = initial_phase
def set_play_button_clicked(uid, clicked=True):
    """Record whether ``uid`` has clicked the play button (``clicked`` defaults to True)."""
    with _state_lock:
        PLAY_BUTTON_CLICKED.update({uid: clicked})
def get_play_button_clicked(uid):
    """Return the recorded play-button state for ``uid``; False when nothing is recorded."""
    with _state_lock:
        was_clicked = PLAY_BUTTON_CLICKED.get(uid, False)
    return was_clicked
def reset_play_button_clicked(uid):
    """Forget ``uid``'s play-button state, so lookups fall back to not clicked."""
    with _state_lock:
        if uid in PLAY_BUTTON_CLICKED:
            del PLAY_BUTTON_CLICKED[uid]
def _get_task_key(uid, env_id, episode_idx):
    """Build the composite "uid:env_id:episode_idx" key used by the per-task stores."""
    return "{}:{}:{}".format(uid, env_id, episode_idx)
def get_execute_count(uid, env_id, episode_idx):
    """Return how many times execute has been counted for this task (0 if never)."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        count = EXECUTE_COUNTS.get(key, 0)
    return count
def increment_execute_count(uid, env_id, episode_idx):
    """Add one to the task's execute counter and return the updated value."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        updated = EXECUTE_COUNTS.get(key, 0) + 1
        EXECUTE_COUNTS[key] = updated
    return updated
def reset_execute_count(uid, env_id, episode_idx):
    """Set the task's execute counter back to 0 (the entry is kept, not removed)."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        EXECUTE_COUNTS[key] = 0
def get_task_start_time(uid, env_id, episode_idx):
    """Return the recorded start time for this task, or None if none was set."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        started_at = TASK_START_TIMES.get(key)
    return started_at
def set_task_start_time(uid, env_id, episode_idx, start_time):
    """Record ``start_time`` as this task's start, overwriting any previous value."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        TASK_START_TIMES[key] = start_time
def clear_task_start_time(uid, env_id, episode_idx):
    """Remove this task's start-time entry; a missing entry is ignored."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        if key in TASK_START_TIMES:
            del TASK_START_TIMES[key]
def cleanup_session(uid):
    """Release every resource associated with ``uid``.

    Steps, in order:
      1. Under ``_state_lock``: detach the session proxy, drop ``uid`` from
         the per-uid maps, and delete per-task entries whose key starts with
         ``"{uid}:"``.
      2. Close the detached proxy. A close failure is logged and swallowed, so
         the remaining steps still run.
      3. Release ``uid``'s concurrency slot.
      4. Delegate to ``user_manager.cleanup_session(uid)``.

    A falsy ``uid`` is a no-op.
    """
    if not uid:
        return
    prefix = f"{uid}:"
    with _state_lock:
        session = GLOBAL_SESSIONS.pop(uid, None)
        for per_uid_store in (TASK_INDEX_MAP, UI_PHASE_MAP, PLAY_BUTTON_CLICKED):
            per_uid_store.pop(uid, None)
        # Collect every stale key first, then delete, so no dict is mutated
        # while it is being iterated.
        # NOTE(review): prefix matching also hits keys of another uid that
        # itself starts with "{uid}:" (e.g. "a" vs "a:b"); confirm uids never
        # contain ":".
        stale = [
            (store, key)
            for store in (EXECUTE_COUNTS, TASK_START_TIMES)
            for key in store
            if key.startswith(prefix)
        ]
        for store, key in stale:
            del store[key]

    # The proxy is closed after _state_lock is released (release_session_slot
    # below takes the same non-reentrant lock).
    if session is not None:
        try:
            LOGGER.info("cleanup_session uid=%s closing ProcessSessionProxy", uid)
            session.close()
            LOGGER.info("cleanup_session uid=%s proxy closed", uid)
        except Exception as exc:
            LOGGER.exception("cleanup_session uid=%s proxy close failed: %s", uid, exc)

    release_session_slot(uid)
    from user_manager import user_manager
    user_manager.cleanup_session(uid)
    LOGGER.info("cleanup_session uid=%s done", uid)
|