File size: 8,610 Bytes
8c546c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c827431
 
 
8c546c2
 
 
 
 
 
 
 
 
 
 
 
b0bb325
8c546c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
077760c
8c546c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
077760c
8c546c2
 
 
 
077760c
 
 
 
 
 
 
 
 
 
8c546c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d421c67
8c546c2
 
 
 
 
 
 
 
 
 
 
 
 
 
0d14afe
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# env.py
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from collections import defaultdict
from collections.abc import Callable, Sequence, Mapping
from functools import partial
from typing import Any

# RoboCasa-specific library imports
from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
from robocasa.utils.dataset_registry import ATOMIC_TASK_DATASETS, COMPOSITE_TASK_DATASETS, TARGET_TASKS, PRETRAINING_TASKS

# Length of the flat state vector emitted by convert_state(). The sizes of the
# concatenated "state.*" components must sum to this — 16 per the observation
# space below; component-wise breakdown not visible here, confirm upstream.
OBS_STATE_DIM = 16
# Length of the flat action vector consumed by convert_action() (slices cover 0:12).
ACTION_DIM = 12
# Bounds of the normalized Box action space exposed to the policy.
ACTION_LOW = -1.0
ACTION_HIGH = 1.0

def convert_state(dict_state):
    """Flatten a simulator state dict into the flat vector LeRobot expects.

    Concatenates the base pose and end-effector/gripper components in a
    fixed order along axis 0; the result should have OBS_STATE_DIM elements.

    Args:
        dict_state: mapping with 1-D array values under the "state.*" keys
            read below.

    Returns:
        np.ndarray: 1-D concatenation of the five state components.
    """
    # NOTE: the previous shallow dict copy was removed — the dict is only
    # read here, and a shallow copy never protected the array values anyway.
    return np.concatenate(
        [
            dict_state["state.base_position"],
            dict_state["state.base_rotation"],
            dict_state["state.end_effector_position_relative"],
            dict_state["state.end_effector_rotation_relative"],
            dict_state["state.gripper_qpos"],
        ],
        axis=0,
    )

def convert_action(action):
    """Split a flat LeRobot action vector into the simulator's action dict.

    A private copy of *action* is taken first so the returned slices do not
    alias the caller's buffer.

    Args:
        action: 1-D array of length ACTION_DIM (12).

    Returns:
        dict mapping "action.*" keys to slices of the copied vector.
    """
    buf = action.copy()
    # (key, start, stop) layout of the flat action vector.
    layout = (
        ("action.base_motion", 0, 4),
        ("action.control_mode", 4, 5),
        ("action.end_effector_position", 5, 8),
        ("action.end_effector_rotation", 8, 11),
        ("action.gripper_close", 11, 12),
    )
    return {key: buf[start:stop] for key, start, stop in layout}

def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
    """Normalize a camera specification into a non-empty list of names.

    Accepts either a comma-separated string or a list/tuple of names.
    Surrounding whitespace is stripped and empty entries are dropped.

    Raises:
        TypeError: if *camera_name* is neither a string nor a list/tuple.
        ValueError: if no names remain after stripping.
    """
    if isinstance(camera_name, str):
        candidates = camera_name.split(",")
    elif isinstance(camera_name, (list, tuple)):
        candidates = [str(entry) for entry in camera_name]
    else:
        raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")

    names = []
    for entry in candidates:
        stripped = entry.strip()
        if stripped:
            names.append(stripped)

    if not names:
        raise ValueError("camera_name resolved to an empty list.")
    return names

class RoboCasaEnv(RoboCasaGymEnv):
    """Gymnasium-compatible RoboCasa environment for LeRobot.

    Wraps ``RoboCasaGymEnv`` so observations arrive in LeRobot's
    ``{"pixels": {...}, "agent_pos": ...}`` layout and actions are accepted
    as a flat ``np.ndarray`` of length ``ACTION_DIM``.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 20}

    def __init__(
        self,
        task: str,
        # Tuple default: the previous list default was a shared mutable object
        # (classic mutable-default-argument pitfall).
        camera_name: Sequence[str] = ("robot0_agentview_left", "robot0_eye_in_hand", "robot0_agentview_right"),
        render_mode: str = "rgb_array",
        obs_type: str = "pixels_agent_pos",
        observation_width: int = 256,
        observation_height: int = 256,
        split: str | None = None,
        **kwargs,
    ):
        """Create a RoboCasa environment for *task*.

        Args:
            task: key into the RoboCasa task registries; its registered
                ``horizon`` becomes ``self._max_episode_steps``.
            camera_name: camera identifiers to render.
            render_mode: only "rgb_array" is advertised; rendering is
                disabled when this is ``None``.
            obs_type: "pixels" or "pixels_agent_pos" ("state" unsupported).
            observation_width: rendered image width in pixels.
            observation_height: rendered image height in pixels.
            split: dataset split forwarded to the base class.
            **kwargs: forwarded to ``RoboCasaGymEnv``; any ``fps`` entry is
                dropped because the base class does not accept it.

        Raises:
            ValueError: if *task* is not a registered RoboCasa task.
        """
        self.obs_type = obs_type
        self.render_mode = render_mode
        self.split = split
        self.task = task

        # "fps" is a LeRobot config key, not a RoboCasa kwarg — discard it.
        kwargs.pop("fps", None)
        self.kwargs = kwargs

        meta_info = {**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}
        try:
            self._max_episode_steps = meta_info[task]['horizon']
        except KeyError:
            # "from None" drops the noisy KeyError context; the ValueError
            # message already lists the valid tasks.
            raise ValueError(
                f"Unknown task '{task}'. Valid tasks are: {list(meta_info.keys())}"
            ) from None

        super().__init__(
            task,
            camera_names=camera_name,
            camera_widths=observation_width,
            camera_heights=observation_height,
            enable_render=(render_mode is not None),
            split=split,
            **kwargs,
        )

    def _create_obs_and_action_space(self):
        """Build ``observation_space`` and ``action_space`` for ``obs_type``."""
        images = {
            cam: spaces.Box(
                low=0, high=255, shape=(self.camera_heights, self.camera_widths, 3), dtype=np.uint8
            )
            for cam in self.camera_names
        }
        if self.obs_type == "state":
            raise NotImplementedError("The 'state' observation type is not supported.")
        elif self.obs_type == "pixels":
            self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
        elif self.obs_type == "pixels_agent_pos":
            self.observation_space = spaces.Dict({
                "pixels": spaces.Dict(images),
                "agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
            })
        else:
            raise ValueError(f"Unknown obs_type: {self.obs_type}")

        # ACTION_DIM is already an int; the redundant int() cast was removed.
        self.action_space = spaces.Box(
            low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
        )

    def reset(self, seed: int | None = None, **kwargs):
        """Reset the episode and return a LeRobot-formatted observation.

        NOTE(review): freeing the offscreen GL context before reset looks
        like a renderer re-initialization workaround — confirm against the
        RoboCasa renderer implementation. Also note *seed* is passed
        positionally, which assumes the base class accepts it that way.
        """
        self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
        observation, info = super().reset(seed, **kwargs)
        return self._format_raw_obs(observation), info

    def _format_raw_obs(self, raw_obs: dict):
        """Convert a raw observation dict to ``{"pixels": ..., "agent_pos": ...}``.

        Keys containing "video." become entries of the "pixels" sub-dict
        with the prefix stripped.
        """
        new_obs = {}
        if self.obs_type == "pixels_agent_pos":
            new_obs["agent_pos"] = convert_state(raw_obs)
        new_obs["pixels"] = {
            key.replace("video.", ""): value
            for key, value in raw_obs.items()
            if "video." in key
        }
        return new_obs

    def step(self, action: np.ndarray):
        """Apply a flat action vector and return a gymnasium step tuple.

        The episode terminates when the base env reports ``done`` or the
        ``success`` flag in ``info`` is truthy. On termination the env
        auto-resets itself; the observation returned is still the
        pre-reset one.
        """
        self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
        action_dict = convert_action(action)
        observation, reward, done, truncated, info = super().step(action_dict)
        new_obs = self._format_raw_obs(observation)

        is_success = bool(info.get("success", 0))
        terminated = done or is_success
        info.update({"task": self.task, "done": done, "is_success": is_success})

        if terminated:
            info["final_info"] = {"task": self.task, "done": bool(done), "is_success": bool(is_success)}
            self.reset()

        return new_obs, reward, terminated, truncated, info


def _make_env_fns(task_name: str, n_envs: int, camera_names: list[str], gym_kwargs: Mapping[str, Any]):
    """Return *n_envs* zero-argument factories producing ``RoboCasaEnv``s.

    Each factory seeds its environment with its own index unless the caller
    supplies an explicit ``seed`` in *gym_kwargs*.
    """
    def _factory(episode_index: int, **env_kwargs):
        # Per-env seed defaults to the env's index for reproducible variety.
        chosen_seed = env_kwargs.pop("seed", episode_index)
        return RoboCasaEnv(task=task_name, camera_name=camera_names, seed=chosen_seed, **env_kwargs)

    return [partial(_factory, idx, **gym_kwargs) for idx in range(n_envs)]


# ======================================================================
# Required LeRobot Hub entry point
# ======================================================================
def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
    """Main function LeRobot calls when loading this environment from the Hub.

    Args:
        n_envs: number of parallel environments per task.
        use_async_envs: use ``AsyncVectorEnv`` (spawn context) instead of
            ``SyncVectorEnv``.
        cfg: optional config object; its attributes override the defaults
            below when present.

    Returns:
        ``{suite_name: {task_id: VectorEnv}}`` with one vector env per task.
    """
    # Choose the vectorized wrapper class.
    env_cls = partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv

    # Single defaults table shared by the cfg and no-cfg paths. Previously
    # the two branches kept separate copies that had already drifted (the
    # no-cfg branch was missing "fps").
    def _get(name, default):
        # cfg attributes win when cfg is supplied; otherwise use the default.
        return getattr(cfg, name, default) if cfg is not None else default

    task_name = _get("task", "CloseFridge")
    gym_kwargs = {
        "obs_type": _get("obs_type", "pixels_agent_pos"),
        "render_mode": _get("render_mode", "rgb_array"),
        "observation_width": _get("observation_width", 256),
        "observation_height": _get("observation_height", 256),
        "camera_name": _get("camera_name", "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"),
        "split": _get("split", None),
        # Kept for parity with the config; RoboCasaEnv pops it before use.
        "fps": _get("fps", 20),
    }

    parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))
    combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}

    # A benchmark-suite name expands to its member tasks; otherwise treat the
    # string as a comma-separated list of individual task names.
    if task_name in combined_tasks:
        task_names = combined_tasks[task_name]
        gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
    else:
        task_names = [t.strip() for t in task_name.split(",")]

    out = defaultdict(dict)

    # One vector env per task, registered under task id 0.
    for task in task_names:
        fns = _make_env_fns(
            task_name=task,
            n_envs=n_envs,
            camera_names=parsed_camera_names,
            gym_kwargs=gym_kwargs,
        )
        out[task][0] = env_cls(fns)

    # Shape: {suite_name: {task_id: VectorEnv}}.
    return {suite: dict(task_map) for suite, task_map in out.items()}