diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh index acfe879..fc2b602 100755 --- a/data/envs/metaworld/generate_dataset_all.sh +++ b/data/envs/metaworld/generate_dataset_all.sh @@ -1,59 +1,15 @@ #!/bin/bash ENVS=( - assembly - basketball bin-picking - box-close - button-press-topdown - button-press-topdown-wall - button-press - button-press-wall - coffee-button - coffee-pull - coffee-push - dial-turn - disassemble - door-close - door-lock - door-open - door-unlock - drawer-close - drawer-open - faucet-close - faucet-open hammer - hand-insert - handle-press-side - handle-press - handle-pull-side - handle-pull - lever-pull - peg-insert-side - peg-unplug-side pick-out-of-hole pick-place - pick-place-wall - plate-slide-back-side - plate-slide-back - plate-slide-side - plate-slide - push-back - push - push-wall - reach - reach-wall shelf-place soccer - stick-pull - stick-push - sweep-into - sweep - window-close - window-open ) for ENV in "${ENVS[@]}"; do - python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2 + # python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2 python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir done diff --git a/data/envs/metaworld/train.py b/data/envs/metaworld/train.py index 095414e..0ea5bde 100644 --- a/data/envs/metaworld/train.py +++ b/data/envs/metaworld/train.py @@ -25,7 +25,7 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse num_workers=8, num_envs_per_worker=8, worker_num_splits=2, - train_for_env_steps=10_000_000, + train_for_env_steps=30_000_000, encoder_mlp_layers=[64, 64], env_frameskip=1, nonlinearity="tanh", diff --git a/data/envs/metaworld/train_all.sh b/data/envs/metaworld/train_all.sh index dbf328a..67ab9a0 100755 --- a/data/envs/metaworld/train_all.sh +++ b/data/envs/metaworld/train_all.sh @@ -1,56 +1,8 @@ #!/bin/bash ENVS=( - assembly - basketball - bin-picking - box-close - button-press-topdown - button-press-topdown-wall - button-press - button-press-wall - coffee-button - coffee-pull - coffee-push - dial-turn disassemble - door-close - door-lock - door-open - door-unlock - drawer-close - drawer-open - faucet-close - faucet-open - hammer - hand-insert - handle-press-side - handle-press - handle-pull-side - handle-pull - lever-pull peg-insert-side - peg-unplug-side - pick-out-of-hole - pick-place - pick-place-wall - plate-slide-back-side - plate-slide-back - plate-slide-side - plate-slide - push-back - push - push-wall - reach - reach-wall - shelf-place - soccer - stick-pull - stick-push - sweep-into - sweep - window-close - window-open ) for ENV in "${ENVS[@]}"; do diff --git a/gia/eval/callback.py b/gia/eval/callback.py index 5c3a080..4b6198f 100644 --- a/gia/eval/callback.py +++ b/gia/eval/callback.py @@ -2,10 +2,10 @@ import glob import json import subprocess -import wandb from accelerate import Accelerator from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments +import wandb from gia.config import Arguments from gia.eval.utils import is_slurm_available diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py index 22c5b49..7464ff5 100644 --- a/gia/eval/rl/envs/core.py +++ b/gia/eval/rl/envs/core.py @@ -1,6 +1,8 @@ +from typing import Dict + import gymnasium as gym import numpy as np -from gymnasium import Env, ObservationWrapper, spaces +from gymnasium import Env, ObservationWrapper, RewardWrapper, spaces from sample_factory.envs.env_wrappers import ( ClipRewardEnv, EpisodicLifeEnv, @@ -12,63 +14,63 @@ from sample_factory.envs.env_wrappers import ( TASK_TO_ENV_MAPPING = { - "atari-alien": "Alien-v4", - "atari-amidar": "Amidar-v4", - "atari-assault": "Assault-v4", - "atari-asterix": "Asterix-v4", - "atari-asteroids": "Asteroids-v4", - "atari-atlantis": "Atlantis-v4", - "atari-bankheist": "BankHeist-v4", - "atari-battlezone": "BattleZone-v4", - "atari-beamrider": "BeamRider-v4", - "atari-berzerk": "Berzerk-v4", - "atari-bowling": "Bowling-v4", - "atari-boxing": "Boxing-v4", - "atari-breakout": "Breakout-v4", - "atari-centipede": "Centipede-v4", - "atari-choppercommand": "ChopperCommand-v4", - "atari-crazyclimber": "CrazyClimber-v4", - "atari-defender": "Defender-v4", - "atari-demonattack": "DemonAttack-v4", - "atari-doubledunk": "DoubleDunk-v4", - "atari-enduro": "Enduro-v4", - "atari-fishingderby": "FishingDerby-v4", - "atari-freeway": "Freeway-v4", - "atari-frostbite": "Frostbite-v4", - "atari-gopher": "Gopher-v4", - "atari-gravitar": "Gravitar-v4", - "atari-hero": "Hero-v4", - "atari-icehockey": "IceHockey-v4", - "atari-jamesbond": "Jamesbond-v4", - "atari-kangaroo": "Kangaroo-v4", - "atari-krull": "Krull-v4", - "atari-kungfumaster": "KungFuMaster-v4", - "atari-montezumarevenge": "MontezumaRevenge-v4", - "atari-mspacman": "MsPacman-v4", - "atari-namethisgame": "NameThisGame-v4", - "atari-phoenix": "Phoenix-v4", - "atari-pitfall": "Pitfall-v4", - "atari-pong": "Pong-v4", - "atari-privateeye": "PrivateEye-v4", - "atari-qbert": "Qbert-v4", - "atari-riverraid": "Riverraid-v4", - "atari-roadrunner": "RoadRunner-v4", - "atari-robotank": "Robotank-v4", - "atari-seaquest": "Seaquest-v4", - "atari-skiing": "Skiing-v4", - "atari-solaris": "Solaris-v4", - "atari-spaceinvaders": "SpaceInvaders-v4", - "atari-stargunner": "StarGunner-v4", + "atari-alien": "ALE/Alien-v5", + "atari-amidar": "ALE/Amidar-v5", + "atari-assault": "ALE/Assault-v5", + "atari-asterix": "ALE/Asterix-v5", + "atari-asteroids": "ALE/Asteroids-v5", + "atari-atlantis": "ALE/Atlantis-v5", + "atari-bankheist": "ALE/BankHeist-v5", + "atari-battlezone": "ALE/BattleZone-v5", + "atari-beamrider": "ALE/BeamRider-v5", + "atari-berzerk": "ALE/Berzerk-v5", + "atari-bowling": "ALE/Bowling-v5", + "atari-boxing": "ALE/Boxing-v5", + "atari-breakout": "ALE/Breakout-v5", + "atari-centipede": "ALE/Centipede-v5", + "atari-choppercommand": "ALE/ChopperCommand-v5", + "atari-crazyclimber": "ALE/CrazyClimber-v5", + "atari-defender": "ALE/Defender-v5", + "atari-demonattack": "ALE/DemonAttack-v5", + "atari-doubledunk": "ALE/DoubleDunk-v5", + "atari-enduro": "ALE/Enduro-v5", + "atari-fishingderby": "ALE/FishingDerby-v5", + "atari-freeway": "ALE/Freeway-v5", + "atari-frostbite": "ALE/Frostbite-v5", + "atari-gopher": "ALE/Gopher-v5", + "atari-gravitar": "ALE/Gravitar-v5", + "atari-hero": "ALE/Hero-v5", + "atari-icehockey": "ALE/IceHockey-v5", + "atari-jamesbond": "ALE/Jamesbond-v5", + "atari-kangaroo": "ALE/Kangaroo-v5", + "atari-krull": "ALE/Krull-v5", + "atari-kungfumaster": "ALE/KungFuMaster-v5", + "atari-montezumarevenge": "ALE/MontezumaRevenge-v5", + "atari-mspacman": "ALE/MsPacman-v5", + "atari-namethisgame": "ALE/NameThisGame-v5", + "atari-phoenix": "ALE/Phoenix-v5", + "atari-pitfall": "ALE/Pitfall-v5", + "atari-pong": "ALE/Pong-v5", + "atari-privateeye": "ALE/PrivateEye-v5", + "atari-qbert": "ALE/Qbert-v5", + "atari-riverraid": "ALE/Riverraid-v5", + "atari-roadrunner": "ALE/RoadRunner-v5", + "atari-robotank": "ALE/Robotank-v5", + "atari-seaquest": "ALE/Seaquest-v5", + "atari-skiing": "ALE/Skiing-v5", + "atari-solaris": "ALE/Solaris-v5", + "atari-spaceinvaders": "ALE/SpaceInvaders-v5", + "atari-stargunner": "ALE/StarGunner-v5", "atari-surround": "ALE/Surround-v5", - "atari-tennis": "Tennis-v4", - "atari-timepilot": "TimePilot-v4", - "atari-tutankham": "Tutankham-v4", - "atari-upndown": "UpNDown-v4", - "atari-venture": "Venture-v4", - "atari-videopinball": "VideoPinball-v4", - "atari-wizardofwor": "WizardOfWor-v4", - "atari-yarsrevenge": "YarsRevenge-v4", - "atari-zaxxon": "Zaxxon-v4", + "atari-tennis": "ALE/Tennis-v5", + "atari-timepilot": "ALE/TimePilot-v5", + "atari-tutankham": "ALE/Tutankham-v5", + "atari-upndown": "ALE/UpNDown-v5", + "atari-venture": "ALE/Venture-v5", + "atari-videopinball": "ALE/VideoPinball-v5", + "atari-wizardofwor": "ALE/WizardOfWor-v5", + "atari-yarsrevenge": "ALE/YarsRevenge-v5", + "atari-zaxxon": "ALE/Zaxxon-v5", "babyai-action-obj-door": "BabyAI-ActionObjDoor-v0", "babyai-blocked-unlock-pickup": "BabyAI-BlockedUnlockPickup-v0", "babyai-boss-level-no-unlock": "BabyAI-BossLevelNoUnlock-v0", @@ -217,7 +219,7 @@ class BabyAIDictObservationWrapper(ObservationWrapper): """ Wrapper for BabyAI environments. - Flatten the image and direction observations and concatenate them. + Flatten the pseudo-image and concatenante it to the direction observation. """ def __init__(self, env: Env) -> None: @@ -231,7 +233,7 @@ class BabyAIDictObservationWrapper(ObservationWrapper): } ) - def observation(self, observation): + def observation(self, observation: Dict[str, np.ndarray]): discrete_observations = np.append(observation["direction"], observation["image"].flatten()) return { "text_observations": observation["mission"], @@ -239,9 +241,15 @@ class BabyAIDictObservationWrapper(ObservationWrapper): } +class FloatRewardWrapper(RewardWrapper): + def reward(self, reward): + return float(reward) + + def make_babyai(task_name: str): env = gym.make(TASK_TO_ENV_MAPPING[task_name]) env = BabyAIDictObservationWrapper(env) + env = FloatRewardWrapper(env) return env diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py index f8531ee..44f5f91 100644 --- a/gia/eval/rl/gym_evaluator.py +++ b/gia/eval/rl/gym_evaluator.py @@ -1,7 +1,7 @@ import gym from gym.vector.vector_env import VectorEnv -from gia.eval.mappings import TASK_TO_ENV_MAPPING +# from gia.eval.mappings import TASK_TO_ENV_MAPPING from gia.eval.rl.rl_evaluator import RLEvaluator diff --git a/tests/eval/rl/envs/test_core.py b/tests/eval/rl/envs/test_core.py index e048772..d572a9d 100644 --- a/tests/eval/rl/envs/test_core.py +++ b/tests/eval/rl/envs/test_core.py @@ -5,16 +5,19 @@ from gia.eval.rl import make from gia.eval.rl.envs.core import get_task_names +OBS_KEYS = {"discrete_observations", "continuous_observations", "image_observations", "text_observations"} + + @pytest.mark.parametrize("task_name", get_task_names()) def test_make(task_name: str): - num_envs = 2 - env = make(task_name, num_envs=num_envs) + env = make(task_name) observation, info = env.reset() for _ in range(10): - action_space = env.single_action_space if hasattr(env, "single_action_space") else env.action_space - action = np.array([action_space.sample() for _ in range(num_envs)]) + action = np.array(env.action_space.sample()) observation, reward, terminated, truncated, info = env.step(action) - assert reward.shape == (num_envs,) - assert terminated.shape == (num_envs,) - assert truncated.shape == (num_envs,) + assert isinstance(info, dict) + assert set(observation.keys()).issubset(OBS_KEYS) + assert isinstance(reward, float) + assert isinstance(terminated, bool) + assert isinstance(truncated, bool) assert isinstance(info, dict)