qgallouedec HF Staff commited on
Commit
24b769c
·
1 Parent(s): 0dc04b5

Upload folder using huggingface_hub

Browse files
.summary/0/events.out.tfevents.1689668898.qgallouedec-MS-7C84 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57ab3e2606c5596edd60eeff747d413f18bc2e204600be0ab03bcbc4088885b7
3
+ size 1721261
README.md CHANGED
@@ -15,7 +15,7 @@ model-index:
15
  type: peg-insert-side-v2
16
  metrics:
17
  - type: mean_reward
18
- value: 629.67 +/- 36.17
19
  name: mean_reward
20
  verified: false
21
  ---
 
15
  type: peg-insert-side-v2
16
  metrics:
17
  - type: mean_reward
18
+ value: 308.94 +/- 175.97
19
  name: mean_reward
20
  verified: false
21
  ---
checkpoint_p0/best_000020128_10305536_reward_380.859.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4f89037b6dc159da0ecbd28afe5180484dbfb721e429e70b92ad21683daac02
3
+ size 98239
checkpoint_p0/checkpoint_000058304_29851648.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3826299d6896f3ce26ffd747dab9b5cf1606fd94acb3f87bf41284484c0cab85
3
+ size 98567
checkpoint_p0/checkpoint_000058608_30007296.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aba0db7bfa3a5e2707771d34c65d1600277b030a2cffa14559da51a61d4843dd
3
+ size 98567
config.json CHANGED
@@ -65,7 +65,7 @@
65
  "summaries_use_frameskip": true,
66
  "heartbeat_interval": 20,
67
  "heartbeat_reporting_interval": 180,
68
- "train_for_env_steps": 10000000,
69
  "train_for_seconds": 10000000000,
70
  "save_every_sec": 15,
71
  "keep_checkpoints": 2,
@@ -128,7 +128,7 @@
128
  "wandb_user": "qgallouedec",
129
  "wandb_project": "sample_facotry_metaworld"
130
  },
131
- "git_hash": "66db1b7a27030aa65fcfa2d6e3503089a7cff207",
132
  "git_repo_name": "https://github.com/huggingface/gia",
133
- "wandb_unique_id": "peg-insert-side-v2_20230708_211703_755132"
134
  }
 
65
  "summaries_use_frameskip": true,
66
  "heartbeat_interval": 20,
67
  "heartbeat_reporting_interval": 180,
68
+ "train_for_env_steps": 30000000,
69
  "train_for_seconds": 10000000000,
70
  "save_every_sec": 15,
71
  "keep_checkpoints": 2,
 
128
  "wandb_user": "qgallouedec",
129
  "wandb_project": "sample_facotry_metaworld"
130
  },
131
+ "git_hash": "952d4a00946fa97ee3267d32a2160be9933e887a",
132
  "git_repo_name": "https://github.com/huggingface/gia",
133
+ "wandb_unique_id": "peg-insert-side-v2_20230718_102816_306101"
134
  }
git.diff CHANGED
@@ -1,3 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  diff --git a/gia/eval/callback.py b/gia/eval/callback.py
2
  index 5c3a080..4b6198f 100644
3
  --- a/gia/eval/callback.py
@@ -14,70 +153,176 @@ index 5c3a080..4b6198f 100644
14
  from gia.config import Arguments
15
  from gia.eval.utils import is_slurm_available
16
 
17
- diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py
18
- index 91b645c..3e2cae7 100644
19
- --- a/gia/eval/evaluator.py
20
- +++ b/gia/eval/evaluator.py
21
- @@ -1,3 +1,5 @@
22
- +from typing import Optional
23
  +
24
- import torch
 
 
 
 
 
 
 
25
 
26
- from gia.config.arguments import Arguments
27
- @@ -5,11 +7,12 @@ from gia.model import GiaModel
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- class Evaluator:
31
- - def __init__(self, args: Arguments, task: str) -> None:
32
- + def __init__(self, args: Arguments, task: str, mean_random: Optional[float] = None) -> None:
33
- self.args = args
34
- self.task = task
35
- + self.mean_random = mean_random
36
 
37
- - @torch.no_grad()
38
- + @torch.inference_mode()
39
- def evaluate(self, model: GiaModel) -> float:
40
- return self._evaluate(model)
41
 
42
- diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py
43
- index ec5e5b2..eeaf7cb 100644
44
- --- a/gia/eval/rl/envs/core.py
45
- +++ b/gia/eval/rl/envs/core.py
46
- @@ -177,7 +177,6 @@ def make(task_name: str, num_envs: int = 1):
47
-
48
- elif task_name.startswith("metaworld"):
49
- import gymnasium as gym
50
- - import metaworld
51
-
52
- env_id = TASK_TO_ENV_MAPPING[task_name]
53
- env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs)
54
- diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
55
- index f0d0b9b..39dc0d2 100644
56
- --- a/gia/eval/rl/gia_agent.py
57
- +++ b/gia/eval/rl/gia_agent.py
58
- @@ -54,7 +54,7 @@ class GiaAgent:
59
- self.action_space = action_space
60
- self.deterministic = deterministic
61
- self.device = next(model.parameters()).device
62
- - self._max_length = self.model.config.max_position_embeddings - 10
63
- + self._max_length = self.model.config.max_position_embeddings - 100 # TODO: fix this
64
-
65
- if isinstance(observation_space, spaces.Box):
66
- self._observation_key = "continuous_observations"
67
- @@ -75,6 +75,11 @@ class GiaAgent:
68
- ) -> Tuple[Tuple[Tensor, Tensor], ...]:
69
- return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
70
-
71
- + def set_model(self, model: GiaModel) -> None:
72
- + self.model = model
73
- + self.device = next(model.parameters()).device
74
- + self._max_length = self.model.config.max_position_embeddings
75
  +
76
- def reset(self, num_envs: int = 1) -> None:
77
- if self.prompter is not None:
78
- prompts = self.prompter.generate_prompts(num_envs)
 
 
 
 
 
79
  diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py
80
- index f8531ee..754c05d 100644
81
  --- a/gia/eval/rl/gym_evaluator.py
82
  +++ b/gia/eval/rl/gym_evaluator.py
83
  @@ -1,7 +1,7 @@
@@ -85,37 +330,38 @@ index f8531ee..754c05d 100644
85
  from gym.vector.vector_env import VectorEnv
86
 
87
  -from gia.eval.mappings import TASK_TO_ENV_MAPPING
88
- +# from gia.eval.rl.envs.mappings import TASK_TO_ENV_MAPPING
89
  from gia.eval.rl.rl_evaluator import RLEvaluator
90
 
91
 
92
- diff --git a/gia/eval/rl/rl_evaluator.py b/gia/eval/rl/rl_evaluator.py
93
- index c5cc423..91189f3 100644
94
- --- a/gia/eval/rl/rl_evaluator.py
95
- +++ b/gia/eval/rl/rl_evaluator.py
96
- @@ -8,6 +8,10 @@ from gia.eval.rl.gia_agent import GiaAgent
 
97
 
98
 
99
- class RLEvaluator(Evaluator):
100
- + def __init__(self, args, task):
101
- + super().__init__(args, task)
102
- + self.agent = GiaAgent()
103
  +
104
- def _build_env(self) -> VectorEnv: # TODO: maybe just a gym.Env ?
105
- raise NotImplementedError
106
-
107
- diff --git a/gia/eval/rl/scores_dict.json b/gia/eval/rl/scores_dict.json
108
- index 1b8ebee..ff7d030 100644
109
- --- a/gia/eval/rl/scores_dict.json
110
- +++ b/gia/eval/rl/scores_dict.json
111
- @@ -929,8 +929,8 @@
112
- },
113
- "metaworld-assembly": {
114
- "expert": {
115
- - "mean": 311.29314618777823,
116
- - "std": 75.04282151450695
117
- + "mean": 3523.81468486244,
118
- + "std": 63.22745220327798
119
- },
120
- "random": {
121
- "mean": 220.65601680730813,
 
 
 
1
+ diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh
2
+ index acfe879..fc2b602 100755
3
+ --- a/data/envs/metaworld/generate_dataset_all.sh
4
+ +++ b/data/envs/metaworld/generate_dataset_all.sh
5
+ @@ -1,59 +1,15 @@
6
+ #!/bin/bash
7
+
8
+ ENVS=(
9
+ - assembly
10
+ - basketball
11
+ bin-picking
12
+ - box-close
13
+ - button-press-topdown
14
+ - button-press-topdown-wall
15
+ - button-press
16
+ - button-press-wall
17
+ - coffee-button
18
+ - coffee-pull
19
+ - coffee-push
20
+ - dial-turn
21
+ - disassemble
22
+ - door-close
23
+ - door-lock
24
+ - door-open
25
+ - door-unlock
26
+ - drawer-close
27
+ - drawer-open
28
+ - faucet-close
29
+ - faucet-open
30
+ hammer
31
+ - hand-insert
32
+ - handle-press-side
33
+ - handle-press
34
+ - handle-pull-side
35
+ - handle-pull
36
+ - lever-pull
37
+ - peg-insert-side
38
+ - peg-unplug-side
39
+ pick-out-of-hole
40
+ pick-place
41
+ - pick-place-wall
42
+ - plate-slide-back-side
43
+ - plate-slide-back
44
+ - plate-slide-side
45
+ - plate-slide
46
+ - push-back
47
+ - push
48
+ - push-wall
49
+ - reach
50
+ - reach-wall
51
+ shelf-place
52
+ soccer
53
+ - stick-pull
54
+ - stick-push
55
+ - sweep-into
56
+ - sweep
57
+ - window-close
58
+ - window-open
59
+ )
60
+
61
+ for ENV in "${ENVS[@]}"; do
62
+ - python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2
63
+ + # python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2
64
+ python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
65
+ done
66
+ diff --git a/data/envs/metaworld/train.py b/data/envs/metaworld/train.py
67
+ index 095414e..0ea5bde 100644
68
+ --- a/data/envs/metaworld/train.py
69
+ +++ b/data/envs/metaworld/train.py
70
+ @@ -25,7 +25,7 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
71
+ num_workers=8,
72
+ num_envs_per_worker=8,
73
+ worker_num_splits=2,
74
+ - train_for_env_steps=10_000_000,
75
+ + train_for_env_steps=30_000_000,
76
+ encoder_mlp_layers=[64, 64],
77
+ env_frameskip=1,
78
+ nonlinearity="tanh",
79
+ diff --git a/data/envs/metaworld/train_all.sh b/data/envs/metaworld/train_all.sh
80
+ index dbf328a..67ab9a0 100755
81
+ --- a/data/envs/metaworld/train_all.sh
82
+ +++ b/data/envs/metaworld/train_all.sh
83
+ @@ -1,56 +1,8 @@
84
+ #!/bin/bash
85
+
86
+ ENVS=(
87
+ - assembly
88
+ - basketball
89
+ - bin-picking
90
+ - box-close
91
+ - button-press-topdown
92
+ - button-press-topdown-wall
93
+ - button-press
94
+ - button-press-wall
95
+ - coffee-button
96
+ - coffee-pull
97
+ - coffee-push
98
+ - dial-turn
99
+ disassemble
100
+ - door-close
101
+ - door-lock
102
+ - door-open
103
+ - door-unlock
104
+ - drawer-close
105
+ - drawer-open
106
+ - faucet-close
107
+ - faucet-open
108
+ - hammer
109
+ - hand-insert
110
+ - handle-press-side
111
+ - handle-press
112
+ - handle-pull-side
113
+ - handle-pull
114
+ - lever-pull
115
+ peg-insert-side
116
+ - peg-unplug-side
117
+ - pick-out-of-hole
118
+ - pick-place
119
+ - pick-place-wall
120
+ - plate-slide-back-side
121
+ - plate-slide-back
122
+ - plate-slide-side
123
+ - plate-slide
124
+ - push-back
125
+ - push
126
+ - push-wall
127
+ - reach
128
+ - reach-wall
129
+ - shelf-place
130
+ - soccer
131
+ - stick-pull
132
+ - stick-push
133
+ - sweep-into
134
+ - sweep
135
+ - window-close
136
+ - window-open
137
+ )
138
+
139
+ for ENV in "${ENVS[@]}"; do
140
  diff --git a/gia/eval/callback.py b/gia/eval/callback.py
141
  index 5c3a080..4b6198f 100644
142
  --- a/gia/eval/callback.py
 
153
  from gia.config import Arguments
154
  from gia.eval.utils import is_slurm_available
155
 
156
+ diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py
157
+ index 22c5b49..7464ff5 100644
158
+ --- a/gia/eval/rl/envs/core.py
159
+ +++ b/gia/eval/rl/envs/core.py
160
+ @@ -1,6 +1,8 @@
161
+ +from typing import Dict
162
  +
163
+ import gymnasium as gym
164
+ import numpy as np
165
+ -from gymnasium import Env, ObservationWrapper, spaces
166
+ +from gymnasium import Env, ObservationWrapper, RewardWrapper, spaces
167
+ from sample_factory.envs.env_wrappers import (
168
+ ClipRewardEnv,
169
+ EpisodicLifeEnv,
170
+ @@ -12,63 +14,63 @@ from sample_factory.envs.env_wrappers import (
171
 
 
 
172
 
173
+ TASK_TO_ENV_MAPPING = {
174
+ - "atari-alien": "Alien-v4",
175
+ - "atari-amidar": "Amidar-v4",
176
+ - "atari-assault": "Assault-v4",
177
+ - "atari-asterix": "Asterix-v4",
178
+ - "atari-asteroids": "Asteroids-v4",
179
+ - "atari-atlantis": "Atlantis-v4",
180
+ - "atari-bankheist": "BankHeist-v4",
181
+ - "atari-battlezone": "BattleZone-v4",
182
+ - "atari-beamrider": "BeamRider-v4",
183
+ - "atari-berzerk": "Berzerk-v4",
184
+ - "atari-bowling": "Bowling-v4",
185
+ - "atari-boxing": "Boxing-v4",
186
+ - "atari-breakout": "Breakout-v4",
187
+ - "atari-centipede": "Centipede-v4",
188
+ - "atari-choppercommand": "ChopperCommand-v4",
189
+ - "atari-crazyclimber": "CrazyClimber-v4",
190
+ - "atari-defender": "Defender-v4",
191
+ - "atari-demonattack": "DemonAttack-v4",
192
+ - "atari-doubledunk": "DoubleDunk-v4",
193
+ - "atari-enduro": "Enduro-v4",
194
+ - "atari-fishingderby": "FishingDerby-v4",
195
+ - "atari-freeway": "Freeway-v4",
196
+ - "atari-frostbite": "Frostbite-v4",
197
+ - "atari-gopher": "Gopher-v4",
198
+ - "atari-gravitar": "Gravitar-v4",
199
+ - "atari-hero": "Hero-v4",
200
+ - "atari-icehockey": "IceHockey-v4",
201
+ - "atari-jamesbond": "Jamesbond-v4",
202
+ - "atari-kangaroo": "Kangaroo-v4",
203
+ - "atari-krull": "Krull-v4",
204
+ - "atari-kungfumaster": "KungFuMaster-v4",
205
+ - "atari-montezumarevenge": "MontezumaRevenge-v4",
206
+ - "atari-mspacman": "MsPacman-v4",
207
+ - "atari-namethisgame": "NameThisGame-v4",
208
+ - "atari-phoenix": "Phoenix-v4",
209
+ - "atari-pitfall": "Pitfall-v4",
210
+ - "atari-pong": "Pong-v4",
211
+ - "atari-privateeye": "PrivateEye-v4",
212
+ - "atari-qbert": "Qbert-v4",
213
+ - "atari-riverraid": "Riverraid-v4",
214
+ - "atari-roadrunner": "RoadRunner-v4",
215
+ - "atari-robotank": "Robotank-v4",
216
+ - "atari-seaquest": "Seaquest-v4",
217
+ - "atari-skiing": "Skiing-v4",
218
+ - "atari-solaris": "Solaris-v4",
219
+ - "atari-spaceinvaders": "SpaceInvaders-v4",
220
+ - "atari-stargunner": "StarGunner-v4",
221
+ + "atari-alien": "ALE/Alien-v5",
222
+ + "atari-amidar": "ALE/Amidar-v5",
223
+ + "atari-assault": "ALE/Assault-v5",
224
+ + "atari-asterix": "ALE/Asterix-v5",
225
+ + "atari-asteroids": "ALE/Asteroids-v5",
226
+ + "atari-atlantis": "ALE/Atlantis-v5",
227
+ + "atari-bankheist": "ALE/BankHeist-v5",
228
+ + "atari-battlezone": "ALE/BattleZone-v5",
229
+ + "atari-beamrider": "ALE/BeamRider-v5",
230
+ + "atari-berzerk": "ALE/Berzerk-v5",
231
+ + "atari-bowling": "ALE/Bowling-v5",
232
+ + "atari-boxing": "ALE/Boxing-v5",
233
+ + "atari-breakout": "ALE/Breakout-v5",
234
+ + "atari-centipede": "ALE/Centipede-v5",
235
+ + "atari-choppercommand": "ALE/ChopperCommand-v5",
236
+ + "atari-crazyclimber": "ALE/CrazyClimber-v5",
237
+ + "atari-defender": "ALE/Defender-v5",
238
+ + "atari-demonattack": "ALE/DemonAttack-v5",
239
+ + "atari-doubledunk": "ALE/DoubleDunk-v5",
240
+ + "atari-enduro": "ALE/Enduro-v5",
241
+ + "atari-fishingderby": "ALE/FishingDerby-v5",
242
+ + "atari-freeway": "ALE/Freeway-v5",
243
+ + "atari-frostbite": "ALE/Frostbite-v5",
244
+ + "atari-gopher": "ALE/Gopher-v5",
245
+ + "atari-gravitar": "ALE/Gravitar-v5",
246
+ + "atari-hero": "ALE/Hero-v5",
247
+ + "atari-icehockey": "ALE/IceHockey-v5",
248
+ + "atari-jamesbond": "ALE/Jamesbond-v5",
249
+ + "atari-kangaroo": "ALE/Kangaroo-v5",
250
+ + "atari-krull": "ALE/Krull-v5",
251
+ + "atari-kungfumaster": "ALE/KungFuMaster-v5",
252
+ + "atari-montezumarevenge": "ALE/MontezumaRevenge-v5",
253
+ + "atari-mspacman": "ALE/MsPacman-v5",
254
+ + "atari-namethisgame": "ALE/NameThisGame-v5",
255
+ + "atari-phoenix": "ALE/Phoenix-v5",
256
+ + "atari-pitfall": "ALE/Pitfall-v5",
257
+ + "atari-pong": "ALE/Pong-v5",
258
+ + "atari-privateeye": "ALE/PrivateEye-v5",
259
+ + "atari-qbert": "ALE/Qbert-v5",
260
+ + "atari-riverraid": "ALE/Riverraid-v5",
261
+ + "atari-roadrunner": "ALE/RoadRunner-v5",
262
+ + "atari-robotank": "ALE/Robotank-v5",
263
+ + "atari-seaquest": "ALE/Seaquest-v5",
264
+ + "atari-skiing": "ALE/Skiing-v5",
265
+ + "atari-solaris": "ALE/Solaris-v5",
266
+ + "atari-spaceinvaders": "ALE/SpaceInvaders-v5",
267
+ + "atari-stargunner": "ALE/StarGunner-v5",
268
+ "atari-surround": "ALE/Surround-v5",
269
+ - "atari-tennis": "Tennis-v4",
270
+ - "atari-timepilot": "TimePilot-v4",
271
+ - "atari-tutankham": "Tutankham-v4",
272
+ - "atari-upndown": "UpNDown-v4",
273
+ - "atari-venture": "Venture-v4",
274
+ - "atari-videopinball": "VideoPinball-v4",
275
+ - "atari-wizardofwor": "WizardOfWor-v4",
276
+ - "atari-yarsrevenge": "YarsRevenge-v4",
277
+ - "atari-zaxxon": "Zaxxon-v4",
278
+ + "atari-tennis": "ALE/Tennis-v5",
279
+ + "atari-timepilot": "ALE/TimePilot-v5",
280
+ + "atari-tutankham": "ALE/Tutankham-v5",
281
+ + "atari-upndown": "ALE/UpNDown-v5",
282
+ + "atari-venture": "ALE/Venture-v5",
283
+ + "atari-videopinball": "ALE/VideoPinball-v5",
284
+ + "atari-wizardofwor": "ALE/WizardOfWor-v5",
285
+ + "atari-yarsrevenge": "ALE/YarsRevenge-v5",
286
+ + "atari-zaxxon": "ALE/Zaxxon-v5",
287
+ "babyai-action-obj-door": "BabyAI-ActionObjDoor-v0",
288
+ "babyai-blocked-unlock-pickup": "BabyAI-BlockedUnlockPickup-v0",
289
+ "babyai-boss-level-no-unlock": "BabyAI-BossLevelNoUnlock-v0",
290
+ @@ -217,7 +219,7 @@ class BabyAIDictObservationWrapper(ObservationWrapper):
291
+ """
292
+ Wrapper for BabyAI environments.
293
 
294
+ - Flatten the image and direction observations and concatenate them.
295
+ + Flatten the pseudo-image and concatenante it to the direction observation.
296
+ """
 
 
 
297
 
298
+ def __init__(self, env: Env) -> None:
299
+ @@ -231,7 +233,7 @@ class BabyAIDictObservationWrapper(ObservationWrapper):
300
+ }
301
+ )
302
 
303
+ - def observation(self, observation):
304
+ + def observation(self, observation: Dict[str, np.ndarray]):
305
+ discrete_observations = np.append(observation["direction"], observation["image"].flatten())
306
+ return {
307
+ "text_observations": observation["mission"],
308
+ @@ -239,9 +241,15 @@ class BabyAIDictObservationWrapper(ObservationWrapper):
309
+ }
310
+
311
+
312
+ +class FloatRewardWrapper(RewardWrapper):
313
+ + def reward(self, reward):
314
+ + return float(reward)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  +
316
+ +
317
+ def make_babyai(task_name: str):
318
+ env = gym.make(TASK_TO_ENV_MAPPING[task_name])
319
+ env = BabyAIDictObservationWrapper(env)
320
+ + env = FloatRewardWrapper(env)
321
+ return env
322
+
323
+
324
  diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py
325
+ index f8531ee..44f5f91 100644
326
  --- a/gia/eval/rl/gym_evaluator.py
327
  +++ b/gia/eval/rl/gym_evaluator.py
328
  @@ -1,7 +1,7 @@
 
330
  from gym.vector.vector_env import VectorEnv
331
 
332
  -from gia.eval.mappings import TASK_TO_ENV_MAPPING
333
+ +# from gia.eval.mappings import TASK_TO_ENV_MAPPING
334
  from gia.eval.rl.rl_evaluator import RLEvaluator
335
 
336
 
337
+ diff --git a/tests/eval/rl/envs/test_core.py b/tests/eval/rl/envs/test_core.py
338
+ index e048772..d572a9d 100644
339
+ --- a/tests/eval/rl/envs/test_core.py
340
+ +++ b/tests/eval/rl/envs/test_core.py
341
+ @@ -5,16 +5,19 @@ from gia.eval.rl import make
342
+ from gia.eval.rl.envs.core import get_task_names
343
 
344
 
345
+ +OBS_KEYS = {"discrete_observations", "continuous_observations", "image_observations", "text_observations"}
346
+ +
 
 
347
  +
348
+ @pytest.mark.parametrize("task_name", get_task_names())
349
+ def test_make(task_name: str):
350
+ - num_envs = 2
351
+ - env = make(task_name, num_envs=num_envs)
352
+ + env = make(task_name)
353
+ observation, info = env.reset()
354
+ for _ in range(10):
355
+ - action_space = env.single_action_space if hasattr(env, "single_action_space") else env.action_space
356
+ - action = np.array([action_space.sample() for _ in range(num_envs)])
357
+ + action = np.array(env.action_space.sample())
358
+ observation, reward, terminated, truncated, info = env.step(action)
359
+ - assert reward.shape == (num_envs,)
360
+ - assert terminated.shape == (num_envs,)
361
+ - assert truncated.shape == (num_envs,)
362
+ + assert isinstance(info, dict)
363
+ + assert set(observation.keys()).issubset(OBS_KEYS)
364
+ + assert isinstance(reward, float)
365
+ + assert isinstance(terminated, bool)
366
+ + assert isinstance(truncated, bool)
367
+ assert isinstance(info, dict)
replay.mp4 CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
 
sf_log.txt CHANGED
The diff for this file is too large to render. See raw diff