| | """ |
| | 状态管理模块 |
| | 管理所有全局状态和 Session 生命周期。 |
| | |
| | GLOBAL_SESSIONS 中存储的是 ProcessSessionProxy,而不是 OracleSession。 |
| | 实际的 OracleSession 运行在独立工作进程中,通过代理对象进行通信。 |
| | """ |
| |
|
| | import logging |
| | import threading |
| |
|
| | from process_session import ProcessSessionProxy |
| |
|
LOGGER = logging.getLogger("robomme.state_manager")

# uid -> ProcessSessionProxy (populated by try_create_session).
GLOBAL_SESSIONS: dict = dict()
# uids that currently hold one of the SESSION_CONCURRENCY_LIMIT slots.
ACTIVE_SESSION_SLOTS: set = set()

# uid -> {"task_index": ..., "total_tasks": ...} (see set_task_index).
TASK_INDEX_MAP: dict = dict()

# uid -> UI phase string; readers fall back to "watching_demo".
UI_PHASE_MAP: dict = dict()

# "<uid>:<env_id>:<episode_idx>" task key -> number of execute actions.
EXECUTE_COUNTS: dict = dict()

# task key (same format as EXECUTE_COUNTS) -> start time recorded by the caller.
TASK_START_TIMES: dict = dict()

# uid -> whether the play button has been clicked.
PLAY_BUTTON_CLICKED: dict = dict()

# Single lock guarding every mutable container above.
_state_lock = threading.Lock()
| |
|
| |
|
def get_session(uid):
    """Return the ProcessSessionProxy registered for ``uid``, or None if there is none."""
    with _state_lock:
        proxy = GLOBAL_SESSIONS.get(uid)
    return proxy
| |
|
| |
|
def _try_reserve_session_slot_locked(uid, session_concurrency_limit):
    """Reserve a concurrency slot for ``uid``; the caller must already hold ``_state_lock``.

    Args:
        uid: session key.
        session_concurrency_limit: maximum number of slots; coerced with ``int()``.

    Returns:
        bool: True if ``uid`` already owned a slot or a free one was taken,
        False if every slot is in use.
    """
    # Re-reserving for a uid that already owns a slot is a no-op success.
    if uid in ACTIVE_SESSION_SLOTS:
        return True

    occupied = len(ACTIVE_SESSION_SLOTS)
    if occupied >= int(session_concurrency_limit):
        LOGGER.info(
            "try_reserve_session_slot rejected uid=%s active_slots=%s limit=%s",
            uid,
            occupied,
            session_concurrency_limit,
        )
        return False

    ACTIVE_SESSION_SLOTS.add(uid)
    occupied = len(ACTIVE_SESSION_SLOTS)
    LOGGER.info("try_reserve_session_slot acquired uid=%s active_slots=%s", uid, occupied)
    return True
| |
|
| |
|
def try_reserve_session_slot(uid):
    """
    Try to reserve a session slot without blocking.

    Args:
        uid: session key; must be non-empty.

    Returns:
        bool: whether the slot was acquired (True also when ``uid`` already
        held a slot).

    Raises:
        ValueError: if ``uid`` is empty.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")

    # Use the shared helper (same lookup try_create_session uses) and resolve
    # it before taking the lock so the config import never runs under it.
    limit = _get_session_concurrency_limit()

    with _state_lock:
        return _try_reserve_session_slot_locked(uid, limit)
| |
|
| |
|
def release_session_slot(uid):
    """Return the concurrency slot held by ``uid``.

    Does nothing for an empty ``uid`` or for a uid that holds no slot.
    """
    if not uid:
        return

    with _state_lock:
        if uid not in ACTIVE_SESSION_SLOTS:
            return
        ACTIVE_SESSION_SLOTS.remove(uid)
        LOGGER.info(
            "release_session_slot uid=%s active_slots=%s",
            uid,
            len(ACTIVE_SESSION_SLOTS),
        )
| |
|
| |
|
def try_create_session(uid):
    """
    Try to create a ProcessSessionProxy without blocking on session slot wait.

    Args:
        uid: session key; must be non-empty.

    Returns:
        bool: whether the session is ready (True if a proxy already existed or
        was just created; False if no concurrency slot was available).

    Raises:
        ValueError: if ``uid`` is empty.
        Exception: anything raised by ``ProcessSessionProxy()``; the slot for
            ``uid`` is discarded before re-raising.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")

    # Resolve the limit before entering the critical section so the import of
    # ``config`` (and any import-time work) never runs while _state_lock is held.
    limit = _get_session_concurrency_limit()

    with _state_lock:
        if GLOBAL_SESSIONS.get(uid) is not None:
            return True
        if not _try_reserve_session_slot_locked(uid, limit):
            return False
        # Construction stays under the lock so check-and-insert is atomic and
        # two callers can never build two proxies for the same uid.
        # NOTE(review): if ProcessSessionProxy() starts its worker eagerly,
        # every other state accessor waits for it here — confirm the cost.
        try:
            GLOBAL_SESSIONS[uid] = ProcessSessionProxy()
        except Exception:
            ACTIVE_SESSION_SLOTS.discard(uid)
            raise
        # Snapshot under the lock so the logged count is consistent.
        total_sessions = len(GLOBAL_SESSIONS)

    LOGGER.info("try_create_session uid=%s total_sessions=%s", uid, total_sessions)
    return True
| |
|
| |
|
def create_session(uid):
    """Ensure a ProcessSessionProxy exists for session key ``uid``.

    Fails immediately (no queueing) when the concurrency limit is reached.

    Returns:
        The same ``uid``.

    Raises:
        ValueError: if ``uid`` is empty.
        RuntimeError: if no session slot is available.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")

    if try_create_session(uid):
        LOGGER.info("create_session uid=%s total_sessions=%s", uid, len(GLOBAL_SESSIONS))
        return uid

    raise RuntimeError("No session slots available")
| |
|
| |
|
| | def _get_session_concurrency_limit(): |
| | from config import SESSION_CONCURRENCY_LIMIT |
| |
|
| | return SESSION_CONCURRENCY_LIMIT |
| |
|
| |
|
def get_task_index(uid):
    """Return the stored task-index dict for ``uid``, or None if never set."""
    with _state_lock:
        info = TASK_INDEX_MAP.get(uid)
    return info
| |
|
| |
|
def set_task_index(uid, task_index, total_tasks):
    """Store ``uid``'s current task position as {"task_index", "total_tasks"}."""
    record = dict(task_index=task_index, total_tasks=total_tasks)
    with _state_lock:
        TASK_INDEX_MAP[uid] = record
| |
|
| |
|
def get_ui_phase(uid):
    """Return the UI phase for ``uid``; uids with no entry report "watching_demo"."""
    with _state_lock:
        phase = UI_PHASE_MAP.get(uid, "watching_demo")
    return phase
| |
|
| |
|
def set_ui_phase(uid, phase):
    """Record ``phase`` as the current UI phase of ``uid``."""
    with _state_lock:
        UI_PHASE_MAP.update({uid: phase})
| |
|
| |
|
def reset_ui_phase(uid):
    """Put ``uid`` back into the initial "watching_demo" UI phase."""
    set_ui_phase(uid, "watching_demo")
| |
|
| |
|
def set_play_button_clicked(uid, clicked=True):
    """Record whether ``uid`` has clicked the play button (defaults to True)."""
    with _state_lock:
        PLAY_BUTTON_CLICKED.update({uid: clicked})
| |
|
| |
|
def get_play_button_clicked(uid):
    """Return the stored play-button flag for ``uid``; False if never recorded."""
    with _state_lock:
        flag = PLAY_BUTTON_CLICKED.get(uid, False)
    return flag
| |
|
| |
|
def reset_play_button_clicked(uid):
    """Forget the play-button flag for ``uid`` (missing entries are ignored)."""
    with _state_lock:
        if uid in PLAY_BUTTON_CLICKED:
            del PLAY_BUTTON_CLICKED[uid]
| |
|
| |
|
| | def _get_task_key(uid, env_id, episode_idx): |
| | return f"{uid}:{env_id}:{episode_idx}" |
| |
|
| |
|
def get_execute_count(uid, env_id, episode_idx):
    """Return how many times execute ran for this task; 0 if never counted."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        count = EXECUTE_COUNTS.get(key, 0)
    return count
| |
|
| |
|
def increment_execute_count(uid, env_id, episode_idx):
    """Add one to this task's execute counter and return the new value."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        updated = EXECUTE_COUNTS.get(key, 0) + 1
        EXECUTE_COUNTS[key] = updated
    return updated
| |
|
| |
|
def reset_execute_count(uid, env_id, episode_idx):
    """Set this task's execute counter to 0 (the entry is kept, not removed)."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        EXECUTE_COUNTS[key] = 0
| |
|
| |
|
def get_task_start_time(uid, env_id, episode_idx):
    """Return the recorded start time for this task, or None if none was set."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        started = TASK_START_TIMES.get(key)
    return started
| |
|
| |
|
def set_task_start_time(uid, env_id, episode_idx, start_time):
    """Record ``start_time`` as the start of this task, replacing any previous value."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        TASK_START_TIMES[key] = start_time
| |
|
| |
|
def clear_task_start_time(uid, env_id, episode_idx):
    """Drop this task's start-time record; a missing record is ignored."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        if key in TASK_START_TIMES:
            del TASK_START_TIMES[key]
| |
|
| |
|
def _pop_session_state(uid):
    """Atomically detach every in-memory state entry owned by ``uid``.

    Removes ``uid`` from GLOBAL_SESSIONS, TASK_INDEX_MAP, UI_PHASE_MAP and
    PLAY_BUTTON_CLICKED, and deletes every EXECUTE_COUNTS / TASK_START_TIMES
    key that starts with ``"<uid>:"`` (the prefix produced by _get_task_key).

    Returns:
        The ProcessSessionProxy that was registered for ``uid``, or None.
    """
    # NOTE(review): the prefix match assumes uids never contain ":"; otherwise
    # cleaning up uid "a" would also drop task keys belonging to uid "a:b".
    task_prefix = f"{uid}:"
    with _state_lock:
        session = GLOBAL_SESSIONS.pop(uid, None)
        TASK_INDEX_MAP.pop(uid, None)
        UI_PHASE_MAP.pop(uid, None)
        PLAY_BUTTON_CLICKED.pop(uid, None)
        for per_task_map in (EXECUTE_COUNTS, TASK_START_TIMES):
            # Collect keys first: a dict must not change size while iterated.
            stale_keys = [key for key in per_task_map if key.startswith(task_prefix)]
            for key in stale_keys:
                del per_task_map[key]
    return session


def cleanup_session(uid):
    """Release every resource held for session ``uid``.

    Order: detach in-memory state under the lock, close the proxy outside the
    lock, release the concurrency slot, then call
    ``user_manager.cleanup_session(uid)``. An empty ``uid`` is a no-op.

    Errors from ``session.close()`` are logged and swallowed so that slot
    release and user_manager cleanup still run; errors raised by
    user_manager propagate to the caller.
    """
    if not uid:
        return

    session = _pop_session_state(uid)

    if session is not None:
        # Closed outside _state_lock: closing the proxy may have to wait on its
        # worker process, and other state accessors should not block on that.
        try:
            LOGGER.info("cleanup_session uid=%s closing ProcessSessionProxy", uid)
            session.close()
            LOGGER.info("cleanup_session uid=%s proxy closed", uid)
        except Exception as exc:
            LOGGER.exception("cleanup_session uid=%s proxy close failed: %s", uid, exc)

    # Release unconditionally: try_reserve_session_slot() can hold a slot for a
    # uid that never got a proxy, and that slot must not leak. Releasing after
    # close() keeps a new session from starting while this one shuts down.
    release_session_slot(uid)

    # Imported at call time — presumably to avoid a circular import; TODO confirm.
    from user_manager import user_manager

    user_manager.cleanup_session(uid)
    LOGGER.info("cleanup_session uid=%s done", uid)
| |
|