Upload folder using huggingface_hub
Browse files- CHANGES.md +215 -0
- STATE_ACTION_SPEC.md +83 -0
- backup/configs.py +489 -0
- backup/env.py +248 -0
- configs.py +8 -3
- configs.py.bak +489 -0
- env.py +187 -80
- env.py.bak +248 -0
CHANGES.md
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Robocasa_Env ๋ณ๊ฒฝ์ฌํญ โ ๊ณต์ GR00T eval ํ๋ฆ๊ณผ ์ ๋ ฌ
|
| 2 |
+
|
| 3 |
+
๋ชฉ์ : `lerobot-eval --env.type=robocasa ...`๊ฐ ๊ณต์
|
| 4 |
+
`Isaac-GR00T/scripts/run_eval.py`์ ๊ฐ์ ๋ฐฉ์์ผ๋ก ๋์ํ๋๋ก ์์ .
|
| 5 |
+
**State/Action ๊ณ์ฝ(`STATE_ACTION_SPEC.md`)์ ๊ทธ๋๋ก ์ ์ง** โ lerobot ์ ์ฑ
(์: ACT)์ด
|
| 6 |
+
12-dim concat action์ ์ถ๋ ฅํ๊ณ , env๊ฐ 16-dim concat `agent_pos`๋ฅผ ๋
ธ์ถํ๋ ์ฝ์์ ๊นจ์ง ์๋๋ค.
|
| 7 |
+
|
| 8 |
+
์์ ํ์ผ: `env.py`, `configs.py` ๋ ๊ฐ. (๋ฐฑ์
: `*.bak`)
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## 1. `env.py` ๋ณ๊ฒฝ ์์ฝ
|
| 13 |
+
|
| 14 |
+
### 1-1. `step()`์์ ์์ฒด `self.reset()` ์ ๊ฑฐ (โ
ํต์ฌ)
|
| 15 |
+
|
| 16 |
+
**Before**
|
| 17 |
+
```python
|
| 18 |
+
if terminated:
|
| 19 |
+
info["final_info"] = {...}
|
| 20 |
+
self.reset() # โ ํ๊ฒฝ์ด ์ค์ค๋ก reset
|
| 21 |
+
return new_obs, reward, terminated, truncated, info
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
**After**
|
| 25 |
+
```python
|
| 26 |
+
# self.reset() ํธ์ถ ์ ๊ฑฐ.
|
| 27 |
+
# gymnasium 0.29+ VectorEnv์ autoreset์ด terminated/truncated๋ฅผ ๋ณด๊ณ
|
| 28 |
+
# final_info๋ฅผ ๋ง๋ค๊ณ ์๋์ผ๋ก ๋ฆฌ์
ํ๋ค. wrapper๊ฐ ํ ๋ฒ ๋ resetํ๋ฉด
|
| 29 |
+
# ์ฒซ obs๊ฐ final obs๋ฅผ ๋ฎ์ด์ฐ๊ณ , lerobot rollout์ final_info["is_success"]๊ฐ
|
| 30 |
+
# ์ ํฉ์ฑ์ ์๋๋ค.
|
| 31 |
+
return new_obs, float(reward), terminated, truncated, info
|
| 32 |
+
```
|
| 33 |
+
- ๊ณต์ GR00T `simulation.py`๋ wrapper ์์์ resetํ์ง ์๊ณ ์ธ๋ถ ๋ฃจํ๊ฐ ์ฒ๋ฆฌํจ.
|
| 34 |
+
- `lerobot_eval.py`์ `rollout()`์ด `info["final_info"][i]["is_success"]`๋ฅผ ์ฝ์ผ๋ฏ๋ก,
|
| 35 |
+
์๋ reset์ ๋งก๊ธฐ๋ ๊ฒ ์ ํํ SR ์ง๊ณ๋ก ์ด์ด์ง.
|
| 36 |
+
|
| 37 |
+
### 1-2. ์ธ์ด ์กฐ๊ฑด(`task`)์ obs๋ก ์ง์ ๋
ธ์ถ
|
| 38 |
+
|
| 39 |
+
`RoboCasaGymEnv.get_observation`์ ์ด๋ฏธ `annotation.human.task_description` ํค๋ก
|
| 40 |
+
ep_meta์ lang์ obs์ ์ฑ์ ์ค๋ค. ์ฌ์ฉ์ ์ฝ๋์ `_format_raw_obs`๋ ์ด๊ฑธ ๋ฒ๋ฆฌ๊ณ
|
| 41 |
+
์์์.
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
def _format_raw_obs(self, raw_obs):
|
| 45 |
+
...
|
| 46 |
+
lang = raw_obs.get("annotation.human.task_description") \
|
| 47 |
+
or self._task_description or self.task
|
| 48 |
+
new_obs["task"] = str(lang)
|
| 49 |
+
self._task_description = str(lang)
|
| 50 |
+
return new_obs
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
- `AsyncVectorEnv`์์๋ `env.call("task_description")`์ด worker process๋ฅผ ๊ฑฐ์ณ ๋น์ธ๊ณ
|
| 54 |
+
ํ์ด๋ฐ ์ด์๊ฐ ์๋ค (README์ *use_async_envs=True ๋ณด๋ฅ: task_description ๋๋ฝ* ์ด์).
|
| 55 |
+
obs์ ์ง์ ๋ฃ์ผ๋ฉด sync/async ๋ชจ๋์์ ๋๊ธฐ์ง ์์.
|
| 56 |
+
- `observation_space`์๋ `"task": spaces.Text(max_length=512)`๋ฅผ ์ถ๊ฐ.
|
| 57 |
+
|
| 58 |
+
### 1-3. horizon์ ๊ณต์ ํฌํผ `get_task_horizon` ์ฌ์ฉ
|
| 59 |
+
|
| 60 |
+
```python
|
| 61 |
+
from robocasa.utils.dataset_registry_utils import get_task_horizon
|
| 62 |
+
...
|
| 63 |
+
self._max_episode_steps = int(get_task_horizon(task))
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
- ๊ธฐ์กด `meta_info[task]['horizon']` ์ง์กฐํ์ ๋์ผ ๊ฒฐ๊ณผ์ง๋ง, ๊ณต์ ํจ์์ ๋์ผํ ์ฝ๋ ๊ฒฝ๋ก๋ฅผ ํ๊ฒ ํ๋ค.
|
| 67 |
+
- `lerobot_eval.rollout()`์ด `env.call("_max_episode_steps")[0]`์ ์ฝ์ผ๋ฏ๋ก,
|
| 68 |
+
**batch์๋ horizon์ด ๊ฐ์ task๋ง ๋ฌถ์ด๋๋ก** task๋ณ๋ก ๋ณ๋ VectorEnv๋ฅผ ๋ง๋ ๋ค (๊ธฐ์กด ๊ตฌ์กฐ ์ ์ง).
|
| 69 |
+
|
| 70 |
+
### 1-4. `_resolve_task_list()` ์ ์ค โ benchmark/๋จ์ผ task/๋ค์ค task ๋ชจ๋ ์ฒ๋ฆฌ
|
| 71 |
+
|
| 72 |
+
๊ณต์ `run_eval.py`์ ๋ค์ ๋ก์ง์ ์ฌํ:
|
| 73 |
+
|
| 74 |
+
```python
|
| 75 |
+
all_env_names = []
|
| 76 |
+
for task_set in task_set_list:
|
| 77 |
+
all_env_names += TASK_SET_REGISTRY[task_set]
|
| 78 |
+
all_env_names = set(all_env_names)
|
| 79 |
+
for env_name in all_env_names:
|
| 80 |
+
config = SimulationConfig(env_name=f"robocasa/{env_name}", split=split, ...)
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
์์ ํ `make_env`:
|
| 84 |
+
- benchmark ํค(`atomic_seen`, `composite_unseen`, `pretrain50`, ...)๋ฉด sub-task ๋ฆฌ์คํธ๋ก ํผ์น๋ค.
|
| 85 |
+
- ๋จ์ผ task ์ด๋ฆ์ด๋ฉด ๊ทธ๋๋ก ์ฌ์ฉ.
|
| 86 |
+
- ์ฌ๋ฌ ๊ฐ๋ฅผ ๊ณต๋ฐฑ/์ฝค๋ง/๋ฆฌ์คํธ ํํ๋ก ๋ชจ๋ ๋ฐ๋๋ค.
|
| 87 |
+
- `--env.task=atomic_seen composite_unseen composite_seen` (draccus list)
|
| 88 |
+
- `--env.task="atomic_seen,composite_unseen"` (์ฝค๋ง ๋ถ๋ฆฌ)
|
| 89 |
+
- `--env.task=PnPCounterToCab` (๋จ์ผ)
|
| 90 |
+
|
| 91 |
+
### 1-5. `cfg.split` ๋ช
์ ์ ์ฐ์
|
| 92 |
+
|
| 93 |
+
**Before**
|
| 94 |
+
```python
|
| 95 |
+
if task_name in combined_tasks:
|
| 96 |
+
task_names = combined_tasks[task_name]
|
| 97 |
+
gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
|
| 98 |
+
# โ ์ฌ์ฉ์๊ฐ --env.split=pretrain ์ค๋ ๊ฐ์ ๋ฎ์ด์
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
**After**
|
| 102 |
+
```python
|
| 103 |
+
if item in TARGET_TASKS:
|
| 104 |
+
split = explicit_split or "target" # explicit์ด ์์ผ๋ฉด ๊ทธ๊ฒ ์ฌ์ฉ
|
| 105 |
+
elif item in PRETRAINING_TASKS:
|
| 106 |
+
split = explicit_split or "pretrain"
|
| 107 |
+
else:
|
| 108 |
+
pairs.append((item, explicit_split)) # ๋จ์ผ task๋ ๊ทธ๋๋ก
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
์ด์ `--env.task=atomic_seen --env.split=pretrain` ๊ฐ์ ์กฐํฉ์ด ์๋๋๋ก ๋์ํ๋ค
|
| 112 |
+
(๊ณต์ `run_eval.py`๋ `--task_set`๊ณผ `--split`์ ๋
๋ฆฝ ์ธ์๋ก ๋ฐ์).
|
| 113 |
+
|
| 114 |
+
### 1-6. GL context ์๋ ์กฐ์์ ๋ณด์กด (๋จ, try/except)
|
| 115 |
+
|
| 116 |
+
`reset()`์ `gl_ctx.free()`, `step()`์ `make_current()`๋ **์ด์ ๋ถ๋ช
์ ์ฐํ hack**์ผ๋ก ๋ณด์ด์ง๋ง
|
| 117 |
+
์ฌ์ฉ์ ํ๊ฒฝ์์ ์๋์ ์ผ๋ก ์ถ๊ฐ๋ ๊ฒ์ผ ๊ฐ๋ฅ์ฑ์ด ์์ด **๋ณด์กด**.
|
| 118 |
+
๋ค๋ง `try/except`๋ก ๊ฐ์ธ ๋ค๋ฅธ ํ๊ฒฝ์์ ๊นจ์ง์ง ์๋๋ก ํ๋ค.
|
| 119 |
+
|
| 120 |
+
### 1-7. ๊ธฐํ ์ ๋ฆฌ
|
| 121 |
+
|
| 122 |
+
- `convert_state` ๊ฒฐ๊ณผ๋ฅผ `float32`๋ก ๋ช
์ (lerobot tensor ๋ณํ ์์ )
|
| 123 |
+
- `convert_action`์ด list/tuple๋ ๋ฐ๋๋ก `np.asarray`
|
| 124 |
+
- `_format_raw_obs`์ `"video." in k` โ `k.startswith("video.")` (์คํ ๋ฐฉ์ง)
|
| 125 |
+
- info dict์ `task`, `task_description`์ ๋งค step ์ฑ์ (lerobot์ `add_envs_task` fallback)
|
| 126 |
+
|
| 127 |
+
---
|
| 128 |
+
|
| 129 |
+
## 2. `configs.py` ๋ณ๊ฒฝ ์์ฝ
|
| 130 |
+
|
| 131 |
+
```python
|
| 132 |
+
@EnvConfig.register_subclass("robocasa")
|
| 133 |
+
@dataclass
|
| 134 |
+
class RoboCasaEnv(HubEnvConfig):
|
| 135 |
+
hub_path: str = "Whalswp/RoboCasa_Env"
|
| 136 |
+
|
| 137 |
+
# โ
list ํ์ฉ โ ๊ณต์ run_eval์ฒ๋ผ ์ฌ๋ฌ task_set ๋์ ์
๋ ฅ
|
| 138 |
+
task: str | list[str] | None = None
|
| 139 |
+
fps: int = 20 # โ
์ ์ค (env.py๊ฐ cfg.fps ์ฐธ์กฐ)
|
| 140 |
+
obs_type: str = "pixels_agent_pos"
|
| 141 |
+
render_mode: str = "rgb_array"
|
| 142 |
+
camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
|
| 143 |
+
observation_height: int = 256
|
| 144 |
+
observation_width: int = 256
|
| 145 |
+
split: str | None = None # `pretrain` | `target` | `all` | None
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
`features` / `features_map` / `hub_path`๋ ๋ณ๊ฒฝํ์ง ์์ โ STATE_ACTION_SPEC ๊ณ์ฝ ์ ์ง.
|
| 149 |
+
|
| 150 |
+
---
|
| 151 |
+
|
| 152 |
+
## 3. ๊ณต์ GR00T eval โ ์์ ํ lerobot-eval ๋์ํ
|
| 153 |
+
|
| 154 |
+
| ๊ณต์ ์ธ์ (`run_eval.py`) | lerobot-eval ์ธ์ | ๋น๊ณ |
|
| 155 |
+
|--------------------------|-------------------|------|
|
| 156 |
+
| `--model_path` | `--policy.path` | lerobot์ `PreTrainedPolicy` ๋ก๋ ๊ฐ๋ฅํด์ผ ํจ |
|
| 157 |
+
| `--task_set atomic_seen composite_unseen` | `--env.task=atomic_seen composite_unseen` | ์ด๋ฒ ์์ ์ผ๋ก ๋๋ฑ |
|
| 158 |
+
| `--split pretrain` | `--env.split=pretrain` | explicit ์ฐ์ ๋๋๋ก ์์ |
|
| 159 |
+
| `--n_episodes 50` | `--eval.n_episodes=50` | ๋์ผ |
|
| 160 |
+
| `--n_envs 5` | `--eval.batch_size=5` | ๋์ผ |
|
| 161 |
+
| `--video_dir <path>` | `--output_dir <path>/videos` | lerobot์ด `output_dir/videos/{task}_{id}/eval_episode_*.mp4` ์๋ |
|
| 162 |
+
| `--n_action_steps 16` | (ํด๋น ์์) | lerobot ์ ์ฑ
์ด single-step ์ถ๋ ฅ. ACT ๋ฑ lerobot ์ ์ฑ
์ ๋ด๋ถ chunk ์ฒ๋ฆฌ |
|
| 163 |
+
|
| 164 |
+
GR00T ์ ์ฑ
์์ฒด๋ chunk(16-step) ์ถ๋ ฅ์ด๋ผ lerobot์ `select_action` ๋จ์ผ ์คํ
ํ๋ฆ๊ณผ ๋ค๋ฅด๋ค.
|
| 165 |
+
์ด๊ฑด ์ ์ฑ
์ด๋ํฐ ์์ญ์ด๋ผ **์ด๋ฒ ์์ ๋ฒ์ ๋ฐ** (env๋ ์ ์ฑ
-ํ๊ฒฝ ์ฌ์ด์ 12-dim ๋จ์ผ ์คํ
๊ณ์ฝ๋ง ์ฑ
์์ง).
|
| 166 |
+
|
| 167 |
+
---
|
| 168 |
+
|
| 169 |
+
## 4. ์ฌ์ฉ ์์
|
| 170 |
+
|
| 171 |
+
```bash
|
| 172 |
+
# ๋จ์ผ task (sanity)
|
| 173 |
+
lerobot-eval \
|
| 174 |
+
--policy.path=BrunoM42/act_base-robocasa_target_PickPlaceCounterToCabinet \
|
| 175 |
+
--env.type=robocasa \
|
| 176 |
+
--env.task=PickPlaceCounterToCabinet \
|
| 177 |
+
--eval.batch_size=5 --eval.n_episodes=5 \
|
| 178 |
+
--policy.device=cuda --trust_remote_code=true \
|
| 179 |
+
--env.split=pretrain \
|
| 180 |
+
--output_dir /home/seonho/clvla/benchmarks/robocasa365/bench_outputs
|
| 181 |
+
|
| 182 |
+
# ์ฌ๋ฌ benchmark ๋์ (= ๊ณต์ --task_set atomic_seen composite_unseen composite_seen)
|
| 183 |
+
lerobot-eval \
|
| 184 |
+
--policy.path=<...> \
|
| 185 |
+
--env.type=robocasa \
|
| 186 |
+
--env.task="atomic_seen composite_unseen composite_seen" \
|
| 187 |
+
--eval.batch_size=5 --eval.n_episodes=50 \
|
| 188 |
+
--policy.device=cuda --trust_remote_code=true \
|
| 189 |
+
--env.split=pretrain \
|
| 190 |
+
--output_dir /home/seonho/clvla/benchmarks/robocasa365/bench_outputs
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
`out[task][0] = VectorEnv(...)` ๊ตฌ์กฐ์ด๋ฏ๋ก `eval_info.json`์ `per_group`์๋
|
| 194 |
+
**task๋ณ success rate**๊ฐ ๊ทธ๋๋ก ๋จ์ด์ง๊ณ , `overall`์ด ์ ์ฒด ํ๊ท ์ด ๋๋ค โ ๊ณต์
|
| 195 |
+
`get_eval_stats.py`๊ฐ ๋ง๋๋ task_set ํ๊ท ๊ณผ ๋น๊ตํ๊ธฐ ์ฌ์ด ํํ.
|
| 196 |
+
|
| 197 |
+
---
|
| 198 |
+
|
| 199 |
+
## 5. ๋ณ๊ฒฝํ์ง ์์ ๊ฒ (์๋)
|
| 200 |
+
|
| 201 |
+
- **State 16-dim / Action 12-dim concat ๊ณ์ฝ** โ `STATE_ACTION_SPEC.md` ์ ์ง
|
| 202 |
+
- `convert_state`, `convert_action`์ ์ธ๋ฑ์ค/ํค ์ ์
|
| 203 |
+
- `hub_path`, `features`, `features_map`
|
| 204 |
+
- `_create_obs_and_action_space`์ ์นด๋ฉ๋ผ/์ก์
๋ฐ์ค shape
|
| 205 |
+
- `gl_ctx.free()` / `make_current()` ํธ์ถ (์ด์ ๋ถ๋ช
, ์์ ์ฐ์ ๋ณด์กด)
|
| 206 |
+
|
| 207 |
+
## 6. ํ๊ณ / ํ์ ๊ณผ์
|
| 208 |
+
|
| 209 |
+
- GR00T ์ ์ฑ
์ lerobot์์ ๊ทธ๋๋ก ์ฐ๋ ค๋ฉด chunkยทhistory๋ฅผ ์ฒ๋ฆฌํ๋ **์ ์ฑ
์ด๋ํฐ**๊ฐ ๋ณ๋๋ก ํ์.
|
| 210 |
+
`lerobot_eval_gap_analysis.md` ยง3.C ์ฐธ๊ณ . ์ด๋ฒ ์์ ์ ํ๊ฒฝ ์ธก๋ง.
|
| 211 |
+
- lerobot์ `eval_info.json`์ ๊ณต์ `evals/<split>/<env>/stats.json` ํธ๋ฆฌ๋ก
|
| 212 |
+
๋ณํํด์ฃผ๋ ์์ ์คํฌ๋ฆฝํธ๊ฐ ์์ผ๋ฉด `gr00t/eval/get_eval_stats.py`๋ฅผ ๊ทธ๋๋ก ์ฌ์ฌ์ฉ ๊ฐ๋ฅ
|
| 213 |
+
(`lerobot_eval_gap_analysis.md` ยง3.D).
|
| 214 |
+
- Hub(`Whalswp/RoboCasa_Env`)๋ฅผ ์ฐ๋ ๊ฒฝ์ฐ, **์ด ๋ก์ปฌ ๋ณ๊ฒฝ์ hub์ ํธ์**ํด์ผ lerobot์
|
| 215 |
+
`trust_remote_code` ๊ฒฝ๋ก๊ฐ ์ ๋ฒ์ ์ ๋ฐ๋๋ค. ๋ก์ปฌ์์ ์ง์ importํ๋ ๊ฒฝ์ฐ์ ๊ทธ๋๋ก ์ ์ฉ๋จ.
|
STATE_ACTION_SPEC.md
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RoboCasa State / Action ๋ช
์ธ
|
| 2 |
+
|
| 3 |
+
> ๊ทผ๊ฑฐ ํ์ผ: `env.py`, `gym_wrapper.py` (`PandaOmronKeyConverter`), `robosuite/controllers/parts/arm/osc.py`, `robosuite/controllers/config/robots/default_pandaomron.json`
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## State (์ด 16์ฐจ์)
|
| 8 |
+
|
| 9 |
+
`env.py: convert_state()` ๊ธฐ์ค์ผ๋ก concatenate๋จ.
|
| 10 |
+
|
| 11 |
+
| ์ธ๋ฑ์ค | ์ฐจ์ | ํค | absolute / relative | ํํ |
|
| 12 |
+
|--------|------|----|---------------------|------|
|
| 13 |
+
| 0~2 | 3 | `state.base_position` | **absolute** | xyz |
|
| 14 |
+
| 3~6 | 4 | `state.base_rotation` | **absolute** | **Quaternion** (`robot0_base_quat`) |
|
| 15 |
+
| 7~9 | 3 | `state.end_effector_position_relative` | **relative** (base โ EE) | xyz |
|
| 16 |
+
| 10~13 | 4 | `state.end_effector_rotation_relative` | **relative** (base โ EE) | **Quaternion** (`robot0_base_to_eef_quat`) |
|
| 17 |
+
| 14~15 | 2 | `state.gripper_qpos` | โ | joint position |
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## Action (์ด 12์ฐจ์)
|
| 22 |
+
|
| 23 |
+
`env.py: convert_action()` ๊ธฐ์ค์ผ๋ก ๋ถํด๋จ.
|
| 24 |
+
|
| 25 |
+
| ์ธ๋ฑ์ค | ์ฐจ์ | ํค | ์ค๋ช
|
|
| 26 |
+
|--------|------|----|------|
|
| 27 |
+
| 0~3 | 4 | `action.base_motion` | ๋ฒ ์ด์ค ์ด๋ (์๋ ์ฐธ๊ณ ) |
|
| 28 |
+
| 4 | 1 | `action.control_mode` | ์ ์ด ๋ชจ๋ ์ค์์น (์๋ ์ฐธ๊ณ ) |
|
| 29 |
+
| 5~7 | 3 | `action.end_effector_position` | EE delta position, **base frame ๊ธฐ์ค** |
|
| 30 |
+
| 8~10 | 3 | `action.end_effector_rotation` | EE delta rotation, **base frame ๊ธฐ์ค, axis-angle** |
|
| 31 |
+
| 11 | 1 | `action.gripper_close` | ๊ทธ๋ฆฌํผ ๋ซ๊ธฐ (0.5 threshold โ binary) |
|
| 32 |
+
|
| 33 |
+
### base_motion (4์ฐจ์) ์์ธ
|
| 34 |
+
|
| 35 |
+
| ์ธ๋ฑ์ค | ๋์ | controller type | ์ค๋ช
|
|
| 36 |
+
|--------|------|-----------------|------|
|
| 37 |
+
| 0~2 | `robot0_base` | `JOINT_VELOCITY` | ๋ชจ๋ฐ์ผ ๋ฒ ์ด์ค x์๋ / y์๋ / yaw์๋ |
|
| 38 |
+
| 3 | `robot0_torso` | `JOINT_POSITION` | ๋ชธํต ์์ง ๋ฆฌํํธ joint position (โ ๋์ด) |
|
| 39 |
+
|
| 40 |
+
### control_mode (1์ฐจ์) ์์ธ
|
| 41 |
+
|
| 42 |
+
| ๊ฐ | base_mode | ๋์ |
|
| 43 |
+
|----|-----------|------|
|
| 44 |
+
| < 0.5 | -1 | **Arm mode** โ ๋ฒ ์ด์ค ๊ณ ์ , ํ๋ก ์กฐ์ (goal: `achieved` ๊ธฐ์ค) |
|
| 45 |
+
| โฅ 0.5 | +1 | **Base mode** โ ๋ฒ ์ด์ค ์ด๋, ํ ๋ชฉํ ์ ์ง (goal: `desired` ๊ธฐ์ค) |
|
| 46 |
+
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
## EE Rotation: axis-angle์ ์ฐ๋ ์ด์
|
| 50 |
+
|
| 51 |
+
OSC controller(`osc.py`)๋ rotation input์ `Rotation.from_rotvec()` ์ผ๋ก ํด์ โ **axis-angle ๊ณ ์ **.
|
| 52 |
+
|
| 53 |
+
RPY(Euler angle) ๋์ axis-angle์ ์ฐ๋ ์ด์ :
|
| 54 |
+
|
| 55 |
+
1. **Gimbal lock ์์** โ RPY๋ ํน์ ์์ธ์์ ๋ ์ถ์ด ๊ฒน์ณ DOF๋ฅผ ์๋ singularity ๋ฐ์. EE๋ ์์ ํ์ ํ๋ฏ๋ก ์ค์ ๋ฌธ์ ๊ฐ ๋จ.
|
| 56 |
+
2. **Delta ์ ์ด์ ์์ฐ์ค๋ฌ์** โ "์ด ์ถ ๋ฐฉํฅ์ผ๋ก ฮธ๋งํผ ํ์ " ์๋ฏธ๊ฐ ์ง๊ด์ ์ด๊ณ ๋ณด๊ฐ์ด smooth. RPY delta๋ ์์ ์์กด์ฑ(rollโpitchโyaw) ๋๋ฌธ์ ํฉ์ฑ์ด ๋ณต์กํจ.
|
| 57 |
+
3. **ํฌ๊ธฐ = ํ์ ๋** โ ๋ฒกํฐ norm์ด ํ์ ๊ฐ์ด๋ผ output clipping์ด ์์ฐ์ค๋ฌ์. (`output_max: [0.5, 0.5, 0.5]` rad)
|
| 58 |
+
|
| 59 |
+
> RPY ์
๋ ฅ์ ์ฝ๋์ ์ง์ํ์ง ์์. ํ์ํ๋ฉด wrapper์์ ๋ณํ ํ์:
|
| 60 |
+
> ```python
|
| 61 |
+
> from scipy.spatial.transform import Rotation
|
| 62 |
+
> axis_angle = Rotation.from_euler('xyz', rpy).as_rotvec()
|
| 63 |
+
> ```
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## ์๋ฎฌ๋ ์ด์
์ฐ๊ฒฐ ํ๋ฆ
|
| 68 |
+
|
| 69 |
+
```
|
| 70 |
+
policy output (12-dim)
|
| 71 |
+
โ convert_action() [env.py]
|
| 72 |
+
action dict (base_motion, control_mode, EE_pos, EE_rot, gripper_close)
|
| 73 |
+
โ unmap_action() [gym_wrapper.py]
|
| 74 |
+
{
|
| 75 |
+
robot0_right: concat(EE_pos[3], EE_rot[3]) โ OSC_POSE controller
|
| 76 |
+
robot0_right_gripper: threshold(gripper_close, 0.5) โ -1 or +1
|
| 77 |
+
robot0_base: base_motion[0:3] โ JOINT_VELOCITY controller
|
| 78 |
+
robot0_torso: base_motion[3:4] โ JOINT_POSITION controller
|
| 79 |
+
robot0_base_mode: threshold(control_mode, 0.5) โ -1 or +1
|
| 80 |
+
}
|
| 81 |
+
โ env.step() [robosuite]
|
| 82 |
+
MuJoCo simulation
|
| 83 |
+
```
|
backup/configs.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import abc
|
| 16 |
+
from dataclasses import dataclass, field, fields
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
import draccus
|
| 20 |
+
|
| 21 |
+
from lerobot.configs.types import FeatureType, PolicyFeature
|
| 22 |
+
from lerobot.robots import RobotConfig
|
| 23 |
+
from lerobot.teleoperators.config import TeleoperatorConfig
|
| 24 |
+
from lerobot.utils.constants import (
|
| 25 |
+
ACTION,
|
| 26 |
+
LIBERO_KEY_EEF_MAT,
|
| 27 |
+
LIBERO_KEY_EEF_POS,
|
| 28 |
+
LIBERO_KEY_EEF_QUAT,
|
| 29 |
+
LIBERO_KEY_GRIPPER_QPOS,
|
| 30 |
+
LIBERO_KEY_GRIPPER_QVEL,
|
| 31 |
+
LIBERO_KEY_JOINTS_POS,
|
| 32 |
+
LIBERO_KEY_JOINTS_VEL,
|
| 33 |
+
LIBERO_KEY_PIXELS_AGENTVIEW,
|
| 34 |
+
LIBERO_KEY_PIXELS_EYE_IN_HAND,
|
| 35 |
+
OBS_ENV_STATE,
|
| 36 |
+
OBS_IMAGE,
|
| 37 |
+
OBS_IMAGES,
|
| 38 |
+
OBS_STATE,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
| 44 |
+
task: str | None = None
|
| 45 |
+
fps: int = 30
|
| 46 |
+
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
| 47 |
+
features_map: dict[str, str] = field(default_factory=dict)
|
| 48 |
+
max_parallel_tasks: int = 1
|
| 49 |
+
disable_env_checker: bool = True
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def type(self) -> str:
|
| 53 |
+
return self.get_choice_name(self.__class__)
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def package_name(self) -> str:
|
| 57 |
+
"""Package name to import if environment not found in gym registry"""
|
| 58 |
+
return f"gym_{self.type}"
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def gym_id(self) -> str:
|
| 62 |
+
"""ID string used in gym.make() to instantiate the environment"""
|
| 63 |
+
return f"{self.package_name}/{self.task}"
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
@abc.abstractmethod
|
| 67 |
+
def gym_kwargs(self) -> dict:
|
| 68 |
+
raise NotImplementedError()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class HubEnvConfig(EnvConfig):
|
| 73 |
+
"""Base class for environments that delegate creation to a hub-hosted make_env.
|
| 74 |
+
|
| 75 |
+
Hub environments download and execute remote code from the HF Hub.
|
| 76 |
+
The hub_path points to a repository containing an env.py with a make_env function.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
hub_path: str | None = None # required: e.g., "username/repo" or "username/repo@branch:file.py"
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def gym_kwargs(self) -> dict:
|
| 83 |
+
# Not used for hub environments - the hub's make_env handles everything
|
| 84 |
+
return {}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@EnvConfig.register_subclass("aloha")
|
| 88 |
+
@dataclass
|
| 89 |
+
class AlohaEnv(EnvConfig):
|
| 90 |
+
task: str | None = "AlohaInsertion-v0"
|
| 91 |
+
fps: int = 50
|
| 92 |
+
episode_length: int = 400
|
| 93 |
+
obs_type: str = "pixels_agent_pos"
|
| 94 |
+
observation_height: int = 480
|
| 95 |
+
observation_width: int = 640
|
| 96 |
+
render_mode: str = "rgb_array"
|
| 97 |
+
features: dict[str, PolicyFeature] = field(
|
| 98 |
+
default_factory=lambda: {
|
| 99 |
+
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
| 100 |
+
}
|
| 101 |
+
)
|
| 102 |
+
features_map: dict[str, str] = field(
|
| 103 |
+
default_factory=lambda: {
|
| 104 |
+
ACTION: ACTION,
|
| 105 |
+
"agent_pos": OBS_STATE,
|
| 106 |
+
"top": f"{OBS_IMAGE}.top",
|
| 107 |
+
"pixels/top": f"{OBS_IMAGES}.top",
|
| 108 |
+
}
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def __post_init__(self):
|
| 112 |
+
if self.obs_type == "pixels":
|
| 113 |
+
self.features["top"] = PolicyFeature(
|
| 114 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 115 |
+
)
|
| 116 |
+
elif self.obs_type == "pixels_agent_pos":
|
| 117 |
+
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
| 118 |
+
self.features["pixels/top"] = PolicyFeature(
|
| 119 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def gym_kwargs(self) -> dict:
|
| 124 |
+
return {
|
| 125 |
+
"obs_type": self.obs_type,
|
| 126 |
+
"render_mode": self.render_mode,
|
| 127 |
+
"max_episode_steps": self.episode_length,
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@EnvConfig.register_subclass("pusht")
|
| 132 |
+
@dataclass
|
| 133 |
+
class PushtEnv(EnvConfig):
|
| 134 |
+
task: str | None = "PushT-v0"
|
| 135 |
+
fps: int = 10
|
| 136 |
+
episode_length: int = 300
|
| 137 |
+
obs_type: str = "pixels_agent_pos"
|
| 138 |
+
render_mode: str = "rgb_array"
|
| 139 |
+
visualization_width: int = 384
|
| 140 |
+
visualization_height: int = 384
|
| 141 |
+
observation_height: int = 384
|
| 142 |
+
observation_width: int = 384
|
| 143 |
+
features: dict[str, PolicyFeature] = field(
|
| 144 |
+
default_factory=lambda: {
|
| 145 |
+
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
| 146 |
+
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
| 147 |
+
}
|
| 148 |
+
)
|
| 149 |
+
features_map: dict[str, str] = field(
|
| 150 |
+
default_factory=lambda: {
|
| 151 |
+
ACTION: ACTION,
|
| 152 |
+
"agent_pos": OBS_STATE,
|
| 153 |
+
"environment_state": OBS_ENV_STATE,
|
| 154 |
+
"pixels": OBS_IMAGE,
|
| 155 |
+
}
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
def __post_init__(self):
|
| 159 |
+
if self.obs_type == "pixels_agent_pos":
|
| 160 |
+
self.features["pixels"] = PolicyFeature(
|
| 161 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 162 |
+
)
|
| 163 |
+
elif self.obs_type == "environment_state_agent_pos":
|
| 164 |
+
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
| 165 |
+
|
| 166 |
+
@property
|
| 167 |
+
def gym_kwargs(self) -> dict:
|
| 168 |
+
return {
|
| 169 |
+
"obs_type": self.obs_type,
|
| 170 |
+
"render_mode": self.render_mode,
|
| 171 |
+
"visualization_width": self.visualization_width,
|
| 172 |
+
"visualization_height": self.visualization_height,
|
| 173 |
+
"max_episode_steps": self.episode_length,
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@dataclass
|
| 178 |
+
class ImagePreprocessingConfig:
|
| 179 |
+
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
| 180 |
+
resize_size: tuple[int, int] | None = None
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
@dataclass
|
| 184 |
+
class RewardClassifierConfig:
|
| 185 |
+
"""Configuration for reward classification."""
|
| 186 |
+
|
| 187 |
+
pretrained_path: str | None = None
|
| 188 |
+
success_threshold: float = 0.5
|
| 189 |
+
success_reward: float = 1.0
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@dataclass
|
| 193 |
+
class InverseKinematicsConfig:
|
| 194 |
+
"""Configuration for inverse kinematics processing."""
|
| 195 |
+
|
| 196 |
+
urdf_path: str | None = None
|
| 197 |
+
target_frame_name: str | None = None
|
| 198 |
+
end_effector_bounds: dict[str, list[float]] | None = None
|
| 199 |
+
end_effector_step_sizes: dict[str, float] | None = None
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@dataclass
|
| 203 |
+
class ObservationConfig:
|
| 204 |
+
"""Configuration for observation processing."""
|
| 205 |
+
|
| 206 |
+
add_joint_velocity_to_observation: bool = False
|
| 207 |
+
add_current_to_observation: bool = False
|
| 208 |
+
add_ee_pose_to_observation: bool = False
|
| 209 |
+
display_cameras: bool = False
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@dataclass
|
| 213 |
+
class GripperConfig:
|
| 214 |
+
"""Configuration for gripper control and penalties."""
|
| 215 |
+
|
| 216 |
+
use_gripper: bool = True
|
| 217 |
+
gripper_penalty: float = 0.0
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@dataclass
|
| 221 |
+
class ResetConfig:
|
| 222 |
+
"""Configuration for environment reset behavior."""
|
| 223 |
+
|
| 224 |
+
fixed_reset_joint_positions: Any | None = None
|
| 225 |
+
reset_time_s: float = 5.0
|
| 226 |
+
control_time_s: float = 20.0
|
| 227 |
+
terminate_on_success: bool = True
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@dataclass
|
| 231 |
+
class HILSerlProcessorConfig:
|
| 232 |
+
"""Configuration for environment processing pipeline."""
|
| 233 |
+
|
| 234 |
+
control_mode: str = "gamepad"
|
| 235 |
+
observation: ObservationConfig | None = None
|
| 236 |
+
image_preprocessing: ImagePreprocessingConfig | None = None
|
| 237 |
+
gripper: GripperConfig | None = None
|
| 238 |
+
reset: ResetConfig | None = None
|
| 239 |
+
inverse_kinematics: InverseKinematicsConfig | None = None
|
| 240 |
+
reward_classifier: RewardClassifierConfig | None = None
|
| 241 |
+
max_gripper_pos: float | None = 100.0
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@EnvConfig.register_subclass(name="gym_manipulator")
|
| 245 |
+
@dataclass
|
| 246 |
+
class HILSerlRobotEnvConfig(EnvConfig):
|
| 247 |
+
"""Configuration for the HILSerlRobotEnv environment."""
|
| 248 |
+
|
| 249 |
+
robot: RobotConfig | None = None
|
| 250 |
+
teleop: TeleoperatorConfig | None = None
|
| 251 |
+
processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
|
| 252 |
+
|
| 253 |
+
name: str = "real_robot"
|
| 254 |
+
|
| 255 |
+
@property
|
| 256 |
+
def gym_kwargs(self) -> dict:
|
| 257 |
+
return {}
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
@EnvConfig.register_subclass("libero")
|
| 261 |
+
@dataclass
|
| 262 |
+
class LiberoEnv(EnvConfig):
|
| 263 |
+
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
| 264 |
+
task_ids: list[int] | None = None
|
| 265 |
+
fps: int = 30
|
| 266 |
+
episode_length: int | None = None
|
| 267 |
+
obs_type: str = "pixels_agent_pos"
|
| 268 |
+
render_mode: str = "rgb_array"
|
| 269 |
+
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
| 270 |
+
init_states: bool = True
|
| 271 |
+
camera_name_mapping: dict[str, str] | None = None
|
| 272 |
+
observation_height: int = 360
|
| 273 |
+
observation_width: int = 360
|
| 274 |
+
features: dict[str, PolicyFeature] = field(
|
| 275 |
+
default_factory=lambda: {
|
| 276 |
+
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
| 277 |
+
}
|
| 278 |
+
)
|
| 279 |
+
features_map: dict[str, str] = field(
|
| 280 |
+
default_factory=lambda: {
|
| 281 |
+
ACTION: ACTION,
|
| 282 |
+
LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
|
| 283 |
+
LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
|
| 284 |
+
LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
|
| 285 |
+
LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
|
| 286 |
+
LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
|
| 287 |
+
LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
|
| 288 |
+
LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
|
| 289 |
+
LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
|
| 290 |
+
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
| 291 |
+
}
|
| 292 |
+
)
|
| 293 |
+
control_mode: str = "relative" # or "absolute"
|
| 294 |
+
|
| 295 |
+
def __post_init__(self):
|
| 296 |
+
if self.obs_type == "pixels":
|
| 297 |
+
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
|
| 298 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 299 |
+
)
|
| 300 |
+
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
|
| 301 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 302 |
+
)
|
| 303 |
+
elif self.obs_type == "pixels_agent_pos":
|
| 304 |
+
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
|
| 305 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 306 |
+
)
|
| 307 |
+
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
|
| 308 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 309 |
+
)
|
| 310 |
+
self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
|
| 311 |
+
type=FeatureType.STATE,
|
| 312 |
+
shape=(3,),
|
| 313 |
+
)
|
| 314 |
+
self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
|
| 315 |
+
type=FeatureType.STATE,
|
| 316 |
+
shape=(4,),
|
| 317 |
+
)
|
| 318 |
+
self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
|
| 319 |
+
type=FeatureType.STATE,
|
| 320 |
+
shape=(3, 3),
|
| 321 |
+
)
|
| 322 |
+
self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
|
| 323 |
+
type=FeatureType.STATE,
|
| 324 |
+
shape=(2,),
|
| 325 |
+
)
|
| 326 |
+
self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
|
| 327 |
+
type=FeatureType.STATE,
|
| 328 |
+
shape=(2,),
|
| 329 |
+
)
|
| 330 |
+
self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
|
| 331 |
+
type=FeatureType.STATE,
|
| 332 |
+
shape=(7,),
|
| 333 |
+
)
|
| 334 |
+
self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
|
| 335 |
+
type=FeatureType.STATE,
|
| 336 |
+
shape=(7,),
|
| 337 |
+
)
|
| 338 |
+
else:
|
| 339 |
+
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
| 340 |
+
|
| 341 |
+
@property
|
| 342 |
+
def gym_kwargs(self) -> dict:
|
| 343 |
+
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
| 344 |
+
if self.task_ids is not None:
|
| 345 |
+
kwargs["task_ids"] = self.task_ids
|
| 346 |
+
return kwargs
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
@EnvConfig.register_subclass("metaworld")
|
| 350 |
+
@dataclass
|
| 351 |
+
class MetaworldEnv(EnvConfig):
|
| 352 |
+
task: str = "metaworld-push-v2" # add all tasks
|
| 353 |
+
fps: int = 80
|
| 354 |
+
episode_length: int = 400
|
| 355 |
+
obs_type: str = "pixels_agent_pos"
|
| 356 |
+
render_mode: str = "rgb_array"
|
| 357 |
+
multitask_eval: bool = True
|
| 358 |
+
features: dict[str, PolicyFeature] = field(
|
| 359 |
+
default_factory=lambda: {
|
| 360 |
+
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
| 361 |
+
}
|
| 362 |
+
)
|
| 363 |
+
features_map: dict[str, str] = field(
|
| 364 |
+
default_factory=lambda: {
|
| 365 |
+
"action": ACTION,
|
| 366 |
+
"agent_pos": OBS_STATE,
|
| 367 |
+
"top": f"{OBS_IMAGE}",
|
| 368 |
+
"pixels/top": f"{OBS_IMAGE}",
|
| 369 |
+
}
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
def __post_init__(self):
|
| 373 |
+
if self.obs_type == "pixels":
|
| 374 |
+
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
|
| 375 |
+
|
| 376 |
+
elif self.obs_type == "pixels_agent_pos":
|
| 377 |
+
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
| 378 |
+
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
|
| 379 |
+
|
| 380 |
+
else:
|
| 381 |
+
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
| 382 |
+
|
| 383 |
+
@property
|
| 384 |
+
def gym_kwargs(self) -> dict:
|
| 385 |
+
return {
|
| 386 |
+
"obs_type": self.obs_type,
|
| 387 |
+
"render_mode": self.render_mode,
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@EnvConfig.register_subclass("isaaclab_arena")
|
| 392 |
+
@dataclass
|
| 393 |
+
class IsaaclabArenaEnv(HubEnvConfig):
|
| 394 |
+
hub_path: str = "nvidia/isaaclab-arena-envs"
|
| 395 |
+
episode_length: int = 300
|
| 396 |
+
num_envs: int = 1
|
| 397 |
+
embodiment: str | None = "gr1_pink"
|
| 398 |
+
object: str | None = "power_drill"
|
| 399 |
+
mimic: bool = False
|
| 400 |
+
teleop_device: str | None = None
|
| 401 |
+
seed: int | None = 42
|
| 402 |
+
device: str | None = "cuda:0"
|
| 403 |
+
disable_fabric: bool = False
|
| 404 |
+
enable_cameras: bool = False
|
| 405 |
+
headless: bool = False
|
| 406 |
+
enable_pinocchio: bool = True
|
| 407 |
+
environment: str | None = "gr1_microwave"
|
| 408 |
+
task: str | None = "Reach out to the microwave and open it."
|
| 409 |
+
state_dim: int = 54
|
| 410 |
+
action_dim: int = 36
|
| 411 |
+
camera_height: int = 512
|
| 412 |
+
camera_width: int = 512
|
| 413 |
+
video: bool = False
|
| 414 |
+
video_length: int = 100
|
| 415 |
+
video_interval: int = 200
|
| 416 |
+
# Comma-separated keys, e.g., "robot_joint_pos,left_eef_pos"
|
| 417 |
+
state_keys: str = "robot_joint_pos"
|
| 418 |
+
# Comma-separated keys, e.g., "robot_pov_cam_rgb,front_cam_rgb"
|
| 419 |
+
# Set to None or "" for environments without cameras
|
| 420 |
+
camera_keys: str | None = None
|
| 421 |
+
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
| 422 |
+
features_map: dict[str, str] = field(default_factory=dict)
|
| 423 |
+
kwargs: dict | None = None
|
| 424 |
+
|
| 425 |
+
def __post_init__(self):
|
| 426 |
+
if self.kwargs:
|
| 427 |
+
# dynamically convert kwargs to fields in the dataclass
|
| 428 |
+
# NOTE! the new fields will not bee seen by the dataclass repr
|
| 429 |
+
field_names = {f.name for f in fields(self)}
|
| 430 |
+
for key, value in self.kwargs.items():
|
| 431 |
+
if key not in field_names and key != "kwargs":
|
| 432 |
+
setattr(self, key, value)
|
| 433 |
+
self.kwargs = None
|
| 434 |
+
|
| 435 |
+
# Set action feature
|
| 436 |
+
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
|
| 437 |
+
self.features_map[ACTION] = ACTION
|
| 438 |
+
|
| 439 |
+
# Set state feature
|
| 440 |
+
self.features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.state_dim,))
|
| 441 |
+
self.features_map[OBS_STATE] = OBS_STATE
|
| 442 |
+
|
| 443 |
+
# Add camera features for each camera key
|
| 444 |
+
if self.enable_cameras and self.camera_keys:
|
| 445 |
+
for cam_key in self.camera_keys.split(","):
|
| 446 |
+
cam_key = cam_key.strip()
|
| 447 |
+
if cam_key:
|
| 448 |
+
self.features[cam_key] = PolicyFeature(
|
| 449 |
+
type=FeatureType.VISUAL,
|
| 450 |
+
shape=(self.camera_height, self.camera_width, 3),
|
| 451 |
+
)
|
| 452 |
+
self.features_map[cam_key] = f"{OBS_IMAGES}.{cam_key}"
|
| 453 |
+
|
| 454 |
+
@property
|
| 455 |
+
def gym_kwargs(self) -> dict:
|
| 456 |
+
return {}
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
# ------------------------ Robocasa365 --------------------------------
|
| 460 |
+
|
| 461 |
+
@EnvConfig.register_subclass("robocasa")
|
| 462 |
+
@dataclass
|
| 463 |
+
class RoboCasaEnv(HubEnvConfig):
|
| 464 |
+
|
| 465 |
+
hub_path: str = "Whalswp/RoboCasa_Env"
|
| 466 |
+
|
| 467 |
+
task: str | None = None
|
| 468 |
+
obs_type: str = "pixels_agent_pos"
|
| 469 |
+
render_mode: str = "rgb_array"
|
| 470 |
+
camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
|
| 471 |
+
observation_height: int = 256
|
| 472 |
+
observation_width: int = 256
|
| 473 |
+
split: str | None = None
|
| 474 |
+
|
| 475 |
+
# VLA ๋ชจ๋ธ ๋ฑ์์ ์ฌ์ฉํ Observation & Action ๊ท๊ฒฉ ๋งคํ
|
| 476 |
+
features: dict[str, PolicyFeature] = field(default_factory=lambda: {
|
| 477 |
+
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
|
| 478 |
+
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(16,)),
|
| 479 |
+
"pixels/robot0_agentview_left": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
| 480 |
+
"pixels/robot0_agentview_right": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
| 481 |
+
"pixels/robot0_eye_in_hand": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
| 482 |
+
})
|
| 483 |
+
features_map: dict[str, str] = field(default_factory=lambda: {
|
| 484 |
+
ACTION: ACTION,
|
| 485 |
+
"agent_pos": OBS_STATE,
|
| 486 |
+
"pixels/robot0_agentview_left": f"{OBS_IMAGES}.robot0_agentview_left",
|
| 487 |
+
"pixels/robot0_agentview_right": f"{OBS_IMAGES}.robot0_agentview_right",
|
| 488 |
+
"pixels/robot0_eye_in_hand": f"{OBS_IMAGES}.robot0_eye_in_hand",
|
| 489 |
+
})
|
backup/env.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# env.py
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
from gymnasium import spaces
|
| 4 |
+
import numpy as np
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from collections.abc import Callable, Sequence, Mapping
|
| 7 |
+
from functools import partial
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
# RoboCasa ์ ์ฉ ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ํฌํธ
|
| 11 |
+
from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
|
| 12 |
+
from robocasa.utils.dataset_registry import ATOMIC_TASK_DATASETS, COMPOSITE_TASK_DATASETS, TARGET_TASKS, PRETRAINING_TASKS
|
| 13 |
+
|
| 14 |
+
OBS_STATE_DIM = 16
|
| 15 |
+
ACTION_DIM = 12
|
| 16 |
+
ACTION_LOW = -1.0
|
| 17 |
+
ACTION_HIGH = 1.0
|
| 18 |
+
|
| 19 |
+
def convert_state(dict_state):
|
| 20 |
+
"""์๋ฎฌ๋ ์ดํฐ ์ํ๋ฅผ LeRobot์ด ๊ธฐ๋ํ๋ ํํ๋ก ๋ณํ(Conversion)ํฉ๋๋ค."""
|
| 21 |
+
dict_state = dict_state.copy()
|
| 22 |
+
final_state = np.concatenate([
|
| 23 |
+
dict_state["state.base_position"],
|
| 24 |
+
dict_state["state.base_rotation"],
|
| 25 |
+
dict_state["state.end_effector_position_relative"],
|
| 26 |
+
dict_state["state.end_effector_rotation_relative"],
|
| 27 |
+
dict_state["state.gripper_qpos"],
|
| 28 |
+
], axis=0)
|
| 29 |
+
return final_state
|
| 30 |
+
|
| 31 |
+
def convert_action(action):
|
| 32 |
+
"""LeRobot์ ์ก์
์ ์๋ฎฌ๋ ์ดํฐ๊ฐ ์ดํดํ๋ dict ํํ๋ก ๋ณํํฉ๋๋ค."""
|
| 33 |
+
action = action.copy()
|
| 34 |
+
output_action = {
|
| 35 |
+
"action.base_motion": action[0:4],
|
| 36 |
+
"action.control_mode": action[4:5],
|
| 37 |
+
"action.end_effector_position": action[5:8],
|
| 38 |
+
"action.end_effector_rotation": action[8:11],
|
| 39 |
+
"action.gripper_close": action[11:12],
|
| 40 |
+
}
|
| 41 |
+
return output_action
|
| 42 |
+
|
| 43 |
+
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
| 44 |
+
"""์นด๋ฉ๋ผ ์ด๋ฆ์ ๋ฆฌ์คํธ ํํ๋ก ์ ๊ทํ(Normalization)ํฉ๋๋ค."""
|
| 45 |
+
if isinstance(camera_name, str):
|
| 46 |
+
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
| 47 |
+
elif isinstance(camera_name, (list, tuple)):
|
| 48 |
+
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
| 49 |
+
else:
|
| 50 |
+
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
| 51 |
+
if not cams:
|
| 52 |
+
raise ValueError("camera_name resolved to an empty list.")
|
| 53 |
+
return cams
|
| 54 |
+
|
| 55 |
+
class RoboCasaEnv(RoboCasaGymEnv):
|
| 56 |
+
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
task: str,
|
| 61 |
+
camera_name: Sequence[str] = ["robot0_agentview_left", "robot0_eye_in_hand", "robot0_agentview_right"],
|
| 62 |
+
render_mode: str = "rgb_array",
|
| 63 |
+
obs_type: str = "pixels_agent_pos",
|
| 64 |
+
observation_width: int = 256,
|
| 65 |
+
observation_height: int = 256,
|
| 66 |
+
split: str | None = None,
|
| 67 |
+
**kwargs
|
| 68 |
+
):
|
| 69 |
+
self.obs_type = obs_type
|
| 70 |
+
self.render_mode = render_mode
|
| 71 |
+
self.split = split
|
| 72 |
+
self.task = task
|
| 73 |
+
self._task_description = ""
|
| 74 |
+
|
| 75 |
+
kwargs.pop("fps", None)
|
| 76 |
+
self.kwargs = kwargs
|
| 77 |
+
|
| 78 |
+
meta_info = {**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}
|
| 79 |
+
try:
|
| 80 |
+
self._max_episode_steps = meta_info[task]['horizon']
|
| 81 |
+
except KeyError:
|
| 82 |
+
raise ValueError(f"Unknown task '{task}'. Valid tasks are: {list(meta_info.keys())}")
|
| 83 |
+
|
| 84 |
+
super().__init__(
|
| 85 |
+
task,
|
| 86 |
+
camera_names=camera_name,
|
| 87 |
+
camera_widths=observation_width,
|
| 88 |
+
camera_heights=observation_height,
|
| 89 |
+
enable_render=(render_mode is not None),
|
| 90 |
+
split=split,
|
| 91 |
+
**kwargs
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
def _create_obs_and_action_space(self):
|
| 95 |
+
images = {}
|
| 96 |
+
for cam in self.camera_names:
|
| 97 |
+
images[cam] = spaces.Box(
|
| 98 |
+
low=0, high=255, shape=(self.camera_heights, self.camera_widths, 3), dtype=np.uint8
|
| 99 |
+
)
|
| 100 |
+
if self.obs_type == "state":
|
| 101 |
+
raise NotImplementedError("The 'state' observation type is not supported.")
|
| 102 |
+
elif self.obs_type == "pixels":
|
| 103 |
+
self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
|
| 104 |
+
elif self.obs_type == "pixels_agent_pos":
|
| 105 |
+
self.observation_space = spaces.Dict({
|
| 106 |
+
"pixels": spaces.Dict(images),
|
| 107 |
+
"agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
|
| 108 |
+
})
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unknown obs_type: {self.obs_type}")
|
| 111 |
+
|
| 112 |
+
self.action_space = spaces.Box(
|
| 113 |
+
low=ACTION_LOW, high=ACTION_HIGH, shape=(int(ACTION_DIM),), dtype=np.float32
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def task_description(self) -> str:
|
| 118 |
+
return self._task_description
|
| 119 |
+
|
| 120 |
+
def reset(self, seed: int | None = None, **kwargs):
|
| 121 |
+
self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
|
| 122 |
+
observation, info = super().reset(seed, **kwargs)
|
| 123 |
+
self._task_description = self.env.get_ep_meta().get("lang", self.task)
|
| 124 |
+
print(f"[RoboCasaEnv] task_description: {self._task_description!r}")
|
| 125 |
+
return self._format_raw_obs(observation), info
|
| 126 |
+
|
| 127 |
+
def _format_raw_obs(self, raw_obs: dict):
|
| 128 |
+
new_obs = {}
|
| 129 |
+
if self.obs_type == "pixels_agent_pos":
|
| 130 |
+
new_obs["agent_pos"] = convert_state(raw_obs)
|
| 131 |
+
new_obs["pixels"] = {}
|
| 132 |
+
for k, v in raw_obs.items():
|
| 133 |
+
if "video." in k:
|
| 134 |
+
new_obs["pixels"][k.replace("video.", "")] = v
|
| 135 |
+
return new_obs
|
| 136 |
+
|
| 137 |
+
def step(self, action: np.ndarray):
|
| 138 |
+
self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
|
| 139 |
+
action_dict = convert_action(action)
|
| 140 |
+
observation, reward, done, truncated, info = super().step(action_dict)
|
| 141 |
+
new_obs = self._format_raw_obs(observation)
|
| 142 |
+
|
| 143 |
+
is_success = bool(info.get("success", 0))
|
| 144 |
+
terminated = done or is_success
|
| 145 |
+
info.update({"task": self.task, "done": done, "is_success": is_success})
|
| 146 |
+
|
| 147 |
+
if terminated:
|
| 148 |
+
info["final_info"] = {"task": self.task, "done": bool(done), "is_success": bool(is_success)}
|
| 149 |
+
self.reset()
|
| 150 |
+
|
| 151 |
+
return new_obs, reward, terminated, truncated, info
|
| 152 |
+
|
| 153 |
+
def render(self):
|
| 154 |
+
frame = super().render()
|
| 155 |
+
if frame is None:
|
| 156 |
+
return frame
|
| 157 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 158 |
+
import textwrap
|
| 159 |
+
|
| 160 |
+
text = self._task_description or self.task
|
| 161 |
+
w = frame.shape[1]
|
| 162 |
+
|
| 163 |
+
try:
|
| 164 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
|
| 165 |
+
except Exception:
|
| 166 |
+
font = ImageFont.load_default()
|
| 167 |
+
|
| 168 |
+
lines = textwrap.wrap(text, width=55)
|
| 169 |
+
line_h = 18
|
| 170 |
+
bar_h = len(lines) * line_h + 10
|
| 171 |
+
|
| 172 |
+
bar = Image.new("RGB", (w, bar_h), color=(30, 30, 30))
|
| 173 |
+
draw = ImageDraw.Draw(bar)
|
| 174 |
+
for i, line in enumerate(lines):
|
| 175 |
+
draw.text((8, 5 + i * line_h), line, font=font, fill=(220, 220, 220))
|
| 176 |
+
|
| 177 |
+
return np.concatenate([frame, np.array(bar)], axis=0)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _make_env_fns(task_name: str, n_envs: int, camera_names: list[str], gym_kwargs: Mapping[str, Any]):
|
| 181 |
+
def _make_env(episode_index: int, **kwargs):
|
| 182 |
+
seed = kwargs.pop("seed", episode_index)
|
| 183 |
+
return RoboCasaEnv(task=task_name, camera_name=camera_names, seed=seed, **kwargs)
|
| 184 |
+
|
| 185 |
+
return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# ======================================================================
|
| 189 |
+
# LeRobot Hub ํ์ ์ง์
์ (Entry Point)
|
| 190 |
+
# ======================================================================
|
| 191 |
+
def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
|
| 192 |
+
"""
|
| 193 |
+
LeRobot์ด Hub์์ ํ๊ฒฝ์ ๋ก๋ํ ๋ ํธ์ถํ๋ ๋ฉ์ธ ํจ์์
๋๋ค.
|
| 194 |
+
"""
|
| 195 |
+
# ํ๊ฒฝ ๋ํผ ํด๋์ค ์ ํ
|
| 196 |
+
env_cls = partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv
|
| 197 |
+
|
| 198 |
+
# ์ค์ ๊ฐ ์ถ์ถ (cfg ๊ฐ์ฒด๊ฐ ์์ผ๋ฉด ์ฌ์ฉํ๊ณ , ์์ผ๋ฉด ๊ธฐ๋ณธ๊ฐ ์ ์ฉ)
|
| 199 |
+
if cfg is not None:
|
| 200 |
+
task_name = getattr(cfg, "task", "CloseFridge")
|
| 201 |
+
fps = getattr(cfg, "fps", 20) # fps ์ถ์ถ
|
| 202 |
+
gym_kwargs = {
|
| 203 |
+
"obs_type": getattr(cfg, "obs_type", "pixels_agent_pos"),
|
| 204 |
+
"render_mode": getattr(cfg, "render_mode", "rgb_array"), # render_mode ์ ์ง
|
| 205 |
+
"observation_width": getattr(cfg, "observation_width", 256),
|
| 206 |
+
"observation_height": getattr(cfg, "observation_height", 256),
|
| 207 |
+
"camera_name": getattr(cfg, "camera_name", "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"),
|
| 208 |
+
"split": getattr(cfg, "split", None),
|
| 209 |
+
"fps": fps, # ํต์ฌ ์ธ์ ๋๋ฝ ๋ฐฉ์ง
|
| 210 |
+
}
|
| 211 |
+
else:
|
| 212 |
+
# cfg ์์ด ์ง์ ํธ์ถ๋ ๋์ ๊ธฐ๋ณธ๊ฐ
|
| 213 |
+
task_name = "CloseFridge"
|
| 214 |
+
gym_kwargs = {
|
| 215 |
+
"obs_type": "pixels_agent_pos",
|
| 216 |
+
"render_mode": "rgb_array",
|
| 217 |
+
"observation_width": 256,
|
| 218 |
+
"observation_height": 256,
|
| 219 |
+
"camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
|
| 220 |
+
"split": None,
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))
|
| 224 |
+
combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
|
| 225 |
+
|
| 226 |
+
# ๋ฒค์น๋งํฌ์ธ์ง ๋จ์ผ ํ์คํฌ์ธ์ง ๊ตฌ๋ถ
|
| 227 |
+
if task_name in combined_tasks:
|
| 228 |
+
task_names = combined_tasks[task_name]
|
| 229 |
+
gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
|
| 230 |
+
else:
|
| 231 |
+
task_names = [t.strip() for t in task_name.split(",")]
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
out = defaultdict(dict)
|
| 235 |
+
|
| 236 |
+
# ํ์คํฌ๋ณ๋ก ํ๊ฒฝ ์์ฑ
|
| 237 |
+
for task in task_names:
|
| 238 |
+
fns = _make_env_fns(
|
| 239 |
+
task_name=task,
|
| 240 |
+
n_envs=n_envs,
|
| 241 |
+
camera_names=parsed_camera_names,
|
| 242 |
+
gym_kwargs=gym_kwargs
|
| 243 |
+
)
|
| 244 |
+
out[task][0] = env_cls(fns)
|
| 245 |
+
|
| 246 |
+
# {suite_name: {task_id: VectorEnv}} ํํ๋ก ๋ฐํ
|
| 247 |
+
#return {"robocasa": dict(out)}
|
| 248 |
+
return {suite: dict(task_map) for suite, task_map in out.items()}
|
configs.py
CHANGED
|
@@ -462,14 +462,19 @@ class IsaaclabArenaEnv(HubEnvConfig):
|
|
| 462 |
@dataclass
|
| 463 |
class RoboCasaEnv(HubEnvConfig):
|
| 464 |
|
| 465 |
-
hub_path: str = "Whalswp/RoboCasa_Env"
|
| 466 |
-
|
| 467 |
-
task
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
obs_type: str = "pixels_agent_pos"
|
| 469 |
render_mode: str = "rgb_array"
|
| 470 |
camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
|
| 471 |
observation_height: int = 256
|
| 472 |
observation_width: int = 256
|
|
|
|
| 473 |
split: str | None = None
|
| 474 |
|
| 475 |
# VLA ๋ชจ๋ธ ๋ฑ์์ ์ฌ์ฉํ Observation & Action ๊ท๊ฒฉ ๋งคํ
|
|
|
|
| 462 |
@dataclass
|
| 463 |
class RoboCasaEnv(HubEnvConfig):
|
| 464 |
|
| 465 |
+
hub_path: str = "Whalswp/RoboCasa_Env"
|
| 466 |
+
|
| 467 |
+
# ๋จ์ผ task ์ด๋ฆ ๋๋ benchmark ํค(`atomic_seen`, `composite_unseen`, ...).
|
| 468 |
+
# ๊ณต์ GR00T eval์ฒ๋ผ ์ฌ๋ฌ ๊ฐ๋ฅผ ๋์์ ๋ฐ์ ์ ์๋๋ก list๋ ํ์ฉ.
|
| 469 |
+
# ์: --env.task=atomic_seen composite_unseen composite_seen
|
| 470 |
+
task: str | list[str] | None = None
|
| 471 |
+
fps: int = 20
|
| 472 |
obs_type: str = "pixels_agent_pos"
|
| 473 |
render_mode: str = "rgb_array"
|
| 474 |
camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
|
| 475 |
observation_height: int = 256
|
| 476 |
observation_width: int = 256
|
| 477 |
+
# `pretrain` | `target` | `all` | None (None์ด๋ฉด task๋ก๋ถํฐ ์ถ๋ก )
|
| 478 |
split: str | None = None
|
| 479 |
|
| 480 |
# VLA ๋ชจ๋ธ ๋ฑ์์ ์ฌ์ฉํ Observation & Action ๊ท๊ฒฉ ๋งคํ
|
configs.py.bak
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import abc
|
| 16 |
+
from dataclasses import dataclass, field, fields
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
import draccus
|
| 20 |
+
|
| 21 |
+
from lerobot.configs.types import FeatureType, PolicyFeature
|
| 22 |
+
from lerobot.robots import RobotConfig
|
| 23 |
+
from lerobot.teleoperators.config import TeleoperatorConfig
|
| 24 |
+
from lerobot.utils.constants import (
|
| 25 |
+
ACTION,
|
| 26 |
+
LIBERO_KEY_EEF_MAT,
|
| 27 |
+
LIBERO_KEY_EEF_POS,
|
| 28 |
+
LIBERO_KEY_EEF_QUAT,
|
| 29 |
+
LIBERO_KEY_GRIPPER_QPOS,
|
| 30 |
+
LIBERO_KEY_GRIPPER_QVEL,
|
| 31 |
+
LIBERO_KEY_JOINTS_POS,
|
| 32 |
+
LIBERO_KEY_JOINTS_VEL,
|
| 33 |
+
LIBERO_KEY_PIXELS_AGENTVIEW,
|
| 34 |
+
LIBERO_KEY_PIXELS_EYE_IN_HAND,
|
| 35 |
+
OBS_ENV_STATE,
|
| 36 |
+
OBS_IMAGE,
|
| 37 |
+
OBS_IMAGES,
|
| 38 |
+
OBS_STATE,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
| 44 |
+
task: str | None = None
|
| 45 |
+
fps: int = 30
|
| 46 |
+
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
| 47 |
+
features_map: dict[str, str] = field(default_factory=dict)
|
| 48 |
+
max_parallel_tasks: int = 1
|
| 49 |
+
disable_env_checker: bool = True
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def type(self) -> str:
|
| 53 |
+
return self.get_choice_name(self.__class__)
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def package_name(self) -> str:
|
| 57 |
+
"""Package name to import if environment not found in gym registry"""
|
| 58 |
+
return f"gym_{self.type}"
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def gym_id(self) -> str:
|
| 62 |
+
"""ID string used in gym.make() to instantiate the environment"""
|
| 63 |
+
return f"{self.package_name}/{self.task}"
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
@abc.abstractmethod
|
| 67 |
+
def gym_kwargs(self) -> dict:
|
| 68 |
+
raise NotImplementedError()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class HubEnvConfig(EnvConfig):
|
| 73 |
+
"""Base class for environments that delegate creation to a hub-hosted make_env.
|
| 74 |
+
|
| 75 |
+
Hub environments download and execute remote code from the HF Hub.
|
| 76 |
+
The hub_path points to a repository containing an env.py with a make_env function.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
hub_path: str | None = None # required: e.g., "username/repo" or "username/repo@branch:file.py"
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def gym_kwargs(self) -> dict:
|
| 83 |
+
# Not used for hub environments - the hub's make_env handles everything
|
| 84 |
+
return {}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@EnvConfig.register_subclass("aloha")
|
| 88 |
+
@dataclass
|
| 89 |
+
class AlohaEnv(EnvConfig):
|
| 90 |
+
task: str | None = "AlohaInsertion-v0"
|
| 91 |
+
fps: int = 50
|
| 92 |
+
episode_length: int = 400
|
| 93 |
+
obs_type: str = "pixels_agent_pos"
|
| 94 |
+
observation_height: int = 480
|
| 95 |
+
observation_width: int = 640
|
| 96 |
+
render_mode: str = "rgb_array"
|
| 97 |
+
features: dict[str, PolicyFeature] = field(
|
| 98 |
+
default_factory=lambda: {
|
| 99 |
+
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
| 100 |
+
}
|
| 101 |
+
)
|
| 102 |
+
features_map: dict[str, str] = field(
|
| 103 |
+
default_factory=lambda: {
|
| 104 |
+
ACTION: ACTION,
|
| 105 |
+
"agent_pos": OBS_STATE,
|
| 106 |
+
"top": f"{OBS_IMAGE}.top",
|
| 107 |
+
"pixels/top": f"{OBS_IMAGES}.top",
|
| 108 |
+
}
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def __post_init__(self):
|
| 112 |
+
if self.obs_type == "pixels":
|
| 113 |
+
self.features["top"] = PolicyFeature(
|
| 114 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 115 |
+
)
|
| 116 |
+
elif self.obs_type == "pixels_agent_pos":
|
| 117 |
+
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
| 118 |
+
self.features["pixels/top"] = PolicyFeature(
|
| 119 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def gym_kwargs(self) -> dict:
|
| 124 |
+
return {
|
| 125 |
+
"obs_type": self.obs_type,
|
| 126 |
+
"render_mode": self.render_mode,
|
| 127 |
+
"max_episode_steps": self.episode_length,
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@EnvConfig.register_subclass("pusht")
|
| 132 |
+
@dataclass
|
| 133 |
+
class PushtEnv(EnvConfig):
|
| 134 |
+
task: str | None = "PushT-v0"
|
| 135 |
+
fps: int = 10
|
| 136 |
+
episode_length: int = 300
|
| 137 |
+
obs_type: str = "pixels_agent_pos"
|
| 138 |
+
render_mode: str = "rgb_array"
|
| 139 |
+
visualization_width: int = 384
|
| 140 |
+
visualization_height: int = 384
|
| 141 |
+
observation_height: int = 384
|
| 142 |
+
observation_width: int = 384
|
| 143 |
+
features: dict[str, PolicyFeature] = field(
|
| 144 |
+
default_factory=lambda: {
|
| 145 |
+
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
| 146 |
+
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
| 147 |
+
}
|
| 148 |
+
)
|
| 149 |
+
features_map: dict[str, str] = field(
|
| 150 |
+
default_factory=lambda: {
|
| 151 |
+
ACTION: ACTION,
|
| 152 |
+
"agent_pos": OBS_STATE,
|
| 153 |
+
"environment_state": OBS_ENV_STATE,
|
| 154 |
+
"pixels": OBS_IMAGE,
|
| 155 |
+
}
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
def __post_init__(self):
|
| 159 |
+
if self.obs_type == "pixels_agent_pos":
|
| 160 |
+
self.features["pixels"] = PolicyFeature(
|
| 161 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 162 |
+
)
|
| 163 |
+
elif self.obs_type == "environment_state_agent_pos":
|
| 164 |
+
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
| 165 |
+
|
| 166 |
+
@property
|
| 167 |
+
def gym_kwargs(self) -> dict:
|
| 168 |
+
return {
|
| 169 |
+
"obs_type": self.obs_type,
|
| 170 |
+
"render_mode": self.render_mode,
|
| 171 |
+
"visualization_width": self.visualization_width,
|
| 172 |
+
"visualization_height": self.visualization_height,
|
| 173 |
+
"max_episode_steps": self.episode_length,
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@dataclass
|
| 178 |
+
class ImagePreprocessingConfig:
|
| 179 |
+
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
| 180 |
+
resize_size: tuple[int, int] | None = None
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
@dataclass
|
| 184 |
+
class RewardClassifierConfig:
|
| 185 |
+
"""Configuration for reward classification."""
|
| 186 |
+
|
| 187 |
+
pretrained_path: str | None = None
|
| 188 |
+
success_threshold: float = 0.5
|
| 189 |
+
success_reward: float = 1.0
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@dataclass
|
| 193 |
+
class InverseKinematicsConfig:
|
| 194 |
+
"""Configuration for inverse kinematics processing."""
|
| 195 |
+
|
| 196 |
+
urdf_path: str | None = None
|
| 197 |
+
target_frame_name: str | None = None
|
| 198 |
+
end_effector_bounds: dict[str, list[float]] | None = None
|
| 199 |
+
end_effector_step_sizes: dict[str, float] | None = None
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@dataclass
|
| 203 |
+
class ObservationConfig:
|
| 204 |
+
"""Configuration for observation processing."""
|
| 205 |
+
|
| 206 |
+
add_joint_velocity_to_observation: bool = False
|
| 207 |
+
add_current_to_observation: bool = False
|
| 208 |
+
add_ee_pose_to_observation: bool = False
|
| 209 |
+
display_cameras: bool = False
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@dataclass
|
| 213 |
+
class GripperConfig:
|
| 214 |
+
"""Configuration for gripper control and penalties."""
|
| 215 |
+
|
| 216 |
+
use_gripper: bool = True
|
| 217 |
+
gripper_penalty: float = 0.0
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
@dataclass
|
| 221 |
+
class ResetConfig:
|
| 222 |
+
"""Configuration for environment reset behavior."""
|
| 223 |
+
|
| 224 |
+
fixed_reset_joint_positions: Any | None = None
|
| 225 |
+
reset_time_s: float = 5.0
|
| 226 |
+
control_time_s: float = 20.0
|
| 227 |
+
terminate_on_success: bool = True
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@dataclass
|
| 231 |
+
class HILSerlProcessorConfig:
|
| 232 |
+
"""Configuration for environment processing pipeline."""
|
| 233 |
+
|
| 234 |
+
control_mode: str = "gamepad"
|
| 235 |
+
observation: ObservationConfig | None = None
|
| 236 |
+
image_preprocessing: ImagePreprocessingConfig | None = None
|
| 237 |
+
gripper: GripperConfig | None = None
|
| 238 |
+
reset: ResetConfig | None = None
|
| 239 |
+
inverse_kinematics: InverseKinematicsConfig | None = None
|
| 240 |
+
reward_classifier: RewardClassifierConfig | None = None
|
| 241 |
+
max_gripper_pos: float | None = 100.0
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@EnvConfig.register_subclass(name="gym_manipulator")
|
| 245 |
+
@dataclass
|
| 246 |
+
class HILSerlRobotEnvConfig(EnvConfig):
|
| 247 |
+
"""Configuration for the HILSerlRobotEnv environment."""
|
| 248 |
+
|
| 249 |
+
robot: RobotConfig | None = None
|
| 250 |
+
teleop: TeleoperatorConfig | None = None
|
| 251 |
+
processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
|
| 252 |
+
|
| 253 |
+
name: str = "real_robot"
|
| 254 |
+
|
| 255 |
+
@property
|
| 256 |
+
def gym_kwargs(self) -> dict:
|
| 257 |
+
return {}
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
@EnvConfig.register_subclass("libero")
|
| 261 |
+
@dataclass
|
| 262 |
+
class LiberoEnv(EnvConfig):
|
| 263 |
+
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
| 264 |
+
task_ids: list[int] | None = None
|
| 265 |
+
fps: int = 30
|
| 266 |
+
episode_length: int | None = None
|
| 267 |
+
obs_type: str = "pixels_agent_pos"
|
| 268 |
+
render_mode: str = "rgb_array"
|
| 269 |
+
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
| 270 |
+
init_states: bool = True
|
| 271 |
+
camera_name_mapping: dict[str, str] | None = None
|
| 272 |
+
observation_height: int = 360
|
| 273 |
+
observation_width: int = 360
|
| 274 |
+
features: dict[str, PolicyFeature] = field(
|
| 275 |
+
default_factory=lambda: {
|
| 276 |
+
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
| 277 |
+
}
|
| 278 |
+
)
|
| 279 |
+
features_map: dict[str, str] = field(
|
| 280 |
+
default_factory=lambda: {
|
| 281 |
+
ACTION: ACTION,
|
| 282 |
+
LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
|
| 283 |
+
LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
|
| 284 |
+
LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
|
| 285 |
+
LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
|
| 286 |
+
LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
|
| 287 |
+
LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
|
| 288 |
+
LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
|
| 289 |
+
LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
|
| 290 |
+
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
| 291 |
+
}
|
| 292 |
+
)
|
| 293 |
+
control_mode: str = "relative" # or "absolute"
|
| 294 |
+
|
| 295 |
+
def __post_init__(self):
|
| 296 |
+
if self.obs_type == "pixels":
|
| 297 |
+
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
|
| 298 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 299 |
+
)
|
| 300 |
+
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
|
| 301 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 302 |
+
)
|
| 303 |
+
elif self.obs_type == "pixels_agent_pos":
|
| 304 |
+
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
|
| 305 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 306 |
+
)
|
| 307 |
+
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
|
| 308 |
+
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
| 309 |
+
)
|
| 310 |
+
self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
|
| 311 |
+
type=FeatureType.STATE,
|
| 312 |
+
shape=(3,),
|
| 313 |
+
)
|
| 314 |
+
self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
|
| 315 |
+
type=FeatureType.STATE,
|
| 316 |
+
shape=(4,),
|
| 317 |
+
)
|
| 318 |
+
self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
|
| 319 |
+
type=FeatureType.STATE,
|
| 320 |
+
shape=(3, 3),
|
| 321 |
+
)
|
| 322 |
+
self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
|
| 323 |
+
type=FeatureType.STATE,
|
| 324 |
+
shape=(2,),
|
| 325 |
+
)
|
| 326 |
+
self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
|
| 327 |
+
type=FeatureType.STATE,
|
| 328 |
+
shape=(2,),
|
| 329 |
+
)
|
| 330 |
+
self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
|
| 331 |
+
type=FeatureType.STATE,
|
| 332 |
+
shape=(7,),
|
| 333 |
+
)
|
| 334 |
+
self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
|
| 335 |
+
type=FeatureType.STATE,
|
| 336 |
+
shape=(7,),
|
| 337 |
+
)
|
| 338 |
+
else:
|
| 339 |
+
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
| 340 |
+
|
| 341 |
+
@property
|
| 342 |
+
def gym_kwargs(self) -> dict:
|
| 343 |
+
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
| 344 |
+
if self.task_ids is not None:
|
| 345 |
+
kwargs["task_ids"] = self.task_ids
|
| 346 |
+
return kwargs
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
@EnvConfig.register_subclass("metaworld")
|
| 350 |
+
@dataclass
|
| 351 |
+
class MetaworldEnv(EnvConfig):
|
| 352 |
+
task: str = "metaworld-push-v2" # add all tasks
|
| 353 |
+
fps: int = 80
|
| 354 |
+
episode_length: int = 400
|
| 355 |
+
obs_type: str = "pixels_agent_pos"
|
| 356 |
+
render_mode: str = "rgb_array"
|
| 357 |
+
multitask_eval: bool = True
|
| 358 |
+
features: dict[str, PolicyFeature] = field(
|
| 359 |
+
default_factory=lambda: {
|
| 360 |
+
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
| 361 |
+
}
|
| 362 |
+
)
|
| 363 |
+
features_map: dict[str, str] = field(
|
| 364 |
+
default_factory=lambda: {
|
| 365 |
+
"action": ACTION,
|
| 366 |
+
"agent_pos": OBS_STATE,
|
| 367 |
+
"top": f"{OBS_IMAGE}",
|
| 368 |
+
"pixels/top": f"{OBS_IMAGE}",
|
| 369 |
+
}
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
def __post_init__(self):
|
| 373 |
+
if self.obs_type == "pixels":
|
| 374 |
+
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
|
| 375 |
+
|
| 376 |
+
elif self.obs_type == "pixels_agent_pos":
|
| 377 |
+
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
| 378 |
+
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
|
| 379 |
+
|
| 380 |
+
else:
|
| 381 |
+
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
| 382 |
+
|
| 383 |
+
@property
|
| 384 |
+
def gym_kwargs(self) -> dict:
|
| 385 |
+
return {
|
| 386 |
+
"obs_type": self.obs_type,
|
| 387 |
+
"render_mode": self.render_mode,
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@EnvConfig.register_subclass("isaaclab_arena")
|
| 392 |
+
@dataclass
|
| 393 |
+
class IsaaclabArenaEnv(HubEnvConfig):
|
| 394 |
+
hub_path: str = "nvidia/isaaclab-arena-envs"
|
| 395 |
+
episode_length: int = 300
|
| 396 |
+
num_envs: int = 1
|
| 397 |
+
embodiment: str | None = "gr1_pink"
|
| 398 |
+
object: str | None = "power_drill"
|
| 399 |
+
mimic: bool = False
|
| 400 |
+
teleop_device: str | None = None
|
| 401 |
+
seed: int | None = 42
|
| 402 |
+
device: str | None = "cuda:0"
|
| 403 |
+
disable_fabric: bool = False
|
| 404 |
+
enable_cameras: bool = False
|
| 405 |
+
headless: bool = False
|
| 406 |
+
enable_pinocchio: bool = True
|
| 407 |
+
environment: str | None = "gr1_microwave"
|
| 408 |
+
task: str | None = "Reach out to the microwave and open it."
|
| 409 |
+
state_dim: int = 54
|
| 410 |
+
action_dim: int = 36
|
| 411 |
+
camera_height: int = 512
|
| 412 |
+
camera_width: int = 512
|
| 413 |
+
video: bool = False
|
| 414 |
+
video_length: int = 100
|
| 415 |
+
video_interval: int = 200
|
| 416 |
+
# Comma-separated keys, e.g., "robot_joint_pos,left_eef_pos"
|
| 417 |
+
state_keys: str = "robot_joint_pos"
|
| 418 |
+
# Comma-separated keys, e.g., "robot_pov_cam_rgb,front_cam_rgb"
|
| 419 |
+
# Set to None or "" for environments without cameras
|
| 420 |
+
camera_keys: str | None = None
|
| 421 |
+
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
| 422 |
+
features_map: dict[str, str] = field(default_factory=dict)
|
| 423 |
+
kwargs: dict | None = None
|
| 424 |
+
|
| 425 |
+
def __post_init__(self):
|
| 426 |
+
if self.kwargs:
|
| 427 |
+
# dynamically convert kwargs to fields in the dataclass
|
| 428 |
+
# NOTE! the new fields will not bee seen by the dataclass repr
|
| 429 |
+
field_names = {f.name for f in fields(self)}
|
| 430 |
+
for key, value in self.kwargs.items():
|
| 431 |
+
if key not in field_names and key != "kwargs":
|
| 432 |
+
setattr(self, key, value)
|
| 433 |
+
self.kwargs = None
|
| 434 |
+
|
| 435 |
+
# Set action feature
|
| 436 |
+
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
|
| 437 |
+
self.features_map[ACTION] = ACTION
|
| 438 |
+
|
| 439 |
+
# Set state feature
|
| 440 |
+
self.features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.state_dim,))
|
| 441 |
+
self.features_map[OBS_STATE] = OBS_STATE
|
| 442 |
+
|
| 443 |
+
# Add camera features for each camera key
|
| 444 |
+
if self.enable_cameras and self.camera_keys:
|
| 445 |
+
for cam_key in self.camera_keys.split(","):
|
| 446 |
+
cam_key = cam_key.strip()
|
| 447 |
+
if cam_key:
|
| 448 |
+
self.features[cam_key] = PolicyFeature(
|
| 449 |
+
type=FeatureType.VISUAL,
|
| 450 |
+
shape=(self.camera_height, self.camera_width, 3),
|
| 451 |
+
)
|
| 452 |
+
self.features_map[cam_key] = f"{OBS_IMAGES}.{cam_key}"
|
| 453 |
+
|
| 454 |
+
@property
|
| 455 |
+
def gym_kwargs(self) -> dict:
|
| 456 |
+
return {}
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
# ------------------------ Robocasa365 --------------------------------
|
| 460 |
+
|
| 461 |
+
@EnvConfig.register_subclass("robocasa")
|
| 462 |
+
@dataclass
|
| 463 |
+
class RoboCasaEnv(HubEnvConfig):
|
| 464 |
+
|
| 465 |
+
hub_path: str = "Whalswp/RoboCasa_Env"
|
| 466 |
+
|
| 467 |
+
task: str | None = None
|
| 468 |
+
obs_type: str = "pixels_agent_pos"
|
| 469 |
+
render_mode: str = "rgb_array"
|
| 470 |
+
camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
|
| 471 |
+
observation_height: int = 256
|
| 472 |
+
observation_width: int = 256
|
| 473 |
+
split: str | None = None
|
| 474 |
+
|
| 475 |
+
# VLA ๋ชจ๋ธ ๋ฑ์์ ์ฌ์ฉํ Observation & Action ๊ท๊ฒฉ ๋งคํ
|
| 476 |
+
features: dict[str, PolicyFeature] = field(default_factory=lambda: {
|
| 477 |
+
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
|
| 478 |
+
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(16,)),
|
| 479 |
+
"pixels/robot0_agentview_left": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
| 480 |
+
"pixels/robot0_agentview_right": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
| 481 |
+
"pixels/robot0_eye_in_hand": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
| 482 |
+
})
|
| 483 |
+
features_map: dict[str, str] = field(default_factory=lambda: {
|
| 484 |
+
ACTION: ACTION,
|
| 485 |
+
"agent_pos": OBS_STATE,
|
| 486 |
+
"pixels/robot0_agentview_left": f"{OBS_IMAGES}.robot0_agentview_left",
|
| 487 |
+
"pixels/robot0_agentview_right": f"{OBS_IMAGES}.robot0_agentview_right",
|
| 488 |
+
"pixels/robot0_eye_in_hand": f"{OBS_IMAGES}.robot0_eye_in_hand",
|
| 489 |
+
})
|
env.py
CHANGED
|
@@ -3,21 +3,30 @@ import gymnasium as gym
|
|
| 3 |
from gymnasium import spaces
|
| 4 |
import numpy as np
|
| 5 |
from collections import defaultdict
|
| 6 |
-
from collections.abc import
|
| 7 |
from functools import partial
|
| 8 |
from typing import Any
|
| 9 |
|
| 10 |
# RoboCasa ์ ์ฉ ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ํฌํธ
|
| 11 |
from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
|
| 12 |
-
from robocasa.utils.dataset_registry import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
OBS_STATE_DIM = 16
|
| 15 |
ACTION_DIM = 12
|
| 16 |
ACTION_LOW = -1.0
|
| 17 |
ACTION_HIGH = 1.0
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
| 21 |
dict_state = dict_state.copy()
|
| 22 |
final_state = np.concatenate([
|
| 23 |
dict_state["state.base_position"],
|
|
@@ -26,11 +35,14 @@ def convert_state(dict_state):
|
|
| 26 |
dict_state["state.end_effector_rotation_relative"],
|
| 27 |
dict_state["state.gripper_qpos"],
|
| 28 |
], axis=0)
|
| 29 |
-
return final_state
|
|
|
|
| 30 |
|
| 31 |
def convert_action(action):
|
| 32 |
-
"""LeRobot์ ์ก์
์ ์๋ฎฌ๋ ์ดํฐ๊ฐ ์ดํดํ๋ dict ํํ๋ก ๋ณํํฉ๋๋ค.
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
output_action = {
|
| 35 |
"action.base_motion": action[0:4],
|
| 36 |
"action.control_mode": action[4:5],
|
|
@@ -40,8 +52,8 @@ def convert_action(action):
|
|
| 40 |
}
|
| 41 |
return output_action
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
if isinstance(camera_name, str):
|
| 46 |
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
| 47 |
elif isinstance(camera_name, (list, tuple)):
|
|
@@ -52,45 +64,71 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
|
| 52 |
raise ValueError("camera_name resolved to an empty list.")
|
| 53 |
return cams
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
class RoboCasaEnv(RoboCasaGymEnv):
|
| 56 |
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
|
| 57 |
|
| 58 |
def __init__(
|
| 59 |
self,
|
| 60 |
task: str,
|
| 61 |
-
camera_name: Sequence[str] =
|
| 62 |
render_mode: str = "rgb_array",
|
| 63 |
obs_type: str = "pixels_agent_pos",
|
| 64 |
observation_width: int = 256,
|
| 65 |
observation_height: int = 256,
|
| 66 |
split: str | None = None,
|
| 67 |
-
**kwargs
|
| 68 |
):
|
| 69 |
self.obs_type = obs_type
|
| 70 |
self.render_mode = render_mode
|
| 71 |
self.split = split
|
| 72 |
self.task = task
|
| 73 |
-
self._task_description =
|
| 74 |
-
|
| 75 |
kwargs.pop("fps", None)
|
| 76 |
self.kwargs = kwargs
|
| 77 |
|
| 78 |
-
|
| 79 |
try:
|
| 80 |
-
self._max_episode_steps =
|
| 81 |
-
except
|
| 82 |
-
|
| 83 |
-
|
|
|
|
| 84 |
super().__init__(
|
| 85 |
task,
|
| 86 |
-
camera_names=camera_name,
|
| 87 |
camera_widths=observation_width,
|
| 88 |
camera_heights=observation_height,
|
| 89 |
enable_render=(render_mode is not None),
|
| 90 |
split=split,
|
| 91 |
-
**kwargs
|
| 92 |
)
|
| 93 |
-
|
| 94 |
def _create_obs_and_action_space(self):
|
| 95 |
images = {}
|
| 96 |
for cam in self.camera_names:
|
|
@@ -100,11 +138,15 @@ class RoboCasaEnv(RoboCasaGymEnv):
|
|
| 100 |
if self.obs_type == "state":
|
| 101 |
raise NotImplementedError("The 'state' observation type is not supported.")
|
| 102 |
elif self.obs_type == "pixels":
|
| 103 |
-
self.observation_space = spaces.Dict({
|
|
|
|
|
|
|
|
|
|
| 104 |
elif self.obs_type == "pixels_agent_pos":
|
| 105 |
self.observation_space = spaces.Dict({
|
| 106 |
"pixels": spaces.Dict(images),
|
| 107 |
"agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
|
|
|
|
| 108 |
})
|
| 109 |
else:
|
| 110 |
raise ValueError(f"Unknown obs_type: {self.obs_type}")
|
|
@@ -117,38 +159,75 @@ class RoboCasaEnv(RoboCasaGymEnv):
|
|
| 117 |
def task_description(self) -> str:
|
| 118 |
return self._task_description
|
| 119 |
|
| 120 |
-
def
|
| 121 |
-
|
| 122 |
-
observation, info = super().reset(seed, **kwargs)
|
| 123 |
-
self._task_description = self.env.get_ep_meta().get("lang", self.task)
|
| 124 |
-
print(f"[RoboCasaEnv] task_description: {self._task_description!r}")
|
| 125 |
-
return self._format_raw_obs(observation), info
|
| 126 |
-
|
| 127 |
-
def _format_raw_obs(self, raw_obs: dict):
|
| 128 |
-
new_obs = {}
|
| 129 |
if self.obs_type == "pixels_agent_pos":
|
| 130 |
new_obs["agent_pos"] = convert_state(raw_obs)
|
| 131 |
new_obs["pixels"] = {}
|
| 132 |
for k, v in raw_obs.items():
|
| 133 |
-
if "video."
|
| 134 |
new_obs["pixels"][k.replace("video.", "")] = v
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
return new_obs
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def step(self, action: np.ndarray):
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
action_dict = convert_action(action)
|
| 140 |
observation, reward, done, truncated, info = super().step(action_dict)
|
| 141 |
new_obs = self._format_raw_obs(observation)
|
| 142 |
|
| 143 |
is_success = bool(info.get("success", 0))
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
def render(self):
|
| 154 |
frame = super().render()
|
|
@@ -185,71 +264,99 @@ def _make_env_fns(task_name: str, n_envs: int, camera_names: list[str], gym_kwar
|
|
| 185 |
return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]
|
| 186 |
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
# ======================================================================
|
| 189 |
# LeRobot Hub ํ์ ์ง์
์ (Entry Point)
|
| 190 |
# ======================================================================
|
| 191 |
def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
"""
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
env_cls = partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv
|
| 197 |
|
| 198 |
-
# ์ค์ ๊ฐ ์ถ์ถ (cfg ๊ฐ์ฒด๊ฐ ์์ผ๋ฉด ์ฌ์ฉํ๊ณ , ์์ผ๋ฉด ๊ธฐ๋ณธ๊ฐ ์ ์ฉ)
|
| 199 |
if cfg is not None:
|
| 200 |
-
|
| 201 |
-
|
| 202 |
gym_kwargs = {
|
| 203 |
"obs_type": getattr(cfg, "obs_type", "pixels_agent_pos"),
|
| 204 |
-
"render_mode": getattr(cfg, "render_mode", "rgb_array"),
|
| 205 |
"observation_width": getattr(cfg, "observation_width", 256),
|
| 206 |
"observation_height": getattr(cfg, "observation_height", 256),
|
| 207 |
-
"camera_name": getattr(
|
| 208 |
-
|
| 209 |
-
|
|
|
|
|
|
|
| 210 |
}
|
| 211 |
else:
|
| 212 |
-
|
| 213 |
-
|
| 214 |
gym_kwargs = {
|
| 215 |
"obs_type": "pixels_agent_pos",
|
| 216 |
"render_mode": "rgb_array",
|
| 217 |
"observation_width": 256,
|
| 218 |
"observation_height": 256,
|
| 219 |
"camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
|
| 220 |
-
"
|
| 221 |
}
|
| 222 |
|
| 223 |
parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
else:
|
| 233 |
-
task_names = []
|
| 234 |
-
for part in parts:
|
| 235 |
-
if part in combined_tasks:
|
| 236 |
-
task_names.extend(combined_tasks[part])
|
| 237 |
-
else:
|
| 238 |
-
task_names.append(part)
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
out = defaultdict(dict)
|
| 242 |
-
|
| 243 |
-
# ํ์คํฌ๋ณ๋ก ํ๊ฒฝ ์์ฑ
|
| 244 |
-
for task in task_names:
|
| 245 |
fns = _make_env_fns(
|
| 246 |
task_name=task,
|
| 247 |
n_envs=n_envs,
|
| 248 |
camera_names=parsed_camera_names,
|
| 249 |
-
gym_kwargs=
|
| 250 |
)
|
|
|
|
|
|
|
| 251 |
out[task][0] = env_cls(fns)
|
| 252 |
|
| 253 |
-
|
| 254 |
-
#return {"robocasa": dict(out)}
|
| 255 |
-
return {suite: dict(task_map) for suite, task_map in out.items()}
|
|
|
|
| 3 |
from gymnasium import spaces
|
| 4 |
import numpy as np
|
| 5 |
from collections import defaultdict
|
| 6 |
+
from collections.abc import Sequence, Mapping
|
| 7 |
from functools import partial
|
| 8 |
from typing import Any
|
| 9 |
|
| 10 |
# RoboCasa ์ ์ฉ ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ํฌํธ
|
| 11 |
from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
|
| 12 |
+
from robocasa.utils.dataset_registry import (
|
| 13 |
+
ATOMIC_TASK_DATASETS,
|
| 14 |
+
COMPOSITE_TASK_DATASETS,
|
| 15 |
+
TARGET_TASKS,
|
| 16 |
+
PRETRAINING_TASKS,
|
| 17 |
+
)
|
| 18 |
+
from robocasa.utils.dataset_registry_utils import get_task_horizon
|
| 19 |
|
| 20 |
OBS_STATE_DIM = 16
|
| 21 |
ACTION_DIM = 12
|
| 22 |
ACTION_LOW = -1.0
|
| 23 |
ACTION_HIGH = 1.0
|
| 24 |
|
| 25 |
+
|
| 26 |
+
def convert_state(dict_state):
|
| 27 |
+
"""์๋ฎฌ๋ ์ดํฐ ์ํ๋ฅผ LeRobot์ด ๊ธฐ๋ํ๋ 16-dim concat ํํ๋ก ๋ณํํฉ๋๋ค.
|
| 28 |
+
์ธ๋ฑ์ค ์ ์๋ STATE_ACTION_SPEC.md ์ฐธ๊ณ .
|
| 29 |
+
"""
|
| 30 |
dict_state = dict_state.copy()
|
| 31 |
final_state = np.concatenate([
|
| 32 |
dict_state["state.base_position"],
|
|
|
|
| 35 |
dict_state["state.end_effector_rotation_relative"],
|
| 36 |
dict_state["state.gripper_qpos"],
|
| 37 |
], axis=0)
|
| 38 |
+
return final_state.astype(np.float32)
|
| 39 |
+
|
| 40 |
|
| 41 |
def convert_action(action):
|
| 42 |
+
"""LeRobot์ 12-dim ์ก์
์ ์๋ฎฌ๋ ์ดํฐ๊ฐ ์ดํดํ๋ dict ํํ๋ก ๋ณํํฉ๋๋ค.
|
| 43 |
+
์ธ๋ฑ์ค ์ ์๋ STATE_ACTION_SPEC.md ์ฐธ๊ณ .
|
| 44 |
+
"""
|
| 45 |
+
action = np.asarray(action).copy()
|
| 46 |
output_action = {
|
| 47 |
"action.base_motion": action[0:4],
|
| 48 |
"action.control_mode": action[4:5],
|
|
|
|
| 52 |
}
|
| 53 |
return output_action
|
| 54 |
|
| 55 |
+
|
| 56 |
+
def _parse_camera_names(camera_name) -> list[str]:
|
| 57 |
if isinstance(camera_name, str):
|
| 58 |
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
| 59 |
elif isinstance(camera_name, (list, tuple)):
|
|
|
|
| 64 |
raise ValueError("camera_name resolved to an empty list.")
|
| 65 |
return cams
|
| 66 |
|
| 67 |
+
|
| 68 |
+
def _normalize_task_arg(task) -> list[str]:
|
| 69 |
+
"""`--env.task=atomic_seen composite_unseen ...` ์ฒ๋ผ ๋ค์ด์ค๋ ๋ค์ํ ํํ๋ฅผ
|
| 70 |
+
list[str]๋ก ์ ๊ทํํ๋ค.
|
| 71 |
+
- draccus๊ฐ list๋ก ํ์ฑํ ๊ฒฝ์ฐ: ๊ทธ๋๋ก ์ ์ง
|
| 72 |
+
- ๋จ์ผ ๋ฌธ์์ด์ด๋ฉด ๊ณต๋ฐฑ ๋๋ ์ฝค๋ง ๋ถ๋ฆฌ
|
| 73 |
+
"""
|
| 74 |
+
if task is None:
|
| 75 |
+
raise ValueError("task is required")
|
| 76 |
+
if isinstance(task, (list, tuple)):
|
| 77 |
+
items = []
|
| 78 |
+
for t in task:
|
| 79 |
+
items.extend(_normalize_task_arg(t))
|
| 80 |
+
return items
|
| 81 |
+
s = str(task).strip()
|
| 82 |
+
if not s:
|
| 83 |
+
return []
|
| 84 |
+
# ๊ณต๋ฐฑ/์ฝค๋ง ๋ชจ๋ ํ์ฉ
|
| 85 |
+
parts = []
|
| 86 |
+
for chunk in s.replace(",", " ").split():
|
| 87 |
+
if chunk:
|
| 88 |
+
parts.append(chunk)
|
| 89 |
+
return parts
|
| 90 |
+
|
| 91 |
+
|
| 92 |
class RoboCasaEnv(RoboCasaGymEnv):
|
| 93 |
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
|
| 94 |
|
| 95 |
def __init__(
|
| 96 |
self,
|
| 97 |
task: str,
|
| 98 |
+
camera_name: Sequence[str] = ("robot0_agentview_left", "robot0_eye_in_hand", "robot0_agentview_right"),
|
| 99 |
render_mode: str = "rgb_array",
|
| 100 |
obs_type: str = "pixels_agent_pos",
|
| 101 |
observation_width: int = 256,
|
| 102 |
observation_height: int = 256,
|
| 103 |
split: str | None = None,
|
| 104 |
+
**kwargs,
|
| 105 |
):
|
| 106 |
self.obs_type = obs_type
|
| 107 |
self.render_mode = render_mode
|
| 108 |
self.split = split
|
| 109 |
self.task = task
|
| 110 |
+
self._task_description: str = task
|
| 111 |
+
|
| 112 |
kwargs.pop("fps", None)
|
| 113 |
self.kwargs = kwargs
|
| 114 |
|
| 115 |
+
# horizon์ ๊ณต์ ํฌํผ ์ฌ์ฉ. ๋ฏธ๋ฑ๋ก ํ์คํฌ๋ ๋ช
์์ ์ผ๋ก ์๋ฌ.
|
| 116 |
try:
|
| 117 |
+
self._max_episode_steps = int(get_task_horizon(task))
|
| 118 |
+
except Exception as e:
|
| 119 |
+
valid = list({**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}.keys())
|
| 120 |
+
raise ValueError(f"Unknown task '{task}'. Valid tasks: {valid[:10]}... ({len(valid)} total)") from e
|
| 121 |
+
|
| 122 |
super().__init__(
|
| 123 |
task,
|
| 124 |
+
camera_names=list(camera_name),
|
| 125 |
camera_widths=observation_width,
|
| 126 |
camera_heights=observation_height,
|
| 127 |
enable_render=(render_mode is not None),
|
| 128 |
split=split,
|
| 129 |
+
**kwargs,
|
| 130 |
)
|
| 131 |
+
|
| 132 |
def _create_obs_and_action_space(self):
|
| 133 |
images = {}
|
| 134 |
for cam in self.camera_names:
|
|
|
|
| 138 |
if self.obs_type == "state":
|
| 139 |
raise NotImplementedError("The 'state' observation type is not supported.")
|
| 140 |
elif self.obs_type == "pixels":
|
| 141 |
+
self.observation_space = spaces.Dict({
|
| 142 |
+
"pixels": spaces.Dict(images),
|
| 143 |
+
"task": spaces.Text(max_length=512),
|
| 144 |
+
})
|
| 145 |
elif self.obs_type == "pixels_agent_pos":
|
| 146 |
self.observation_space = spaces.Dict({
|
| 147 |
"pixels": spaces.Dict(images),
|
| 148 |
"agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
|
| 149 |
+
"task": spaces.Text(max_length=512),
|
| 150 |
})
|
| 151 |
else:
|
| 152 |
raise ValueError(f"Unknown obs_type: {self.obs_type}")
|
|
|
|
| 159 |
def task_description(self) -> str:
|
| 160 |
return self._task_description
|
| 161 |
|
| 162 |
+
def _format_raw_obs(self, raw_obs: dict) -> dict:
|
| 163 |
+
new_obs: dict[str, Any] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
if self.obs_type == "pixels_agent_pos":
|
| 165 |
new_obs["agent_pos"] = convert_state(raw_obs)
|
| 166 |
new_obs["pixels"] = {}
|
| 167 |
for k, v in raw_obs.items():
|
| 168 |
+
if k.startswith("video."):
|
| 169 |
new_obs["pixels"][k.replace("video.", "")] = v
|
| 170 |
+
# ์ธ์ด ์กฐ๊ฑด: AsyncVectorEnv์์๋ ๋๊ธฐ์ง ์๋๋ก obs์ ์ง์ ๋
ธ์ถ.
|
| 171 |
+
# RoboCasaGymEnv๊ฐ obs์ ์ฑ์์ฃผ๋ annotation์ด ์์ผ๋ฉด ์ฐ์ ์ฌ์ฉ.
|
| 172 |
+
lang = raw_obs.get("annotation.human.task_description")
|
| 173 |
+
if not lang:
|
| 174 |
+
lang = self._task_description or self.task
|
| 175 |
+
new_obs["task"] = str(lang)
|
| 176 |
+
self._task_description = str(lang)
|
| 177 |
return new_obs
|
| 178 |
|
| 179 |
+
def reset(self, seed: int | None = None, **kwargs):
|
| 180 |
+
# mujoco offscreen GL ์ปจํ
์คํธ๊ฐ reset ์์ ์ stale ์ํ๊ฐ ๋๋ ํ๊ฒฝ์์์
|
| 181 |
+
# ์ฐํ. (์ฌ์ฉ์ ํ๊ฒฝ์์ ํ์ํด ๋ณด์กด)
|
| 182 |
+
try:
|
| 183 |
+
self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
|
| 184 |
+
except Exception:
|
| 185 |
+
pass
|
| 186 |
+
|
| 187 |
+
observation, info = super().reset(seed=seed, **kwargs)
|
| 188 |
+
# ep meta์ lang์ด ๋ ํ๋ถํ ๋ฌธ์ฅ์ ์ฃผ๋ฏ๋ก ๊ฐ๋ฅํ๋ฉด ์ฌ์ฉ
|
| 189 |
+
try:
|
| 190 |
+
ep_lang = self.env.get_ep_meta().get("lang", None)
|
| 191 |
+
if ep_lang:
|
| 192 |
+
self._task_description = str(ep_lang)
|
| 193 |
+
except Exception:
|
| 194 |
+
pass
|
| 195 |
+
formatted = self._format_raw_obs(observation)
|
| 196 |
+
info = dict(info or {})
|
| 197 |
+
info["task"] = self.task
|
| 198 |
+
info["task_description"] = self._task_description
|
| 199 |
+
return formatted, info
|
| 200 |
+
|
| 201 |
def step(self, action: np.ndarray):
|
| 202 |
+
try:
|
| 203 |
+
self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
|
| 204 |
+
except Exception:
|
| 205 |
+
pass
|
| 206 |
+
|
| 207 |
action_dict = convert_action(action)
|
| 208 |
observation, reward, done, truncated, info = super().step(action_dict)
|
| 209 |
new_obs = self._format_raw_obs(observation)
|
| 210 |
|
| 211 |
is_success = bool(info.get("success", 0))
|
| 212 |
+
# ๊ณต์ GR00T eval๊ณผ ๋์ผํ๊ฒ: success๋ termination ์ ํธ๋ก ์ง์ ๋ณํํ์ง ์๋๋ค.
|
| 213 |
+
# gymnasium VectorEnv์ autoreset์ด terminated/truncated๋ฅผ ๋ณด๊ณ final_info๋ฅผ ๋ง๋ ๋ค.
|
| 214 |
+
terminated = bool(done) or is_success
|
| 215 |
+
|
| 216 |
+
info = dict(info or {})
|
| 217 |
+
info.update({
|
| 218 |
+
"task": self.task,
|
| 219 |
+
"task_description": self._task_description,
|
| 220 |
+
"is_success": is_success,
|
| 221 |
+
"success": is_success,
|
| 222 |
+
"done": bool(done),
|
| 223 |
+
})
|
| 224 |
+
|
| 225 |
+
# NOTE: ์์ฒด self.reset() ํธ์ถ์ ์ ๊ฑฐ.
|
| 226 |
+
# - gymnasium 0.29+ VectorEnv๊ฐ terminated/truncated ์ ์๋์ผ๋ก resetํ๊ณ
|
| 227 |
+
# final_info๋ฅผ ์ฑ์์ค๋ค. wrapper๊ฐ ํ ๋ฒ ๋ resetํ๋ฉด ์ฒซ obs๊ฐ final obs๋ฅผ
|
| 228 |
+
# ๋ฎ์ด์ฐ๊ณ , lerobot rollout์ final_info["is_success"]๊ฐ ์ ํฉ์ฑ์ ์๋๋ค.
|
| 229 |
+
|
| 230 |
+
return new_obs, float(reward), terminated, truncated, info
|
| 231 |
|
| 232 |
def render(self):
|
| 233 |
frame = super().render()
|
|
|
|
| 264 |
return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]
|
| 265 |
|
| 266 |
|
| 267 |
+
def _resolve_task_list(task_arg, explicit_split: str | None):
|
| 268 |
+
"""`task` ์ธ์(๋ฌธ์์ด/๋ฆฌ์คํธ)์ ์ฌ์ฉ์๊ฐ ๋ช
์ํ split์ ๋ณด๊ณ
|
| 269 |
+
์ค์ ๋ก ๋์ธ (task_name, split) ํ์ด ๋ฆฌ์คํธ๋ฅผ ๋ง๋ ๋ค.
|
| 270 |
+
|
| 271 |
+
๊ณต์ run_eval.py์ฒ๋ผ:
|
| 272 |
+
- benchmark ํค(`atomic_seen`, `composite_unseen`, ...)๊ฐ ๋ค์ด์ค๋ฉด ํผ์น๋ค.
|
| 273 |
+
- explicit_split์ด ์ง์ ๋๋ฉด ๊ทธ๊ฒ์ ์ฐ์ ํ๋ค.
|
| 274 |
+
- ์๋๋ฉด TARGET_TASKS / PRETRAINING_TASKS ๋ฑ๋ก๋ถ์์ ์๋ ์ถ๋ก .
|
| 275 |
+
- ๋จ์ผ task ์ด๋ฆ์ด๋ฉด ๊ทธ๋๋ก ์ฌ์ฉ.
|
| 276 |
+
"""
|
| 277 |
+
items = _normalize_task_arg(task_arg)
|
| 278 |
+
pairs: list[tuple[str, str | None]] = []
|
| 279 |
+
for item in items:
|
| 280 |
+
if item in TARGET_TASKS:
|
| 281 |
+
split = explicit_split or "target"
|
| 282 |
+
for sub in TARGET_TASKS[item]:
|
| 283 |
+
pairs.append((sub, split))
|
| 284 |
+
elif item in PRETRAINING_TASKS:
|
| 285 |
+
split = explicit_split or "pretrain"
|
| 286 |
+
for sub in PRETRAINING_TASKS[item]:
|
| 287 |
+
pairs.append((sub, split))
|
| 288 |
+
else:
|
| 289 |
+
# ๋จ์ผ task ์ด๋ฆ
|
| 290 |
+
pairs.append((item, explicit_split))
|
| 291 |
+
# ์ค๋ณต ์ ๊ฑฐ (์์ ์ ์ง)
|
| 292 |
+
seen = set()
|
| 293 |
+
uniq: list[tuple[str, str | None]] = []
|
| 294 |
+
for p in pairs:
|
| 295 |
+
if p in seen:
|
| 296 |
+
continue
|
| 297 |
+
seen.add(p)
|
| 298 |
+
uniq.append(p)
|
| 299 |
+
return uniq
|
| 300 |
+
|
| 301 |
+
|
| 302 |
# ======================================================================
|
| 303 |
# LeRobot Hub ํ์ ์ง์
์ (Entry Point)
|
| 304 |
# ======================================================================
|
| 305 |
def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
|
| 306 |
+
"""LeRobot์ด Hub์์ ํ๊ฒฝ์ ๋ก๋ํ ๋ ํธ์ถํ๋ ๋ฉ์ธ ํจ์.
|
| 307 |
+
|
| 308 |
+
๊ณต์ GR00T eval(`Isaac-GR00T/scripts/run_eval.py`)๊ณผ ๋์ผํ๊ฒ:
|
| 309 |
+
- benchmark ํค(`atomic_seen`, `composite_unseen`, ...)๋ sub-task ๋ฆฌ์คํธ๋ก ํผ์น๋ค.
|
| 310 |
+
- `cfg.split`์ด ์ง์ ๋๋ฉด ๊ทธ๊ฒ์ ์ฐ์ ํ๋ค.
|
| 311 |
+
- ๊ฐ sub-task๋ ์์ ์ horizon์ผ๋ก ๋ณ๋ VectorEnv๋ฅผ ๋ง๋ ๋ค.
|
| 312 |
"""
|
| 313 |
+
env_cls = (
|
| 314 |
+
partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv
|
| 315 |
+
)
|
|
|
|
| 316 |
|
|
|
|
| 317 |
if cfg is not None:
|
| 318 |
+
task_arg = getattr(cfg, "task", None)
|
| 319 |
+
explicit_split = getattr(cfg, "split", None)
|
| 320 |
gym_kwargs = {
|
| 321 |
"obs_type": getattr(cfg, "obs_type", "pixels_agent_pos"),
|
| 322 |
+
"render_mode": getattr(cfg, "render_mode", "rgb_array"),
|
| 323 |
"observation_width": getattr(cfg, "observation_width", 256),
|
| 324 |
"observation_height": getattr(cfg, "observation_height", 256),
|
| 325 |
+
"camera_name": getattr(
|
| 326 |
+
cfg, "camera_name",
|
| 327 |
+
"robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
|
| 328 |
+
),
|
| 329 |
+
"fps": getattr(cfg, "fps", 20),
|
| 330 |
}
|
| 331 |
else:
|
| 332 |
+
task_arg = "CloseFridge"
|
| 333 |
+
explicit_split = None
|
| 334 |
gym_kwargs = {
|
| 335 |
"obs_type": "pixels_agent_pos",
|
| 336 |
"render_mode": "rgb_array",
|
| 337 |
"observation_width": 256,
|
| 338 |
"observation_height": 256,
|
| 339 |
"camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
|
| 340 |
+
"fps": 20,
|
| 341 |
}
|
| 342 |
|
| 343 |
parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))
|
| 344 |
+
task_split_pairs = _resolve_task_list(task_arg, explicit_split)
|
| 345 |
+
if not task_split_pairs:
|
| 346 |
+
raise ValueError(f"No tasks resolved from task={task_arg!r}, split={explicit_split!r}")
|
| 347 |
+
|
| 348 |
+
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
| 349 |
+
for idx, (task, split) in enumerate(task_split_pairs):
|
| 350 |
+
per_task_kwargs = dict(gym_kwargs)
|
| 351 |
+
per_task_kwargs["split"] = split
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
fns = _make_env_fns(
|
| 353 |
task_name=task,
|
| 354 |
n_envs=n_envs,
|
| 355 |
camera_names=parsed_camera_names,
|
| 356 |
+
gym_kwargs=per_task_kwargs,
|
| 357 |
)
|
| 358 |
+
# `{suite: {task_id: VectorEnv}}` ๊ตฌ์กฐ: lerobot_eval์ group/task ๋จ์๋ก ์ง๊ณํ๋ฏ๋ก
|
| 359 |
+
# task_name ์์ฒด๋ฅผ suite ํค๋ก ์ฌ์ฉํด per-task SR์ ๊ทธ๋๋ก ๋
ธ์ถ.
|
| 360 |
out[task][0] = env_cls(fns)
|
| 361 |
|
| 362 |
+
return {suite: dict(task_map) for suite, task_map in out.items()}
|
|
|
|
|
|
env.py.bak
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# env.py
|
| 2 |
+
import gymnasium as gym
|
| 3 |
+
from gymnasium import spaces
|
| 4 |
+
import numpy as np
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from collections.abc import Callable, Sequence, Mapping
|
| 7 |
+
from functools import partial
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
# RoboCasa ์ ์ฉ ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ํฌํธ
|
| 11 |
+
from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
|
| 12 |
+
from robocasa.utils.dataset_registry import ATOMIC_TASK_DATASETS, COMPOSITE_TASK_DATASETS, TARGET_TASKS, PRETRAINING_TASKS
|
| 13 |
+
|
| 14 |
+
OBS_STATE_DIM = 16
|
| 15 |
+
ACTION_DIM = 12
|
| 16 |
+
ACTION_LOW = -1.0
|
| 17 |
+
ACTION_HIGH = 1.0
|
| 18 |
+
|
| 19 |
+
def convert_state(dict_state):
|
| 20 |
+
"""์๋ฎฌ๋ ์ดํฐ ์ํ๋ฅผ LeRobot์ด ๊ธฐ๋ํ๋ ํํ๋ก ๋ณํ(Conversion)ํฉ๋๋ค."""
|
| 21 |
+
dict_state = dict_state.copy()
|
| 22 |
+
final_state = np.concatenate([
|
| 23 |
+
dict_state["state.base_position"],
|
| 24 |
+
dict_state["state.base_rotation"],
|
| 25 |
+
dict_state["state.end_effector_position_relative"],
|
| 26 |
+
dict_state["state.end_effector_rotation_relative"],
|
| 27 |
+
dict_state["state.gripper_qpos"],
|
| 28 |
+
], axis=0)
|
| 29 |
+
return final_state
|
| 30 |
+
|
| 31 |
+
def convert_action(action):
|
| 32 |
+
"""LeRobot์ ์ก์
์ ์๋ฎฌ๋ ์ดํฐ๊ฐ ์ดํดํ๋ dict ํํ๋ก ๋ณํํฉ๋๋ค."""
|
| 33 |
+
action = action.copy()
|
| 34 |
+
output_action = {
|
| 35 |
+
"action.base_motion": action[0:4],
|
| 36 |
+
"action.control_mode": action[4:5],
|
| 37 |
+
"action.end_effector_position": action[5:8],
|
| 38 |
+
"action.end_effector_rotation": action[8:11],
|
| 39 |
+
"action.gripper_close": action[11:12],
|
| 40 |
+
}
|
| 41 |
+
return output_action
|
| 42 |
+
|
| 43 |
+
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
| 44 |
+
"""์นด๋ฉ๋ผ ์ด๋ฆ์ ๋ฆฌ์คํธ ํํ๋ก ์ ๊ทํ(Normalization)ํฉ๋๋ค."""
|
| 45 |
+
if isinstance(camera_name, str):
|
| 46 |
+
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
| 47 |
+
elif isinstance(camera_name, (list, tuple)):
|
| 48 |
+
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
| 49 |
+
else:
|
| 50 |
+
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
| 51 |
+
if not cams:
|
| 52 |
+
raise ValueError("camera_name resolved to an empty list.")
|
| 53 |
+
return cams
|
| 54 |
+
|
| 55 |
+
class RoboCasaEnv(RoboCasaGymEnv):
|
| 56 |
+
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
task: str,
|
| 61 |
+
camera_name: Sequence[str] = ["robot0_agentview_left", "robot0_eye_in_hand", "robot0_agentview_right"],
|
| 62 |
+
render_mode: str = "rgb_array",
|
| 63 |
+
obs_type: str = "pixels_agent_pos",
|
| 64 |
+
observation_width: int = 256,
|
| 65 |
+
observation_height: int = 256,
|
| 66 |
+
split: str | None = None,
|
| 67 |
+
**kwargs
|
| 68 |
+
):
|
| 69 |
+
self.obs_type = obs_type
|
| 70 |
+
self.render_mode = render_mode
|
| 71 |
+
self.split = split
|
| 72 |
+
self.task = task
|
| 73 |
+
self._task_description = ""
|
| 74 |
+
|
| 75 |
+
kwargs.pop("fps", None)
|
| 76 |
+
self.kwargs = kwargs
|
| 77 |
+
|
| 78 |
+
meta_info = {**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}
|
| 79 |
+
try:
|
| 80 |
+
self._max_episode_steps = meta_info[task]['horizon']
|
| 81 |
+
except KeyError:
|
| 82 |
+
raise ValueError(f"Unknown task '{task}'. Valid tasks are: {list(meta_info.keys())}")
|
| 83 |
+
|
| 84 |
+
super().__init__(
|
| 85 |
+
task,
|
| 86 |
+
camera_names=camera_name,
|
| 87 |
+
camera_widths=observation_width,
|
| 88 |
+
camera_heights=observation_height,
|
| 89 |
+
enable_render=(render_mode is not None),
|
| 90 |
+
split=split,
|
| 91 |
+
**kwargs
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
def _create_obs_and_action_space(self):
|
| 95 |
+
images = {}
|
| 96 |
+
for cam in self.camera_names:
|
| 97 |
+
images[cam] = spaces.Box(
|
| 98 |
+
low=0, high=255, shape=(self.camera_heights, self.camera_widths, 3), dtype=np.uint8
|
| 99 |
+
)
|
| 100 |
+
if self.obs_type == "state":
|
| 101 |
+
raise NotImplementedError("The 'state' observation type is not supported.")
|
| 102 |
+
elif self.obs_type == "pixels":
|
| 103 |
+
self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
|
| 104 |
+
elif self.obs_type == "pixels_agent_pos":
|
| 105 |
+
self.observation_space = spaces.Dict({
|
| 106 |
+
"pixels": spaces.Dict(images),
|
| 107 |
+
"agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
|
| 108 |
+
})
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unknown obs_type: {self.obs_type}")
|
| 111 |
+
|
| 112 |
+
self.action_space = spaces.Box(
|
| 113 |
+
low=ACTION_LOW, high=ACTION_HIGH, shape=(int(ACTION_DIM),), dtype=np.float32
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def task_description(self) -> str:
|
| 118 |
+
return self._task_description
|
| 119 |
+
|
| 120 |
+
def reset(self, seed: int | None = None, **kwargs):
|
| 121 |
+
self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
|
| 122 |
+
observation, info = super().reset(seed, **kwargs)
|
| 123 |
+
self._task_description = self.env.get_ep_meta().get("lang", self.task)
|
| 124 |
+
print(f"[RoboCasaEnv] task_description: {self._task_description!r}")
|
| 125 |
+
return self._format_raw_obs(observation), info
|
| 126 |
+
|
| 127 |
+
def _format_raw_obs(self, raw_obs: dict):
|
| 128 |
+
new_obs = {}
|
| 129 |
+
if self.obs_type == "pixels_agent_pos":
|
| 130 |
+
new_obs["agent_pos"] = convert_state(raw_obs)
|
| 131 |
+
new_obs["pixels"] = {}
|
| 132 |
+
for k, v in raw_obs.items():
|
| 133 |
+
if "video." in k:
|
| 134 |
+
new_obs["pixels"][k.replace("video.", "")] = v
|
| 135 |
+
return new_obs
|
| 136 |
+
|
| 137 |
+
def step(self, action: np.ndarray):
|
| 138 |
+
self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
|
| 139 |
+
action_dict = convert_action(action)
|
| 140 |
+
observation, reward, done, truncated, info = super().step(action_dict)
|
| 141 |
+
new_obs = self._format_raw_obs(observation)
|
| 142 |
+
|
| 143 |
+
is_success = bool(info.get("success", 0))
|
| 144 |
+
terminated = done or is_success
|
| 145 |
+
info.update({"task": self.task, "done": done, "is_success": is_success})
|
| 146 |
+
|
| 147 |
+
if terminated:
|
| 148 |
+
info["final_info"] = {"task": self.task, "done": bool(done), "is_success": bool(is_success)}
|
| 149 |
+
self.reset()
|
| 150 |
+
|
| 151 |
+
return new_obs, reward, terminated, truncated, info
|
| 152 |
+
|
| 153 |
+
def render(self):
|
| 154 |
+
frame = super().render()
|
| 155 |
+
if frame is None:
|
| 156 |
+
return frame
|
| 157 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 158 |
+
import textwrap
|
| 159 |
+
|
| 160 |
+
text = self._task_description or self.task
|
| 161 |
+
w = frame.shape[1]
|
| 162 |
+
|
| 163 |
+
try:
|
| 164 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
|
| 165 |
+
except Exception:
|
| 166 |
+
font = ImageFont.load_default()
|
| 167 |
+
|
| 168 |
+
lines = textwrap.wrap(text, width=55)
|
| 169 |
+
line_h = 18
|
| 170 |
+
bar_h = len(lines) * line_h + 10
|
| 171 |
+
|
| 172 |
+
bar = Image.new("RGB", (w, bar_h), color=(30, 30, 30))
|
| 173 |
+
draw = ImageDraw.Draw(bar)
|
| 174 |
+
for i, line in enumerate(lines):
|
| 175 |
+
draw.text((8, 5 + i * line_h), line, font=font, fill=(220, 220, 220))
|
| 176 |
+
|
| 177 |
+
return np.concatenate([frame, np.array(bar)], axis=0)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _make_env_fns(task_name: str, n_envs: int, camera_names: list[str], gym_kwargs: Mapping[str, Any]):
|
| 181 |
+
def _make_env(episode_index: int, **kwargs):
|
| 182 |
+
seed = kwargs.pop("seed", episode_index)
|
| 183 |
+
return RoboCasaEnv(task=task_name, camera_name=camera_names, seed=seed, **kwargs)
|
| 184 |
+
|
| 185 |
+
return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# ======================================================================
|
| 189 |
+
# LeRobot Hub ํ์ ์ง์
์ (Entry Point)
|
| 190 |
+
# ======================================================================
|
| 191 |
+
def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
|
| 192 |
+
"""
|
| 193 |
+
LeRobot์ด Hub์์ ํ๊ฒฝ์ ๋ก๋ํ ๋ ํธ์ถํ๋ ๋ฉ์ธ ํจ์์
๋๋ค.
|
| 194 |
+
"""
|
| 195 |
+
# ํ๊ฒฝ ๋ํผ ํด๋์ค ์ ํ
|
| 196 |
+
env_cls = partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv
|
| 197 |
+
|
| 198 |
+
# ์ค์ ๊ฐ ์ถ์ถ (cfg ๊ฐ์ฒด๊ฐ ์์ผ๋ฉด ์ฌ์ฉํ๊ณ , ์์ผ๋ฉด ๊ธฐ๋ณธ๊ฐ ์ ์ฉ)
|
| 199 |
+
if cfg is not None:
|
| 200 |
+
task_name = getattr(cfg, "task", "CloseFridge")
|
| 201 |
+
fps = getattr(cfg, "fps", 20) # fps ์ถ์ถ
|
| 202 |
+
gym_kwargs = {
|
| 203 |
+
"obs_type": getattr(cfg, "obs_type", "pixels_agent_pos"),
|
| 204 |
+
"render_mode": getattr(cfg, "render_mode", "rgb_array"), # render_mode ์ ์ง
|
| 205 |
+
"observation_width": getattr(cfg, "observation_width", 256),
|
| 206 |
+
"observation_height": getattr(cfg, "observation_height", 256),
|
| 207 |
+
"camera_name": getattr(cfg, "camera_name", "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"),
|
| 208 |
+
"split": getattr(cfg, "split", None),
|
| 209 |
+
"fps": fps, # ํต์ฌ ์ธ์ ๋๋ฝ ๋ฐฉ์ง
|
| 210 |
+
}
|
| 211 |
+
else:
|
| 212 |
+
# cfg ์์ด ์ง์ ํธ์ถ๋ ๋์ ๊ธฐ๋ณธ๊ฐ
|
| 213 |
+
task_name = "CloseFridge"
|
| 214 |
+
gym_kwargs = {
|
| 215 |
+
"obs_type": "pixels_agent_pos",
|
| 216 |
+
"render_mode": "rgb_array",
|
| 217 |
+
"observation_width": 256,
|
| 218 |
+
"observation_height": 256,
|
| 219 |
+
"camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
|
| 220 |
+
"split": None,
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))
|
| 224 |
+
combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
|
| 225 |
+
|
| 226 |
+
# ๋ฒค์น๋งํฌ์ธ์ง ๋จ์ผ ํ์คํฌ์ธ์ง ๊ตฌ๋ถ
|
| 227 |
+
if task_name in combined_tasks:
|
| 228 |
+
task_names = combined_tasks[task_name]
|
| 229 |
+
gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
|
| 230 |
+
else:
|
| 231 |
+
task_names = [t.strip() for t in task_name.split(",")]
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
out = defaultdict(dict)
|
| 235 |
+
|
| 236 |
+
# ํ์คํฌ๋ณ๋ก ํ๊ฒฝ ์์ฑ
|
| 237 |
+
for task in task_names:
|
| 238 |
+
fns = _make_env_fns(
|
| 239 |
+
task_name=task,
|
| 240 |
+
n_envs=n_envs,
|
| 241 |
+
camera_names=parsed_camera_names,
|
| 242 |
+
gym_kwargs=gym_kwargs
|
| 243 |
+
)
|
| 244 |
+
out[task][0] = env_cls(fns)
|
| 245 |
+
|
| 246 |
+
# {suite_name: {task_id: VectorEnv}} ํํ๋ก ๋ฐํ
|
| 247 |
+
#return {"robocasa": dict(out)}
|
| 248 |
+
return {suite: dict(task_map) for suite, task_map in out.items()}
|