# env.py
"""Gymnasium wrappers that expose RoboCasa tasks to LeRobot.

Provides:
  * converters between RoboCasa's dict-based state/action formats and the
    flat arrays LeRobot expects,
  * ``RoboCasaEnv``, a thin Gymnasium-style wrapper around ``RoboCasaGymEnv``,
  * ``make_env``, the entry point LeRobot calls when loading this environment
    from the Hub.
"""
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from collections import defaultdict
from collections.abc import Callable, Sequence, Mapping
from functools import partial
from typing import Any

# RoboCasa-specific imports.
from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
from robocasa.utils.dataset_registry import (
    ATOMIC_TASK_DATASETS,
    COMPOSITE_TASK_DATASETS,
    TARGET_TASKS,
    PRETRAINING_TASKS,
)

# Flattened proprioceptive state: base position/rotation + relative
# end-effector position/rotation + gripper joint positions (see convert_state).
OBS_STATE_DIM = 16
# Flat action layout consumed by convert_action() below.
ACTION_DIM = 12
ACTION_LOW = -1.0
ACTION_HIGH = 1.0

# Default camera set. A tuple (not a list) so the function default below is
# immutable — avoids the shared-mutable-default-argument pitfall.
_DEFAULT_CAMERAS = (
    "robot0_agentview_left",
    "robot0_eye_in_hand",
    "robot0_agentview_right",
)


def convert_state(dict_state):
    """Convert the simulator's dict state into the flat vector LeRobot expects.

    Args:
        dict_state: mapping with the ``state.*`` keys listed below, each a
            1-D numpy array.

    Returns:
        1-D numpy array of length OBS_STATE_DIM (the concatenation order is
        part of the contract with the trained policy).
    """
    dict_state = dict_state.copy()
    final_state = np.concatenate(
        [
            dict_state["state.base_position"],
            dict_state["state.base_rotation"],
            dict_state["state.end_effector_position_relative"],
            dict_state["state.end_effector_rotation_relative"],
            dict_state["state.gripper_qpos"],
        ],
        axis=0,
    )
    return final_state


def convert_action(action):
    """Convert LeRobot's flat action vector into the dict the simulator expects.

    Args:
        action: 1-D array of length ACTION_DIM. Slices below define the
            fixed layout shared with the policy.

    Returns:
        dict of named action components (views into a copy of ``action``).
    """
    action = action.copy()
    output_action = {
        "action.base_motion": action[0:4],
        "action.control_mode": action[4:5],
        "action.end_effector_position": action[5:8],
        "action.end_effector_rotation": action[8:11],
        "action.gripper_close": action[11:12],
    }
    return output_action


def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
    """Normalize camera names into a non-empty list of strings.

    Accepts either a comma-separated string (``"cam_a,cam_b"``) or a
    sequence of names; blank entries are dropped.

    Raises:
        TypeError: if ``camera_name`` is neither a str nor a list/tuple.
        ValueError: if no non-blank camera names remain after parsing.
    """
    if isinstance(camera_name, str):
        cams = [c.strip() for c in camera_name.split(",") if c.strip()]
    elif isinstance(camera_name, (list, tuple)):
        cams = [str(c).strip() for c in camera_name if str(c).strip()]
    else:
        raise TypeError(
            f"camera_name must be str or sequence[str], got {type(camera_name).__name__}"
        )
    if not cams:
        raise ValueError("camera_name resolved to an empty list.")
    return cams


class RoboCasaEnv(RoboCasaGymEnv):
    """A single RoboCasa task wrapped for LeRobot consumption.

    Translates between RoboCasa's dict observations/actions and the
    ``pixels``/``agent_pos`` format LeRobot policies consume.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 20}

    def __init__(
        self,
        task: str,
        # Immutable default — was a mutable list (shared across calls).
        camera_name: Sequence[str] = _DEFAULT_CAMERAS,
        render_mode: str = "rgb_array",
        obs_type: str = "pixels_agent_pos",
        observation_width: int = 256,
        observation_height: int = 256,
        split: str | None = None,
        **kwargs,
    ):
        """Build the wrapped environment.

        Args:
            task: RoboCasa task name; must exist in the atomic/composite
                dataset registry (its ``horizon`` sets max episode length).
            camera_name: camera names rendered into the observation.
            render_mode: passed through; rendering is enabled when not None.
            obs_type: "pixels" or "pixels_agent_pos" ("state" unsupported).
            observation_width / observation_height: camera image size.
            split: optional dataset split forwarded to the base env.
            **kwargs: forwarded to RoboCasaGymEnv (``fps`` is stripped —
                the base env does not accept it).

        Raises:
            ValueError: if ``task`` is not in the dataset registry.
        """
        self.obs_type = obs_type
        self.render_mode = render_mode
        self.split = split
        self.task = task
        # ``fps`` may arrive from make_env's gym_kwargs; the base env does
        # not accept it, so drop it before forwarding.
        kwargs.pop("fps", None)
        self.kwargs = kwargs
        meta_info = {**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}
        try:
            self._max_episode_steps = meta_info[task]["horizon"]
        except KeyError:
            raise ValueError(
                f"Unknown task '{task}'. Valid tasks are: {list(meta_info.keys())}"
            )
        super().__init__(
            task,
            camera_names=camera_name,
            camera_widths=observation_width,
            camera_heights=observation_height,
            enable_render=(render_mode is not None),
            split=split,
            **kwargs,
        )

    def _create_obs_and_action_space(self):
        """Define observation/action spaces from the configured cameras."""
        images = {}
        for cam in self.camera_names:
            images[cam] = spaces.Box(
                low=0,
                high=255,
                shape=(self.camera_heights, self.camera_widths, 3),
                dtype=np.uint8,
            )
        if self.obs_type == "state":
            raise NotImplementedError("The 'state' observation type is not supported.")
        elif self.obs_type == "pixels":
            self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
        elif self.obs_type == "pixels_agent_pos":
            self.observation_space = spaces.Dict(
                {
                    "pixels": spaces.Dict(images),
                    "agent_pos": spaces.Box(
                        low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32
                    ),
                }
            )
        else:
            raise ValueError(f"Unknown obs_type: {self.obs_type}")
        self.action_space = spaces.Box(
            low=ACTION_LOW, high=ACTION_HIGH, shape=(int(ACTION_DIM),), dtype=np.float32
        )

    def reset(self, seed: int | None = None, **kwargs):
        """Reset the env and return the observation in LeRobot format.

        Returns:
            (formatted_observation, info) tuple.
        """
        # Release the offscreen GL context before the base env rebuilds it.
        # NOTE(review): assumes a render context already exists when reset is
        # called — confirm this holds on the very first reset.
        self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
        # Pass seed by keyword: Gymnasium declares Env.reset(seed=...) as
        # keyword-only, so a positional seed breaks against that API.
        observation, info = super().reset(seed=seed, **kwargs)
        return self._format_raw_obs(observation), info

    def _format_raw_obs(self, raw_obs: dict):
        """Reshape a raw RoboCasa observation dict into LeRobot's format.

        ``video.<cam>`` entries become ``pixels[<cam>]``; when obs_type is
        "pixels_agent_pos" the flattened state is added as ``agent_pos``.
        """
        new_obs = {}
        if self.obs_type == "pixels_agent_pos":
            new_obs["agent_pos"] = convert_state(raw_obs)
        new_obs["pixels"] = {}
        for k, v in raw_obs.items():
            if "video." in k:
                new_obs["pixels"][k.replace("video.", "")] = v
        return new_obs

    def step(self, action: np.ndarray):
        """Step with a flat action; return (obs, reward, terminated, truncated, info).

        Marks the episode terminated on either environment ``done`` or a
        reported success, records a ``final_info`` snapshot, and eagerly
        resets the underlying sim (LeRobot-style autoreset).
        """
        # Re-bind the offscreen GL context to this env before rendering.
        self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
        action_dict = convert_action(action)
        observation, reward, done, truncated, info = super().step(action_dict)
        new_obs = self._format_raw_obs(observation)
        is_success = bool(info.get("success", 0))
        terminated = done or is_success
        info.update({"task": self.task, "done": done, "is_success": is_success})
        if terminated:
            # Snapshot episode outcome before the in-place autoreset below
            # wipes the underlying sim state.
            info["final_info"] = {
                "task": self.task,
                "done": bool(done),
                "is_success": bool(is_success),
            }
            # NOTE(review): resetting inside step() duplicates Gymnasium
            # vector-env autoreset; the returned new_obs is from BEFORE this
            # reset. Kept as-is to preserve existing behavior.
            self.reset()
        return new_obs, reward, terminated, truncated, info


def _make_env_fns(
    task_name: str,
    n_envs: int,
    camera_names: list[str],
    gym_kwargs: Mapping[str, Any],
):
    """Build ``n_envs`` zero-arg factories for a vector env.

    Each factory seeds its env with its episode index unless a ``seed`` is
    supplied via ``gym_kwargs``.
    """

    def _make_env(episode_index: int, **kwargs):
        seed = kwargs.pop("seed", episode_index)
        return RoboCasaEnv(task=task_name, camera_name=camera_names, seed=seed, **kwargs)

    return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]


# ======================================================================
# LeRobot Hub required entry point
# ======================================================================
def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
    """Main function LeRobot calls when loading this environment from the Hub.

    Args:
        n_envs: number of parallel environments per task.
        use_async_envs: use AsyncVectorEnv (spawned subprocesses) instead of
            SyncVectorEnv.
        cfg: optional config object; attributes override the defaults below.

    Returns:
        ``{suite_name: {task_id: VectorEnv}}`` mapping.
    """
    # Choose the vector-env class (spawn context avoids fork-related GL issues).
    env_cls = (
        partial(gym.vector.AsyncVectorEnv, context="spawn")
        if use_async_envs
        else gym.vector.SyncVectorEnv
    )

    # Extract settings from cfg when present; otherwise fall back to defaults.
    if cfg is not None:
        task_name = getattr(cfg, "task", "CloseFridge")
        fps = getattr(cfg, "fps", 20)
        gym_kwargs = {
            "obs_type": getattr(cfg, "obs_type", "pixels_agent_pos"),
            "render_mode": getattr(cfg, "render_mode", "rgb_array"),
            "observation_width": getattr(cfg, "observation_width", 256),
            "observation_height": getattr(cfg, "observation_height", 256),
            "camera_name": getattr(
                cfg,
                "camera_name",
                "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
            ),
            "split": getattr(cfg, "split", None),
            "fps": fps,  # forwarded so the env can strip/ignore it uniformly
        }
    else:
        # Defaults used when called directly without a cfg object.
        task_name = "CloseFridge"
        gym_kwargs = {
            "obs_type": "pixels_agent_pos",
            "render_mode": "rgb_array",
            "observation_width": 256,
            "observation_height": 256,
            "camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
            "split": None,
        }

    parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))
    combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}

    # Distinguish a named benchmark (expands to many tasks) from a single
    # task or an explicit comma-separated task list.
    if task_name in combined_tasks:
        task_names = combined_tasks[task_name]
        gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
    else:
        task_names = [t.strip() for t in task_name.split(",")]

    out = defaultdict(dict)
    # Build one vector env per task.
    for task in task_names:
        fns = _make_env_fns(
            task_name=task,
            n_envs=n_envs,
            camera_names=parsed_camera_names,
            gym_kwargs=gym_kwargs,
        )
        out[task][0] = env_cls(fns)

    # Return as {suite_name: {task_id: VectorEnv}}.
    return {suite: dict(task_map) for suite, task_map in out.items()}