Whalswp commited on
Commit
d893a8a
ยท
verified ยท
1 Parent(s): eb2032b

Upload folder using huggingface_hub

Browse files
Files changed (8) hide show
  1. CHANGES.md +215 -0
  2. STATE_ACTION_SPEC.md +83 -0
  3. backup/configs.py +489 -0
  4. backup/env.py +248 -0
  5. configs.py +8 -3
  6. configs.py.bak +489 -0
  7. env.py +187 -80
  8. env.py.bak +248 -0
CHANGES.md ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Robocasa_Env ๋ณ€๊ฒฝ์‚ฌํ•ญ โ€” ๊ณต์‹ GR00T eval ํ๋ฆ„๊ณผ ์ •๋ ฌ
2
+
3
+ ๋ชฉ์ : `lerobot-eval --env.type=robocasa ...`๊ฐ€ ๊ณต์‹
4
+ `Isaac-GR00T/scripts/run_eval.py`์™€ ๊ฐ™์€ ๋ฐฉ์‹์œผ๋กœ ๋™์ž‘ํ•˜๋„๋ก ์ˆ˜์ •.
5
+ **State/Action ๊ณ„์•ฝ(`STATE_ACTION_SPEC.md`)์€ ๊ทธ๋Œ€๋กœ ์œ ์ง€** โ€” lerobot ์ •์ฑ…(์˜ˆ: ACT)์ด
6
+ 12-dim concat action์„ ์ถœ๋ ฅํ•˜๊ณ , env๊ฐ€ 16-dim concat `agent_pos`๋ฅผ ๋…ธ์ถœํ•˜๋Š” ์•ฝ์†์„ ๊นจ์ง€ ์•Š๋Š”๋‹ค.
7
+
8
+ ์ˆ˜์ • ํŒŒ์ผ: `env.py`, `configs.py` ๋‘ ๊ฐœ. (๋ฐฑ์—…: `*.bak`)
9
+
10
+ ---
11
+
12
+ ## 1. `env.py` ๋ณ€๊ฒฝ ์š”์•ฝ
13
+
14
+ ### 1-1. `step()`์—์„œ ์ž์ฒด `self.reset()` ์ œ๊ฑฐ (โ˜… ํ•ต์‹ฌ)
15
+
16
+ **Before**
17
+ ```python
18
+ if terminated:
19
+ info["final_info"] = {...}
20
+ self.reset() # โ† ํ™˜๊ฒฝ์ด ์Šค์Šค๋กœ reset
21
+ return new_obs, reward, terminated, truncated, info
22
+ ```
23
+
24
+ **After**
25
+ ```python
26
+ # self.reset() ํ˜ธ์ถœ ์ œ๊ฑฐ.
27
+ # gymnasium 0.29+ VectorEnv์˜ autoreset์ด terminated/truncated๋ฅผ ๋ณด๊ณ 
28
+ # final_info๋ฅผ ๋งŒ๋“ค๊ณ  ์ž๋™์œผ๋กœ ๋ฆฌ์…‹ํ•œ๋‹ค. wrapper๊ฐ€ ํ•œ ๋ฒˆ ๋” resetํ•˜๋ฉด
29
+ # ์ฒซ obs๊ฐ€ final obs๋ฅผ ๋ฎ์–ด์“ฐ๊ณ , lerobot rollout์˜ final_info["is_success"]๊ฐ€
30
+ # ์ •ํ•ฉ์„ฑ์„ ์žƒ๋Š”๋‹ค.
31
+ return new_obs, float(reward), terminated, truncated, info
32
+ ```
33
+ - ๊ณต์‹ GR00T `simulation.py`๋„ wrapper ์•ˆ์—์„œ resetํ•˜์ง€ ์•Š๊ณ  ์™ธ๋ถ€ ๋ฃจํ”„๊ฐ€ ์ฒ˜๋ฆฌํ•จ.
34
+ - `lerobot_eval.py`์˜ `rollout()`์ด `info["final_info"][i]["is_success"]`๋ฅผ ์ฝ์œผ๋ฏ€๋กœ,
35
+ ์ž๋™ reset์— ๋งก๊ธฐ๋Š” ๊ฒŒ ์ •ํ™•ํ•œ SR ์ง‘๊ณ„๋กœ ์ด์–ด์ง.
36
+
37
+ ### 1-2. ์–ธ์–ด ์กฐ๊ฑด(`task`)์„ obs๋กœ ์ง์ ‘ ๋…ธ์ถœ
38
+
39
+ `RoboCasaGymEnv.get_observation`์€ ์ด๋ฏธ `annotation.human.task_description` ํ‚ค๋กœ
40
+ ep_meta์˜ lang์„ obs์— ์ฑ„์›Œ ์ค€๋‹ค. ์‚ฌ์šฉ์ž ์ฝ”๋“œ์˜ `_format_raw_obs`๋Š” ์ด๊ฑธ ๋ฒ„๋ฆฌ๊ณ 
41
+ ์žˆ์—ˆ์Œ.
42
+
43
+ ```python
44
+ def _format_raw_obs(self, raw_obs):
45
+ ...
46
+ lang = raw_obs.get("annotation.human.task_description") \
47
+ or self._task_description or self.task
48
+ new_obs["task"] = str(lang)
49
+ self._task_description = str(lang)
50
+ return new_obs
51
+ ```
52
+
53
+ - `AsyncVectorEnv`์—์„œ๋Š” `env.call("task_description")`์ด worker process๋ฅผ ๊ฑฐ์ณ ๋น„์‹ธ๊ณ 
54
+ ํƒ€์ด๋ฐ ์ด์Šˆ๊ฐ€ ์žˆ๋‹ค (README์˜ *use_async_envs=True ๋ณด๋ฅ˜: task_description ๋ˆ„๋ฝ* ์ด์Šˆ).
55
+ obs์— ์ง์ ‘ ๋„ฃ์œผ๋ฉด sync/async ๋ชจ๋‘์—์„œ ๋Š๊ธฐ์ง€ ์•Š์Œ.
56
+ - `observation_space`์—๋„ `"task": spaces.Text(max_length=512)`๋ฅผ ์ถ”๊ฐ€.
57
+
58
+ ### 1-3. horizon์€ ๊ณต์‹ ํ—ฌํผ `get_task_horizon` ์‚ฌ์šฉ
59
+
60
+ ```python
61
+ from robocasa.utils.dataset_registry_utils import get_task_horizon
62
+ ...
63
+ self._max_episode_steps = int(get_task_horizon(task))
64
+ ```
65
+
66
+ - ๊ธฐ์กด `meta_info[task]['horizon']` ์ง์กฐํšŒ์™€ ๋™์ผ ๊ฒฐ๊ณผ์ง€๋งŒ, ๊ณต์‹ ํ•จ์ˆ˜์™€ ๋™์ผํ•œ ์ฝ”๋“œ ๊ฒฝ๋กœ๋ฅผ ํƒ€๊ฒŒ ํ•œ๋‹ค.
67
+ - `lerobot_eval.rollout()`์ด `env.call("_max_episode_steps")[0]`์„ ์ฝ์œผ๋ฏ€๋กœ,
68
+ **batch์—๋Š” horizon์ด ๊ฐ™์€ task๋งŒ ๋ฌถ์ด๋„๋ก** task๋ณ„๋กœ ๋ณ„๋„ VectorEnv๋ฅผ ๋งŒ๋“ ๋‹ค (๊ธฐ์กด ๊ตฌ์กฐ ์œ ์ง€).
69
+
70
+ ### 1-4. `_resolve_task_list()` ์‹ ์„ค โ€” benchmark/๋‹จ์ผ task/๋‹ค์ค‘ task ๋ชจ๋‘ ์ฒ˜๋ฆฌ
71
+
72
+ ๊ณต์‹ `run_eval.py`์˜ ๋‹ค์Œ ๋กœ์ง์„ ์žฌํ˜„:
73
+
74
+ ```python
75
+ all_env_names = []
76
+ for task_set in task_set_list:
77
+ all_env_names += TASK_SET_REGISTRY[task_set]
78
+ all_env_names = set(all_env_names)
79
+ for env_name in all_env_names:
80
+ config = SimulationConfig(env_name=f"robocasa/{env_name}", split=split, ...)
81
+ ```
82
+
83
+ ์ˆ˜์ • ํ›„ `make_env`:
84
+ - benchmark ํ‚ค(`atomic_seen`, `composite_unseen`, `pretrain50`, ...)๋ฉด sub-task ๋ฆฌ์ŠคํŠธ๋กœ ํŽผ์นœ๋‹ค.
85
+ - ๋‹จ์ผ task ์ด๋ฆ„์ด๋ฉด ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ.
86
+ - ์—ฌ๋Ÿฌ ๊ฐœ๋ฅผ ๊ณต๋ฐฑ/์ฝค๋งˆ/๋ฆฌ์ŠคํŠธ ํ˜•ํƒœ๋กœ ๋ชจ๋‘ ๋ฐ›๋Š”๋‹ค.
87
+ - `--env.task=atomic_seen composite_unseen composite_seen` (draccus list)
88
+ - `--env.task="atomic_seen,composite_unseen"` (์ฝค๋งˆ ๋ถ„๋ฆฌ)
89
+ - `--env.task=PnPCounterToCab` (๋‹จ์ผ)
90
+
91
+ ### 1-5. `cfg.split` ๋ช…์‹œ ์‹œ ์šฐ์„ 
92
+
93
+ **Before**
94
+ ```python
95
+ if task_name in combined_tasks:
96
+ task_names = combined_tasks[task_name]
97
+ gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
98
+ # โ† ์‚ฌ์šฉ์ž๊ฐ€ --env.split=pretrain ์ค˜๋„ ๊ฐ•์ œ ๋ฎ์–ด์”€
99
+ ```
100
+
101
+ **After**
102
+ ```python
103
+ if item in TARGET_TASKS:
104
+ split = explicit_split or "target" # explicit์ด ์žˆ์œผ๋ฉด ๊ทธ๊ฒƒ ์‚ฌ์šฉ
105
+ elif item in PRETRAINING_TASKS:
106
+ split = explicit_split or "pretrain"
107
+ else:
108
+ pairs.append((item, explicit_split)) # ๋‹จ์ผ task๋Š” ๊ทธ๋Œ€๋กœ
109
+ ```
110
+
111
+ ์ด์ œ `--env.task=atomic_seen --env.split=pretrain` ๊ฐ™์€ ์กฐํ•ฉ์ด ์˜๋„๋Œ€๋กœ ๋™์ž‘ํ•œ๋‹ค
112
+ (๊ณต์‹ `run_eval.py`๋„ `--task_set`๊ณผ `--split`์„ ๋…๋ฆฝ ์ธ์ž๋กœ ๋ฐ›์Œ).
113
+
114
+ ### 1-6. GL context ์ˆ˜๋™ ์กฐ์ž‘์€ ๋ณด์กด (๋‹จ, try/except)
115
+
116
+ `reset()`์˜ `gl_ctx.free()`, `step()`์˜ `make_current()`๋Š” **์ด์œ  ๋ถˆ๋ช…์˜ ์šฐํšŒ hack**์œผ๋กœ ๋ณด์ด์ง€๋งŒ
117
+ ์‚ฌ์šฉ์ž ํ™˜๊ฒฝ์—์„œ ์˜๋„์ ์œผ๋กœ ์ถ”๊ฐ€๋œ ๊ฒƒ์ผ ๊ฐ€๋Šฅ์„ฑ์ด ์žˆ์–ด **๋ณด์กด**.
118
+ ๋‹ค๋งŒ `try/except`๋กœ ๊ฐ์‹ธ ๋‹ค๋ฅธ ํ™˜๊ฒฝ์—์„œ ๊นจ์ง€์ง€ ์•Š๋„๋ก ํ–ˆ๋‹ค.
119
+
120
+ ### 1-7. ๊ธฐํƒ€ ์ •๋ฆฌ
121
+
122
+ - `convert_state` ๊ฒฐ๊ณผ๋ฅผ `float32`๋กœ ๋ช…์‹œ (lerobot tensor ๋ณ€ํ™˜ ์•ˆ์ „)
123
+ - `convert_action`์ด list/tuple๋„ ๋ฐ›๋„๋ก `np.asarray`
124
+ - `_format_raw_obs`์˜ `"video." in k` โ†’ `k.startswith("video.")` (์˜คํƒ ๋ฐฉ์ง€)
125
+ - info dict์— `task`, `task_description`์„ ๋งค step ์ฑ„์›€ (lerobot์˜ `add_envs_task` fallback)
126
+
127
+ ---
128
+
129
+ ## 2. `configs.py` ๋ณ€๊ฒฝ ์š”์•ฝ
130
+
131
+ ```python
132
+ @EnvConfig.register_subclass("robocasa")
133
+ @dataclass
134
+ class RoboCasaEnv(HubEnvConfig):
135
+ hub_path: str = "Whalswp/RoboCasa_Env"
136
+
137
+ # โ˜… list ํ—ˆ์šฉ โ€” ๊ณต์‹ run_eval์ฒ˜๋Ÿผ ์—ฌ๋Ÿฌ task_set ๋™์‹œ ์ž…๋ ฅ
138
+ task: str | list[str] | None = None
139
+ fps: int = 20 # โ˜… ์‹ ์„ค (env.py๊ฐ€ cfg.fps ์ฐธ์กฐ)
140
+ obs_type: str = "pixels_agent_pos"
141
+ render_mode: str = "rgb_array"
142
+ camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
143
+ observation_height: int = 256
144
+ observation_width: int = 256
145
+ split: str | None = None # `pretrain` | `target` | `all` | None
146
+ ```
147
+
148
+ `features` / `features_map` / `hub_path`๋Š” ๋ณ€๊ฒฝํ•˜์ง€ ์•Š์Œ โ€” STATE_ACTION_SPEC ๊ณ„์•ฝ ์œ ์ง€.
149
+
150
+ ---
151
+
152
+ ## 3. ๊ณต์‹ GR00T eval โ†” ์ˆ˜์ • ํ›„ lerobot-eval ๋Œ€์‘ํ‘œ
153
+
154
+ | ๊ณต์‹ ์ธ์ž (`run_eval.py`) | lerobot-eval ์ธ์ž | ๋น„๊ณ  |
155
+ |--------------------------|-------------------|------|
156
+ | `--model_path` | `--policy.path` | lerobot์€ `PreTrainedPolicy` ๋กœ๋“œ ๊ฐ€๋Šฅํ•ด์•ผ ํ•จ |
157
+ | `--task_set atomic_seen composite_unseen` | `--env.task=atomic_seen composite_unseen` | ์ด๋ฒˆ ์ˆ˜์ •์œผ๋กœ ๋™๋“ฑ |
158
+ | `--split pretrain` | `--env.split=pretrain` | explicit ์šฐ์„ ๋˜๋„๋ก ์ˆ˜์ • |
159
+ | `--n_episodes 50` | `--eval.n_episodes=50` | ๋™์ผ |
160
+ | `--n_envs 5` | `--eval.batch_size=5` | ๋™์ผ |
161
+ | `--video_dir <path>` | `--output_dir <path>/videos` | lerobot์ด `output_dir/videos/{task}_{id}/eval_episode_*.mp4` ์ž๋™ |
162
+ | `--n_action_steps 16` | (ํ•ด๋‹น ์—†์Œ) | lerobot ์ •์ฑ…์ด single-step ์ถœ๋ ฅ. ACT ๋“ฑ lerobot ์ •์ฑ…์€ ๋‚ด๋ถ€ chunk ์ฒ˜๋ฆฌ |
163
+
164
+ GR00T ์ •์ฑ… ์ž์ฒด๋Š” chunk(16-step) ์ถœ๋ ฅ์ด๋ผ lerobot์˜ `select_action` ๋‹จ์ผ ์Šคํ… ํ๋ฆ„๊ณผ ๋‹ค๋ฅด๋‹ค.
165
+ ์ด๊ฑด ์ •์ฑ… ์–ด๋Œ‘ํ„ฐ ์˜์—ญ์ด๋ผ **์ด๋ฒˆ ์ˆ˜์ • ๋ฒ”์œ„ ๋ฐ–** (env๋Š” ์ •์ฑ…-ํ™˜๊ฒฝ ์‚ฌ์ด์˜ 12-dim ๋‹จ์ผ ์Šคํ… ๊ณ„์•ฝ๋งŒ ์ฑ…์ž„์ง).
166
+
167
+ ---
168
+
169
+ ## 4. ์‚ฌ์šฉ ์˜ˆ์‹œ
170
+
171
+ ```bash
172
+ # ๋‹จ์ผ task (sanity)
173
+ lerobot-eval \
174
+ --policy.path=BrunoM42/act_base-robocasa_target_PickPlaceCounterToCabinet \
175
+ --env.type=robocasa \
176
+ --env.task=PickPlaceCounterToCabinet \
177
+ --eval.batch_size=5 --eval.n_episodes=5 \
178
+ --policy.device=cuda --trust_remote_code=true \
179
+ --env.split=pretrain \
180
+ --output_dir /home/seonho/clvla/benchmarks/robocasa365/bench_outputs
181
+
182
+ # ์—ฌ๋Ÿฌ benchmark ๋™์‹œ (= ๊ณต์‹ --task_set atomic_seen composite_unseen composite_seen)
183
+ lerobot-eval \
184
+ --policy.path=<...> \
185
+ --env.type=robocasa \
186
+ --env.task="atomic_seen composite_unseen composite_seen" \
187
+ --eval.batch_size=5 --eval.n_episodes=50 \
188
+ --policy.device=cuda --trust_remote_code=true \
189
+ --env.split=pretrain \
190
+ --output_dir /home/seonho/clvla/benchmarks/robocasa365/bench_outputs
191
+ ```
192
+
193
+ `out[task][0] = VectorEnv(...)` ๊ตฌ์กฐ์ด๋ฏ€๋กœ `eval_info.json`์˜ `per_group`์—๋Š”
194
+ **task๋ณ„ success rate**๊ฐ€ ๊ทธ๋Œ€๋กœ ๋–จ์–ด์ง€๊ณ , `overall`์ด ์ „์ฒด ํ‰๊ท ์ด ๋œ๋‹ค โ€” ๊ณต์‹
195
+ `get_eval_stats.py`๊ฐ€ ๋งŒ๋“œ๋Š” task_set ํ‰๊ท ๊ณผ ๋น„๊ตํ•˜๊ธฐ ์‰ฌ์šด ํ˜•ํƒœ.
196
+
197
+ ---
198
+
199
+ ## 5. ๋ณ€๊ฒฝํ•˜์ง€ ์•Š์€ ๊ฒƒ (์˜๋„)
200
+
201
+ - **State 16-dim / Action 12-dim concat ๊ณ„์•ฝ** โ€” `STATE_ACTION_SPEC.md` ์œ ์ง€
202
+ - `convert_state`, `convert_action`์˜ ์ธ๋ฑ์Šค/ํ‚ค ์ •์˜
203
+ - `hub_path`, `features`, `features_map`
204
+ - `_create_obs_and_action_space`์˜ ์นด๋ฉ”๋ผ/์•ก์…˜ ๋ฐ•์Šค shape
205
+ - `gl_ctx.free()` / `make_current()` ํ˜ธ์ถœ (์ด์œ  ๋ถˆ๋ช…, ์•ˆ์ „ ์šฐ์„  ๋ณด์กด)
206
+
207
+ ## 6. ํ•œ๊ณ„ / ํ›„์† ๊ณผ์ œ
208
+
209
+ - GR00T ์ •์ฑ…์„ lerobot์—์„œ ๊ทธ๋Œ€๋กœ ์“ฐ๋ ค๋ฉด chunkยทhistory๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” **์ •์ฑ… ์–ด๋Œ‘ํ„ฐ**๊ฐ€ ๋ณ„๋„๋กœ ํ•„์š”.
210
+ `lerobot_eval_gap_analysis.md` ยง3.C ์ฐธ๊ณ . ์ด๋ฒˆ ์ˆ˜์ •์€ ํ™˜๊ฒฝ ์ธก๋งŒ.
211
+ - lerobot์˜ `eval_info.json`์„ ๊ณต์‹ `evals/<split>/<env>/stats.json` ํŠธ๋ฆฌ๋กœ
212
+ ๋ณ€ํ™˜ํ•ด์ฃผ๋Š” ์ž‘์€ ์Šคํฌ๋ฆฝํŠธ๊ฐ€ ์žˆ์œผ๋ฉด `gr00t/eval/get_eval_stats.py`๋ฅผ ๊ทธ๋Œ€๋กœ ์žฌ์‚ฌ์šฉ ๊ฐ€๋Šฅ
213
+ (`lerobot_eval_gap_analysis.md` ยง3.D).
214
+ - Hub(`Whalswp/RoboCasa_Env`)๋ฅผ ์“ฐ๋Š” ๊ฒฝ์šฐ, **์ด ๋กœ์ปฌ ๋ณ€๊ฒฝ์„ hub์— ํ‘ธ์‹œ**ํ•ด์•ผ lerobot์˜
215
+ `trust_remote_code` ๊ฒฝ๋กœ๊ฐ€ ์ƒˆ ๋ฒ„์ „์„ ๋ฐ›๋Š”๋‹ค. ๋กœ์ปฌ์—์„œ ์ง์ ‘ importํ•˜๋Š” ๊ฒฝ์šฐ์—” ๊ทธ๋Œ€๋กœ ์ ์šฉ๋จ.
STATE_ACTION_SPEC.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RoboCasa State / Action ๋ช…์„ธ
2
+
3
+ > ๊ทผ๊ฑฐ ํŒŒ์ผ: `env.py`, `gym_wrapper.py` (`PandaOmronKeyConverter`), `robosuite/controllers/parts/arm/osc.py`, `robosuite/controllers/config/robots/default_pandaomron.json`
4
+
5
+ ---
6
+
7
+ ## State (์ด 16์ฐจ์›)
8
+
9
+ `env.py: convert_state()` ๊ธฐ์ค€์œผ๋กœ concatenate๋จ.
10
+
11
+ | ์ธ๋ฑ์Šค | ์ฐจ์› | ํ‚ค | absolute / relative | ํ‘œํ˜„ |
12
+ |--------|------|----|---------------------|------|
13
+ | 0~2 | 3 | `state.base_position` | **absolute** | xyz |
14
+ | 3~6 | 4 | `state.base_rotation` | **absolute** | **Quaternion** (`robot0_base_quat`) |
15
+ | 7~9 | 3 | `state.end_effector_position_relative` | **relative** (base โ†’ EE) | xyz |
16
+ | 10~13 | 4 | `state.end_effector_rotation_relative` | **relative** (base โ†’ EE) | **Quaternion** (`robot0_base_to_eef_quat`) |
17
+ | 14~15 | 2 | `state.gripper_qpos` | โ€” | joint position |
18
+
19
+ ---
20
+
21
+ ## Action (์ด 12์ฐจ์›)
22
+
23
+ `env.py: convert_action()` ๊ธฐ์ค€์œผ๋กœ ๋ถ„ํ•ด๋จ.
24
+
25
+ | ์ธ๋ฑ์Šค | ์ฐจ์› | ํ‚ค | ์„ค๋ช… |
26
+ |--------|------|----|------|
27
+ | 0~3 | 4 | `action.base_motion` | ๋ฒ ์ด์Šค ์ด๋™ (์•„๋ž˜ ์ฐธ๊ณ ) |
28
+ | 4 | 1 | `action.control_mode` | ์ œ์–ด ๋ชจ๋“œ ์Šค์œ„์น˜ (์•„๋ž˜ ์ฐธ๊ณ ) |
29
+ | 5~7 | 3 | `action.end_effector_position` | EE delta position, **base frame ๊ธฐ์ค€** |
30
+ | 8~10 | 3 | `action.end_effector_rotation` | EE delta rotation, **base frame ๊ธฐ์ค€, axis-angle** |
31
+ | 11 | 1 | `action.gripper_close` | ๊ทธ๋ฆฌํผ ๋‹ซ๊ธฐ (0.5 threshold โ†’ binary) |
32
+
33
+ ### base_motion (4์ฐจ์›) ์ƒ์„ธ
34
+
35
+ | ์ธ๋ฑ์Šค | ๋Œ€์ƒ | controller type | ์„ค๋ช… |
36
+ |--------|------|-----------------|------|
37
+ | 0~2 | `robot0_base` | `JOINT_VELOCITY` | ๋ชจ๋ฐ”์ผ ๋ฒ ์ด์Šค x์†๋„ / y์†๋„ / yaw์†๋„ |
38
+ | 3 | `robot0_torso` | `JOINT_POSITION` | ๋ชธํ†ต ์ˆ˜์ง ๋ฆฌํ”„ํŠธ joint position (โ‰ˆ ๋†’์ด) |
39
+
40
+ ### control_mode (1์ฐจ์›) ์ƒ์„ธ
41
+
42
+ | ๊ฐ’ | base_mode | ๋™์ž‘ |
43
+ |----|-----------|------|
44
+ | < 0.5 | -1 | **Arm mode** โ€” ๋ฒ ์ด์Šค ๊ณ ์ •, ํŒ”๋กœ ์กฐ์ž‘ (goal: `achieved` ๊ธฐ์ค€) |
45
+ | โ‰ฅ 0.5 | +1 | **Base mode** โ€” ๋ฒ ์ด์Šค ์ด๋™, ํŒ” ๋ชฉํ‘œ ์œ ์ง€ (goal: `desired` ๊ธฐ์ค€) |
46
+
47
+ ---
48
+
49
+ ## EE Rotation: axis-angle์„ ์“ฐ๋Š” ์ด์œ 
50
+
51
+ OSC controller(`osc.py`)๋Š” rotation input์„ `Rotation.from_rotvec()` ์œผ๋กœ ํ•ด์„ โ†’ **axis-angle ๊ณ ์ •**.
52
+
53
+ RPY(Euler angle) ๋Œ€์‹  axis-angle์„ ์“ฐ๋Š” ์ด์œ :
54
+
55
+ 1. **Gimbal lock ์—†์Œ** โ€” RPY๋Š” ํŠน์ • ์ž์„ธ์—์„œ ๋‘ ์ถ•์ด ๊ฒน์ณ DOF๋ฅผ ์žƒ๋Š” singularity ๋ฐœ์ƒ. EE๋Š” ์ž์œ  ํšŒ์ „ํ•˜๋ฏ€๋กœ ์‹ค์ œ ๋ฌธ์ œ๊ฐ€ ๋จ.
56
+ 2. **Delta ์ œ์–ด์— ์ž์—ฐ์Šค๋Ÿฌ์›€** โ€” "์ด ์ถ• ๋ฐฉํ–ฅ์œผ๋กœ ฮธ๋งŒํผ ํšŒ์ „" ์˜๋ฏธ๊ฐ€ ์ง๊ด€์ ์ด๊ณ  ๋ณด๊ฐ„์ด smooth. RPY delta๋Š” ์ˆœ์„œ ์˜์กด์„ฑ(rollโ†’pitchโ†’yaw) ๋•Œ๋ฌธ์— ํ•ฉ์„ฑ์ด ๋ณต์žกํ•จ.
57
+ 3. **ํฌ๊ธฐ = ํšŒ์ „๋Ÿ‰** โ€” ๋ฒกํ„ฐ norm์ด ํšŒ์ „๊ฐ์ด๋ผ output clipping์ด ์ž์—ฐ์Šค๋Ÿฌ์›€. (`output_max: [0.5, 0.5, 0.5]` rad)
58
+
59
+ > RPY ์ž…๋ ฅ์€ ์ฝ”๋“œ์ƒ ์ง€์›ํ•˜์ง€ ์•Š์Œ. ํ•„์š”ํ•˜๋ฉด wrapper์—์„œ ๋ณ€ํ™˜ ํ•„์š”:
60
+ > ```python
61
+ > from scipy.spatial.transform import Rotation
62
+ > axis_angle = Rotation.from_euler('xyz', rpy).as_rotvec()
63
+ > ```
64
+
65
+ ---
66
+
67
+ ## ์‹œ๋ฎฌ๋ ˆ์ด์…˜ ์—ฐ๊ฒฐ ํ๋ฆ„
68
+
69
+ ```
70
+ policy output (12-dim)
71
+ โ†“ convert_action() [env.py]
72
+ action dict (base_motion, control_mode, EE_pos, EE_rot, gripper_close)
73
+ โ†“ unmap_action() [gym_wrapper.py]
74
+ {
75
+ robot0_right: concat(EE_pos[3], EE_rot[3]) โ†’ OSC_POSE controller
76
+ robot0_right_gripper: threshold(gripper_close, 0.5) โ†’ -1 or +1
77
+ robot0_base: base_motion[0:3] โ†’ JOINT_VELOCITY controller
78
+ robot0_torso: base_motion[3:4] โ†’ JOINT_POSITION controller
79
+ robot0_base_mode: threshold(control_mode, 0.5) โ†’ -1 or +1
80
+ }
81
+ โ†“ env.step() [robosuite]
82
+ MuJoCo simulation
83
+ ```
backup/configs.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ from dataclasses import dataclass, field, fields
17
+ from typing import Any
18
+
19
+ import draccus
20
+
21
+ from lerobot.configs.types import FeatureType, PolicyFeature
22
+ from lerobot.robots import RobotConfig
23
+ from lerobot.teleoperators.config import TeleoperatorConfig
24
+ from lerobot.utils.constants import (
25
+ ACTION,
26
+ LIBERO_KEY_EEF_MAT,
27
+ LIBERO_KEY_EEF_POS,
28
+ LIBERO_KEY_EEF_QUAT,
29
+ LIBERO_KEY_GRIPPER_QPOS,
30
+ LIBERO_KEY_GRIPPER_QVEL,
31
+ LIBERO_KEY_JOINTS_POS,
32
+ LIBERO_KEY_JOINTS_VEL,
33
+ LIBERO_KEY_PIXELS_AGENTVIEW,
34
+ LIBERO_KEY_PIXELS_EYE_IN_HAND,
35
+ OBS_ENV_STATE,
36
+ OBS_IMAGE,
37
+ OBS_IMAGES,
38
+ OBS_STATE,
39
+ )
40
+
41
+
42
+ @dataclass
43
+ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
44
+ task: str | None = None
45
+ fps: int = 30
46
+ features: dict[str, PolicyFeature] = field(default_factory=dict)
47
+ features_map: dict[str, str] = field(default_factory=dict)
48
+ max_parallel_tasks: int = 1
49
+ disable_env_checker: bool = True
50
+
51
+ @property
52
+ def type(self) -> str:
53
+ return self.get_choice_name(self.__class__)
54
+
55
+ @property
56
+ def package_name(self) -> str:
57
+ """Package name to import if environment not found in gym registry"""
58
+ return f"gym_{self.type}"
59
+
60
+ @property
61
+ def gym_id(self) -> str:
62
+ """ID string used in gym.make() to instantiate the environment"""
63
+ return f"{self.package_name}/{self.task}"
64
+
65
+ @property
66
+ @abc.abstractmethod
67
+ def gym_kwargs(self) -> dict:
68
+ raise NotImplementedError()
69
+
70
+
71
+ @dataclass
72
+ class HubEnvConfig(EnvConfig):
73
+ """Base class for environments that delegate creation to a hub-hosted make_env.
74
+
75
+ Hub environments download and execute remote code from the HF Hub.
76
+ The hub_path points to a repository containing an env.py with a make_env function.
77
+ """
78
+
79
+ hub_path: str | None = None # required: e.g., "username/repo" or "username/repo@branch:file.py"
80
+
81
+ @property
82
+ def gym_kwargs(self) -> dict:
83
+ # Not used for hub environments - the hub's make_env handles everything
84
+ return {}
85
+
86
+
87
+ @EnvConfig.register_subclass("aloha")
88
+ @dataclass
89
+ class AlohaEnv(EnvConfig):
90
+ task: str | None = "AlohaInsertion-v0"
91
+ fps: int = 50
92
+ episode_length: int = 400
93
+ obs_type: str = "pixels_agent_pos"
94
+ observation_height: int = 480
95
+ observation_width: int = 640
96
+ render_mode: str = "rgb_array"
97
+ features: dict[str, PolicyFeature] = field(
98
+ default_factory=lambda: {
99
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
100
+ }
101
+ )
102
+ features_map: dict[str, str] = field(
103
+ default_factory=lambda: {
104
+ ACTION: ACTION,
105
+ "agent_pos": OBS_STATE,
106
+ "top": f"{OBS_IMAGE}.top",
107
+ "pixels/top": f"{OBS_IMAGES}.top",
108
+ }
109
+ )
110
+
111
+ def __post_init__(self):
112
+ if self.obs_type == "pixels":
113
+ self.features["top"] = PolicyFeature(
114
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
115
+ )
116
+ elif self.obs_type == "pixels_agent_pos":
117
+ self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
118
+ self.features["pixels/top"] = PolicyFeature(
119
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
120
+ )
121
+
122
+ @property
123
+ def gym_kwargs(self) -> dict:
124
+ return {
125
+ "obs_type": self.obs_type,
126
+ "render_mode": self.render_mode,
127
+ "max_episode_steps": self.episode_length,
128
+ }
129
+
130
+
131
+ @EnvConfig.register_subclass("pusht")
132
+ @dataclass
133
+ class PushtEnv(EnvConfig):
134
+ task: str | None = "PushT-v0"
135
+ fps: int = 10
136
+ episode_length: int = 300
137
+ obs_type: str = "pixels_agent_pos"
138
+ render_mode: str = "rgb_array"
139
+ visualization_width: int = 384
140
+ visualization_height: int = 384
141
+ observation_height: int = 384
142
+ observation_width: int = 384
143
+ features: dict[str, PolicyFeature] = field(
144
+ default_factory=lambda: {
145
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
146
+ "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
147
+ }
148
+ )
149
+ features_map: dict[str, str] = field(
150
+ default_factory=lambda: {
151
+ ACTION: ACTION,
152
+ "agent_pos": OBS_STATE,
153
+ "environment_state": OBS_ENV_STATE,
154
+ "pixels": OBS_IMAGE,
155
+ }
156
+ )
157
+
158
+ def __post_init__(self):
159
+ if self.obs_type == "pixels_agent_pos":
160
+ self.features["pixels"] = PolicyFeature(
161
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
162
+ )
163
+ elif self.obs_type == "environment_state_agent_pos":
164
+ self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
165
+
166
+ @property
167
+ def gym_kwargs(self) -> dict:
168
+ return {
169
+ "obs_type": self.obs_type,
170
+ "render_mode": self.render_mode,
171
+ "visualization_width": self.visualization_width,
172
+ "visualization_height": self.visualization_height,
173
+ "max_episode_steps": self.episode_length,
174
+ }
175
+
176
+
177
+ @dataclass
178
+ class ImagePreprocessingConfig:
179
+ crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
180
+ resize_size: tuple[int, int] | None = None
181
+
182
+
183
+ @dataclass
184
+ class RewardClassifierConfig:
185
+ """Configuration for reward classification."""
186
+
187
+ pretrained_path: str | None = None
188
+ success_threshold: float = 0.5
189
+ success_reward: float = 1.0
190
+
191
+
192
+ @dataclass
193
+ class InverseKinematicsConfig:
194
+ """Configuration for inverse kinematics processing."""
195
+
196
+ urdf_path: str | None = None
197
+ target_frame_name: str | None = None
198
+ end_effector_bounds: dict[str, list[float]] | None = None
199
+ end_effector_step_sizes: dict[str, float] | None = None
200
+
201
+
202
+ @dataclass
203
+ class ObservationConfig:
204
+ """Configuration for observation processing."""
205
+
206
+ add_joint_velocity_to_observation: bool = False
207
+ add_current_to_observation: bool = False
208
+ add_ee_pose_to_observation: bool = False
209
+ display_cameras: bool = False
210
+
211
+
212
+ @dataclass
213
+ class GripperConfig:
214
+ """Configuration for gripper control and penalties."""
215
+
216
+ use_gripper: bool = True
217
+ gripper_penalty: float = 0.0
218
+
219
+
220
+ @dataclass
221
+ class ResetConfig:
222
+ """Configuration for environment reset behavior."""
223
+
224
+ fixed_reset_joint_positions: Any | None = None
225
+ reset_time_s: float = 5.0
226
+ control_time_s: float = 20.0
227
+ terminate_on_success: bool = True
228
+
229
+
230
+ @dataclass
231
+ class HILSerlProcessorConfig:
232
+ """Configuration for environment processing pipeline."""
233
+
234
+ control_mode: str = "gamepad"
235
+ observation: ObservationConfig | None = None
236
+ image_preprocessing: ImagePreprocessingConfig | None = None
237
+ gripper: GripperConfig | None = None
238
+ reset: ResetConfig | None = None
239
+ inverse_kinematics: InverseKinematicsConfig | None = None
240
+ reward_classifier: RewardClassifierConfig | None = None
241
+ max_gripper_pos: float | None = 100.0
242
+
243
+
244
+ @EnvConfig.register_subclass(name="gym_manipulator")
245
+ @dataclass
246
+ class HILSerlRobotEnvConfig(EnvConfig):
247
+ """Configuration for the HILSerlRobotEnv environment."""
248
+
249
+ robot: RobotConfig | None = None
250
+ teleop: TeleoperatorConfig | None = None
251
+ processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
252
+
253
+ name: str = "real_robot"
254
+
255
+ @property
256
+ def gym_kwargs(self) -> dict:
257
+ return {}
258
+
259
+
260
+ @EnvConfig.register_subclass("libero")
261
+ @dataclass
262
+ class LiberoEnv(EnvConfig):
263
+ task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
264
+ task_ids: list[int] | None = None
265
+ fps: int = 30
266
+ episode_length: int | None = None
267
+ obs_type: str = "pixels_agent_pos"
268
+ render_mode: str = "rgb_array"
269
+ camera_name: str = "agentview_image,robot0_eye_in_hand_image"
270
+ init_states: bool = True
271
+ camera_name_mapping: dict[str, str] | None = None
272
+ observation_height: int = 360
273
+ observation_width: int = 360
274
+ features: dict[str, PolicyFeature] = field(
275
+ default_factory=lambda: {
276
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
277
+ }
278
+ )
279
+ features_map: dict[str, str] = field(
280
+ default_factory=lambda: {
281
+ ACTION: ACTION,
282
+ LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
283
+ LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
284
+ LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
285
+ LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
286
+ LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
287
+ LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
288
+ LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
289
+ LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
290
+ LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
291
+ }
292
+ )
293
+ control_mode: str = "relative" # or "absolute"
294
+
295
+ def __post_init__(self):
296
+ if self.obs_type == "pixels":
297
+ self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
298
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
299
+ )
300
+ self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
301
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
302
+ )
303
+ elif self.obs_type == "pixels_agent_pos":
304
+ self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
305
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
306
+ )
307
+ self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
308
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
309
+ )
310
+ self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
311
+ type=FeatureType.STATE,
312
+ shape=(3,),
313
+ )
314
+ self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
315
+ type=FeatureType.STATE,
316
+ shape=(4,),
317
+ )
318
+ self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
319
+ type=FeatureType.STATE,
320
+ shape=(3, 3),
321
+ )
322
+ self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
323
+ type=FeatureType.STATE,
324
+ shape=(2,),
325
+ )
326
+ self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
327
+ type=FeatureType.STATE,
328
+ shape=(2,),
329
+ )
330
+ self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
331
+ type=FeatureType.STATE,
332
+ shape=(7,),
333
+ )
334
+ self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
335
+ type=FeatureType.STATE,
336
+ shape=(7,),
337
+ )
338
+ else:
339
+ raise ValueError(f"Unsupported obs_type: {self.obs_type}")
340
+
341
+ @property
342
+ def gym_kwargs(self) -> dict:
343
+ kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
344
+ if self.task_ids is not None:
345
+ kwargs["task_ids"] = self.task_ids
346
+ return kwargs
347
+
348
+
349
+ @EnvConfig.register_subclass("metaworld")
350
+ @dataclass
351
+ class MetaworldEnv(EnvConfig):
352
+ task: str = "metaworld-push-v2" # add all tasks
353
+ fps: int = 80
354
+ episode_length: int = 400
355
+ obs_type: str = "pixels_agent_pos"
356
+ render_mode: str = "rgb_array"
357
+ multitask_eval: bool = True
358
+ features: dict[str, PolicyFeature] = field(
359
+ default_factory=lambda: {
360
+ "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
361
+ }
362
+ )
363
+ features_map: dict[str, str] = field(
364
+ default_factory=lambda: {
365
+ "action": ACTION,
366
+ "agent_pos": OBS_STATE,
367
+ "top": f"{OBS_IMAGE}",
368
+ "pixels/top": f"{OBS_IMAGE}",
369
+ }
370
+ )
371
+
372
+ def __post_init__(self):
373
+ if self.obs_type == "pixels":
374
+ self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
375
+
376
+ elif self.obs_type == "pixels_agent_pos":
377
+ self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
378
+ self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
379
+
380
+ else:
381
+ raise ValueError(f"Unsupported obs_type: {self.obs_type}")
382
+
383
+ @property
384
+ def gym_kwargs(self) -> dict:
385
+ return {
386
+ "obs_type": self.obs_type,
387
+ "render_mode": self.render_mode,
388
+ }
389
+
390
+
391
+ @EnvConfig.register_subclass("isaaclab_arena")
392
+ @dataclass
393
+ class IsaaclabArenaEnv(HubEnvConfig):
394
+ hub_path: str = "nvidia/isaaclab-arena-envs"
395
+ episode_length: int = 300
396
+ num_envs: int = 1
397
+ embodiment: str | None = "gr1_pink"
398
+ object: str | None = "power_drill"
399
+ mimic: bool = False
400
+ teleop_device: str | None = None
401
+ seed: int | None = 42
402
+ device: str | None = "cuda:0"
403
+ disable_fabric: bool = False
404
+ enable_cameras: bool = False
405
+ headless: bool = False
406
+ enable_pinocchio: bool = True
407
+ environment: str | None = "gr1_microwave"
408
+ task: str | None = "Reach out to the microwave and open it."
409
+ state_dim: int = 54
410
+ action_dim: int = 36
411
+ camera_height: int = 512
412
+ camera_width: int = 512
413
+ video: bool = False
414
+ video_length: int = 100
415
+ video_interval: int = 200
416
+ # Comma-separated keys, e.g., "robot_joint_pos,left_eef_pos"
417
+ state_keys: str = "robot_joint_pos"
418
+ # Comma-separated keys, e.g., "robot_pov_cam_rgb,front_cam_rgb"
419
+ # Set to None or "" for environments without cameras
420
+ camera_keys: str | None = None
421
+ features: dict[str, PolicyFeature] = field(default_factory=dict)
422
+ features_map: dict[str, str] = field(default_factory=dict)
423
+ kwargs: dict | None = None
424
+
425
+ def __post_init__(self):
426
+ if self.kwargs:
427
+ # dynamically convert kwargs to fields in the dataclass
428
+ # NOTE! the new fields will not bee seen by the dataclass repr
429
+ field_names = {f.name for f in fields(self)}
430
+ for key, value in self.kwargs.items():
431
+ if key not in field_names and key != "kwargs":
432
+ setattr(self, key, value)
433
+ self.kwargs = None
434
+
435
+ # Set action feature
436
+ self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
437
+ self.features_map[ACTION] = ACTION
438
+
439
+ # Set state feature
440
+ self.features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.state_dim,))
441
+ self.features_map[OBS_STATE] = OBS_STATE
442
+
443
+ # Add camera features for each camera key
444
+ if self.enable_cameras and self.camera_keys:
445
+ for cam_key in self.camera_keys.split(","):
446
+ cam_key = cam_key.strip()
447
+ if cam_key:
448
+ self.features[cam_key] = PolicyFeature(
449
+ type=FeatureType.VISUAL,
450
+ shape=(self.camera_height, self.camera_width, 3),
451
+ )
452
+ self.features_map[cam_key] = f"{OBS_IMAGES}.{cam_key}"
453
+
454
+ @property
455
+ def gym_kwargs(self) -> dict:
456
+ return {}
457
+
458
+
459
+ # ------------------------ Robocasa365 --------------------------------
460
+
461
+ @EnvConfig.register_subclass("robocasa")
462
+ @dataclass
463
+ class RoboCasaEnv(HubEnvConfig):
464
+
465
+ hub_path: str = "Whalswp/RoboCasa_Env"
466
+
467
+ task: str | None = None
468
+ obs_type: str = "pixels_agent_pos"
469
+ render_mode: str = "rgb_array"
470
+ camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
471
+ observation_height: int = 256
472
+ observation_width: int = 256
473
+ split: str | None = None
474
+
475
+ # VLA ๋ชจ๋ธ ๋“ฑ์—์„œ ์‚ฌ์šฉํ•  Observation & Action ๊ทœ๊ฒฉ ๋งคํ•‘
476
+ features: dict[str, PolicyFeature] = field(default_factory=lambda: {
477
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
478
+ "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(16,)),
479
+ "pixels/robot0_agentview_left": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
480
+ "pixels/robot0_agentview_right": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
481
+ "pixels/robot0_eye_in_hand": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
482
+ })
483
+ features_map: dict[str, str] = field(default_factory=lambda: {
484
+ ACTION: ACTION,
485
+ "agent_pos": OBS_STATE,
486
+ "pixels/robot0_agentview_left": f"{OBS_IMAGES}.robot0_agentview_left",
487
+ "pixels/robot0_agentview_right": f"{OBS_IMAGES}.robot0_agentview_right",
488
+ "pixels/robot0_eye_in_hand": f"{OBS_IMAGES}.robot0_eye_in_hand",
489
+ })
backup/env.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # env.py
2
+ import gymnasium as gym
3
+ from gymnasium import spaces
4
+ import numpy as np
5
+ from collections import defaultdict
6
+ from collections.abc import Callable, Sequence, Mapping
7
+ from functools import partial
8
+ from typing import Any
9
+
10
+ # RoboCasa ์ „์šฉ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์ž„ํฌํŠธ
11
+ from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
12
+ from robocasa.utils.dataset_registry import ATOMIC_TASK_DATASETS, COMPOSITE_TASK_DATASETS, TARGET_TASKS, PRETRAINING_TASKS
13
+
14
+ OBS_STATE_DIM = 16
15
+ ACTION_DIM = 12
16
+ ACTION_LOW = -1.0
17
+ ACTION_HIGH = 1.0
18
+
19
+ def convert_state(dict_state):
20
+ """์‹œ๋ฎฌ๋ ˆ์ดํ„ฐ ์ƒํƒœ๋ฅผ LeRobot์ด ๊ธฐ๋Œ€ํ•˜๋Š” ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜(Conversion)ํ•ฉ๋‹ˆ๋‹ค."""
21
+ dict_state = dict_state.copy()
22
+ final_state = np.concatenate([
23
+ dict_state["state.base_position"],
24
+ dict_state["state.base_rotation"],
25
+ dict_state["state.end_effector_position_relative"],
26
+ dict_state["state.end_effector_rotation_relative"],
27
+ dict_state["state.gripper_qpos"],
28
+ ], axis=0)
29
+ return final_state
30
+
31
+ def convert_action(action):
32
+ """LeRobot์˜ ์•ก์…˜์„ ์‹œ๋ฎฌ๋ ˆ์ดํ„ฐ๊ฐ€ ์ดํ•ดํ•˜๋Š” dict ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค."""
33
+ action = action.copy()
34
+ output_action = {
35
+ "action.base_motion": action[0:4],
36
+ "action.control_mode": action[4:5],
37
+ "action.end_effector_position": action[5:8],
38
+ "action.end_effector_rotation": action[8:11],
39
+ "action.gripper_close": action[11:12],
40
+ }
41
+ return output_action
42
+
43
+ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
44
+ """์นด๋ฉ”๋ผ ์ด๋ฆ„์„ ๋ฆฌ์ŠคํŠธ ํ˜•ํƒœ๋กœ ์ •๊ทœํ™”(Normalization)ํ•ฉ๋‹ˆ๋‹ค."""
45
+ if isinstance(camera_name, str):
46
+ cams = [c.strip() for c in camera_name.split(",") if c.strip()]
47
+ elif isinstance(camera_name, (list, tuple)):
48
+ cams = [str(c).strip() for c in camera_name if str(c).strip()]
49
+ else:
50
+ raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
51
+ if not cams:
52
+ raise ValueError("camera_name resolved to an empty list.")
53
+ return cams
54
+
55
+ class RoboCasaEnv(RoboCasaGymEnv):
56
+ metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
57
+
58
+ def __init__(
59
+ self,
60
+ task: str,
61
+ camera_name: Sequence[str] = ["robot0_agentview_left", "robot0_eye_in_hand", "robot0_agentview_right"],
62
+ render_mode: str = "rgb_array",
63
+ obs_type: str = "pixels_agent_pos",
64
+ observation_width: int = 256,
65
+ observation_height: int = 256,
66
+ split: str | None = None,
67
+ **kwargs
68
+ ):
69
+ self.obs_type = obs_type
70
+ self.render_mode = render_mode
71
+ self.split = split
72
+ self.task = task
73
+ self._task_description = ""
74
+
75
+ kwargs.pop("fps", None)
76
+ self.kwargs = kwargs
77
+
78
+ meta_info = {**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}
79
+ try:
80
+ self._max_episode_steps = meta_info[task]['horizon']
81
+ except KeyError:
82
+ raise ValueError(f"Unknown task '{task}'. Valid tasks are: {list(meta_info.keys())}")
83
+
84
+ super().__init__(
85
+ task,
86
+ camera_names=camera_name,
87
+ camera_widths=observation_width,
88
+ camera_heights=observation_height,
89
+ enable_render=(render_mode is not None),
90
+ split=split,
91
+ **kwargs
92
+ )
93
+
94
+ def _create_obs_and_action_space(self):
95
+ images = {}
96
+ for cam in self.camera_names:
97
+ images[cam] = spaces.Box(
98
+ low=0, high=255, shape=(self.camera_heights, self.camera_widths, 3), dtype=np.uint8
99
+ )
100
+ if self.obs_type == "state":
101
+ raise NotImplementedError("The 'state' observation type is not supported.")
102
+ elif self.obs_type == "pixels":
103
+ self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
104
+ elif self.obs_type == "pixels_agent_pos":
105
+ self.observation_space = spaces.Dict({
106
+ "pixels": spaces.Dict(images),
107
+ "agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
108
+ })
109
+ else:
110
+ raise ValueError(f"Unknown obs_type: {self.obs_type}")
111
+
112
+ self.action_space = spaces.Box(
113
+ low=ACTION_LOW, high=ACTION_HIGH, shape=(int(ACTION_DIM),), dtype=np.float32
114
+ )
115
+
116
+ @property
117
+ def task_description(self) -> str:
118
+ return self._task_description
119
+
120
+ def reset(self, seed: int | None = None, **kwargs):
121
+ self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
122
+ observation, info = super().reset(seed, **kwargs)
123
+ self._task_description = self.env.get_ep_meta().get("lang", self.task)
124
+ print(f"[RoboCasaEnv] task_description: {self._task_description!r}")
125
+ return self._format_raw_obs(observation), info
126
+
127
+ def _format_raw_obs(self, raw_obs: dict):
128
+ new_obs = {}
129
+ if self.obs_type == "pixels_agent_pos":
130
+ new_obs["agent_pos"] = convert_state(raw_obs)
131
+ new_obs["pixels"] = {}
132
+ for k, v in raw_obs.items():
133
+ if "video." in k:
134
+ new_obs["pixels"][k.replace("video.", "")] = v
135
+ return new_obs
136
+
137
+ def step(self, action: np.ndarray):
138
+ self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
139
+ action_dict = convert_action(action)
140
+ observation, reward, done, truncated, info = super().step(action_dict)
141
+ new_obs = self._format_raw_obs(observation)
142
+
143
+ is_success = bool(info.get("success", 0))
144
+ terminated = done or is_success
145
+ info.update({"task": self.task, "done": done, "is_success": is_success})
146
+
147
+ if terminated:
148
+ info["final_info"] = {"task": self.task, "done": bool(done), "is_success": bool(is_success)}
149
+ self.reset()
150
+
151
+ return new_obs, reward, terminated, truncated, info
152
+
153
+ def render(self):
154
+ frame = super().render()
155
+ if frame is None:
156
+ return frame
157
+ from PIL import Image, ImageDraw, ImageFont
158
+ import textwrap
159
+
160
+ text = self._task_description or self.task
161
+ w = frame.shape[1]
162
+
163
+ try:
164
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
165
+ except Exception:
166
+ font = ImageFont.load_default()
167
+
168
+ lines = textwrap.wrap(text, width=55)
169
+ line_h = 18
170
+ bar_h = len(lines) * line_h + 10
171
+
172
+ bar = Image.new("RGB", (w, bar_h), color=(30, 30, 30))
173
+ draw = ImageDraw.Draw(bar)
174
+ for i, line in enumerate(lines):
175
+ draw.text((8, 5 + i * line_h), line, font=font, fill=(220, 220, 220))
176
+
177
+ return np.concatenate([frame, np.array(bar)], axis=0)
178
+
179
+
180
+ def _make_env_fns(task_name: str, n_envs: int, camera_names: list[str], gym_kwargs: Mapping[str, Any]):
181
+ def _make_env(episode_index: int, **kwargs):
182
+ seed = kwargs.pop("seed", episode_index)
183
+ return RoboCasaEnv(task=task_name, camera_name=camera_names, seed=seed, **kwargs)
184
+
185
+ return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]
186
+
187
+
188
+ # ======================================================================
189
+ # LeRobot Hub ํ•„์ˆ˜ ์ง„์ž…์ (Entry Point)
190
+ # ======================================================================
191
+ def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
192
+ """
193
+ LeRobot์ด Hub์—์„œ ํ™˜๊ฒฝ์„ ๋กœ๋“œํ•  ๋•Œ ํ˜ธ์ถœํ•˜๋Š” ๋ฉ”์ธ ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
194
+ """
195
+ # ํ™˜๊ฒฝ ๋ž˜ํผ ํด๋ž˜์Šค ์„ ํƒ
196
+ env_cls = partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv
197
+
198
+ # ์„ค์ •๊ฐ’ ์ถ”์ถœ (cfg ๊ฐ์ฒด๊ฐ€ ์žˆ์œผ๋ฉด ์‚ฌ์šฉํ•˜๊ณ , ์—†์œผ๋ฉด ๊ธฐ๋ณธ๊ฐ’ ์ ์šฉ)
199
+ if cfg is not None:
200
+ task_name = getattr(cfg, "task", "CloseFridge")
201
+ fps = getattr(cfg, "fps", 20) # fps ์ถ”์ถœ
202
+ gym_kwargs = {
203
+ "obs_type": getattr(cfg, "obs_type", "pixels_agent_pos"),
204
+ "render_mode": getattr(cfg, "render_mode", "rgb_array"), # render_mode ์œ ์ง€
205
+ "observation_width": getattr(cfg, "observation_width", 256),
206
+ "observation_height": getattr(cfg, "observation_height", 256),
207
+ "camera_name": getattr(cfg, "camera_name", "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"),
208
+ "split": getattr(cfg, "split", None),
209
+ "fps": fps, # ํ•ต์‹ฌ ์ธ์ž ๋ˆ„๋ฝ ๋ฐฉ์ง€
210
+ }
211
+ else:
212
+ # cfg ์—†์ด ์ง์ ‘ ํ˜ธ์ถœ๋  ๋•Œ์˜ ๊ธฐ๋ณธ๊ฐ’
213
+ task_name = "CloseFridge"
214
+ gym_kwargs = {
215
+ "obs_type": "pixels_agent_pos",
216
+ "render_mode": "rgb_array",
217
+ "observation_width": 256,
218
+ "observation_height": 256,
219
+ "camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
220
+ "split": None,
221
+ }
222
+
223
+ parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))
224
+ combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
225
+
226
+ # ๋ฒค์น˜๋งˆํฌ์ธ์ง€ ๋‹จ์ผ ํƒœ์Šคํฌ์ธ์ง€ ๊ตฌ๋ถ„
227
+ if task_name in combined_tasks:
228
+ task_names = combined_tasks[task_name]
229
+ gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
230
+ else:
231
+ task_names = [t.strip() for t in task_name.split(",")]
232
+
233
+
234
+ out = defaultdict(dict)
235
+
236
+ # ํƒœ์Šคํฌ๋ณ„๋กœ ํ™˜๊ฒฝ ์ƒ์„ฑ
237
+ for task in task_names:
238
+ fns = _make_env_fns(
239
+ task_name=task,
240
+ n_envs=n_envs,
241
+ camera_names=parsed_camera_names,
242
+ gym_kwargs=gym_kwargs
243
+ )
244
+ out[task][0] = env_cls(fns)
245
+
246
+ # {suite_name: {task_id: VectorEnv}} ํ˜•ํƒœ๋กœ ๋ฐ˜ํ™˜
247
+ #return {"robocasa": dict(out)}
248
+ return {suite: dict(task_map) for suite, task_map in out.items()}
configs.py CHANGED
@@ -462,14 +462,19 @@ class IsaaclabArenaEnv(HubEnvConfig):
462
  @dataclass
463
  class RoboCasaEnv(HubEnvConfig):
464
 
465
- hub_path: str = "Whalswp/RoboCasa_Env"
466
-
467
- task: str | None = None
 
 
 
 
468
  obs_type: str = "pixels_agent_pos"
469
  render_mode: str = "rgb_array"
470
  camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
471
  observation_height: int = 256
472
  observation_width: int = 256
 
473
  split: str | None = None
474
 
475
  # VLA ๋ชจ๋ธ ๋“ฑ์—์„œ ์‚ฌ์šฉํ•  Observation & Action ๊ทœ๊ฒฉ ๋งคํ•‘
 
462
  @dataclass
463
  class RoboCasaEnv(HubEnvConfig):
464
 
465
+ hub_path: str = "Whalswp/RoboCasa_Env"
466
+
467
+ # ๋‹จ์ผ task ์ด๋ฆ„ ๋˜๋Š” benchmark ํ‚ค(`atomic_seen`, `composite_unseen`, ...).
468
+ # ๊ณต์‹ GR00T eval์ฒ˜๋Ÿผ ์—ฌ๋Ÿฌ ๊ฐœ๋ฅผ ๋™์‹œ์— ๋ฐ›์„ ์ˆ˜ ์žˆ๋„๋ก list๋„ ํ—ˆ์šฉ.
469
+ # ์˜ˆ: --env.task=atomic_seen composite_unseen composite_seen
470
+ task: str | list[str] | None = None
471
+ fps: int = 20
472
  obs_type: str = "pixels_agent_pos"
473
  render_mode: str = "rgb_array"
474
  camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
475
  observation_height: int = 256
476
  observation_width: int = 256
477
+ # `pretrain` | `target` | `all` | None (None์ด๋ฉด task๋กœ๋ถ€ํ„ฐ ์ถ”๋ก )
478
  split: str | None = None
479
 
480
  # VLA ๋ชจ๋ธ ๋“ฑ์—์„œ ์‚ฌ์šฉํ•  Observation & Action ๊ทœ๊ฒฉ ๋งคํ•‘
configs.py.bak ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ from dataclasses import dataclass, field, fields
17
+ from typing import Any
18
+
19
+ import draccus
20
+
21
+ from lerobot.configs.types import FeatureType, PolicyFeature
22
+ from lerobot.robots import RobotConfig
23
+ from lerobot.teleoperators.config import TeleoperatorConfig
24
+ from lerobot.utils.constants import (
25
+ ACTION,
26
+ LIBERO_KEY_EEF_MAT,
27
+ LIBERO_KEY_EEF_POS,
28
+ LIBERO_KEY_EEF_QUAT,
29
+ LIBERO_KEY_GRIPPER_QPOS,
30
+ LIBERO_KEY_GRIPPER_QVEL,
31
+ LIBERO_KEY_JOINTS_POS,
32
+ LIBERO_KEY_JOINTS_VEL,
33
+ LIBERO_KEY_PIXELS_AGENTVIEW,
34
+ LIBERO_KEY_PIXELS_EYE_IN_HAND,
35
+ OBS_ENV_STATE,
36
+ OBS_IMAGE,
37
+ OBS_IMAGES,
38
+ OBS_STATE,
39
+ )
40
+
41
+
42
+ @dataclass
43
+ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
44
+ task: str | None = None
45
+ fps: int = 30
46
+ features: dict[str, PolicyFeature] = field(default_factory=dict)
47
+ features_map: dict[str, str] = field(default_factory=dict)
48
+ max_parallel_tasks: int = 1
49
+ disable_env_checker: bool = True
50
+
51
+ @property
52
+ def type(self) -> str:
53
+ return self.get_choice_name(self.__class__)
54
+
55
+ @property
56
+ def package_name(self) -> str:
57
+ """Package name to import if environment not found in gym registry"""
58
+ return f"gym_{self.type}"
59
+
60
+ @property
61
+ def gym_id(self) -> str:
62
+ """ID string used in gym.make() to instantiate the environment"""
63
+ return f"{self.package_name}/{self.task}"
64
+
65
+ @property
66
+ @abc.abstractmethod
67
+ def gym_kwargs(self) -> dict:
68
+ raise NotImplementedError()
69
+
70
+
71
+ @dataclass
72
+ class HubEnvConfig(EnvConfig):
73
+ """Base class for environments that delegate creation to a hub-hosted make_env.
74
+
75
+ Hub environments download and execute remote code from the HF Hub.
76
+ The hub_path points to a repository containing an env.py with a make_env function.
77
+ """
78
+
79
+ hub_path: str | None = None # required: e.g., "username/repo" or "username/repo@branch:file.py"
80
+
81
+ @property
82
+ def gym_kwargs(self) -> dict:
83
+ # Not used for hub environments - the hub's make_env handles everything
84
+ return {}
85
+
86
+
87
+ @EnvConfig.register_subclass("aloha")
88
+ @dataclass
89
+ class AlohaEnv(EnvConfig):
90
+ task: str | None = "AlohaInsertion-v0"
91
+ fps: int = 50
92
+ episode_length: int = 400
93
+ obs_type: str = "pixels_agent_pos"
94
+ observation_height: int = 480
95
+ observation_width: int = 640
96
+ render_mode: str = "rgb_array"
97
+ features: dict[str, PolicyFeature] = field(
98
+ default_factory=lambda: {
99
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
100
+ }
101
+ )
102
+ features_map: dict[str, str] = field(
103
+ default_factory=lambda: {
104
+ ACTION: ACTION,
105
+ "agent_pos": OBS_STATE,
106
+ "top": f"{OBS_IMAGE}.top",
107
+ "pixels/top": f"{OBS_IMAGES}.top",
108
+ }
109
+ )
110
+
111
+ def __post_init__(self):
112
+ if self.obs_type == "pixels":
113
+ self.features["top"] = PolicyFeature(
114
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
115
+ )
116
+ elif self.obs_type == "pixels_agent_pos":
117
+ self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
118
+ self.features["pixels/top"] = PolicyFeature(
119
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
120
+ )
121
+
122
+ @property
123
+ def gym_kwargs(self) -> dict:
124
+ return {
125
+ "obs_type": self.obs_type,
126
+ "render_mode": self.render_mode,
127
+ "max_episode_steps": self.episode_length,
128
+ }
129
+
130
+
131
+ @EnvConfig.register_subclass("pusht")
132
+ @dataclass
133
+ class PushtEnv(EnvConfig):
134
+ task: str | None = "PushT-v0"
135
+ fps: int = 10
136
+ episode_length: int = 300
137
+ obs_type: str = "pixels_agent_pos"
138
+ render_mode: str = "rgb_array"
139
+ visualization_width: int = 384
140
+ visualization_height: int = 384
141
+ observation_height: int = 384
142
+ observation_width: int = 384
143
+ features: dict[str, PolicyFeature] = field(
144
+ default_factory=lambda: {
145
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
146
+ "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
147
+ }
148
+ )
149
+ features_map: dict[str, str] = field(
150
+ default_factory=lambda: {
151
+ ACTION: ACTION,
152
+ "agent_pos": OBS_STATE,
153
+ "environment_state": OBS_ENV_STATE,
154
+ "pixels": OBS_IMAGE,
155
+ }
156
+ )
157
+
158
+ def __post_init__(self):
159
+ if self.obs_type == "pixels_agent_pos":
160
+ self.features["pixels"] = PolicyFeature(
161
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
162
+ )
163
+ elif self.obs_type == "environment_state_agent_pos":
164
+ self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
165
+
166
+ @property
167
+ def gym_kwargs(self) -> dict:
168
+ return {
169
+ "obs_type": self.obs_type,
170
+ "render_mode": self.render_mode,
171
+ "visualization_width": self.visualization_width,
172
+ "visualization_height": self.visualization_height,
173
+ "max_episode_steps": self.episode_length,
174
+ }
175
+
176
+
177
+ @dataclass
178
+ class ImagePreprocessingConfig:
179
+ crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
180
+ resize_size: tuple[int, int] | None = None
181
+
182
+
183
+ @dataclass
184
+ class RewardClassifierConfig:
185
+ """Configuration for reward classification."""
186
+
187
+ pretrained_path: str | None = None
188
+ success_threshold: float = 0.5
189
+ success_reward: float = 1.0
190
+
191
+
192
+ @dataclass
193
+ class InverseKinematicsConfig:
194
+ """Configuration for inverse kinematics processing."""
195
+
196
+ urdf_path: str | None = None
197
+ target_frame_name: str | None = None
198
+ end_effector_bounds: dict[str, list[float]] | None = None
199
+ end_effector_step_sizes: dict[str, float] | None = None
200
+
201
+
202
+ @dataclass
203
+ class ObservationConfig:
204
+ """Configuration for observation processing."""
205
+
206
+ add_joint_velocity_to_observation: bool = False
207
+ add_current_to_observation: bool = False
208
+ add_ee_pose_to_observation: bool = False
209
+ display_cameras: bool = False
210
+
211
+
212
+ @dataclass
213
+ class GripperConfig:
214
+ """Configuration for gripper control and penalties."""
215
+
216
+ use_gripper: bool = True
217
+ gripper_penalty: float = 0.0
218
+
219
+
220
+ @dataclass
221
+ class ResetConfig:
222
+ """Configuration for environment reset behavior."""
223
+
224
+ fixed_reset_joint_positions: Any | None = None
225
+ reset_time_s: float = 5.0
226
+ control_time_s: float = 20.0
227
+ terminate_on_success: bool = True
228
+
229
+
230
+ @dataclass
231
+ class HILSerlProcessorConfig:
232
+ """Configuration for environment processing pipeline."""
233
+
234
+ control_mode: str = "gamepad"
235
+ observation: ObservationConfig | None = None
236
+ image_preprocessing: ImagePreprocessingConfig | None = None
237
+ gripper: GripperConfig | None = None
238
+ reset: ResetConfig | None = None
239
+ inverse_kinematics: InverseKinematicsConfig | None = None
240
+ reward_classifier: RewardClassifierConfig | None = None
241
+ max_gripper_pos: float | None = 100.0
242
+
243
+
244
+ @EnvConfig.register_subclass(name="gym_manipulator")
245
+ @dataclass
246
+ class HILSerlRobotEnvConfig(EnvConfig):
247
+ """Configuration for the HILSerlRobotEnv environment."""
248
+
249
+ robot: RobotConfig | None = None
250
+ teleop: TeleoperatorConfig | None = None
251
+ processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
252
+
253
+ name: str = "real_robot"
254
+
255
+ @property
256
+ def gym_kwargs(self) -> dict:
257
+ return {}
258
+
259
+
260
+ @EnvConfig.register_subclass("libero")
261
+ @dataclass
262
+ class LiberoEnv(EnvConfig):
263
+ task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
264
+ task_ids: list[int] | None = None
265
+ fps: int = 30
266
+ episode_length: int | None = None
267
+ obs_type: str = "pixels_agent_pos"
268
+ render_mode: str = "rgb_array"
269
+ camera_name: str = "agentview_image,robot0_eye_in_hand_image"
270
+ init_states: bool = True
271
+ camera_name_mapping: dict[str, str] | None = None
272
+ observation_height: int = 360
273
+ observation_width: int = 360
274
+ features: dict[str, PolicyFeature] = field(
275
+ default_factory=lambda: {
276
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
277
+ }
278
+ )
279
+ features_map: dict[str, str] = field(
280
+ default_factory=lambda: {
281
+ ACTION: ACTION,
282
+ LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
283
+ LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
284
+ LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
285
+ LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
286
+ LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
287
+ LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
288
+ LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
289
+ LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
290
+ LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
291
+ }
292
+ )
293
+ control_mode: str = "relative" # or "absolute"
294
+
295
+ def __post_init__(self):
296
+ if self.obs_type == "pixels":
297
+ self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
298
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
299
+ )
300
+ self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
301
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
302
+ )
303
+ elif self.obs_type == "pixels_agent_pos":
304
+ self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
305
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
306
+ )
307
+ self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
308
+ type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
309
+ )
310
+ self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
311
+ type=FeatureType.STATE,
312
+ shape=(3,),
313
+ )
314
+ self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
315
+ type=FeatureType.STATE,
316
+ shape=(4,),
317
+ )
318
+ self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
319
+ type=FeatureType.STATE,
320
+ shape=(3, 3),
321
+ )
322
+ self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
323
+ type=FeatureType.STATE,
324
+ shape=(2,),
325
+ )
326
+ self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
327
+ type=FeatureType.STATE,
328
+ shape=(2,),
329
+ )
330
+ self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
331
+ type=FeatureType.STATE,
332
+ shape=(7,),
333
+ )
334
+ self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
335
+ type=FeatureType.STATE,
336
+ shape=(7,),
337
+ )
338
+ else:
339
+ raise ValueError(f"Unsupported obs_type: {self.obs_type}")
340
+
341
+ @property
342
+ def gym_kwargs(self) -> dict:
343
+ kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
344
+ if self.task_ids is not None:
345
+ kwargs["task_ids"] = self.task_ids
346
+ return kwargs
347
+
348
+
349
+ @EnvConfig.register_subclass("metaworld")
350
+ @dataclass
351
+ class MetaworldEnv(EnvConfig):
352
+ task: str = "metaworld-push-v2" # add all tasks
353
+ fps: int = 80
354
+ episode_length: int = 400
355
+ obs_type: str = "pixels_agent_pos"
356
+ render_mode: str = "rgb_array"
357
+ multitask_eval: bool = True
358
+ features: dict[str, PolicyFeature] = field(
359
+ default_factory=lambda: {
360
+ "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
361
+ }
362
+ )
363
+ features_map: dict[str, str] = field(
364
+ default_factory=lambda: {
365
+ "action": ACTION,
366
+ "agent_pos": OBS_STATE,
367
+ "top": f"{OBS_IMAGE}",
368
+ "pixels/top": f"{OBS_IMAGE}",
369
+ }
370
+ )
371
+
372
+ def __post_init__(self):
373
+ if self.obs_type == "pixels":
374
+ self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
375
+
376
+ elif self.obs_type == "pixels_agent_pos":
377
+ self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
378
+ self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
379
+
380
+ else:
381
+ raise ValueError(f"Unsupported obs_type: {self.obs_type}")
382
+
383
+ @property
384
+ def gym_kwargs(self) -> dict:
385
+ return {
386
+ "obs_type": self.obs_type,
387
+ "render_mode": self.render_mode,
388
+ }
389
+
390
+
391
+ @EnvConfig.register_subclass("isaaclab_arena")
392
+ @dataclass
393
+ class IsaaclabArenaEnv(HubEnvConfig):
394
+ hub_path: str = "nvidia/isaaclab-arena-envs"
395
+ episode_length: int = 300
396
+ num_envs: int = 1
397
+ embodiment: str | None = "gr1_pink"
398
+ object: str | None = "power_drill"
399
+ mimic: bool = False
400
+ teleop_device: str | None = None
401
+ seed: int | None = 42
402
+ device: str | None = "cuda:0"
403
+ disable_fabric: bool = False
404
+ enable_cameras: bool = False
405
+ headless: bool = False
406
+ enable_pinocchio: bool = True
407
+ environment: str | None = "gr1_microwave"
408
+ task: str | None = "Reach out to the microwave and open it."
409
+ state_dim: int = 54
410
+ action_dim: int = 36
411
+ camera_height: int = 512
412
+ camera_width: int = 512
413
+ video: bool = False
414
+ video_length: int = 100
415
+ video_interval: int = 200
416
+ # Comma-separated keys, e.g., "robot_joint_pos,left_eef_pos"
417
+ state_keys: str = "robot_joint_pos"
418
+ # Comma-separated keys, e.g., "robot_pov_cam_rgb,front_cam_rgb"
419
+ # Set to None or "" for environments without cameras
420
+ camera_keys: str | None = None
421
+ features: dict[str, PolicyFeature] = field(default_factory=dict)
422
+ features_map: dict[str, str] = field(default_factory=dict)
423
+ kwargs: dict | None = None
424
+
425
+ def __post_init__(self):
426
+ if self.kwargs:
427
+ # dynamically convert kwargs to fields in the dataclass
428
+ # NOTE! the new fields will not bee seen by the dataclass repr
429
+ field_names = {f.name for f in fields(self)}
430
+ for key, value in self.kwargs.items():
431
+ if key not in field_names and key != "kwargs":
432
+ setattr(self, key, value)
433
+ self.kwargs = None
434
+
435
+ # Set action feature
436
+ self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))
437
+ self.features_map[ACTION] = ACTION
438
+
439
+ # Set state feature
440
+ self.features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.state_dim,))
441
+ self.features_map[OBS_STATE] = OBS_STATE
442
+
443
+ # Add camera features for each camera key
444
+ if self.enable_cameras and self.camera_keys:
445
+ for cam_key in self.camera_keys.split(","):
446
+ cam_key = cam_key.strip()
447
+ if cam_key:
448
+ self.features[cam_key] = PolicyFeature(
449
+ type=FeatureType.VISUAL,
450
+ shape=(self.camera_height, self.camera_width, 3),
451
+ )
452
+ self.features_map[cam_key] = f"{OBS_IMAGES}.{cam_key}"
453
+
454
+ @property
455
+ def gym_kwargs(self) -> dict:
456
+ return {}
457
+
458
+
459
+ # ------------------------ Robocasa365 --------------------------------
460
+
461
+ @EnvConfig.register_subclass("robocasa")
462
+ @dataclass
463
+ class RoboCasaEnv(HubEnvConfig):
464
+
465
+ hub_path: str = "Whalswp/RoboCasa_Env"
466
+
467
+ task: str | None = None
468
+ obs_type: str = "pixels_agent_pos"
469
+ render_mode: str = "rgb_array"
470
+ camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
471
+ observation_height: int = 256
472
+ observation_width: int = 256
473
+ split: str | None = None
474
+
475
+ # VLA ๋ชจ๋ธ ๋“ฑ์—์„œ ์‚ฌ์šฉํ•  Observation & Action ๊ทœ๊ฒฉ ๋งคํ•‘
476
+ features: dict[str, PolicyFeature] = field(default_factory=lambda: {
477
+ ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
478
+ "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(16,)),
479
+ "pixels/robot0_agentview_left": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
480
+ "pixels/robot0_agentview_right": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
481
+ "pixels/robot0_eye_in_hand": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
482
+ })
483
+ features_map: dict[str, str] = field(default_factory=lambda: {
484
+ ACTION: ACTION,
485
+ "agent_pos": OBS_STATE,
486
+ "pixels/robot0_agentview_left": f"{OBS_IMAGES}.robot0_agentview_left",
487
+ "pixels/robot0_agentview_right": f"{OBS_IMAGES}.robot0_agentview_right",
488
+ "pixels/robot0_eye_in_hand": f"{OBS_IMAGES}.robot0_eye_in_hand",
489
+ })
env.py CHANGED
@@ -3,21 +3,30 @@ import gymnasium as gym
3
  from gymnasium import spaces
4
  import numpy as np
5
  from collections import defaultdict
6
- from collections.abc import Callable, Sequence, Mapping
7
  from functools import partial
8
  from typing import Any
9
 
10
  # RoboCasa ์ „์šฉ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์ž„ํฌํŠธ
11
  from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
12
- from robocasa.utils.dataset_registry import ATOMIC_TASK_DATASETS, COMPOSITE_TASK_DATASETS, TARGET_TASKS, PRETRAINING_TASKS
 
 
 
 
 
 
13
 
14
  OBS_STATE_DIM = 16
15
  ACTION_DIM = 12
16
  ACTION_LOW = -1.0
17
  ACTION_HIGH = 1.0
18
 
19
- def convert_state(dict_state):
20
- """์‹œ๋ฎฌ๋ ˆ์ดํ„ฐ ์ƒํƒœ๋ฅผ LeRobot์ด ๊ธฐ๋Œ€ํ•˜๋Š” ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜(Conversion)ํ•ฉ๋‹ˆ๋‹ค."""
 
 
 
21
  dict_state = dict_state.copy()
22
  final_state = np.concatenate([
23
  dict_state["state.base_position"],
@@ -26,11 +35,14 @@ def convert_state(dict_state):
26
  dict_state["state.end_effector_rotation_relative"],
27
  dict_state["state.gripper_qpos"],
28
  ], axis=0)
29
- return final_state
 
30
 
31
  def convert_action(action):
32
- """LeRobot์˜ ์•ก์…˜์„ ์‹œ๋ฎฌ๋ ˆ์ดํ„ฐ๊ฐ€ ์ดํ•ดํ•˜๋Š” dict ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค."""
33
- action = action.copy()
 
 
34
  output_action = {
35
  "action.base_motion": action[0:4],
36
  "action.control_mode": action[4:5],
@@ -40,8 +52,8 @@ def convert_action(action):
40
  }
41
  return output_action
42
 
43
- def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
44
- """์นด๋ฉ”๋ผ ์ด๋ฆ„์„ ๋ฆฌ์ŠคํŠธ ํ˜•ํƒœ๋กœ ์ •๊ทœํ™”(Normalization)ํ•ฉ๋‹ˆ๋‹ค."""
45
  if isinstance(camera_name, str):
46
  cams = [c.strip() for c in camera_name.split(",") if c.strip()]
47
  elif isinstance(camera_name, (list, tuple)):
@@ -52,45 +64,71 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
52
  raise ValueError("camera_name resolved to an empty list.")
53
  return cams
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  class RoboCasaEnv(RoboCasaGymEnv):
56
  metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
57
 
58
  def __init__(
59
  self,
60
  task: str,
61
- camera_name: Sequence[str] = ["robot0_agentview_left", "robot0_eye_in_hand", "robot0_agentview_right"],
62
  render_mode: str = "rgb_array",
63
  obs_type: str = "pixels_agent_pos",
64
  observation_width: int = 256,
65
  observation_height: int = 256,
66
  split: str | None = None,
67
- **kwargs
68
  ):
69
  self.obs_type = obs_type
70
  self.render_mode = render_mode
71
  self.split = split
72
  self.task = task
73
- self._task_description = ""
74
-
75
  kwargs.pop("fps", None)
76
  self.kwargs = kwargs
77
 
78
- meta_info = {**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}
79
  try:
80
- self._max_episode_steps = meta_info[task]['horizon']
81
- except KeyError:
82
- raise ValueError(f"Unknown task '{task}'. Valid tasks are: {list(meta_info.keys())}")
83
-
 
84
  super().__init__(
85
  task,
86
- camera_names=camera_name,
87
  camera_widths=observation_width,
88
  camera_heights=observation_height,
89
  enable_render=(render_mode is not None),
90
  split=split,
91
- **kwargs
92
  )
93
-
94
  def _create_obs_and_action_space(self):
95
  images = {}
96
  for cam in self.camera_names:
@@ -100,11 +138,15 @@ class RoboCasaEnv(RoboCasaGymEnv):
100
  if self.obs_type == "state":
101
  raise NotImplementedError("The 'state' observation type is not supported.")
102
  elif self.obs_type == "pixels":
103
- self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
 
 
 
104
  elif self.obs_type == "pixels_agent_pos":
105
  self.observation_space = spaces.Dict({
106
  "pixels": spaces.Dict(images),
107
  "agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
 
108
  })
109
  else:
110
  raise ValueError(f"Unknown obs_type: {self.obs_type}")
@@ -117,38 +159,75 @@ class RoboCasaEnv(RoboCasaGymEnv):
117
  def task_description(self) -> str:
118
  return self._task_description
119
 
120
- def reset(self, seed: int | None = None, **kwargs):
121
- self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
122
- observation, info = super().reset(seed, **kwargs)
123
- self._task_description = self.env.get_ep_meta().get("lang", self.task)
124
- print(f"[RoboCasaEnv] task_description: {self._task_description!r}")
125
- return self._format_raw_obs(observation), info
126
-
127
- def _format_raw_obs(self, raw_obs: dict):
128
- new_obs = {}
129
  if self.obs_type == "pixels_agent_pos":
130
  new_obs["agent_pos"] = convert_state(raw_obs)
131
  new_obs["pixels"] = {}
132
  for k, v in raw_obs.items():
133
- if "video." in k:
134
  new_obs["pixels"][k.replace("video.", "")] = v
 
 
 
 
 
 
 
135
  return new_obs
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  def step(self, action: np.ndarray):
138
- self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
 
 
 
 
139
  action_dict = convert_action(action)
140
  observation, reward, done, truncated, info = super().step(action_dict)
141
  new_obs = self._format_raw_obs(observation)
142
 
143
  is_success = bool(info.get("success", 0))
144
- terminated = done or is_success
145
- info.update({"task": self.task, "done": done, "is_success": is_success})
146
-
147
- if terminated:
148
- info["final_info"] = {"task": self.task, "done": bool(done), "is_success": bool(is_success)}
149
- self.reset()
150
-
151
- return new_obs, reward, terminated, truncated, info
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  def render(self):
154
  frame = super().render()
@@ -185,71 +264,99 @@ def _make_env_fns(task_name: str, n_envs: int, camera_names: list[str], gym_kwar
185
  return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]
186
 
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  # ======================================================================
189
  # LeRobot Hub ํ•„์ˆ˜ ์ง„์ž…์ (Entry Point)
190
  # ======================================================================
191
  def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
 
 
 
 
 
 
192
  """
193
- LeRobot์ด Hub์—์„œ ํ™˜๊ฒฝ์„ ๋กœ๋“œํ•  ๋•Œ ํ˜ธ์ถœํ•˜๋Š” ๋ฉ”์ธ ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
194
- """
195
- # ํ™˜๊ฒฝ ๋ž˜ํผ ํด๋ž˜์Šค ์„ ํƒ
196
- env_cls = partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv
197
 
198
- # ์„ค์ •๊ฐ’ ์ถ”์ถœ (cfg ๊ฐ์ฒด๊ฐ€ ์žˆ์œผ๋ฉด ์‚ฌ์šฉํ•˜๊ณ , ์—†์œผ๋ฉด ๊ธฐ๋ณธ๊ฐ’ ์ ์šฉ)
199
  if cfg is not None:
200
- task_name = getattr(cfg, "task", "CloseFridge")
201
- fps = getattr(cfg, "fps", 20) # fps ์ถ”์ถœ
202
  gym_kwargs = {
203
  "obs_type": getattr(cfg, "obs_type", "pixels_agent_pos"),
204
- "render_mode": getattr(cfg, "render_mode", "rgb_array"), # render_mode ์œ ์ง€
205
  "observation_width": getattr(cfg, "observation_width", 256),
206
  "observation_height": getattr(cfg, "observation_height", 256),
207
- "camera_name": getattr(cfg, "camera_name", "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"),
208
- "split": getattr(cfg, "split", None),
209
- "fps": fps, # ํ•ต์‹ฌ ์ธ์ž ๋ˆ„๋ฝ ๋ฐฉ์ง€
 
 
210
  }
211
  else:
212
- # cfg ์—†์ด ์ง์ ‘ ํ˜ธ์ถœ๋  ๋•Œ์˜ ๊ธฐ๋ณธ๊ฐ’
213
- task_name = "CloseFridge"
214
  gym_kwargs = {
215
  "obs_type": "pixels_agent_pos",
216
  "render_mode": "rgb_array",
217
  "observation_width": 256,
218
  "observation_height": 256,
219
  "camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
220
- "split": None,
221
  }
222
 
223
  parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))
224
- combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
225
-
226
- # ๋ฒค์น˜๋งˆํฌ์ธ์ง€ ๋‹จ์ผ ํƒœ์Šคํฌ์ธ์ง€ ๊ตฌ๋ถ„
227
- parts = [t.strip() for t in task_name.split(",")]
228
- if len(parts) == 1 and parts[0] in combined_tasks:
229
- task_names = combined_tasks[parts[0]]
230
- if gym_kwargs.get("split") is None:
231
- gym_kwargs["split"] = "target" if parts[0] in TARGET_TASKS else "pretrain"
232
- else:
233
- task_names = []
234
- for part in parts:
235
- if part in combined_tasks:
236
- task_names.extend(combined_tasks[part])
237
- else:
238
- task_names.append(part)
239
-
240
-
241
- out = defaultdict(dict)
242
-
243
- # ํƒœ์Šคํฌ๋ณ„๋กœ ํ™˜๊ฒฝ ์ƒ์„ฑ
244
- for task in task_names:
245
  fns = _make_env_fns(
246
  task_name=task,
247
  n_envs=n_envs,
248
  camera_names=parsed_camera_names,
249
- gym_kwargs=gym_kwargs
250
  )
 
 
251
  out[task][0] = env_cls(fns)
252
 
253
- # {suite_name: {task_id: VectorEnv}} ํ˜•ํƒœ๋กœ ๋ฐ˜ํ™˜
254
- #return {"robocasa": dict(out)}
255
- return {suite: dict(task_map) for suite, task_map in out.items()}
 
3
  from gymnasium import spaces
4
  import numpy as np
5
  from collections import defaultdict
6
+ from collections.abc import Sequence, Mapping
7
  from functools import partial
8
  from typing import Any
9
 
10
  # RoboCasa ์ „์šฉ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์ž„ํฌํŠธ
11
  from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
12
+ from robocasa.utils.dataset_registry import (
13
+ ATOMIC_TASK_DATASETS,
14
+ COMPOSITE_TASK_DATASETS,
15
+ TARGET_TASKS,
16
+ PRETRAINING_TASKS,
17
+ )
18
+ from robocasa.utils.dataset_registry_utils import get_task_horizon
19
 
20
  OBS_STATE_DIM = 16
21
  ACTION_DIM = 12
22
  ACTION_LOW = -1.0
23
  ACTION_HIGH = 1.0
24
 
25
+
26
+ def convert_state(dict_state):
27
+ """์‹œ๋ฎฌ๋ ˆ์ดํ„ฐ ์ƒํƒœ๋ฅผ LeRobot์ด ๊ธฐ๋Œ€ํ•˜๋Š” 16-dim concat ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
28
+ ์ธ๋ฑ์Šค ์ •์˜๋Š” STATE_ACTION_SPEC.md ์ฐธ๊ณ .
29
+ """
30
  dict_state = dict_state.copy()
31
  final_state = np.concatenate([
32
  dict_state["state.base_position"],
 
35
  dict_state["state.end_effector_rotation_relative"],
36
  dict_state["state.gripper_qpos"],
37
  ], axis=0)
38
+ return final_state.astype(np.float32)
39
+
40
 
41
  def convert_action(action):
42
+ """LeRobot์˜ 12-dim ์•ก์…˜์„ ์‹œ๋ฎฌ๋ ˆ์ดํ„ฐ๊ฐ€ ์ดํ•ดํ•˜๋Š” dict ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
43
+ ์ธ๋ฑ์Šค ์ •์˜๋Š” STATE_ACTION_SPEC.md ์ฐธ๊ณ .
44
+ """
45
+ action = np.asarray(action).copy()
46
  output_action = {
47
  "action.base_motion": action[0:4],
48
  "action.control_mode": action[4:5],
 
52
  }
53
  return output_action
54
 
55
+
56
+ def _parse_camera_names(camera_name) -> list[str]:
57
  if isinstance(camera_name, str):
58
  cams = [c.strip() for c in camera_name.split(",") if c.strip()]
59
  elif isinstance(camera_name, (list, tuple)):
 
64
  raise ValueError("camera_name resolved to an empty list.")
65
  return cams
66
 
67
+
68
+ def _normalize_task_arg(task) -> list[str]:
69
+ """`--env.task=atomic_seen composite_unseen ...` ์ฒ˜๋Ÿผ ๋“ค์–ด์˜ค๋Š” ๋‹ค์–‘ํ•œ ํ˜•ํƒœ๋ฅผ
70
+ list[str]๋กœ ์ •๊ทœํ™”ํ•œ๋‹ค.
71
+ - draccus๊ฐ€ list๋กœ ํŒŒ์‹ฑํ•œ ๊ฒฝ์šฐ: ๊ทธ๋Œ€๋กœ ์œ ์ง€
72
+ - ๋‹จ์ผ ๋ฌธ์ž์—ด์ด๋ฉด ๊ณต๋ฐฑ ๋˜๋Š” ์ฝค๋งˆ ๋ถ„๋ฆฌ
73
+ """
74
+ if task is None:
75
+ raise ValueError("task is required")
76
+ if isinstance(task, (list, tuple)):
77
+ items = []
78
+ for t in task:
79
+ items.extend(_normalize_task_arg(t))
80
+ return items
81
+ s = str(task).strip()
82
+ if not s:
83
+ return []
84
+ # ๊ณต๋ฐฑ/์ฝค๋งˆ ๋ชจ๋‘ ํ—ˆ์šฉ
85
+ parts = []
86
+ for chunk in s.replace(",", " ").split():
87
+ if chunk:
88
+ parts.append(chunk)
89
+ return parts
90
+
91
+
92
  class RoboCasaEnv(RoboCasaGymEnv):
93
  metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
94
 
95
  def __init__(
96
  self,
97
  task: str,
98
+ camera_name: Sequence[str] = ("robot0_agentview_left", "robot0_eye_in_hand", "robot0_agentview_right"),
99
  render_mode: str = "rgb_array",
100
  obs_type: str = "pixels_agent_pos",
101
  observation_width: int = 256,
102
  observation_height: int = 256,
103
  split: str | None = None,
104
+ **kwargs,
105
  ):
106
  self.obs_type = obs_type
107
  self.render_mode = render_mode
108
  self.split = split
109
  self.task = task
110
+ self._task_description: str = task
111
+
112
  kwargs.pop("fps", None)
113
  self.kwargs = kwargs
114
 
115
+ # horizon์€ ๊ณต์‹ ํ—ฌํผ ์‚ฌ์šฉ. ๋ฏธ๋“ฑ๋ก ํƒœ์Šคํฌ๋Š” ๋ช…์‹œ์ ์œผ๋กœ ์—๋Ÿฌ.
116
  try:
117
+ self._max_episode_steps = int(get_task_horizon(task))
118
+ except Exception as e:
119
+ valid = list({**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}.keys())
120
+ raise ValueError(f"Unknown task '{task}'. Valid tasks: {valid[:10]}... ({len(valid)} total)") from e
121
+
122
  super().__init__(
123
  task,
124
+ camera_names=list(camera_name),
125
  camera_widths=observation_width,
126
  camera_heights=observation_height,
127
  enable_render=(render_mode is not None),
128
  split=split,
129
+ **kwargs,
130
  )
131
+
132
  def _create_obs_and_action_space(self):
133
  images = {}
134
  for cam in self.camera_names:
 
138
  if self.obs_type == "state":
139
  raise NotImplementedError("The 'state' observation type is not supported.")
140
  elif self.obs_type == "pixels":
141
+ self.observation_space = spaces.Dict({
142
+ "pixels": spaces.Dict(images),
143
+ "task": spaces.Text(max_length=512),
144
+ })
145
  elif self.obs_type == "pixels_agent_pos":
146
  self.observation_space = spaces.Dict({
147
  "pixels": spaces.Dict(images),
148
  "agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
149
+ "task": spaces.Text(max_length=512),
150
  })
151
  else:
152
  raise ValueError(f"Unknown obs_type: {self.obs_type}")
 
159
  def task_description(self) -> str:
160
  return self._task_description
161
 
162
+ def _format_raw_obs(self, raw_obs: dict) -> dict:
163
+ new_obs: dict[str, Any] = {}
 
 
 
 
 
 
 
164
  if self.obs_type == "pixels_agent_pos":
165
  new_obs["agent_pos"] = convert_state(raw_obs)
166
  new_obs["pixels"] = {}
167
  for k, v in raw_obs.items():
168
+ if k.startswith("video."):
169
  new_obs["pixels"][k.replace("video.", "")] = v
170
+ # ์–ธ์–ด ์กฐ๊ฑด: AsyncVectorEnv์—์„œ๋„ ๋Š๊ธฐ์ง€ ์•Š๋„๋ก obs์— ์ง์ ‘ ๋…ธ์ถœ.
171
+ # RoboCasaGymEnv๊ฐ€ obs์— ์ฑ„์›Œ์ฃผ๋Š” annotation์ด ์žˆ์œผ๋ฉด ์šฐ์„  ์‚ฌ์šฉ.
172
+ lang = raw_obs.get("annotation.human.task_description")
173
+ if not lang:
174
+ lang = self._task_description or self.task
175
+ new_obs["task"] = str(lang)
176
+ self._task_description = str(lang)
177
  return new_obs
178
 
179
+ def reset(self, seed: int | None = None, **kwargs):
180
+ # mujoco offscreen GL ์ปจํ…์ŠคํŠธ๊ฐ€ reset ์‹œ์ ์— stale ์ƒํƒœ๊ฐ€ ๋˜๋Š” ํ™˜๊ฒฝ์—์„œ์˜
181
+ # ์šฐํšŒ. (์‚ฌ์šฉ์ž ํ™˜๊ฒฝ์—์„œ ํ•„์š”ํ•ด ๋ณด์กด)
182
+ try:
183
+ self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
184
+ except Exception:
185
+ pass
186
+
187
+ observation, info = super().reset(seed=seed, **kwargs)
188
+ # ep meta์˜ lang์ด ๋” ํ’๋ถ€ํ•œ ๋ฌธ์žฅ์„ ์ฃผ๋ฏ€๋กœ ๊ฐ€๋Šฅํ•˜๋ฉด ์‚ฌ์šฉ
189
+ try:
190
+ ep_lang = self.env.get_ep_meta().get("lang", None)
191
+ if ep_lang:
192
+ self._task_description = str(ep_lang)
193
+ except Exception:
194
+ pass
195
+ formatted = self._format_raw_obs(observation)
196
+ info = dict(info or {})
197
+ info["task"] = self.task
198
+ info["task_description"] = self._task_description
199
+ return formatted, info
200
+
201
  def step(self, action: np.ndarray):
202
+ try:
203
+ self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
204
+ except Exception:
205
+ pass
206
+
207
  action_dict = convert_action(action)
208
  observation, reward, done, truncated, info = super().step(action_dict)
209
  new_obs = self._format_raw_obs(observation)
210
 
211
  is_success = bool(info.get("success", 0))
212
+ # ๊ณต์‹ GR00T eval๊ณผ ๋™์ผํ•˜๊ฒŒ: success๋Š” termination ์‹ ํ˜ธ๋กœ ์ง์ ‘ ๋ณ€ํ™˜ํ•˜์ง€ ์•Š๋Š”๋‹ค.
213
+ # gymnasium VectorEnv์˜ autoreset์ด terminated/truncated๋ฅผ ๋ณด๊ณ  final_info๋ฅผ ๋งŒ๋“ ๋‹ค.
214
+ terminated = bool(done) or is_success
215
+
216
+ info = dict(info or {})
217
+ info.update({
218
+ "task": self.task,
219
+ "task_description": self._task_description,
220
+ "is_success": is_success,
221
+ "success": is_success,
222
+ "done": bool(done),
223
+ })
224
+
225
+ # NOTE: ์ž์ฒด self.reset() ํ˜ธ์ถœ์€ ์ œ๊ฑฐ.
226
+ # - gymnasium 0.29+ VectorEnv๊ฐ€ terminated/truncated ์‹œ ์ž๋™์œผ๋กœ resetํ•˜๊ณ 
227
+ # final_info๋ฅผ ์ฑ„์›Œ์ค€๋‹ค. wrapper๊ฐ€ ํ•œ ๋ฒˆ ๋” resetํ•˜๋ฉด ์ฒซ obs๊ฐ€ final obs๋ฅผ
228
+ # ๋ฎ์–ด์“ฐ๊ณ , lerobot rollout์˜ final_info["is_success"]๊ฐ€ ์ •ํ•ฉ์„ฑ์„ ์žƒ๋Š”๋‹ค.
229
+
230
+ return new_obs, float(reward), terminated, truncated, info
231
 
232
  def render(self):
233
  frame = super().render()
 
264
  return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]
265
 
266
 
267
+ def _resolve_task_list(task_arg, explicit_split: str | None):
268
+ """`task` ์ธ์ž(๋ฌธ์ž์—ด/๋ฆฌ์ŠคํŠธ)์™€ ์‚ฌ์šฉ์ž๊ฐ€ ๋ช…์‹œํ•œ split์„ ๋ณด๊ณ 
269
+ ์‹ค์ œ๋กœ ๋„์šธ (task_name, split) ํŽ˜์–ด ๋ฆฌ์ŠคํŠธ๋ฅผ ๋งŒ๋“ ๋‹ค.
270
+
271
+ ๊ณต์‹ run_eval.py์ฒ˜๋Ÿผ:
272
+ - benchmark ํ‚ค(`atomic_seen`, `composite_unseen`, ...)๊ฐ€ ๋“ค์–ด์˜ค๋ฉด ํŽผ์นœ๋‹ค.
273
+ - explicit_split์ด ์ง€์ •๋˜๋ฉด ๊ทธ๊ฒƒ์„ ์šฐ์„ ํ•œ๋‹ค.
274
+ - ์•„๋‹ˆ๋ฉด TARGET_TASKS / PRETRAINING_TASKS ๋“ฑ๋ก๋ถ€์—์„œ ์ž๋™ ์ถ”๋ก .
275
+ - ๋‹จ์ผ task ์ด๋ฆ„์ด๋ฉด ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ.
276
+ """
277
+ items = _normalize_task_arg(task_arg)
278
+ pairs: list[tuple[str, str | None]] = []
279
+ for item in items:
280
+ if item in TARGET_TASKS:
281
+ split = explicit_split or "target"
282
+ for sub in TARGET_TASKS[item]:
283
+ pairs.append((sub, split))
284
+ elif item in PRETRAINING_TASKS:
285
+ split = explicit_split or "pretrain"
286
+ for sub in PRETRAINING_TASKS[item]:
287
+ pairs.append((sub, split))
288
+ else:
289
+ # ๋‹จ์ผ task ์ด๋ฆ„
290
+ pairs.append((item, explicit_split))
291
+ # ์ค‘๋ณต ์ œ๊ฑฐ (์ˆœ์„œ ์œ ์ง€)
292
+ seen = set()
293
+ uniq: list[tuple[str, str | None]] = []
294
+ for p in pairs:
295
+ if p in seen:
296
+ continue
297
+ seen.add(p)
298
+ uniq.append(p)
299
+ return uniq
300
+
301
+
302
  # ======================================================================
303
  # LeRobot Hub ํ•„์ˆ˜ ์ง„์ž…์ (Entry Point)
304
  # ======================================================================
305
  def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
306
+ """LeRobot์ด Hub์—์„œ ํ™˜๊ฒฝ์„ ๋กœ๋“œํ•  ๋•Œ ํ˜ธ์ถœํ•˜๋Š” ๋ฉ”์ธ ํ•จ์ˆ˜.
307
+
308
+ ๊ณต์‹ GR00T eval(`Isaac-GR00T/scripts/run_eval.py`)๊ณผ ๋™์ผํ•˜๊ฒŒ:
309
+ - benchmark ํ‚ค(`atomic_seen`, `composite_unseen`, ...)๋Š” sub-task ๋ฆฌ์ŠคํŠธ๋กœ ํŽผ์นœ๋‹ค.
310
+ - `cfg.split`์ด ์ง€์ •๋˜๋ฉด ๊ทธ๊ฒƒ์„ ์šฐ์„ ํ•œ๋‹ค.
311
+ - ๊ฐ sub-task๋Š” ์ž์‹ ์˜ horizon์œผ๋กœ ๋ณ„๋„ VectorEnv๋ฅผ ๋งŒ๋“ ๋‹ค.
312
  """
313
+ env_cls = (
314
+ partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv
315
+ )
 
316
 
 
317
  if cfg is not None:
318
+ task_arg = getattr(cfg, "task", None)
319
+ explicit_split = getattr(cfg, "split", None)
320
  gym_kwargs = {
321
  "obs_type": getattr(cfg, "obs_type", "pixels_agent_pos"),
322
+ "render_mode": getattr(cfg, "render_mode", "rgb_array"),
323
  "observation_width": getattr(cfg, "observation_width", 256),
324
  "observation_height": getattr(cfg, "observation_height", 256),
325
+ "camera_name": getattr(
326
+ cfg, "camera_name",
327
+ "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
328
+ ),
329
+ "fps": getattr(cfg, "fps", 20),
330
  }
331
  else:
332
+ task_arg = "CloseFridge"
333
+ explicit_split = None
334
  gym_kwargs = {
335
  "obs_type": "pixels_agent_pos",
336
  "render_mode": "rgb_array",
337
  "observation_width": 256,
338
  "observation_height": 256,
339
  "camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
340
+ "fps": 20,
341
  }
342
 
343
  parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))
344
+ task_split_pairs = _resolve_task_list(task_arg, explicit_split)
345
+ if not task_split_pairs:
346
+ raise ValueError(f"No tasks resolved from task={task_arg!r}, split={explicit_split!r}")
347
+
348
+ out: dict[str, dict[int, Any]] = defaultdict(dict)
349
+ for idx, (task, split) in enumerate(task_split_pairs):
350
+ per_task_kwargs = dict(gym_kwargs)
351
+ per_task_kwargs["split"] = split
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  fns = _make_env_fns(
353
  task_name=task,
354
  n_envs=n_envs,
355
  camera_names=parsed_camera_names,
356
+ gym_kwargs=per_task_kwargs,
357
  )
358
+ # `{suite: {task_id: VectorEnv}}` ๊ตฌ์กฐ: lerobot_eval์€ group/task ๋‹จ์œ„๋กœ ์ง‘๊ณ„ํ•˜๋ฏ€๋กœ
359
+ # task_name ์ž์ฒด๋ฅผ suite ํ‚ค๋กœ ์‚ฌ์šฉํ•ด per-task SR์„ ๊ทธ๋Œ€๋กœ ๋…ธ์ถœ.
360
  out[task][0] = env_cls(fns)
361
 
362
+ return {suite: dict(task_map) for suite, task_map in out.items()}
 
 
env.py.bak ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # env.py
2
+ import gymnasium as gym
3
+ from gymnasium import spaces
4
+ import numpy as np
5
+ from collections import defaultdict
6
+ from collections.abc import Callable, Sequence, Mapping
7
+ from functools import partial
8
+ from typing import Any
9
+
10
+ # RoboCasa ์ „์šฉ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์ž„ํฌํŠธ
11
+ from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
12
+ from robocasa.utils.dataset_registry import ATOMIC_TASK_DATASETS, COMPOSITE_TASK_DATASETS, TARGET_TASKS, PRETRAINING_TASKS
13
+
14
+ OBS_STATE_DIM = 16
15
+ ACTION_DIM = 12
16
+ ACTION_LOW = -1.0
17
+ ACTION_HIGH = 1.0
18
+
19
+ def convert_state(dict_state):
20
+ """์‹œ๋ฎฌ๋ ˆ์ดํ„ฐ ์ƒํƒœ๋ฅผ LeRobot์ด ๊ธฐ๋Œ€ํ•˜๋Š” ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜(Conversion)ํ•ฉ๋‹ˆ๋‹ค."""
21
+ dict_state = dict_state.copy()
22
+ final_state = np.concatenate([
23
+ dict_state["state.base_position"],
24
+ dict_state["state.base_rotation"],
25
+ dict_state["state.end_effector_position_relative"],
26
+ dict_state["state.end_effector_rotation_relative"],
27
+ dict_state["state.gripper_qpos"],
28
+ ], axis=0)
29
+ return final_state
30
+
31
+ def convert_action(action):
32
+ """LeRobot์˜ ์•ก์…˜์„ ์‹œ๋ฎฌ๋ ˆ์ดํ„ฐ๊ฐ€ ์ดํ•ดํ•˜๋Š” dict ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค."""
33
+ action = action.copy()
34
+ output_action = {
35
+ "action.base_motion": action[0:4],
36
+ "action.control_mode": action[4:5],
37
+ "action.end_effector_position": action[5:8],
38
+ "action.end_effector_rotation": action[8:11],
39
+ "action.gripper_close": action[11:12],
40
+ }
41
+ return output_action
42
+
43
+ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
44
+ """์นด๋ฉ”๋ผ ์ด๋ฆ„์„ ๋ฆฌ์ŠคํŠธ ํ˜•ํƒœ๋กœ ์ •๊ทœํ™”(Normalization)ํ•ฉ๋‹ˆ๋‹ค."""
45
+ if isinstance(camera_name, str):
46
+ cams = [c.strip() for c in camera_name.split(",") if c.strip()]
47
+ elif isinstance(camera_name, (list, tuple)):
48
+ cams = [str(c).strip() for c in camera_name if str(c).strip()]
49
+ else:
50
+ raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
51
+ if not cams:
52
+ raise ValueError("camera_name resolved to an empty list.")
53
+ return cams
54
+
55
+ class RoboCasaEnv(RoboCasaGymEnv):
56
+ metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
57
+
58
+ def __init__(
59
+ self,
60
+ task: str,
61
+ camera_name: Sequence[str] = ["robot0_agentview_left", "robot0_eye_in_hand", "robot0_agentview_right"],
62
+ render_mode: str = "rgb_array",
63
+ obs_type: str = "pixels_agent_pos",
64
+ observation_width: int = 256,
65
+ observation_height: int = 256,
66
+ split: str | None = None,
67
+ **kwargs
68
+ ):
69
+ self.obs_type = obs_type
70
+ self.render_mode = render_mode
71
+ self.split = split
72
+ self.task = task
73
+ self._task_description = ""
74
+
75
+ kwargs.pop("fps", None)
76
+ self.kwargs = kwargs
77
+
78
+ meta_info = {**ATOMIC_TASK_DATASETS, **COMPOSITE_TASK_DATASETS}
79
+ try:
80
+ self._max_episode_steps = meta_info[task]['horizon']
81
+ except KeyError:
82
+ raise ValueError(f"Unknown task '{task}'. Valid tasks are: {list(meta_info.keys())}")
83
+
84
+ super().__init__(
85
+ task,
86
+ camera_names=camera_name,
87
+ camera_widths=observation_width,
88
+ camera_heights=observation_height,
89
+ enable_render=(render_mode is not None),
90
+ split=split,
91
+ **kwargs
92
+ )
93
+
94
+ def _create_obs_and_action_space(self):
95
+ images = {}
96
+ for cam in self.camera_names:
97
+ images[cam] = spaces.Box(
98
+ low=0, high=255, shape=(self.camera_heights, self.camera_widths, 3), dtype=np.uint8
99
+ )
100
+ if self.obs_type == "state":
101
+ raise NotImplementedError("The 'state' observation type is not supported.")
102
+ elif self.obs_type == "pixels":
103
+ self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
104
+ elif self.obs_type == "pixels_agent_pos":
105
+ self.observation_space = spaces.Dict({
106
+ "pixels": spaces.Dict(images),
107
+ "agent_pos": spaces.Box(low=-1000, high=1000, shape=(OBS_STATE_DIM,), dtype=np.float32),
108
+ })
109
+ else:
110
+ raise ValueError(f"Unknown obs_type: {self.obs_type}")
111
+
112
+ self.action_space = spaces.Box(
113
+ low=ACTION_LOW, high=ACTION_HIGH, shape=(int(ACTION_DIM),), dtype=np.float32
114
+ )
115
+
116
+ @property
117
+ def task_description(self) -> str:
118
+ return self._task_description
119
+
120
+ def reset(self, seed: int | None = None, **kwargs):
121
+ self.unwrapped.sim._render_context_offscreen.gl_ctx.free()
122
+ observation, info = super().reset(seed, **kwargs)
123
+ self._task_description = self.env.get_ep_meta().get("lang", self.task)
124
+ print(f"[RoboCasaEnv] task_description: {self._task_description!r}")
125
+ return self._format_raw_obs(observation), info
126
+
127
+ def _format_raw_obs(self, raw_obs: dict):
128
+ new_obs = {}
129
+ if self.obs_type == "pixels_agent_pos":
130
+ new_obs["agent_pos"] = convert_state(raw_obs)
131
+ new_obs["pixels"] = {}
132
+ for k, v in raw_obs.items():
133
+ if "video." in k:
134
+ new_obs["pixels"][k.replace("video.", "")] = v
135
+ return new_obs
136
+
137
+ def step(self, action: np.ndarray):
138
+ self.unwrapped.sim._render_context_offscreen.gl_ctx.make_current()
139
+ action_dict = convert_action(action)
140
+ observation, reward, done, truncated, info = super().step(action_dict)
141
+ new_obs = self._format_raw_obs(observation)
142
+
143
+ is_success = bool(info.get("success", 0))
144
+ terminated = done or is_success
145
+ info.update({"task": self.task, "done": done, "is_success": is_success})
146
+
147
+ if terminated:
148
+ info["final_info"] = {"task": self.task, "done": bool(done), "is_success": bool(is_success)}
149
+ self.reset()
150
+
151
+ return new_obs, reward, terminated, truncated, info
152
+
153
+ def render(self):
154
+ frame = super().render()
155
+ if frame is None:
156
+ return frame
157
+ from PIL import Image, ImageDraw, ImageFont
158
+ import textwrap
159
+
160
+ text = self._task_description or self.task
161
+ w = frame.shape[1]
162
+
163
+ try:
164
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
165
+ except Exception:
166
+ font = ImageFont.load_default()
167
+
168
+ lines = textwrap.wrap(text, width=55)
169
+ line_h = 18
170
+ bar_h = len(lines) * line_h + 10
171
+
172
+ bar = Image.new("RGB", (w, bar_h), color=(30, 30, 30))
173
+ draw = ImageDraw.Draw(bar)
174
+ for i, line in enumerate(lines):
175
+ draw.text((8, 5 + i * line_h), line, font=font, fill=(220, 220, 220))
176
+
177
+ return np.concatenate([frame, np.array(bar)], axis=0)
178
+
179
+
180
+ def _make_env_fns(task_name: str, n_envs: int, camera_names: list[str], gym_kwargs: Mapping[str, Any]):
181
+ def _make_env(episode_index: int, **kwargs):
182
+ seed = kwargs.pop("seed", episode_index)
183
+ return RoboCasaEnv(task=task_name, camera_name=camera_names, seed=seed, **kwargs)
184
+
185
+ return [partial(_make_env, i, **gym_kwargs) for i in range(n_envs)]
186
+
187
+
188
+ # ======================================================================
189
+ # LeRobot Hub ํ•„์ˆ˜ ์ง„์ž…์ (Entry Point)
190
+ # ======================================================================
191
+ def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[str, dict[int, Any]]:
192
+ """
193
+ LeRobot์ด Hub์—์„œ ํ™˜๊ฒฝ์„ ๋กœ๋“œํ•  ๋•Œ ํ˜ธ์ถœํ•˜๋Š” ๋ฉ”์ธ ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค.
194
+ """
195
+ # ํ™˜๊ฒฝ ๋ž˜ํผ ํด๋ž˜์Šค ์„ ํƒ
196
+ env_cls = partial(gym.vector.AsyncVectorEnv, context="spawn") if use_async_envs else gym.vector.SyncVectorEnv
197
+
198
+ # ์„ค์ •๊ฐ’ ์ถ”์ถœ (cfg ๊ฐ์ฒด๊ฐ€ ์žˆ์œผ๋ฉด ์‚ฌ์šฉํ•˜๊ณ , ์—†์œผ๋ฉด ๊ธฐ๋ณธ๊ฐ’ ์ ์šฉ)
199
+ if cfg is not None:
200
+ task_name = getattr(cfg, "task", "CloseFridge")
201
+ fps = getattr(cfg, "fps", 20) # fps ์ถ”์ถœ
202
+ gym_kwargs = {
203
+ "obs_type": getattr(cfg, "obs_type", "pixels_agent_pos"),
204
+ "render_mode": getattr(cfg, "render_mode", "rgb_array"), # render_mode ์œ ์ง€
205
+ "observation_width": getattr(cfg, "observation_width", 256),
206
+ "observation_height": getattr(cfg, "observation_height", 256),
207
+ "camera_name": getattr(cfg, "camera_name", "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"),
208
+ "split": getattr(cfg, "split", None),
209
+ "fps": fps, # ํ•ต์‹ฌ ์ธ์ž ๋ˆ„๋ฝ ๋ฐฉ์ง€
210
+ }
211
+ else:
212
+ # cfg ์—†์ด ์ง์ ‘ ํ˜ธ์ถœ๋  ๋•Œ์˜ ๊ธฐ๋ณธ๊ฐ’
213
+ task_name = "CloseFridge"
214
+ gym_kwargs = {
215
+ "obs_type": "pixels_agent_pos",
216
+ "render_mode": "rgb_array",
217
+ "observation_width": 256,
218
+ "observation_height": 256,
219
+ "camera_name": "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right",
220
+ "split": None,
221
+ }
222
+
223
+ parsed_camera_names = _parse_camera_names(gym_kwargs.pop("camera_name"))
224
+ combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
225
+
226
+ # ๋ฒค์น˜๋งˆํฌ์ธ์ง€ ๋‹จ์ผ ํƒœ์Šคํฌ์ธ์ง€ ๊ตฌ๋ถ„
227
+ if task_name in combined_tasks:
228
+ task_names = combined_tasks[task_name]
229
+ gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
230
+ else:
231
+ task_names = [t.strip() for t in task_name.split(",")]
232
+
233
+
234
+ out = defaultdict(dict)
235
+
236
+ # ํƒœ์Šคํฌ๋ณ„๋กœ ํ™˜๊ฒฝ ์ƒ์„ฑ
237
+ for task in task_names:
238
+ fns = _make_env_fns(
239
+ task_name=task,
240
+ n_envs=n_envs,
241
+ camera_names=parsed_camera_names,
242
+ gym_kwargs=gym_kwargs
243
+ )
244
+ out[task][0] = env_cls(fns)
245
+
246
+ # {suite_name: {task_id: VectorEnv}} ํ˜•ํƒœ๋กœ ๋ฐ˜ํ™˜
247
+ #return {"robocasa": dict(out)}
248
+ return {suite: dict(task_map) for suite, task_map in out.items()}