peg-insert-side-v2 / git.diff
qgallouedec's picture
qgallouedec HF Staff
Upload folder using huggingface_hub
24b769c
diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh
index acfe879..fc2b602 100755
--- a/data/envs/metaworld/generate_dataset_all.sh
+++ b/data/envs/metaworld/generate_dataset_all.sh
@@ -1,59 +1,15 @@
#!/bin/bash
ENVS=(
- assembly
- basketball
bin-picking
- box-close
- button-press-topdown
- button-press-topdown-wall
- button-press
- button-press-wall
- coffee-button
- coffee-pull
- coffee-push
- dial-turn
- disassemble
- door-close
- door-lock
- door-open
- door-unlock
- drawer-close
- drawer-open
- faucet-close
- faucet-open
hammer
- hand-insert
- handle-press-side
- handle-press
- handle-pull-side
- handle-pull
- lever-pull
- peg-insert-side
- peg-unplug-side
pick-out-of-hole
pick-place
- pick-place-wall
- plate-slide-back-side
- plate-slide-back
- plate-slide-side
- plate-slide
- push-back
- push
- push-wall
- reach
- reach-wall
shelf-place
soccer
- stick-pull
- stick-push
- sweep-into
- sweep
- window-close
- window-open
)
for ENV in "${ENVS[@]}"; do
- python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2
+ # python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2
python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
done
diff --git a/data/envs/metaworld/train.py b/data/envs/metaworld/train.py
index 095414e..0ea5bde 100644
--- a/data/envs/metaworld/train.py
+++ b/data/envs/metaworld/train.py
@@ -25,7 +25,7 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
num_workers=8,
num_envs_per_worker=8,
worker_num_splits=2,
- train_for_env_steps=10_000_000,
+ train_for_env_steps=30_000_000,
encoder_mlp_layers=[64, 64],
env_frameskip=1,
nonlinearity="tanh",
diff --git a/data/envs/metaworld/train_all.sh b/data/envs/metaworld/train_all.sh
index dbf328a..67ab9a0 100755
--- a/data/envs/metaworld/train_all.sh
+++ b/data/envs/metaworld/train_all.sh
@@ -1,56 +1,8 @@
#!/bin/bash
ENVS=(
- assembly
- basketball
- bin-picking
- box-close
- button-press-topdown
- button-press-topdown-wall
- button-press
- button-press-wall
- coffee-button
- coffee-pull
- coffee-push
- dial-turn
disassemble
- door-close
- door-lock
- door-open
- door-unlock
- drawer-close
- drawer-open
- faucet-close
- faucet-open
- hammer
- hand-insert
- handle-press-side
- handle-press
- handle-pull-side
- handle-pull
- lever-pull
peg-insert-side
- peg-unplug-side
- pick-out-of-hole
- pick-place
- pick-place-wall
- plate-slide-back-side
- plate-slide-back
- plate-slide-side
- plate-slide
- push-back
- push
- push-wall
- reach
- reach-wall
- shelf-place
- soccer
- stick-pull
- stick-push
- sweep-into
- sweep
- window-close
- window-open
)
for ENV in "${ENVS[@]}"; do
diff --git a/gia/eval/callback.py b/gia/eval/callback.py
index 5c3a080..4b6198f 100644
--- a/gia/eval/callback.py
+++ b/gia/eval/callback.py
@@ -2,10 +2,10 @@ import glob
import json
import subprocess
-import wandb
from accelerate import Accelerator
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
+import wandb
from gia.config import Arguments
from gia.eval.utils import is_slurm_available
diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py
index 22c5b49..7464ff5 100644
--- a/gia/eval/rl/envs/core.py
+++ b/gia/eval/rl/envs/core.py
@@ -1,6 +1,8 @@
+from typing import Dict
+
import gymnasium as gym
import numpy as np
-from gymnasium import Env, ObservationWrapper, spaces
+from gymnasium import Env, ObservationWrapper, RewardWrapper, spaces
from sample_factory.envs.env_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
@@ -12,63 +14,63 @@ from sample_factory.envs.env_wrappers import (
TASK_TO_ENV_MAPPING = {
- "atari-alien": "Alien-v4",
- "atari-amidar": "Amidar-v4",
- "atari-assault": "Assault-v4",
- "atari-asterix": "Asterix-v4",
- "atari-asteroids": "Asteroids-v4",
- "atari-atlantis": "Atlantis-v4",
- "atari-bankheist": "BankHeist-v4",
- "atari-battlezone": "BattleZone-v4",
- "atari-beamrider": "BeamRider-v4",
- "atari-berzerk": "Berzerk-v4",
- "atari-bowling": "Bowling-v4",
- "atari-boxing": "Boxing-v4",
- "atari-breakout": "Breakout-v4",
- "atari-centipede": "Centipede-v4",
- "atari-choppercommand": "ChopperCommand-v4",
- "atari-crazyclimber": "CrazyClimber-v4",
- "atari-defender": "Defender-v4",
- "atari-demonattack": "DemonAttack-v4",
- "atari-doubledunk": "DoubleDunk-v4",
- "atari-enduro": "Enduro-v4",
- "atari-fishingderby": "FishingDerby-v4",
- "atari-freeway": "Freeway-v4",
- "atari-frostbite": "Frostbite-v4",
- "atari-gopher": "Gopher-v4",
- "atari-gravitar": "Gravitar-v4",
- "atari-hero": "Hero-v4",
- "atari-icehockey": "IceHockey-v4",
- "atari-jamesbond": "Jamesbond-v4",
- "atari-kangaroo": "Kangaroo-v4",
- "atari-krull": "Krull-v4",
- "atari-kungfumaster": "KungFuMaster-v4",
- "atari-montezumarevenge": "MontezumaRevenge-v4",
- "atari-mspacman": "MsPacman-v4",
- "atari-namethisgame": "NameThisGame-v4",
- "atari-phoenix": "Phoenix-v4",
- "atari-pitfall": "Pitfall-v4",
- "atari-pong": "Pong-v4",
- "atari-privateeye": "PrivateEye-v4",
- "atari-qbert": "Qbert-v4",
- "atari-riverraid": "Riverraid-v4",
- "atari-roadrunner": "RoadRunner-v4",
- "atari-robotank": "Robotank-v4",
- "atari-seaquest": "Seaquest-v4",
- "atari-skiing": "Skiing-v4",
- "atari-solaris": "Solaris-v4",
- "atari-spaceinvaders": "SpaceInvaders-v4",
- "atari-stargunner": "StarGunner-v4",
+ "atari-alien": "ALE/Alien-v5",
+ "atari-amidar": "ALE/Amidar-v5",
+ "atari-assault": "ALE/Assault-v5",
+ "atari-asterix": "ALE/Asterix-v5",
+ "atari-asteroids": "ALE/Asteroids-v5",
+ "atari-atlantis": "ALE/Atlantis-v5",
+ "atari-bankheist": "ALE/BankHeist-v5",
+ "atari-battlezone": "ALE/BattleZone-v5",
+ "atari-beamrider": "ALE/BeamRider-v5",
+ "atari-berzerk": "ALE/Berzerk-v5",
+ "atari-bowling": "ALE/Bowling-v5",
+ "atari-boxing": "ALE/Boxing-v5",
+ "atari-breakout": "ALE/Breakout-v5",
+ "atari-centipede": "ALE/Centipede-v5",
+ "atari-choppercommand": "ALE/ChopperCommand-v5",
+ "atari-crazyclimber": "ALE/CrazyClimber-v5",
+ "atari-defender": "ALE/Defender-v5",
+ "atari-demonattack": "ALE/DemonAttack-v5",
+ "atari-doubledunk": "ALE/DoubleDunk-v5",
+ "atari-enduro": "ALE/Enduro-v5",
+ "atari-fishingderby": "ALE/FishingDerby-v5",
+ "atari-freeway": "ALE/Freeway-v5",
+ "atari-frostbite": "ALE/Frostbite-v5",
+ "atari-gopher": "ALE/Gopher-v5",
+ "atari-gravitar": "ALE/Gravitar-v5",
+ "atari-hero": "ALE/Hero-v5",
+ "atari-icehockey": "ALE/IceHockey-v5",
+ "atari-jamesbond": "ALE/Jamesbond-v5",
+ "atari-kangaroo": "ALE/Kangaroo-v5",
+ "atari-krull": "ALE/Krull-v5",
+ "atari-kungfumaster": "ALE/KungFuMaster-v5",
+ "atari-montezumarevenge": "ALE/MontezumaRevenge-v5",
+ "atari-mspacman": "ALE/MsPacman-v5",
+ "atari-namethisgame": "ALE/NameThisGame-v5",
+ "atari-phoenix": "ALE/Phoenix-v5",
+ "atari-pitfall": "ALE/Pitfall-v5",
+ "atari-pong": "ALE/Pong-v5",
+ "atari-privateeye": "ALE/PrivateEye-v5",
+ "atari-qbert": "ALE/Qbert-v5",
+ "atari-riverraid": "ALE/Riverraid-v5",
+ "atari-roadrunner": "ALE/RoadRunner-v5",
+ "atari-robotank": "ALE/Robotank-v5",
+ "atari-seaquest": "ALE/Seaquest-v5",
+ "atari-skiing": "ALE/Skiing-v5",
+ "atari-solaris": "ALE/Solaris-v5",
+ "atari-spaceinvaders": "ALE/SpaceInvaders-v5",
+ "atari-stargunner": "ALE/StarGunner-v5",
"atari-surround": "ALE/Surround-v5",
- "atari-tennis": "Tennis-v4",
- "atari-timepilot": "TimePilot-v4",
- "atari-tutankham": "Tutankham-v4",
- "atari-upndown": "UpNDown-v4",
- "atari-venture": "Venture-v4",
- "atari-videopinball": "VideoPinball-v4",
- "atari-wizardofwor": "WizardOfWor-v4",
- "atari-yarsrevenge": "YarsRevenge-v4",
- "atari-zaxxon": "Zaxxon-v4",
+ "atari-tennis": "ALE/Tennis-v5",
+ "atari-timepilot": "ALE/TimePilot-v5",
+ "atari-tutankham": "ALE/Tutankham-v5",
+ "atari-upndown": "ALE/UpNDown-v5",
+ "atari-venture": "ALE/Venture-v5",
+ "atari-videopinball": "ALE/VideoPinball-v5",
+ "atari-wizardofwor": "ALE/WizardOfWor-v5",
+ "atari-yarsrevenge": "ALE/YarsRevenge-v5",
+ "atari-zaxxon": "ALE/Zaxxon-v5",
"babyai-action-obj-door": "BabyAI-ActionObjDoor-v0",
"babyai-blocked-unlock-pickup": "BabyAI-BlockedUnlockPickup-v0",
"babyai-boss-level-no-unlock": "BabyAI-BossLevelNoUnlock-v0",
@@ -217,7 +219,7 @@ class BabyAIDictObservationWrapper(ObservationWrapper):
"""
Wrapper for BabyAI environments.
- Flatten the image and direction observations and concatenate them.
+    Flatten the pseudo-image and concatenate it to the direction observation.
"""
def __init__(self, env: Env) -> None:
@@ -231,7 +233,7 @@ class BabyAIDictObservationWrapper(ObservationWrapper):
}
)
- def observation(self, observation):
+ def observation(self, observation: Dict[str, np.ndarray]):
discrete_observations = np.append(observation["direction"], observation["image"].flatten())
return {
"text_observations": observation["mission"],
@@ -239,9 +241,15 @@ class BabyAIDictObservationWrapper(ObservationWrapper):
}
+class FloatRewardWrapper(RewardWrapper):
+ def reward(self, reward):
+ return float(reward)
+
+
def make_babyai(task_name: str):
env = gym.make(TASK_TO_ENV_MAPPING[task_name])
env = BabyAIDictObservationWrapper(env)
+ env = FloatRewardWrapper(env)
return env
diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py
index f8531ee..44f5f91 100644
--- a/gia/eval/rl/gym_evaluator.py
+++ b/gia/eval/rl/gym_evaluator.py
@@ -1,7 +1,7 @@
import gym
from gym.vector.vector_env import VectorEnv
-from gia.eval.mappings import TASK_TO_ENV_MAPPING
+# from gia.eval.mappings import TASK_TO_ENV_MAPPING
from gia.eval.rl.rl_evaluator import RLEvaluator
diff --git a/tests/eval/rl/envs/test_core.py b/tests/eval/rl/envs/test_core.py
index e048772..d572a9d 100644
--- a/tests/eval/rl/envs/test_core.py
+++ b/tests/eval/rl/envs/test_core.py
@@ -5,16 +5,19 @@ from gia.eval.rl import make
from gia.eval.rl.envs.core import get_task_names
+OBS_KEYS = {"discrete_observations", "continuous_observations", "image_observations", "text_observations"}
+
+
@pytest.mark.parametrize("task_name", get_task_names())
def test_make(task_name: str):
- num_envs = 2
- env = make(task_name, num_envs=num_envs)
+ env = make(task_name)
observation, info = env.reset()
for _ in range(10):
- action_space = env.single_action_space if hasattr(env, "single_action_space") else env.action_space
- action = np.array([action_space.sample() for _ in range(num_envs)])
+ action = np.array(env.action_space.sample())
observation, reward, terminated, truncated, info = env.step(action)
- assert reward.shape == (num_envs,)
- assert terminated.shape == (num_envs,)
- assert truncated.shape == (num_envs,)
+ assert isinstance(info, dict)
+ assert set(observation.keys()).issubset(OBS_KEYS)
+ assert isinstance(reward, float)
+ assert isinstance(terminated, bool)
+ assert isinstance(truncated, bool)
assert isinstance(info, dict)