| """ |
| 多进程会话管理模块 |
| |
| 本模块实现了多进程架构,将每个用户的 OracleSession 运行在独立的工作进程中。 |
| 这样可以确保重计算任务不会阻塞主进程,多个用户可以并发使用系统。 |
| |
| 架构说明: |
| 1. ProcessSessionProxy: 主进程中的代理类,提供与 OracleSession 相同的接口 |
| 2. session_worker_loop: 工作进程中的循环函数,运行实际的 OracleSession |
| 3. 进程间通信:通过 multiprocessing.Queue 进行命令和结果的传递 |
| 4. 视频帧同步:工作进程产生的新帧通过 stream_queue 推送到主进程,由后台线程同步到代理的本地缓存 |
| """ |
| import logging |
| import multiprocessing |
| import queue |
| import threading |
| import time |
| import sys |
| import os |
|
|
| |
# Make the repository root (parent of this file's directory) and its ``src``
# sub-directory importable, so the project imports that follow in this module
# can be resolved when it is run from a source checkout.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
src_dir = os.path.join(parent_dir, "src")
for _import_root in (parent_dir, src_dir):
    # Each missing root is inserted at the front; when neither was present,
    # ``src`` therefore ends up ahead of the repository root.
    if _import_root not in sys.path:
        sys.path.insert(0, _import_root)
del _import_root
|
|
| from oracle_logic import OracleSession, DEFAULT_DATASET_ROOT |
|
|
| |
# ``ScrewPlanFailure`` comes from robomme's planner fail-safe module. When that
# module (or the name inside it) is unavailable, fall back to RuntimeError so
# ``except ScrewPlanFailure`` clauses stay valid.
# NOTE(review): with the fallback active, *any* RuntimeError raised while
# executing an action is reported as a screw-plan failure.
try:
    from robomme.robomme_env.utils import planner_fail_safe as _planner_fail_safe
except ImportError:
    _planner_fail_safe = None
ScrewPlanFailure = getattr(_planner_fail_safe, "ScrewPlanFailure", RuntimeError)
|
|
| |
class ScrewPlanFailureError(RuntimeError):
    """Screw-plan failure surfaced to the main process.

    A dedicated ``RuntimeError`` subclass, so UI code can catch it separately
    and show it via a ``gr.Info`` popup instead of a hard error.
    """
|
|
| |
# Logger shared by the main-process proxy and the worker-process loop.
LOGGER = logging.getLogger("robomme.process_session")

# Command names carried in the ``"cmd"`` field of messages on the command
# queue; each maps to one OracleSession operation in the worker.
CMD_LOAD_EPISODE = "load_episode"
CMD_UPDATE_OBSERVATION = "update_observation"
CMD_GET_PIL_IMAGE = "get_pil_image"
CMD_EXECUTE_ACTION = "execute_action"
CMD_GET_REFERENCE_ACTION = "get_reference_action"
# Asks the worker to close its session and leave its loop.
CMD_CLOSE = "close"
|
|
|
|
def _setup_worker_logging():
    """Configure logging inside a worker process.

    Installs a stdout handler on the root logger (``force=True`` replaces any
    handlers already present). It also lowers a fixed list of chatty
    third-party loggers to WARNING and keeps the ``robomme`` logger at DEBUG.

    The root level comes from the ``LOG_LEVEL`` environment variable (default
    ``DEBUG``). It accepts a level name in any case (``info``) or a numeric
    level (``20``). Unrecognised values fall back to DEBUG, so a bad setting
    cannot crash worker start-up.
    """
    level_name = os.getenv("LOG_LEVEL", "DEBUG").strip().upper()
    if level_name.isdigit():
        level = int(level_name)
    else:
        level = getattr(logging, level_name, None)
        # getattr can also hit non-level attributes such as
        # ``logging.BASIC_FORMAT`` (a str), which basicConfig() rejects.
        if not isinstance(level, int):
            level = logging.DEBUG
    logging.basicConfig(
        level=level,
        format=(
            "%(asctime)s | %(levelname)s | %(name)s | "
            "pid=%(process)d tid=%(threadName)s | %(message)s"
        ),
        stream=sys.stdout,
        force=True,
    )
    for noisy_logger in (
        "asyncio",
        "httpx",
        "httpcore",
        "urllib3",
        "matplotlib",
        "PIL",
        "h5py",
        "trimesh",
        "toppra",
    ):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
    # NOTE(review): the handler installed above has no level of its own, so
    # robomme DEBUG records are printed even when LOG_LEVEL is higher —
    # confirm this is intended.
    logging.getLogger("robomme").setLevel(logging.DEBUG)
    LOGGER.debug("worker logging initialized level=%s", logging.getLevelName(level))
|
|
def _sanitize_options(options):
    """Return picklable copies of solve-option dicts for cross-process transfer.

    Each option is shallow-copied, so the caller's dicts are left untouched.
    The ``"solve"`` entry (a function, which cannot be pickled) is dropped,
    and ``"available"`` is coerced to a plain ``bool``.

    Args:
        options: Iterable of option dicts, or a falsy value.

    Returns:
        list: The cleaned option dicts; an empty list when *options* is falsy.
    """
    if not options:
        return []

    def _clean(option):
        cleaned = option.copy()
        cleaned.pop("solve", None)
        if "available" in cleaned:
            cleaned["available"] = bool(cleaned["available"])
        return cleaned

    return [_clean(option) for option in options]
|
|
def _worker_is_demonstration(session):
    """Return the loaded env's ``current_task_demonstration`` flag, or False when no env is loaded."""
    if session.env:
        return getattr(session.env, "current_task_demonstration", False)
    return False


def _worker_mark_frames_forwarded(session):
    """Move the forwarded-frame cursors to the current end of both frame lists."""
    session.last_base_frame_idx = len(session.base_frames)
    session.last_wrist_frame_idx = len(session.wrist_frames)


def _worker_take_new_frames(session):
    """Return ``(new_base, new_wrist)`` frames added since the cursors last moved, then advance them."""
    new_base = session.base_frames[session.last_base_frame_idx:]
    new_wrist = session.wrist_frames[session.last_wrist_frame_idx:]
    _worker_mark_frames_forwarded(session)
    return new_base, new_wrist


def _worker_stream_frames(stream_queue, new_base, new_wrist, source):
    """Push a non-empty batch of frames to the main process via *stream_queue*."""
    if new_base or new_wrist:
        stream_queue.put({"base": new_base, "wrist": new_wrist})
        LOGGER.debug(
            "worker %s streamed frames base=%s wrist=%s",
            source,
            len(new_base),
            len(new_wrist),
        )


def _worker_handle_load_episode(session, stream_queue, args, kwargs):
    """Load an episode; the reply's ``state`` carries the full UI-visible session state."""
    res = session.load_episode(*args, **kwargs)
    LOGGER.info(
        "worker load_episode env=%s episode=%s result_msg=%s",
        getattr(session, "env_id", None),
        getattr(session, "episode_idx", None),
        res[1] if isinstance(res, tuple) and len(res) > 1 else None,
    )
    # The freshly loaded frame lists travel inside ``state`` below, so mark
    # them as forwarded instead of streaming them a second time.
    _worker_mark_frames_forwarded(session)
    state_update = {
        "env_id": session.env_id,
        "episode_idx": session.episode_idx,
        "language_goal": session.language_goal,
        "difficulty": session.difficulty,
        "seed": session.seed,
        "demonstration_frames": session.demonstration_frames,
        "last_execution_frames": [],
        "base_frames": session.base_frames,
        "wrist_frames": session.wrist_frames,
        "available_options": session.available_options,
        "raw_solve_options": _sanitize_options(session.raw_solve_options),
        "seg_vis": session.seg_vis,
        "is_demonstration": _worker_is_demonstration(session),
        "non_demonstration_task_length": session.non_demonstration_task_length,
    }
    return {"status": "success", "result": res, "state": state_update}


def _worker_handle_execute_action(session, stream_queue, args, kwargs):
    """Execute one action, stream any not-yet-streamed frames, and report the new state."""
    execute_base_start = len(session.base_frames)
    try:
        res = session.execute_action(*args, **kwargs)
    except ScrewPlanFailure as e:
        LOGGER.warning("worker screw_plan_failure: %s", e)
        return {"status": "screw_plan_failure", "message": str(e)}
    except Exception as e:
        LOGGER.exception("worker execution_error")
        return {"status": "execution_error", "message": str(e)}
    LOGGER.info(
        "worker execute_action done env=%s episode=%s done=%s",
        getattr(session, "env_id", None),
        getattr(session, "episode_idx", None),
        res[2] if isinstance(res, tuple) and len(res) > 2 else None,
    )

    new_base, new_wrist = _worker_take_new_frames(session)
    # Presumably the number of base frames OracleSession already pushed through
    # ``stream_frame_callback`` during this action; skip them so they are not
    # delivered twice. TODO confirm against OracleSession.
    streamed_count = int(getattr(session, "_execute_streamed_frame_count", 0) or 0)
    if streamed_count > 0:
        new_base = new_base[streamed_count:]
    _worker_stream_frames(stream_queue, new_base, new_wrist, "execute_action")

    state_update = {
        "last_execution_frames": session.base_frames[execute_base_start:],
        "available_options": session.available_options,
        "raw_solve_options": _sanitize_options(session.raw_solve_options),
        "seg_vis": session.seg_vis,
        "is_demonstration": _worker_is_demonstration(session),
    }
    return {"status": "success", "result": res, "state": state_update}


def _worker_handle_update_observation(session, stream_queue, args, kwargs):
    """Refresh the observation, stream new frames, and report the option/segmentation state."""
    res = session.update_observation(*args, **kwargs)
    new_base, new_wrist = _worker_take_new_frames(session)
    _worker_stream_frames(stream_queue, new_base, new_wrist, "update_observation")
    state_update = {
        "available_options": session.available_options,
        "raw_solve_options": _sanitize_options(session.raw_solve_options),
        "seg_vis": session.seg_vis,
        "is_demonstration": _worker_is_demonstration(session),
    }
    return {"status": "success", "result": res, "state": state_update}


def _worker_handle_get_pil_image(session, stream_queue, args, kwargs):
    """Return the session's current PIL image."""
    return {"status": "success", "result": session.get_pil_image(*args, **kwargs)}


def _worker_handle_get_reference_action(session, stream_queue, args, kwargs):
    """Return the session's reference action for the current step."""
    res = session.get_reference_action(*args, **kwargs)
    LOGGER.debug("worker get_reference_action ok=%s", bool(res.get("ok")) if isinstance(res, dict) else None)
    return {"status": "success", "result": res}


def session_worker_loop(cmd_queue, result_queue, stream_queue, dataset_root, gui_render):
    """Main loop of the worker process that hosts one ``OracleSession``.

    Creates the session, then serves commands from *cmd_queue* until
    ``CMD_CLOSE`` arrives. Each other command gets one reply dict on
    *result_queue*. The reply's ``status`` is ``"success"``,
    ``"screw_plan_failure"``, ``"execution_error"`` or ``"error"``. It also
    echoes the ``req_id`` the command carried (``None`` if it had none).

    A command that raises is reported as ``"error"`` and the loop keeps
    serving, so one bad request does not take the session down. Frames the
    session produces are pushed to *stream_queue* as
    ``{"base": [...], "wrist": [...]}``.

    If setup or the loop machinery itself fails, one ``"fatal"`` reply is sent
    and the worker exits. The session is closed on every exit path.

    Args:
        cmd_queue: Queue the main process sends command dicts on.
        result_queue: Queue this worker sends replies on.
        stream_queue: Queue this worker pushes newly produced frames on.
        dataset_root: Dataset root directory passed to ``OracleSession``.
        gui_render: Whether ``OracleSession`` should use GUI rendering.
    """
    session = None
    req_id = None
    try:
        # Inside the try so a logging failure is still reported as "fatal".
        _setup_worker_logging()
        LOGGER.info(
            "worker loop starting dataset_root=%s gui_render=%s",
            dataset_root,
            gui_render,
        )
        session = OracleSession(dataset_root=dataset_root, gui_render=gui_render)
        session.stream_frame_callback = lambda frames: stream_queue.put({"base": frames, "wrist": []})
        LOGGER.info("worker OracleSession initialized")

        handlers = {
            CMD_LOAD_EPISODE: _worker_handle_load_episode,
            CMD_EXECUTE_ACTION: _worker_handle_execute_action,
            CMD_GET_PIL_IMAGE: _worker_handle_get_pil_image,
            CMD_UPDATE_OBSERVATION: _worker_handle_update_observation,
            CMD_GET_REFERENCE_ACTION: _worker_handle_get_reference_action,
        }

        while True:
            try:
                cmd_data = cmd_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            cmd = cmd_data.get("cmd")
            args = cmd_data.get("args", [])
            kwargs = cmd_data.get("kwargs", {})
            req_id = cmd_data.get("req_id")
            LOGGER.debug("worker received cmd=%s args=%s kwargs_keys=%s", cmd, len(args), list(kwargs.keys()))

            if cmd == CMD_CLOSE:
                # The session itself is closed in the ``finally`` below.
                LOGGER.info("worker received close command, exiting")
                break

            handler = handlers.get(cmd)
            if handler is None:
                LOGGER.error("worker unknown command=%s", cmd)
                reply = {"status": "error", "message": f"Unknown command: {cmd}"}
            else:
                try:
                    reply = handler(session, stream_queue, args, kwargs)
                except Exception as e:
                    # Report the failure but keep serving: the proxy would
                    # otherwise wait out its full timeout on every later call.
                    LOGGER.exception("worker command failed cmd=%s", cmd)
                    reply = {"status": "error", "message": str(e)}
            # Echo the request id so the proxy can discard stale replies.
            reply["req_id"] = req_id
            result_queue.put(reply)
            req_id = None

    except Exception as e:
        LOGGER.exception("worker fatal error")
        result_queue.put({"status": "fatal", "message": str(e), "req_id": req_id})
    finally:
        if session is not None:
            try:
                session.close()
            except Exception:
                LOGGER.exception("worker failed to close OracleSession")
|
|
|
|
class ProcessSessionProxy:
    """Main-process proxy exposing the ``OracleSession`` interface.

    Each proxy starts its own worker process (``spawn`` context) running
    ``session_worker_loop``. It forwards every public method call to the
    ``OracleSession`` living there over three ``multiprocessing`` queues:

    * ``cmd_queue``: commands sent to the worker;
    * ``result_queue``: one reply per command (status, result, state);
    * ``stream_queue``: frames pushed by the worker as they are produced.

    Session attributes the UI reads (``env_id``, ``base_frames``,
    ``available_options``, ...) are mirrored locally. Each reply's ``state``
    dict is copied onto the proxy, and a daemon thread appends streamed frames
    to ``base_frames`` / ``wrist_frames``.
    """

    # Maximum time a single command may take in the worker, in seconds.
    _COMMAND_TIMEOUT_S = 600
    # Poll interval used to notice a dead worker while waiting for a reply.
    _LIVENESS_POLL_S = 0.5
    # Grace period for the worker / sync thread to exit in close(), in seconds.
    _CLOSE_JOIN_TIMEOUT_S = 1

    def __init__(self, dataset_root=DEFAULT_DATASET_ROOT, gui_render=False):
        """Start the worker process and the frame-sync thread.

        Args:
            dataset_root: Dataset root directory forwarded to ``OracleSession``.
            gui_render: Whether the worker's ``OracleSession`` uses GUI rendering.
        """
        # "spawn" starts the worker from a fresh interpreter rather than a
        # fork of this process.
        ctx = multiprocessing.get_context("spawn")

        self.cmd_queue = ctx.Queue()
        self.result_queue = ctx.Queue()
        self.stream_queue = ctx.Queue()

        self.process = ctx.Process(
            target=session_worker_loop,
            args=(self.cmd_queue, self.result_queue, self.stream_queue, dataset_root, gui_render),
            daemon=True,
        )
        self.process.start()
        LOGGER.info(
            "ProcessSessionProxy started worker pid=%s dataset_root=%s gui_render=%s",
            self.process.pid,
            dataset_root,
            gui_render,
        )

        # Local mirror of the worker-side session state, refreshed from the
        # ``state`` payload of replies (see _send_cmd).
        self.env_id = None
        self.episode_idx = None
        self.language_goal = ""
        self.difficulty = None
        self.seed = None
        self.demonstration_frames = []
        self.last_execution_frames = []
        self.base_frames = []
        self.wrist_frames = []
        self.available_options = []
        self.raw_solve_options = []
        self.seg_vis = None
        self.is_demonstration = False
        self.non_demonstration_task_length = None

        # Incrementing id used to match replies to the request that caused them.
        self._next_req_id = 0

        self.stop_sync = False
        self.sync_thread = threading.Thread(target=self._sync_loop, daemon=True)
        self.sync_thread.start()

    def _sync_loop(self):
        """Background thread: append frames from ``stream_queue`` to the local caches.

        Runs until ``stop_sync`` is set, extending ``base_frames`` and
        ``wrist_frames`` with every batch the worker streams. UI refresh code
        reads frames straight from these attributes.
        """
        while not self.stop_sync:
            try:
                frames = self.stream_queue.get(timeout=0.1)
                new_base = frames.get("base", [])
                new_wrist = frames.get("wrist", [])
                if new_base:
                    self.base_frames.extend(new_base)
                if new_wrist:
                    self.wrist_frames.extend(new_wrist)
            except queue.Empty:
                continue
            except Exception:
                LOGGER.exception("ProcessSessionProxy sync loop crashed")
                break

    def _wait_for_reply(self, cmd, req_id, deadline):
        """Block until the worker's reply to request *req_id* arrives.

        Replies tagged with a different request id (e.g. a late answer to an
        earlier command that timed out) are discarded. Replies without an id
        are accepted as-is.

        Args:
            cmd: Command name, used for logging only.
            req_id: Id attached to the command that was sent.
            deadline: ``time.monotonic()`` value after which to give up.

        Returns:
            dict: The worker's reply.

        Raises:
            TimeoutError: *deadline* passed without a matching reply.
            RuntimeError: The worker process exited without replying.
        """
        worker_seen_dead = False
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                LOGGER.error("proxy command timeout cmd=%s pid=%s", cmd, self.process.pid)
                raise TimeoutError("工作进程超时")
            try:
                res = self.result_queue.get(timeout=min(remaining, self._LIVENESS_POLL_S))
            except queue.Empty:
                if self.process.is_alive():
                    continue
                if worker_seen_dead:
                    LOGGER.error(
                        "proxy worker exited without reply cmd=%s pid=%s exitcode=%s",
                        cmd,
                        self.process.pid,
                        self.process.exitcode,
                    )
                    raise RuntimeError(
                        f"工作进程致命错误: worker process exited (exitcode={self.process.exitcode})"
                    )
                # Poll once more: a reply the worker flushed right before
                # exiting may still be unread in the pipe.
                worker_seen_dead = True
                continue
            reply_id = res.get("req_id")
            if reply_id is not None and reply_id != req_id:
                LOGGER.warning(
                    "proxy discarding stale reply req_id=%s expected=%s cmd=%s",
                    reply_id,
                    req_id,
                    cmd,
                )
                continue
            return res

    def _send_cmd(self, cmd, *args, **kwargs):
        """Send a command to the worker and wait for its reply.

        Any ``state`` dict in the reply is copied onto this proxy's attributes.

        Args:
            cmd: Command name (one of the ``CMD_*`` constants).
            *args: Positional arguments forwarded to the session method.
            **kwargs: Keyword arguments forwarded to the session method.

        Returns:
            The ``result`` field of the worker's reply.

        Raises:
            ScrewPlanFailureError: The worker reported a screw-plan failure.
            RuntimeError: The worker reported an error, or exited.
            TimeoutError: No reply within ``_COMMAND_TIMEOUT_S`` (600 s).
        """
        self._next_req_id += 1
        req_id = self._next_req_id
        start_ts = time.monotonic()
        LOGGER.debug(
            "proxy send cmd=%s pid=%s args=%s kwargs_keys=%s",
            cmd,
            self.process.pid,
            len(args),
            list(kwargs.keys()),
        )
        self.cmd_queue.put({"cmd": cmd, "args": args, "kwargs": kwargs, "req_id": req_id})
        res = self._wait_for_reply(cmd, req_id, start_ts + self._COMMAND_TIMEOUT_S)
        status = res.get("status")
        LOGGER.debug(
            "proxy recv cmd=%s pid=%s status=%s elapsed_ms=%s",
            cmd,
            self.process.pid,
            status,
            int((time.monotonic() - start_ts) * 1000),
        )

        if status == "screw_plan_failure":
            raise ScrewPlanFailureError(f"screw plan failed: {res.get('message', 'unknown error')}")
        if status == "execution_error":
            raise RuntimeError(f"Execution error: {res.get('message', 'unknown error')}")
        if status == "fatal":
            raise RuntimeError(f"工作进程致命错误: {res.get('message')}")
        if status == "error":
            raise RuntimeError(f"命令执行错误: {res.get('message')}")

        for key, value in (res.get("state") or {}).items():
            # A None frame list means "not sent": keep the frames the sync
            # thread has accumulated rather than wiping them.
            if key in ("base_frames", "wrist_frames") and value is None:
                continue
            setattr(self, key, value)

        return res.get("result")

    def load_episode(self, env_id, episode_idx):
        """Load an episode in the worker.

        Args:
            env_id: Environment id.
            episode_idx: Episode index.

        Returns:
            The worker session's ``load_episode`` result (an image and a status
            message, per the original interface).
        """
        return self._send_cmd(CMD_LOAD_EPISODE, env_id, episode_idx)

    def execute_action(self, action_idx, click_coords):
        """Execute an action in the worker (the expensive operation).

        Args:
            action_idx: Action index.
            click_coords: Click coordinates ``(x, y)`` or None.

        Returns:
            The worker session's ``execute_action`` result (image, status
            message, done flag, per the original interface).
        """
        self.last_execution_frames = []
        return self._send_cmd(CMD_EXECUTE_ACTION, action_idx, click_coords)

    def get_pil_image(self, use_segmented=True):
        """Fetch the current PIL image from the worker.

        Args:
            use_segmented: Whether to return the segmented view.

        Returns:
            PIL.Image: The image object.
        """
        return self._send_cmd(CMD_GET_PIL_IMAGE, use_segmented=use_segmented)

    def update_observation(self, use_segmentation=True):
        """Refresh the observation in the worker.

        Args:
            use_segmentation: Whether to use the segmented view.

        Returns:
            The worker session's ``update_observation`` result (an image and a
            status message, per the original interface).
        """
        return self._send_cmd(CMD_UPDATE_OBSERVATION, use_segmentation=use_segmentation)

    def get_reference_action(self):
        """Fetch the reference action and coordinates for the current step.

        Returns:
            dict: The reference-action result from the worker.
        """
        return self._send_cmd(CMD_GET_REFERENCE_ACTION)

    def close(self):
        """Shut down the worker process and the frame-sync thread.

        Asks the worker to exit via ``CMD_CLOSE``, waits up to
        ``_CLOSE_JOIN_TIMEOUT_S`` for it, and terminates (then reaps) it if it
        is still running. Only then does it stop and join the sync thread.
        Safe to call more than once.
        """
        try:
            self.cmd_queue.put({"cmd": CMD_CLOSE})
            LOGGER.debug("proxy close command sent pid=%s", self.process.pid)
        except Exception:
            LOGGER.exception("proxy failed to send close command pid=%s", self.process.pid)

        # The sync thread keeps draining stream_queue while the worker shuts
        # down. A process that put items on a queue waits for them to be
        # flushed before it can exit.
        self.process.join(timeout=self._CLOSE_JOIN_TIMEOUT_S)
        if self.process.is_alive():
            LOGGER.warning("proxy worker still alive after join; terminating pid=%s", self.process.pid)
            self.process.terminate()
            # Reap the terminated child so it does not linger as a zombie.
            self.process.join(timeout=self._CLOSE_JOIN_TIMEOUT_S)

        self.stop_sync = True
        self.sync_thread.join(timeout=self._CLOSE_JOIN_TIMEOUT_S)
        LOGGER.info("ProcessSessionProxy closed pid=%s", self.process.pid)
|
|