| """ |
| 状态管理模块 |
| 管理所有全局状态和 Session 生命周期。 |
| |
| GLOBAL_SESSIONS 中存储的是 ProcessSessionProxy,而不是 OracleSession。 |
| 实际的 OracleSession 运行在独立工作进程中,通过代理对象进行通信。 |
| """ |
|
|
| import logging |
| import threading |
|
|
| from process_session import ProcessSessionProxy |
|
|
LOGGER = logging.getLogger("robomme.state_manager")

# uid -> ProcessSessionProxy (the proxy object, not the worker-side session).
GLOBAL_SESSIONS = dict()
# uids that currently hold one of the limited concurrency slots.
ACTIVE_SESSION_SLOTS = set()

# uid -> {"task_index": ..., "total_tasks": ...}
TASK_INDEX_MAP = dict()

# uid -> UI phase value; readers fall back to "watching_demo" when unset.
UI_PHASE_MAP = dict()

# "uid:env_id:episode_idx" -> number of execute calls recorded for that task.
EXECUTE_COUNTS = dict()

# "uid:env_id:episode_idx" -> start time recorded for that task.
TASK_START_TIMES = dict()

# uid -> whether the play button has been clicked.
PLAY_BUTTON_CLICKED = dict()

# Guards every registry above. threading.Lock is not reentrant, so helpers
# that are called while it is held must not try to acquire it again.
_state_lock = threading.Lock()
|
|
|
|
def get_session(uid):
    """Return the ProcessSessionProxy registered for ``uid``, or None if absent."""
    with _state_lock:
        proxy = GLOBAL_SESSIONS.get(uid)
    return proxy
|
|
|
|
def _try_reserve_session_slot_locked(uid, session_concurrency_limit):
    """
    Reserve a concurrency slot for ``uid``; the caller must hold ``_state_lock``.

    Reserving is idempotent: a uid that already owns a slot is reported as
    successful without consuming another one.

    Returns:
        bool: True if ``uid`` holds a slot after the call, False if the limit
        has been reached.
    """
    if uid in ACTIVE_SESSION_SLOTS:
        return True

    slots_in_use = len(ACTIVE_SESSION_SLOTS)
    if slots_in_use < int(session_concurrency_limit):
        ACTIVE_SESSION_SLOTS.add(uid)
        LOGGER.info(
            "try_reserve_session_slot acquired uid=%s active_slots=%s",
            uid,
            len(ACTIVE_SESSION_SLOTS),
        )
        return True

    # At capacity: report the rejection together with the configured limit.
    LOGGER.info(
        "try_reserve_session_slot rejected uid=%s active_slots=%s limit=%s",
        uid,
        slots_in_use,
        session_concurrency_limit,
    )
    return False
|
|
|
|
def try_reserve_session_slot(uid):
    """
    Try to reserve a session slot without blocking.

    The concurrency limit is obtained through ``_get_session_concurrency_limit``
    so that this function and ``try_create_session`` share a single lookup path.

    Returns:
        bool: whether the slot was acquired (or was already held by ``uid``)

    Raises:
        ValueError: if ``uid`` is empty.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")

    # Resolve the limit before taking the lock so the config import never
    # runs while other threads are waiting on _state_lock.
    limit = _get_session_concurrency_limit()

    with _state_lock:
        return _try_reserve_session_slot_locked(uid, limit)
|
|
|
|
def release_session_slot(uid):
    """Give back the concurrency slot held by ``uid``; a no-op if it holds none."""
    if not uid:
        return

    with _state_lock:
        if uid not in ACTIVE_SESSION_SLOTS:
            return
        ACTIVE_SESSION_SLOTS.remove(uid)
        remaining = len(ACTIVE_SESSION_SLOTS)
        LOGGER.info(
            "release_session_slot uid=%s active_slots=%s",
            uid,
            remaining,
        )
|
|
|
|
def try_create_session(uid):
    """
    Try to create a ProcessSessionProxy for ``uid`` without waiting for a slot.

    An already-registered session counts as success. Otherwise a concurrency
    slot is reserved and a new proxy is constructed while ``_state_lock`` is
    held; if construction fails the slot is dropped and the error propagates.

    Returns:
        bool: True if a session exists for ``uid`` after the call, False if no
        slot was available.

    Raises:
        ValueError: if ``uid`` is empty.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")

    with _state_lock:
        already_registered = GLOBAL_SESSIONS.get(uid) is not None
        if already_registered:
            return True

        slot_granted = _try_reserve_session_slot_locked(
            uid, _get_session_concurrency_limit()
        )
        if not slot_granted:
            return False

        try:
            proxy = ProcessSessionProxy()
        except Exception:
            # Roll back the reservation so a failed start does not leak a slot.
            ACTIVE_SESSION_SLOTS.discard(uid)
            raise
        else:
            GLOBAL_SESSIONS[uid] = proxy

    LOGGER.info("try_create_session uid=%s total_sessions=%s", uid, len(GLOBAL_SESSIONS))
    return True
|
|
|
|
def create_session(uid):
    """
    Create a ProcessSessionProxy for the given session key.

    Fails immediately when the concurrency limit is reached; there is no
    queueing or waiting.

    Returns:
        The same ``uid`` that was passed in.

    Raises:
        ValueError: if ``uid`` is empty.
        RuntimeError: if no session slot is available.
    """
    if not uid:
        raise ValueError("Session uid cannot be empty")

    if try_create_session(uid):
        LOGGER.info("create_session uid=%s total_sessions=%s", uid, len(GLOBAL_SESSIONS))
        return uid

    raise RuntimeError("No session slots available")
|
|
|
|
def _get_session_concurrency_limit():
    """Return ``config.SESSION_CONCURRENCY_LIMIT``, importing it at call time."""
    # NOTE(review): the import is deferred to call time, presumably to avoid a
    # circular import with config -- confirm before hoisting it to module level.
    from config import SESSION_CONCURRENCY_LIMIT as configured_limit

    return configured_limit
|
|
|
|
def get_task_index(uid):
    """Return the task-index record stored for ``uid``, or None if there is none."""
    with _state_lock:
        record = TASK_INDEX_MAP.get(uid)
    return record
|
|
|
|
def set_task_index(uid, task_index, total_tasks):
    """Store the current task index and the total task count for ``uid``."""
    record = {"task_index": task_index, "total_tasks": total_tasks}
    with _state_lock:
        TASK_INDEX_MAP[uid] = record
|
|
|
|
def get_ui_phase(uid):
    """Return the UI phase for ``uid``; defaults to "watching_demo" when unset."""
    with _state_lock:
        phase = UI_PHASE_MAP.get(uid, "watching_demo")
    return phase
|
|
|
|
def set_ui_phase(uid, phase):
    """Record ``phase`` as the current UI phase for ``uid``."""
    with _state_lock:
        UI_PHASE_MAP.update({uid: phase})
|
|
|
|
def reset_ui_phase(uid):
    """Put the UI phase for ``uid`` back to the initial "watching_demo" phase."""
    initial_phase = "watching_demo"
    with _state_lock:
        UI_PHASE_MAP[uid] = initial_phase
|
|
|
|
def set_play_button_clicked(uid, clicked=True):
    """Record whether the play button has been clicked for ``uid``."""
    with _state_lock:
        PLAY_BUTTON_CLICKED.update({uid: clicked})
|
|
|
|
def get_play_button_clicked(uid):
    """Return the stored play-button flag for ``uid``; False when never set."""
    with _state_lock:
        was_clicked = PLAY_BUTTON_CLICKED.get(uid, False)
    return was_clicked
|
|
|
|
def reset_play_button_clicked(uid):
    """Forget the play-button flag for ``uid``; a no-op if none is stored."""
    with _state_lock:
        if uid in PLAY_BUTTON_CLICKED:
            del PLAY_BUTTON_CLICKED[uid]
|
|
|
|
def _get_task_key(uid, env_id, episode_idx):
    """Build the "uid:env_id:episode_idx" key used by the per-task registries."""
    return ":".join(f"{part}" for part in (uid, env_id, episode_idx))
|
|
|
|
def get_execute_count(uid, env_id, episode_idx):
    """Return how many executes are recorded for the task; 0 if none."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        count = EXECUTE_COUNTS.get(key, 0)
    return count
|
|
|
|
def increment_execute_count(uid, env_id, episode_idx):
    """Add one to the task's execute count and return the new value."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        updated = EXECUTE_COUNTS.get(key, 0) + 1
        EXECUTE_COUNTS[key] = updated
        return updated
|
|
|
|
def reset_execute_count(uid, env_id, episode_idx):
    """Set the task's execute count back to 0 (the entry is kept, not removed)."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        EXECUTE_COUNTS[key] = 0
|
|
|
|
def get_task_start_time(uid, env_id, episode_idx):
    """Return the start time recorded for the task, or None if none is stored."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        started_at = TASK_START_TIMES.get(key)
    return started_at
|
|
|
|
def set_task_start_time(uid, env_id, episode_idx, start_time):
    """Record ``start_time`` as the start time of the given task."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        TASK_START_TIMES.update({key: start_time})
|
|
|
|
def clear_task_start_time(uid, env_id, episode_idx):
    """Remove the task's start-time record; a no-op if none is stored."""
    key = _get_task_key(uid, env_id, episode_idx)
    with _state_lock:
        if key in TASK_START_TIMES:
            del TASK_START_TIMES[key]
|
|
|
|
def cleanup_session(uid):
    """
    Release everything this module tracks for ``uid``.

    Steps, in order: drop the per-uid and per-task registry entries under
    ``_state_lock``; close the session proxy (if any) outside the lock; give
    back the concurrency slot; finally delegate to
    ``user_manager.cleanup_session``. A failure while closing the proxy is
    logged and does not stop the remaining steps. An empty ``uid`` is a no-op.
    """
    if not uid:
        return

    key_prefix = f"{uid}:"

    with _state_lock:
        proxy = GLOBAL_SESSIONS.pop(uid, None)
        for per_uid_map in (TASK_INDEX_MAP, UI_PHASE_MAP, PLAY_BUTTON_CLICKED):
            per_uid_map.pop(uid, None)

        # Snapshot matching keys first; a dict cannot be resized while it is
        # being iterated.
        stale_counts = [k for k in EXECUTE_COUNTS if k.startswith(key_prefix)]
        stale_starts = [k for k in TASK_START_TIMES if k.startswith(key_prefix)]

        for stale_key in stale_counts:
            EXECUTE_COUNTS.pop(stale_key)
        for stale_key in stale_starts:
            TASK_START_TIMES.pop(stale_key)

    # The proxy is closed after the lock is released, so other state accessors
    # are not blocked while it shuts down.
    if proxy is not None:
        try:
            LOGGER.info("cleanup_session uid=%s closing ProcessSessionProxy", uid)
            proxy.close()
            LOGGER.info("cleanup_session uid=%s proxy closed", uid)
        except Exception as exc:
            LOGGER.exception("cleanup_session uid=%s proxy close failed: %s", uid, exc)

    release_session_slot(uid)

    from user_manager import user_manager

    user_manager.cleanup_session(uid)
    LOGGER.info("cleanup_session uid=%s done", uid)
|
|