# env.py
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from collections import defaultdict
from collections.abc import Callable, Sequence, Mapping
from functools import partial
from typing import Any
# RoboCasa ์ ์ฉ ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ํฌํธ
from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
from robocasa.utils.dataset_registry import ATOMIC_TASK_DATASETS, COMPOSITE_TASK_DATASETS, TARGET_TASKS, PRETRAINING_TASKS
# Length of the flattened state vector built by convert_state()
# (used as the shape of the "agent_pos" observation Box).
OBS_STATE_DIM = 16
# Length of the flat action vector split up by convert_action().
ACTION_DIM = 12
# Per-dimension bounds of the action space Box.
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
def convert_state(dict_state):
    """Flatten a simulator state dict into the 1-D vector LeRobot expects.

    Concatenates, in fixed order: base position, base rotation, relative
    end-effector position, relative end-effector rotation, and gripper
    joint positions.

    Args:
        dict_state: Mapping holding the ``state.*`` numpy arrays listed below.

    Returns:
        np.ndarray: 1-D concatenation of the five state components
        (OBS_STATE_DIM entries for the standard robot configuration).

    Raises:
        KeyError: If any required ``state.*`` key is missing.
    """
    # NOTE: the dict is only read here, so no defensive copy is needed
    # (the original shallow copy was pure overhead).
    return np.concatenate(
        [
            dict_state["state.base_position"],
            dict_state["state.base_rotation"],
            dict_state["state.end_effector_position_relative"],
            dict_state["state.end_effector_rotation_relative"],
            dict_state["state.gripper_qpos"],
        ],
        axis=0,
    )
def convert_action(action):
    """Split a flat LeRobot action vector into the simulator's action dict.

    The input array is copied first, so every value in the returned dict is
    a view into a private buffer rather than into the caller's array.

    Args:
        action: Flat numpy array of length ACTION_DIM.

    Returns:
        dict: ``action.*`` keys mapped to slices of the copied vector.
    """
    buf = action.copy()
    # (key, start, stop) layout of the flat action vector.
    layout = (
        ("action.base_motion", 0, 4),
        ("action.control_mode", 4, 5),
        ("action.end_effector_position", 5, 8),
        ("action.end_effector_rotation", 8, 11),
        ("action.gripper_close", 11, 12),
    )
    return {key: buf[start:stop] for key, start, stop in layout}
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""์นด๋ฉ๋ผ ์ด๋ฆ์ ๋ฆฌ์คํธ ํํ๋ก ์ ๊ทํ(Normalization)ํฉ๋๋ค."""
if isinstance(camera_name, str):
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
elif isinstance(camera_name, (list, tuple)):
cams = [str(c).strip() for c in camera_name if str(c).strip()]
else:
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
if not cams:
raise ValueError("camera_name resolved to an empty list.")
return cams
class RoboCasaEnv(RoboCasaGymEnv):
    """Gymnasium environment adapting RoboCasa tasks to LeRobot's I/O layout.

    Raw dict observations are reformatted into the ``pixels`` /
    ``agent_pos`` structure LeRobot expects (see ``_format_raw_obs``), and
    flat action vectors are expanded back into the simulator's action dict
    via ``convert_action``.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 20}

    def __init__(
        self,
        task: str,
        # Tuple default instead of a list: avoids the shared mutable-default
        # argument pitfall (the base class receives it as camera_names).
        camera_name: Sequence[str] = (
            "robot0_agentview_left",
            "robot0_eye_in_hand",
            "robot0_agentview_right",
        ),
        render_mode: str = "rgb_array",
        obs_type: str = "pixels_agent_pos",
        observation_width: int = 256,
        observation_height: int = 256,
        split: str | None = None,
        **kwargs,
    ):
        """Create a RoboCasa task environment.

        Args:
            task: Task name; must be a key of the atomic or composite task
                dataset registries.
            camera_name: Camera(s) to render, forwarded as ``camera_names``.
            render_mode: Rendering mode; only "rgb_array" is advertised in
                ``metadata``. Rendering is disabled when this is None.
            obs_type: "pixels" or "pixels_agent_pos" ("state" is rejected in
                ``_create_obs_and_action_space``).
            observation_width: Rendered image width in pixels.
            observation_height: Rendered image height in pixels.
            split: Optional dataset split forwarded to the base class.
            **kwargs: Extra arguments for RoboCasaGymEnv; any ``fps`` entry
                is discarded here since the base class does not take it.

        Raises:
            ValueError: If *task* is not a known RoboCasa task.
        """
        self.obs_type = obs_type
        self.render_mode = render_mode
        self.split = split
        self.task = task
        # fps is a LeRobot-side setting; drop it before forwarding kwargs.
        kwargs.pop("fps", None)
        self.kwargs = kwargs
        meta_info = {**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}
        try:
            self._max_episode_steps = meta_info[task]['horizon']
        except KeyError as err:
            # Chain the KeyError so the original lookup failure is preserved.
            raise ValueError(f"Unknown task '{task}'. Valid tasks are: {list(meta_info.keys())}") from err
        super().__init__(
            task,
            camera_names=camera_name,
            camera_widths=observation_width,
            camera_heights=observation_height,
            enable_render=(render_mode is not None),
            split=split,
            **kwargs
        )

    def _create_obs_and_action_space(self):
        """Build ``observation_space`` / ``action_space`` for ``obs_type``.

        Raises:
            NotImplementedError: For obs_type == "state".
            ValueError: For any other unknown obs_type.
        """
        images = {}
        for cam in self.camera_names:
            images[cam] = spaces.Box(
                low=0, high=255, shape=(self.camera_heights, self.camera_widths, 3), dtype=np.uint8
            )
        if self.obs_type == "state":
            raise NotImplementedError("The 'state' observation type is not supported.")
        elif self.obs_type == "pixels":
            self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
        elif self.obs_type == "pixels_agent_pos":
            self.observation_space = spaces.Dict({
                "pixels": spaces.Dict(images),
                "agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
            })
        else:
            raise ValueError(f"Unknown obs_type: {self.obs_type}")
        self.action_space = spaces.Box(
            low=ACTION_LOW, high=ACTION_HIGH, shape=(int(ACTION_DIM),), dtype=np.float32
        )

    def reset(self, seed: int | None = None, **kwargs):
        """Reset the episode and return ``(formatted_obs, info)``."""
        # Free the offscreen GL context before resetting; presumably a
        # workaround for a renderer/context leak — TODO(review): confirm
        # against the RoboCasa offscreen-renderer lifecycle.
        self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
        # NOTE(review): seed is passed positionally; gymnasium's Env.reset
        # takes seed keyword-only — confirm RoboCasaGymEnv.reset accepts it.
        observation, info = super().reset(seed, **kwargs)
        return self._format_raw_obs(observation), info

    def _format_raw_obs(self, raw_obs: dict):
        """Convert a raw obs dict into LeRobot's pixels/agent_pos layout."""
        new_obs = {}
        if self.obs_type == "pixels_agent_pos":
            new_obs["agent_pos"] = convert_state(raw_obs)
        new_obs["pixels"] = {}
        # Camera frames arrive under "video.<camera_name>" keys.
        for k, v in raw_obs.items():
            if "video." in k:
                new_obs["pixels"][k.replace("video.", "")] = v
        return new_obs

    def step(self, action: np.ndarray):
        """Apply a flat action vector; auto-reset once the episode ends.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where obs is the
            terminal observation even when an auto-reset happened.
        """
        # Make the offscreen GL context current (it is freed in reset()).
        self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
        action_dict = convert_action(action)
        observation, reward, done, truncated, info = super().step(action_dict)
        new_obs = self._format_raw_obs(observation)
        is_success = bool(info.get("success", 0))
        terminated = done or is_success
        info.update({"task": self.task, "done": done, "is_success": is_success})
        if terminated:
            # Stash final stats, then auto-reset (manual vector-env autoreset
            # pattern); the returned obs stays the terminal one.
            info["final_info"] = {"task": self.task, "done": bool(done), "is_success": bool(is_success)}
            self.reset()
        return new_obs, reward, terminated, truncated, info
def _make_env_fns(task_name: str, n_envs: int, camera_names: list[str], gym_kwargs: Mapping[str, Any]):
def _make_env(episode_index: int, **kwargs):
seed = kwargs.pop("seed", episode_index)
return RoboCasaEnv(task=task_name, camera_name=camera_names, seed=seed, **kwargs)
return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]
# ======================================================================
# LeRobot Hub integration entry point
# ======================================================================
def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
    """Main entry point LeRobot calls when loading this environment from the Hub.

    Args:
        n_envs: Number of parallel environments per task.
        use_async_envs: Use AsyncVectorEnv (spawned subprocesses) instead of
            the in-process SyncVectorEnv.
        cfg: Optional config object; its attributes override the defaults
            read below. When None, the defaults are used directly.

    Returns:
        Mapping of ``{suite_name: {task_id: VectorEnv}}``; task_id is
        always 0 (one vector env per task).
    """
    # Choose the vector-env wrapper class.
    env_cls = partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv

    def _cfg(attr: str, default):
        # One accessor instead of two duplicated default blocks.
        return getattr(cfg, attr, default) if cfg is not None else default

    task_name = _cfg("task", "CloseFridge")
    gym_kwargs = {
        "obs_type": _cfg("obs_type", "pixels_agent_pos"),
        "render_mode": _cfg("render_mode", "rgb_array"),
        "observation_width": _cfg("observation_width", 256),
        "observation_height": _cfg("observation_height", 256),
        "camera_name": _cfg("camera_name", "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"),
        "split": _cfg("split", None),
    }
    if cfg is not None:
        # Forward fps so a cfg-specified value is not silently dropped
        # (RoboCasaEnv.__init__ pops it before calling the base class).
        gym_kwargs["fps"] = _cfg("fps", 20)

    parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))

    # A benchmark name expands to its registered task list; anything else is
    # treated as a comma-separated list of individual task names.
    combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
    if task_name in combined_tasks:
        task_names = combined_tasks[task_name]
        gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
    else:
        task_names = [t.strip() for t in task_name.split(",")]

    # One vector env per task, keyed by task name then task_id 0.
    out = defaultdict(dict)
    for task in task_names:
        fns = _make_env_fns(
            task_name=task,
            n_envs=n_envs,
            camera_names=parsed_camera_names,
            gym_kwargs=gym_kwargs,
        )
        out[task][0] = env_cls(fns)

    # {suite_name: {task_id: VectorEnv}}
    return {suite: dict(task_map) for suite, task_map in out.items()}