|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,59 +1,15 @@ |
|
|
#!/bin/bash |
|
|
|
|
|
ENVS=( |
|
|
- assembly |
|
|
- basketball |
|
|
bin-picking |
|
|
- box-close |
|
|
- button-press-topdown |
|
|
- button-press-topdown-wall |
|
|
- button-press |
|
|
- button-press-wall |
|
|
- coffee-button |
|
|
- coffee-pull |
|
|
- coffee-push |
|
|
- dial-turn |
|
|
- disassemble |
|
|
- door-close |
|
|
- door-lock |
|
|
- door-open |
|
|
- door-unlock |
|
|
- drawer-close |
|
|
- drawer-open |
|
|
- faucet-close |
|
|
- faucet-open |
|
|
hammer |
|
|
- hand-insert |
|
|
- handle-press-side |
|
|
- handle-press |
|
|
- handle-pull-side |
|
|
- handle-pull |
|
|
- lever-pull |
|
|
- peg-insert-side |
|
|
- peg-unplug-side |
|
|
pick-out-of-hole |
|
|
pick-place |
|
|
- pick-place-wall |
|
|
- plate-slide-back-side |
|
|
- plate-slide-back |
|
|
- plate-slide-side |
|
|
- plate-slide |
|
|
- push-back |
|
|
- push |
|
|
- push-wall |
|
|
- reach |
|
|
- reach-wall |
|
|
shelf-place |
|
|
soccer |
|
|
- stick-pull |
|
|
- stick-push |
|
|
- sweep-into |
|
|
- sweep |
|
|
- window-close |
|
|
- window-open |
|
|
) |
|
|
|
|
|
for ENV in "${ENVS[@]}"; do |
|
|
- python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2 |
|
|
+ # python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2 |
|
|
python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir |
|
|
done |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -25,7 +25,7 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse |
|
|
num_workers=8, |
|
|
num_envs_per_worker=8, |
|
|
worker_num_splits=2, |
|
|
- train_for_env_steps=10_000_000, |
|
|
+ train_for_env_steps=30_000_000, |
|
|
encoder_mlp_layers=[64, 64], |
|
|
env_frameskip=1, |
|
|
nonlinearity="tanh", |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,56 +1,8 @@ |
|
|
#!/bin/bash |
|
|
|
|
|
ENVS=( |
|
|
- assembly |
|
|
- basketball |
|
|
- bin-picking |
|
|
- box-close |
|
|
- button-press-topdown |
|
|
- button-press-topdown-wall |
|
|
- button-press |
|
|
- button-press-wall |
|
|
- coffee-button |
|
|
- coffee-pull |
|
|
- coffee-push |
|
|
- dial-turn |
|
|
disassemble |
|
|
- door-close |
|
|
- door-lock |
|
|
- door-open |
|
|
- door-unlock |
|
|
- drawer-close |
|
|
- drawer-open |
|
|
- faucet-close |
|
|
- faucet-open |
|
|
- hammer |
|
|
- hand-insert |
|
|
- handle-press-side |
|
|
- handle-press |
|
|
- handle-pull-side |
|
|
- handle-pull |
|
|
- lever-pull |
|
|
peg-insert-side |
|
|
- peg-unplug-side |
|
|
- pick-out-of-hole |
|
|
- pick-place |
|
|
- pick-place-wall |
|
|
- plate-slide-back-side |
|
|
- plate-slide-back |
|
|
- plate-slide-side |
|
|
- plate-slide |
|
|
- push-back |
|
|
- push |
|
|
- push-wall |
|
|
- reach |
|
|
- reach-wall |
|
|
- shelf-place |
|
|
- soccer |
|
|
- stick-pull |
|
|
- stick-push |
|
|
- sweep-into |
|
|
- sweep |
|
|
- window-close |
|
|
- window-open |
|
|
) |
|
|
|
|
|
for ENV in "${ENVS[@]}"; do |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -2,10 +2,10 @@ import glob |
|
|
import json |
|
|
import subprocess |
|
|
|
|
|
-import wandb |
|
|
from accelerate import Accelerator |
|
|
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments |
|
|
|
|
|
+import wandb |
|
|
from gia.config import Arguments |
|
|
from gia.eval.utils import is_slurm_available |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,6 +1,8 @@ |
|
|
+from typing import Dict |
|
|
+ |
|
|
import gymnasium as gym |
|
|
import numpy as np |
|
|
-from gymnasium import Env, ObservationWrapper, spaces |
|
|
+from gymnasium import Env, ObservationWrapper, RewardWrapper, spaces |
|
|
from sample_factory.envs.env_wrappers import ( |
|
|
ClipRewardEnv, |
|
|
EpisodicLifeEnv, |
|
|
@@ -12,63 +14,63 @@ from sample_factory.envs.env_wrappers import ( |
|
|
|
|
|
|
|
|
TASK_TO_ENV_MAPPING = { |
|
|
- "atari-alien": "Alien-v4", |
|
|
- "atari-amidar": "Amidar-v4", |
|
|
- "atari-assault": "Assault-v4", |
|
|
- "atari-asterix": "Asterix-v4", |
|
|
- "atari-asteroids": "Asteroids-v4", |
|
|
- "atari-atlantis": "Atlantis-v4", |
|
|
- "atari-bankheist": "BankHeist-v4", |
|
|
- "atari-battlezone": "BattleZone-v4", |
|
|
- "atari-beamrider": "BeamRider-v4", |
|
|
- "atari-berzerk": "Berzerk-v4", |
|
|
- "atari-bowling": "Bowling-v4", |
|
|
- "atari-boxing": "Boxing-v4", |
|
|
- "atari-breakout": "Breakout-v4", |
|
|
- "atari-centipede": "Centipede-v4", |
|
|
- "atari-choppercommand": "ChopperCommand-v4", |
|
|
- "atari-crazyclimber": "CrazyClimber-v4", |
|
|
- "atari-defender": "Defender-v4", |
|
|
- "atari-demonattack": "DemonAttack-v4", |
|
|
- "atari-doubledunk": "DoubleDunk-v4", |
|
|
- "atari-enduro": "Enduro-v4", |
|
|
- "atari-fishingderby": "FishingDerby-v4", |
|
|
- "atari-freeway": "Freeway-v4", |
|
|
- "atari-frostbite": "Frostbite-v4", |
|
|
- "atari-gopher": "Gopher-v4", |
|
|
- "atari-gravitar": "Gravitar-v4", |
|
|
- "atari-hero": "Hero-v4", |
|
|
- "atari-icehockey": "IceHockey-v4", |
|
|
- "atari-jamesbond": "Jamesbond-v4", |
|
|
- "atari-kangaroo": "Kangaroo-v4", |
|
|
- "atari-krull": "Krull-v4", |
|
|
- "atari-kungfumaster": "KungFuMaster-v4", |
|
|
- "atari-montezumarevenge": "MontezumaRevenge-v4", |
|
|
- "atari-mspacman": "MsPacman-v4", |
|
|
- "atari-namethisgame": "NameThisGame-v4", |
|
|
- "atari-phoenix": "Phoenix-v4", |
|
|
- "atari-pitfall": "Pitfall-v4", |
|
|
- "atari-pong": "Pong-v4", |
|
|
- "atari-privateeye": "PrivateEye-v4", |
|
|
- "atari-qbert": "Qbert-v4", |
|
|
- "atari-riverraid": "Riverraid-v4", |
|
|
- "atari-roadrunner": "RoadRunner-v4", |
|
|
- "atari-robotank": "Robotank-v4", |
|
|
- "atari-seaquest": "Seaquest-v4", |
|
|
- "atari-skiing": "Skiing-v4", |
|
|
- "atari-solaris": "Solaris-v4", |
|
|
- "atari-spaceinvaders": "SpaceInvaders-v4", |
|
|
- "atari-stargunner": "StarGunner-v4", |
|
|
+ "atari-alien": "ALE/Alien-v5", |
|
|
+ "atari-amidar": "ALE/Amidar-v5", |
|
|
+ "atari-assault": "ALE/Assault-v5", |
|
|
+ "atari-asterix": "ALE/Asterix-v5", |
|
|
+ "atari-asteroids": "ALE/Asteroids-v5", |
|
|
+ "atari-atlantis": "ALE/Atlantis-v5", |
|
|
+ "atari-bankheist": "ALE/BankHeist-v5", |
|
|
+ "atari-battlezone": "ALE/BattleZone-v5", |
|
|
+ "atari-beamrider": "ALE/BeamRider-v5", |
|
|
+ "atari-berzerk": "ALE/Berzerk-v5", |
|
|
+ "atari-bowling": "ALE/Bowling-v5", |
|
|
+ "atari-boxing": "ALE/Boxing-v5", |
|
|
+ "atari-breakout": "ALE/Breakout-v5", |
|
|
+ "atari-centipede": "ALE/Centipede-v5", |
|
|
+ "atari-choppercommand": "ALE/ChopperCommand-v5", |
|
|
+ "atari-crazyclimber": "ALE/CrazyClimber-v5", |
|
|
+ "atari-defender": "ALE/Defender-v5", |
|
|
+ "atari-demonattack": "ALE/DemonAttack-v5", |
|
|
+ "atari-doubledunk": "ALE/DoubleDunk-v5", |
|
|
+ "atari-enduro": "ALE/Enduro-v5", |
|
|
+ "atari-fishingderby": "ALE/FishingDerby-v5", |
|
|
+ "atari-freeway": "ALE/Freeway-v5", |
|
|
+ "atari-frostbite": "ALE/Frostbite-v5", |
|
|
+ "atari-gopher": "ALE/Gopher-v5", |
|
|
+ "atari-gravitar": "ALE/Gravitar-v5", |
|
|
+ "atari-hero": "ALE/Hero-v5", |
|
|
+ "atari-icehockey": "ALE/IceHockey-v5", |
|
|
+ "atari-jamesbond": "ALE/Jamesbond-v5", |
|
|
+ "atari-kangaroo": "ALE/Kangaroo-v5", |
|
|
+ "atari-krull": "ALE/Krull-v5", |
|
|
+ "atari-kungfumaster": "ALE/KungFuMaster-v5", |
|
|
+ "atari-montezumarevenge": "ALE/MontezumaRevenge-v5", |
|
|
+ "atari-mspacman": "ALE/MsPacman-v5", |
|
|
+ "atari-namethisgame": "ALE/NameThisGame-v5", |
|
|
+ "atari-phoenix": "ALE/Phoenix-v5", |
|
|
+ "atari-pitfall": "ALE/Pitfall-v5", |
|
|
+ "atari-pong": "ALE/Pong-v5", |
|
|
+ "atari-privateeye": "ALE/PrivateEye-v5", |
|
|
+ "atari-qbert": "ALE/Qbert-v5", |
|
|
+ "atari-riverraid": "ALE/Riverraid-v5", |
|
|
+ "atari-roadrunner": "ALE/RoadRunner-v5", |
|
|
+ "atari-robotank": "ALE/Robotank-v5", |
|
|
+ "atari-seaquest": "ALE/Seaquest-v5", |
|
|
+ "atari-skiing": "ALE/Skiing-v5", |
|
|
+ "atari-solaris": "ALE/Solaris-v5", |
|
|
+ "atari-spaceinvaders": "ALE/SpaceInvaders-v5", |
|
|
+ "atari-stargunner": "ALE/StarGunner-v5", |
|
|
"atari-surround": "ALE/Surround-v5", |
|
|
- "atari-tennis": "Tennis-v4", |
|
|
- "atari-timepilot": "TimePilot-v4", |
|
|
- "atari-tutankham": "Tutankham-v4", |
|
|
- "atari-upndown": "UpNDown-v4", |
|
|
- "atari-venture": "Venture-v4", |
|
|
- "atari-videopinball": "VideoPinball-v4", |
|
|
- "atari-wizardofwor": "WizardOfWor-v4", |
|
|
- "atari-yarsrevenge": "YarsRevenge-v4", |
|
|
- "atari-zaxxon": "Zaxxon-v4", |
|
|
+ "atari-tennis": "ALE/Tennis-v5", |
|
|
+ "atari-timepilot": "ALE/TimePilot-v5", |
|
|
+ "atari-tutankham": "ALE/Tutankham-v5", |
|
|
+ "atari-upndown": "ALE/UpNDown-v5", |
|
|
+ "atari-venture": "ALE/Venture-v5", |
|
|
+ "atari-videopinball": "ALE/VideoPinball-v5", |
|
|
+ "atari-wizardofwor": "ALE/WizardOfWor-v5", |
|
|
+ "atari-yarsrevenge": "ALE/YarsRevenge-v5", |
|
|
+ "atari-zaxxon": "ALE/Zaxxon-v5", |
|
|
"babyai-action-obj-door": "BabyAI-ActionObjDoor-v0", |
|
|
"babyai-blocked-unlock-pickup": "BabyAI-BlockedUnlockPickup-v0", |
|
|
"babyai-boss-level-no-unlock": "BabyAI-BossLevelNoUnlock-v0", |
|
|
@@ -217,7 +219,7 @@ class BabyAIDictObservationWrapper(ObservationWrapper): |
|
|
""" |
|
|
Wrapper for BabyAI environments. |
|
|
|
|
|
- Flatten the image and direction observations and concatenate them. |
|
|
+ Flatten the pseudo-image and concatenante it to the direction observation. |
|
|
""" |
|
|
|
|
|
def __init__(self, env: Env) -> None: |
|
|
@@ -231,7 +233,7 @@ class BabyAIDictObservationWrapper(ObservationWrapper): |
|
|
} |
|
|
) |
|
|
|
|
|
- def observation(self, observation): |
|
|
+ def observation(self, observation: Dict[str, np.ndarray]): |
|
|
discrete_observations = np.append(observation["direction"], observation["image"].flatten()) |
|
|
return { |
|
|
"text_observations": observation["mission"], |
|
|
@@ -239,9 +241,15 @@ class BabyAIDictObservationWrapper(ObservationWrapper): |
|
|
} |
|
|
|
|
|
|
|
|
+class FloatRewardWrapper(RewardWrapper): |
|
|
+ def reward(self, reward): |
|
|
+ return float(reward) |
|
|
+ |
|
|
+ |
|
|
def make_babyai(task_name: str): |
|
|
env = gym.make(TASK_TO_ENV_MAPPING[task_name]) |
|
|
env = BabyAIDictObservationWrapper(env) |
|
|
+ env = FloatRewardWrapper(env) |
|
|
return env |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,7 +1,7 @@ |
|
|
import gym |
|
|
from gym.vector.vector_env import VectorEnv |
|
|
|
|
|
-from gia.eval.mappings import TASK_TO_ENV_MAPPING |
|
|
+# from gia.eval.mappings import TASK_TO_ENV_MAPPING |
|
|
from gia.eval.rl.rl_evaluator import RLEvaluator |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -5,16 +5,19 @@ from gia.eval.rl import make |
|
|
from gia.eval.rl.envs.core import get_task_names |
|
|
|
|
|
|
|
|
+OBS_KEYS = {"discrete_observations", "continuous_observations", "image_observations", "text_observations"} |
|
|
+ |
|
|
+ |
|
|
@pytest.mark.parametrize("task_name", get_task_names()) |
|
|
def test_make(task_name: str): |
|
|
- num_envs = 2 |
|
|
- env = make(task_name, num_envs=num_envs) |
|
|
+ env = make(task_name) |
|
|
observation, info = env.reset() |
|
|
for _ in range(10): |
|
|
- action_space = env.single_action_space if hasattr(env, "single_action_space") else env.action_space |
|
|
- action = np.array([action_space.sample() for _ in range(num_envs)]) |
|
|
+ action = np.array(env.action_space.sample()) |
|
|
observation, reward, terminated, truncated, info = env.step(action) |
|
|
- assert reward.shape == (num_envs,) |
|
|
- assert terminated.shape == (num_envs,) |
|
|
- assert truncated.shape == (num_envs,) |
|
|
+ assert isinstance(info, dict) |
|
|
+ assert set(observation.keys()).issubset(OBS_KEYS) |
|
|
+ assert isinstance(reward, float) |
|
|
+ assert isinstance(terminated, bool) |
|
|
+ assert isinstance(truncated, bool) |
|
|
assert isinstance(info, dict) |
|
|
|