""" 多进程会话管理模块 本模块实现了多进程架构,将每个用户的 OracleSession 运行在独立的工作进程中。 这样可以确保重计算任务不会阻塞主进程,多个用户可以并发使用系统。 架构说明: 1. ProcessSessionProxy: 主进程中的代理类,提供与 OracleSession 相同的接口 2. session_worker_loop: 工作进程中的循环函数,运行实际的 OracleSession 3. 进程间通信:通过 multiprocessing.Queue 进行命令和结果的传递 4. 视频帧同步:工作进程产生的新帧通过 stream_queue 推送到主进程,由后台线程同步到代理的本地缓存 """ import logging import multiprocessing import queue import threading import time import sys import os # 添加父目录到路径(逻辑复制自 oracle_logic.py) current_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(current_dir) src_dir = os.path.join(parent_dir, "src") if parent_dir not in sys.path: sys.path.insert(0, parent_dir) if src_dir not in sys.path: sys.path.insert(0, src_dir) from oracle_logic import OracleSession, DEFAULT_DATASET_ROOT # Import ScrewPlanFailure for exception handling try: from robomme.robomme_env.utils.planner_fail_safe import ScrewPlanFailure except ImportError: # Fallback if import fails ScrewPlanFailure = RuntimeError # Custom exception for screw plan failures (to be caught in gradio_callbacks) class ScrewPlanFailureError(RuntimeError): """Exception raised when screw plan fails, to be caught and displayed via gr.Info popup""" pass # 定义命令常量 CMD_LOAD_EPISODE = "load_episode" CMD_UPDATE_OBSERVATION = "update_observation" CMD_GET_PIL_IMAGE = "get_pil_image" CMD_EXECUTE_ACTION = "execute_action" CMD_GET_REFERENCE_ACTION = "get_reference_action" CMD_CLOSE = "close" LOGGER = logging.getLogger("robomme.process_session") def _setup_worker_logging(): level_name = os.getenv("LOG_LEVEL", "DEBUG").upper() level = getattr(logging, level_name, logging.DEBUG) logging.basicConfig( level=level, format=( "%(asctime)s | %(levelname)s | %(name)s | " "pid=%(process)d tid=%(threadName)s | %(message)s" ), stream=sys.stdout, force=True, ) for noisy_logger in [ "asyncio", "httpx", "httpcore", "urllib3", "matplotlib", "PIL", "h5py", "trimesh", "toppra", ]: logging.getLogger(noisy_logger).setLevel(logging.WARNING) logging.getLogger("robomme").setLevel(logging.DEBUG) LOGGER.debug("worker logging initialized level=%s", level_name) def _sanitize_options(options): """ 清理选项数据,移除不可序列化的项(如 'solve' 函数) 在跨进程通信时,需要确保所有数据都可以被 pickle 序列化。 raw_solve_options 中包含的 'solve' 函数无法序列化,需要移除。 'available' 字段可能是复杂对象,需要转换为简单的布尔值。 Args: options: 原始选项列表 Returns: list: 清理后的选项列表 """ clean_opts = [] if not options: return clean_opts for opt in options: clean_opt = opt.copy() if "solve" in clean_opt: del clean_opt["solve"] if "available" in clean_opt: # Only keep truthiness for UI logic clean_opt["available"] = bool(clean_opt["available"]) clean_opts.append(clean_opt) return clean_opts def session_worker_loop(cmd_queue, result_queue, stream_queue, dataset_root, gui_render): """ 工作进程主循环 此函数在工作进程中运行,负责: 1. 初始化 OracleSession 实例 2. 监听来自主进程的命令(通过 cmd_queue) 3. 执行命令并返回结果(通过 result_queue) 4. 监控视频帧变化,将新帧推送到流队列(通过 stream_queue) 5. 处理异常和清理资源 Args: cmd_queue: 命令队列,主进程发送命令到此队列 result_queue: 结果队列,工作进程返回命令执行结果到此队列 stream_queue: 流队列,工作进程推送新视频帧到此队列 dataset_root: 数据集根目录路径 gui_render: 是否使用GUI渲染模式 """ _setup_worker_logging() session = None try: LOGGER.info( "worker loop starting dataset_root=%s gui_render=%s", dataset_root, gui_render, ) session = OracleSession(dataset_root=dataset_root, gui_render=gui_render) session.stream_frame_callback = lambda frames: stream_queue.put({"base": frames, "wrist": []}) LOGGER.info("worker OracleSession initialized") while True: try: # Check for commands cmd_data = cmd_queue.get(timeout=0.1) except queue.Empty: continue cmd = cmd_data["cmd"] args = cmd_data.get("args", []) kwargs = cmd_data.get("kwargs", {}) LOGGER.debug("worker received cmd=%s args=%s kwargs_keys=%s", cmd, len(args), list(kwargs.keys())) if cmd == CMD_CLOSE: if session: session.close() LOGGER.info("worker received close command, exiting") break elif cmd == CMD_LOAD_EPISODE: # 加载环境episode res = session.load_episode(*args, **kwargs) LOGGER.info( "worker load_episode env=%s episode=%s result_msg=%s", getattr(session, "env_id", None), getattr(session, "episode_idx", None), res[1] if isinstance(res, tuple) and len(res) > 1 else None, ) # 更新帧索引跟踪(用于增量同步) session.last_base_frame_idx = len(session.base_frames) session.last_wrist_frame_idx = len(session.wrist_frames) # 获取演示状态(从 DemonstrationWrapper 获取) is_demonstration = False if session.env: is_demonstration = getattr(session.env, 'current_task_demonstration', False) # 构建状态更新(完整同步,因为这是加载操作) state_update = { "env_id": session.env_id, "episode_idx": session.episode_idx, "language_goal": session.language_goal, "difficulty": session.difficulty, "seed": session.seed, "demonstration_frames": session.demonstration_frames, "last_execution_frames": [], "base_frames": session.base_frames, # 加载时完整同步 "wrist_frames": session.wrist_frames, # 加载时完整同步 "available_options": session.available_options, "raw_solve_options": _sanitize_options(session.raw_solve_options), "seg_vis": session.seg_vis, "is_demonstration": is_demonstration, "non_demonstration_task_length": session.non_demonstration_task_length # 同步非demonstration任务长度 } result_queue.put({"status": "success", "result": res, "state": state_update}) elif cmd == CMD_EXECUTE_ACTION: # 执行动作(重计算任务) execute_base_start = len(session.base_frames) try: res = session.execute_action(*args, **kwargs) LOGGER.info( "worker execute_action done env=%s episode=%s done=%s", getattr(session, "env_id", None), getattr(session, "episode_idx", None), res[2] if isinstance(res, tuple) and len(res) > 2 else None, ) except ScrewPlanFailure as e: # 捕获 ScrewPlanFailure 并作为特殊状态传递到主进程,用于显示弹窗 LOGGER.warning("worker screw_plan_failure: %s", e) result_queue.put({"status": "screw_plan_failure", "message": str(e)}) continue except Exception as e: # 捕获所有其他异常并传递到主进程,用于显示弹窗 LOGGER.exception("worker execution_error") result_queue.put({"status": "execution_error", "message": str(e)}) continue # 增量帧同步:只发送新增的帧 new_base = session.base_frames[session.last_base_frame_idx:] new_wrist = session.wrist_frames[session.last_wrist_frame_idx:] streamed_count = int(getattr(session, "_execute_streamed_frame_count", 0) or 0) # Frames already pushed by stream_frame_callback during solve() should not be sent twice. if streamed_count > 0 and new_base: if streamed_count >= len(new_base): new_base = [] else: new_base = new_base[streamed_count:] # 更新帧索引 session.last_base_frame_idx = len(session.base_frames) session.last_wrist_frame_idx = len(session.wrist_frames) # 如果有新帧,推送到流队列 if new_base or new_wrist: stream_queue.put({"base": new_base, "wrist": new_wrist}) LOGGER.debug( "worker execute_action streamed frames base=%s wrist=%s", len(new_base), len(new_wrist), ) # 获取演示状态(从 DemonstrationWrapper 获取) is_demonstration = False if session.env: is_demonstration = getattr(session.env, 'current_task_demonstration', False) execution_frames = session.base_frames[execute_base_start:] # 构建状态更新(只更新选项和分割视图,帧通过流队列同步) state_update = { "last_execution_frames": execution_frames, "available_options": session.available_options, "raw_solve_options": _sanitize_options(session.raw_solve_options), "seg_vis": session.seg_vis, "is_demonstration": is_demonstration } result_queue.put({"status": "success", "result": res, "state": state_update}) elif cmd == CMD_GET_PIL_IMAGE: res = session.get_pil_image(*args, **kwargs) result_queue.put({"status": "success", "result": res}) elif cmd == CMD_UPDATE_OBSERVATION: # 更新观察(获取当前环境状态) res = session.update_observation(*args, **kwargs) # 增量帧同步 new_base = session.base_frames[session.last_base_frame_idx:] new_wrist = session.wrist_frames[session.last_wrist_frame_idx:] # 更新帧索引 session.last_base_frame_idx = len(session.base_frames) session.last_wrist_frame_idx = len(session.wrist_frames) # 如果有新帧,推送到流队列 if new_base or new_wrist: stream_queue.put({"base": new_base, "wrist": new_wrist}) LOGGER.debug( "worker update_observation streamed frames base=%s wrist=%s", len(new_base), len(new_wrist), ) # 获取演示状态(从 DemonstrationWrapper 获取) is_demonstration = False if session.env: is_demonstration = getattr(session.env, 'current_task_demonstration', False) # 构建状态更新 state_update = { "available_options": session.available_options, "raw_solve_options": _sanitize_options(session.raw_solve_options), "seg_vis": session.seg_vis, "is_demonstration": is_demonstration } result_queue.put({"status": "success", "result": res, "state": state_update}) elif cmd == CMD_GET_REFERENCE_ACTION: res = session.get_reference_action(*args, **kwargs) LOGGER.debug("worker get_reference_action ok=%s", bool(res.get("ok")) if isinstance(res, dict) else None) result_queue.put({"status": "success", "result": res}) else: LOGGER.error("worker unknown command=%s", cmd) result_queue.put({"status": "error", "message": f"Unknown command: {cmd}"}) except Exception as e: LOGGER.exception("worker fatal error") result_queue.put({"status": "fatal", "message": str(e)}) class ProcessSessionProxy: """ 进程会话代理类 此类在主进程中运行,提供与 OracleSession 相同的接口。 所有方法调用都会被转发到工作进程中的实际 OracleSession 实例。 主要功能: 1. 启动和管理工作进程 2. 通过队列与工作进程通信 3. 维护本地状态缓存(从工作进程同步) 4. 后台线程实时同步视频帧 """ def __init__(self, dataset_root=DEFAULT_DATASET_ROOT, gui_render=False): """ 初始化代理对象 Args: dataset_root: 数据集根目录路径 gui_render: 是否使用GUI渲染模式 """ # 使用 spawn 上下文以获得更清晰的进程隔离 ctx = multiprocessing.get_context("spawn") # 创建进程间通信队列 self.cmd_queue = ctx.Queue() # 命令队列:主进程 -> 工作进程 self.result_queue = ctx.Queue() # 结果队列:工作进程 -> 主进程 self.stream_queue = ctx.Queue() # 流队列:工作进程 -> 主进程(视频帧) # 启动工作进程 self.process = ctx.Process( target=session_worker_loop, args=(self.cmd_queue, self.result_queue, self.stream_queue, dataset_root, gui_render), daemon=True ) self.process.start() LOGGER.info( "ProcessSessionProxy started worker pid=%s dataset_root=%s gui_render=%s", self.process.pid, dataset_root, gui_render, ) # 本地状态缓存(从工作进程同步) self.env_id = None self.episode_idx = None self.language_goal = "" self.difficulty = None self.seed = None self.demonstration_frames = [] self.last_execution_frames = [] self.base_frames = [] # 由后台同步线程持续更新 self.wrist_frames = [] # 由后台同步线程持续更新 self.available_options = [] self.raw_solve_options = [] self.seg_vis = None self.is_demonstration = False # 演示模式标志 self.non_demonstration_task_length = None # 从工作进程同步的非demonstration任务长度 # 帧同步线程:从流队列接收新帧并更新本地缓存 self.stop_sync = False self.sync_thread = threading.Thread(target=self._sync_loop, daemon=True) self.sync_thread.start() def _sync_loop(self): """ 后台线程循环:从流队列消费视频帧并更新本地缓存 此线程持续运行,实时接收工作进程推送的新视频帧, 并将其追加到本地的 base_frames 和 wrist_frames 列表中。 UI 刷新逻辑会直接从代理的本地缓存读取帧数据。 """ while not self.stop_sync: try: # Use a short timeout to check stop_sync frequently frames = self.stream_queue.get(timeout=0.1) new_base = frames.get("base", []) new_wrist = frames.get("wrist", []) # Append to local lists if new_base: self.base_frames.extend(new_base) if new_wrist: self.wrist_frames.extend(new_wrist) except queue.Empty: continue except Exception: LOGGER.exception("ProcessSessionProxy sync loop crashed") break def _send_cmd(self, cmd, *args, **kwargs): """ 发送命令到工作进程并等待结果 Args: cmd: 命令名称 *args: 位置参数 **kwargs: 关键字参数 Returns: 命令执行结果 Raises: RuntimeError: 工作进程返回错误或致命错误 TimeoutError: 工作进程超时(600秒) """ # 发送命令到工作进程 start_ts = time.time() LOGGER.debug( "proxy send cmd=%s pid=%s args=%s kwargs_keys=%s", cmd, self.process.pid, len(args), list(kwargs.keys()), ) self.cmd_queue.put({"cmd": cmd, "args": args, "kwargs": kwargs}) try: # 等待结果(重任务如加载/执行可能需要较长时间,设置600秒超时) res = self.result_queue.get(timeout=600) elapsed_ms = int((time.time() - start_ts) * 1000) LOGGER.debug( "proxy recv cmd=%s pid=%s status=%s elapsed_ms=%s", cmd, self.process.pid, res.get("status"), elapsed_ms, ) # 检查错误状态并转换为异常,以便在 gradio_callbacks 中捕获并显示弹窗 if res.get("status") == "screw_plan_failure": raise ScrewPlanFailureError(f"screw plan failed: {res.get('message', 'unknown error')}") if res.get("status") == "execution_error": raise RuntimeError(f"Execution error: {res.get('message', 'unknown error')}") if res.get("status") == "fatal": raise RuntimeError(f"工作进程致命错误: {res.get('message')}") if res.get("status") == "error": raise RuntimeError(f"命令执行错误: {res.get('message')}") # 更新本地状态缓存(如果工作进程返回了状态更新) if "state" in res: state = res["state"] for k, v in state.items(): if k in ["base_frames", "wrist_frames"]: # 对于帧数据:只有在显式发送时才替换(如加载时) # 否则由同步循环处理增量更新 if v is not None: setattr(self, k, v) else: # 其他状态直接更新 setattr(self, k, v) return res.get("result") except queue.Empty: LOGGER.error("proxy command timeout cmd=%s pid=%s", cmd, self.process.pid) raise TimeoutError("工作进程超时") def load_episode(self, env_id, episode_idx): """ 加载环境episode(在工作进程中执行) Args: env_id: 环境ID episode_idx: episode索引 Returns: tuple: (PIL.Image, str) 图像和状态消息 """ return self._send_cmd(CMD_LOAD_EPISODE, env_id, episode_idx) def execute_action(self, action_idx, click_coords): """ 执行动作(在工作进程中执行,重计算任务) Args: action_idx: 动作索引 click_coords: 点击坐标 (x, y) 或 None Returns: tuple: (PIL.Image, str, bool) 图像、状态消息、是否完成 """ self.last_execution_frames = [] return self._send_cmd(CMD_EXECUTE_ACTION, action_idx, click_coords) def get_pil_image(self, use_segmented=True): """ 获取PIL图像(在工作进程中执行) Args: use_segmented: 是否使用分割视图 Returns: PIL.Image: 图像对象 """ return self._send_cmd(CMD_GET_PIL_IMAGE, use_segmented=use_segmented) def update_observation(self, use_segmentation=True): """ 更新观察(在工作进程中执行) Args: use_segmentation: 是否使用分割视图 Returns: tuple: (PIL.Image, str) 图像和状态消息 """ return self._send_cmd(CMD_UPDATE_OBSERVATION, use_segmentation=use_segmentation) def get_reference_action(self): """ 获取当前步参考动作与坐标(在工作进程中执行) Returns: dict: 参考动作结果 """ return self._send_cmd(CMD_GET_REFERENCE_ACTION) def close(self): """ 关闭代理并清理资源 此方法会: 1. 停止帧同步线程 2. 发送关闭命令到工作进程 3. 等待工作进程优雅退出(最多1秒) 4. 如果进程仍在运行,强制终止 """ self.stop_sync = True try: self.cmd_queue.put({"cmd": CMD_CLOSE}) LOGGER.debug("proxy close command sent pid=%s", self.process.pid) except Exception: LOGGER.exception("proxy failed to send close command pid=%s", self.process.pid) pass # 等待工作进程优雅退出 self.process.join(timeout=1) if self.process.is_alive(): LOGGER.warning("proxy worker still alive after join; terminating pid=%s", self.process.pid) self.process.terminate() LOGGER.info("ProcessSessionProxy closed pid=%s", self.process.pid)