diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh index acfe879..0185b2b 100755 --- a/data/envs/metaworld/generate_dataset_all.sh +++ b/data/envs/metaworld/generate_dataset_all.sh @@ -2,8 +2,6 @@ ENVS=( assembly - basketball - bin-picking box-close button-press-topdown button-press-topdown-wall @@ -11,9 +9,7 @@ ENVS=( button-press-wall coffee-button coffee-pull - coffee-push dial-turn - disassemble door-close door-lock door-open @@ -22,29 +18,15 @@ ENVS=( drawer-open faucet-close faucet-open - hammer hand-insert handle-press-side handle-press handle-pull-side handle-pull lever-pull - peg-insert-side - peg-unplug-side - pick-out-of-hole - pick-place - pick-place-wall - plate-slide-back-side - plate-slide-back - plate-slide-side - plate-slide push-back push push-wall - reach - reach-wall - shelf-place - soccer stick-pull stick-push sweep-into @@ -54,6 +36,6 @@ ENVS=( ) for ENV in "${ENVS[@]}"; do - python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2 + # python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2 python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir done diff --git a/data/envs/metaworld/train_all.sh b/data/envs/metaworld/train_all.sh index dbf328a..166ccb8 100755 --- a/data/envs/metaworld/train_all.sh +++ b/data/envs/metaworld/train_all.sh @@ -1,56 +1,10 @@ #!/bin/bash ENVS=( - assembly - basketball bin-picking - box-close - button-press-topdown - button-press-topdown-wall - button-press - button-press-wall - coffee-button - coffee-pull - coffee-push - dial-turn disassemble - door-close - door-lock - door-open - door-unlock - drawer-close - drawer-open - faucet-close - faucet-open - hammer - hand-insert - handle-press-side - handle-press - handle-pull-side - handle-pull - lever-pull peg-insert-side - peg-unplug-side - pick-out-of-hole - pick-place pick-place-wall - plate-slide-back-side - plate-slide-back - plate-slide-side - plate-slide - push-back - push - push-wall - reach - reach-wall - shelf-place - soccer - stick-pull - stick-push - sweep-into - sweep - window-close - window-open ) for ENV in "${ENVS[@]}"; do diff --git a/data/envs/mujoco/create_mujoco_dataset.sh b/data/envs/mujoco/create_mujoco_dataset.sh old mode 100644 new mode 100755 index d8ce6d6..c4dfebb --- a/data/envs/mujoco/create_mujoco_dataset.sh +++ b/data/envs/mujoco/create_mujoco_dataset.sh @@ -2,7 +2,7 @@ # creates 100,000 per environment from models hosted on the hub ENVS=( - ant halfcheetah hopper doublependulum pendulum reacher swimmer walker + ant ) for ENV in "${ENVS[@]}"; do diff --git a/gia/eval/callback.py b/gia/eval/callback.py index 5c3a080..4b6198f 100644 --- a/gia/eval/callback.py +++ b/gia/eval/callback.py @@ -2,10 +2,10 @@ import glob import json import subprocess -import wandb from accelerate import Accelerator from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments +import wandb from gia.config import Arguments from gia.eval.utils import is_slurm_available diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py index af6d86e..f6098d4 100644 --- a/gia/eval/rl/gia_agent.py +++ b/gia/eval/rl/gia_agent.py @@ -94,7 +94,7 @@ class GiaAgent: elif isinstance(self.observation_space, spaces.MultiDiscrete): self._observation_key = "discrete_observations" else: - raise TypeError("Unsupported observation space") + print("Unsupported observation space") if isinstance(self.action_space, spaces.Box): self._num_act_tokens = self.action_space.shape[0] diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py index f8531ee..44f5f91 100644 --- a/gia/eval/rl/gym_evaluator.py +++ b/gia/eval/rl/gym_evaluator.py @@ -1,7 +1,7 @@ import gym from gym.vector.vector_env import VectorEnv -from gia.eval.mappings import TASK_TO_ENV_MAPPING +# from gia.eval.mappings import TASK_TO_ENV_MAPPING from gia.eval.rl.rl_evaluator import RLEvaluator