Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- Metaworld/metaworld/envs/__init__.py +6 -0
- Metaworld/metaworld/envs/__pycache__/__init__.cpython-38.pyc +0 -0
- Metaworld/metaworld/envs/__pycache__/asset_path_utils.cpython-38.pyc +0 -0
- Metaworld/metaworld/envs/__pycache__/reward_utils.cpython-38.pyc +0 -0
- Metaworld/metaworld/envs/asset_path_utils.py +12 -0
- Metaworld/metaworld/envs/mujoco/__init__.py +0 -0
- Metaworld/metaworld/envs/mujoco/__pycache__/__init__.cpython-38.pyc +0 -0
- Metaworld/metaworld/envs/mujoco/__pycache__/env_dict.cpython-38.pyc +0 -0
- Metaworld/metaworld/envs/mujoco/__pycache__/mujoco_env.cpython-38.pyc +0 -0
- Metaworld/metaworld/envs/mujoco/env_dict.py +643 -0
- Metaworld/metaworld/envs/mujoco/mujoco_env.py +155 -0
- Metaworld/metaworld/envs/mujoco/sawyer_xyz/__init__.py +0 -0
- Metaworld/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py +607 -0
- Metaworld/metaworld/envs/mujoco/sawyer_xyz/v1/sawyer_coffee_push.py +132 -0
- Metaworld/metaworld/envs/mujoco/sawyer_xyz/v1/sawyer_dial_turn.py +115 -0
- Metaworld/metaworld/envs/mujoco/sawyer_xyz/v1/sawyer_disassemble_peg.py +186 -0
- Metaworld/metaworld/envs/mujoco/sawyer_xyz/v1/sawyer_plate_slide_back_side.py +119 -0
- Metaworld/metaworld/envs/mujoco/sawyer_xyz/v1/sawyer_plate_slide_side.py +124 -0
- Metaworld/metaworld/envs/reward_utils.py +220 -0
- Metaworld/metaworld/policies/__pycache__/__init__.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_bin_picking_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_box_close_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_button_press_topdown_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_button_press_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_button_press_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_button_press_wall_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_button_press_wall_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_coffee_pull_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_coffee_push_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_dial_turn_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_disassemble_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_disassemble_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_door_close_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_door_lock_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_door_lock_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_door_open_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_door_unlock_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_door_unlock_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_drawer_close_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_drawer_open_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_faucet_close_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_hammer_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_hammer_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_handle_press_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_handle_press_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_handle_pull_v1_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_handle_pull_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_peg_insertion_side_v2_policy.cpython-38.pyc +0 -0
- Metaworld/metaworld/policies/__pycache__/sawyer_peg_unplug_side_v2_policy.cpython-38.pyc +0 -0
.gitattributes
CHANGED
|
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 36 |
Metaworld/zarr_path:[[:space:]]data/metaworld_door-close_expert.zarr/data/point_cloud/10.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
Metaworld/zarr_path:[[:space:]]data/metaworld_door-open_expert.zarr/data/point_cloud/7.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/5.0.0 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 36 |
Metaworld/zarr_path:[[:space:]]data/metaworld_door-close_expert.zarr/data/point_cloud/10.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
Metaworld/zarr_path:[[:space:]]data/metaworld_door-open_expert.zarr/data/point_cloud/7.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/5.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/6.0.0 filter=lfs diff=lfs merge=lfs -text
|
Metaworld/metaworld/envs/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from metaworld.envs.mujoco.env_dict import (ALL_V2_ENVIRONMENTS_GOAL_HIDDEN,
|
| 2 |
+
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
|
| 3 |
+
)
|
| 4 |
+
|
| 5 |
+
__all__ = ['ALL_V2_ENVIRONMENTS_GOAL_HIDDEN',
|
| 6 |
+
'ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE']
|
Metaworld/metaworld/envs/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (341 Bytes). View file
|
|
|
Metaworld/metaworld/envs/__pycache__/asset_path_utils.cpython-38.pyc
ADDED
|
Binary file (621 Bytes). View file
|
|
|
Metaworld/metaworld/envs/__pycache__/reward_utils.cpython-38.pyc
ADDED
|
Binary file (7.36 kB). View file
|
|
|
Metaworld/metaworld/envs/asset_path_utils.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# Asset directories for the v1 and v2 environments; both are resolved
# relative to the directory containing this module.
ENV_ASSET_DIR_V1 = os.path.join(os.path.dirname(__file__), 'assets_v1')
ENV_ASSET_DIR_V2 = os.path.join(os.path.dirname(__file__), 'assets_v2')


def full_v1_path_for(file_name):
    """Return ``file_name`` joined onto the v1 asset directory.

    Args:
        file_name: Path of an asset relative to ``ENV_ASSET_DIR_V1``.

    Returns:
        The result of ``os.path.join(ENV_ASSET_DIR_V1, file_name)``.
    """
    return os.path.join(ENV_ASSET_DIR_V1, file_name)


def full_v2_path_for(file_name):
    """Return ``file_name`` joined onto the v2 asset directory.

    Args:
        file_name: Path of an asset relative to ``ENV_ASSET_DIR_V2``.

    Returns:
        The result of ``os.path.join(ENV_ASSET_DIR_V2, file_name)``.
    """
    return os.path.join(ENV_ASSET_DIR_V2, file_name)
|
Metaworld/metaworld/envs/mujoco/__init__.py
ADDED
|
File without changes
|
Metaworld/metaworld/envs/mujoco/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (183 Bytes). View file
|
|
|
Metaworld/metaworld/envs/mujoco/__pycache__/env_dict.cpython-38.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
Metaworld/metaworld/envs/mujoco/__pycache__/mujoco_env.cpython-38.pyc
ADDED
|
Binary file (5.3 kB). View file
|
|
|
Metaworld/metaworld/envs/mujoco/env_dict.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from metaworld.envs.mujoco.sawyer_xyz.v1 import (
|
| 7 |
+
SawyerNutAssemblyEnv,
|
| 8 |
+
SawyerBasketballEnv,
|
| 9 |
+
SawyerBinPickingEnv,
|
| 10 |
+
SawyerBoxCloseEnv,
|
| 11 |
+
SawyerButtonPressEnv,
|
| 12 |
+
SawyerButtonPressTopdownEnv,
|
| 13 |
+
SawyerButtonPressTopdownWallEnv,
|
| 14 |
+
SawyerButtonPressWallEnv,
|
| 15 |
+
SawyerCoffeeButtonEnv,
|
| 16 |
+
SawyerCoffeePullEnv,
|
| 17 |
+
SawyerCoffeePushEnv,
|
| 18 |
+
SawyerDialTurnEnv,
|
| 19 |
+
SawyerNutDisassembleEnv,
|
| 20 |
+
SawyerDoorEnv,
|
| 21 |
+
SawyerDoorCloseEnv,
|
| 22 |
+
SawyerDoorLockEnv,
|
| 23 |
+
SawyerDoorUnlockEnv,
|
| 24 |
+
SawyerDrawerCloseEnv,
|
| 25 |
+
SawyerDrawerOpenEnv,
|
| 26 |
+
SawyerFaucetCloseEnv,
|
| 27 |
+
SawyerFaucetOpenEnv,
|
| 28 |
+
SawyerHammerEnv,
|
| 29 |
+
SawyerHandInsertEnv,
|
| 30 |
+
SawyerHandlePressEnv,
|
| 31 |
+
SawyerHandlePressSideEnv,
|
| 32 |
+
SawyerHandlePullEnv,
|
| 33 |
+
SawyerHandlePullSideEnv,
|
| 34 |
+
SawyerLeverPullEnv,
|
| 35 |
+
SawyerPegInsertionSideEnv,
|
| 36 |
+
SawyerPegUnplugSideEnv,
|
| 37 |
+
SawyerPickOutOfHoleEnv,
|
| 38 |
+
SawyerPlateSlideEnv,
|
| 39 |
+
SawyerPlateSlideBackEnv,
|
| 40 |
+
SawyerPlateSlideBackSideEnv,
|
| 41 |
+
SawyerPlateSlideSideEnv,
|
| 42 |
+
SawyerPushBackEnv,
|
| 43 |
+
SawyerReachPushPickPlaceEnv,
|
| 44 |
+
SawyerReachPushPickPlaceWallEnv,
|
| 45 |
+
SawyerShelfPlaceEnv,
|
| 46 |
+
SawyerSoccerEnv,
|
| 47 |
+
SawyerStickPullEnv,
|
| 48 |
+
SawyerStickPushEnv,
|
| 49 |
+
SawyerSweepEnv,
|
| 50 |
+
SawyerSweepIntoGoalEnv,
|
| 51 |
+
SawyerWindowCloseEnv,
|
| 52 |
+
SawyerWindowOpenEnv,
|
| 53 |
+
)
|
| 54 |
+
from metaworld.envs.mujoco.sawyer_xyz.v2 import (
|
| 55 |
+
SawyerNutAssemblyEnvV2,
|
| 56 |
+
SawyerBasketballEnvV2,
|
| 57 |
+
SawyerBinPickingEnvV2,
|
| 58 |
+
SawyerBoxCloseEnvV2,
|
| 59 |
+
SawyerButtonPressTopdownEnvV2,
|
| 60 |
+
SawyerButtonPressTopdownWallEnvV2,
|
| 61 |
+
SawyerButtonPressEnvV2,
|
| 62 |
+
SawyerButtonPressWallEnvV2,
|
| 63 |
+
SawyerCoffeeButtonEnvV2,
|
| 64 |
+
SawyerCoffeePullEnvV2,
|
| 65 |
+
SawyerCoffeePushEnvV2,
|
| 66 |
+
SawyerDialTurnEnvV2,
|
| 67 |
+
SawyerNutDisassembleEnvV2,
|
| 68 |
+
SawyerDoorCloseEnvV2,
|
| 69 |
+
SawyerDoorLockEnvV2,
|
| 70 |
+
SawyerDoorUnlockEnvV2,
|
| 71 |
+
SawyerDoorEnvV2,
|
| 72 |
+
SawyerDrawerCloseEnvV2,
|
| 73 |
+
SawyerDrawerOpenEnvV2,
|
| 74 |
+
SawyerFaucetCloseEnvV2,
|
| 75 |
+
SawyerFaucetOpenEnvV2,
|
| 76 |
+
SawyerHammerEnvV2,
|
| 77 |
+
SawyerHandInsertEnvV2,
|
| 78 |
+
SawyerHandlePressSideEnvV2,
|
| 79 |
+
SawyerHandlePressEnvV2,
|
| 80 |
+
SawyerHandlePullSideEnvV2,
|
| 81 |
+
SawyerHandlePullEnvV2,
|
| 82 |
+
SawyerLeverPullEnvV2,
|
| 83 |
+
SawyerPegInsertionSideEnvV2,
|
| 84 |
+
SawyerPegUnplugSideEnvV2,
|
| 85 |
+
SawyerPickOutOfHoleEnvV2,
|
| 86 |
+
SawyerPickPlaceEnvV2,
|
| 87 |
+
SawyerPickPlaceWallEnvV2,
|
| 88 |
+
SawyerPlateSlideBackSideEnvV2,
|
| 89 |
+
SawyerPlateSlideBackEnvV2,
|
| 90 |
+
SawyerPlateSlideSideEnvV2,
|
| 91 |
+
SawyerPlateSlideEnvV2,
|
| 92 |
+
SawyerPushBackEnvV2,
|
| 93 |
+
SawyerPushEnvV2,
|
| 94 |
+
SawyerPushWallEnvV2,
|
| 95 |
+
SawyerReachEnvV2,
|
| 96 |
+
SawyerReachWallEnvV2,
|
| 97 |
+
SawyerShelfPlaceEnvV2,
|
| 98 |
+
SawyerSoccerEnvV2,
|
| 99 |
+
SawyerStickPullEnvV2,
|
| 100 |
+
SawyerStickPushEnvV2,
|
| 101 |
+
SawyerSweepEnvV2,
|
| 102 |
+
SawyerSweepIntoGoalEnvV2,
|
| 103 |
+
SawyerWindowCloseEnvV2,
|
| 104 |
+
SawyerWindowOpenEnvV2,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# Registry of every v1 task name -> environment class.
# Insertion order matters: the ``task_id`` values computed later in this
# module are positions in this mapping's key list.
ALL_V1_ENVIRONMENTS = OrderedDict((
    ('reach-v1', SawyerReachPushPickPlaceEnv),
    ('push-v1', SawyerReachPushPickPlaceEnv),
    ('pick-place-v1', SawyerReachPushPickPlaceEnv),
    ('door-open-v1', SawyerDoorEnv),
    ('drawer-open-v1', SawyerDrawerOpenEnv),
    ('drawer-close-v1', SawyerDrawerCloseEnv),
    ('button-press-topdown-v1', SawyerButtonPressTopdownEnv),
    ('peg-insert-side-v1', SawyerPegInsertionSideEnv),
    ('window-open-v1', SawyerWindowOpenEnv),
    ('window-close-v1', SawyerWindowCloseEnv),
    ('door-close-v1', SawyerDoorCloseEnv),
    ('reach-wall-v1', SawyerReachPushPickPlaceWallEnv),
    ('pick-place-wall-v1', SawyerReachPushPickPlaceWallEnv),
    ('push-wall-v1', SawyerReachPushPickPlaceWallEnv),
    ('button-press-v1', SawyerButtonPressEnv),
    ('button-press-topdown-wall-v1', SawyerButtonPressTopdownWallEnv),
    ('button-press-wall-v1', SawyerButtonPressWallEnv),
    ('peg-unplug-side-v1', SawyerPegUnplugSideEnv),
    ('disassemble-v1', SawyerNutDisassembleEnv),
    ('hammer-v1', SawyerHammerEnv),
    ('plate-slide-v1', SawyerPlateSlideEnv),
    ('plate-slide-side-v1', SawyerPlateSlideSideEnv),
    ('plate-slide-back-v1', SawyerPlateSlideBackEnv),
    ('plate-slide-back-side-v1', SawyerPlateSlideBackSideEnv),
    ('handle-press-v1', SawyerHandlePressEnv),
    ('handle-pull-v1', SawyerHandlePullEnv),
    ('handle-press-side-v1', SawyerHandlePressSideEnv),
    ('handle-pull-side-v1', SawyerHandlePullSideEnv),
    ('stick-push-v1', SawyerStickPushEnv),
    ('stick-pull-v1', SawyerStickPullEnv),
    ('basketball-v1', SawyerBasketballEnv),
    ('soccer-v1', SawyerSoccerEnv),
    ('faucet-open-v1', SawyerFaucetOpenEnv),
    ('faucet-close-v1', SawyerFaucetCloseEnv),
    ('coffee-push-v1', SawyerCoffeePushEnv),
    ('coffee-pull-v1', SawyerCoffeePullEnv),
    ('coffee-button-v1', SawyerCoffeeButtonEnv),
    ('sweep-v1', SawyerSweepEnv),
    ('sweep-into-v1', SawyerSweepIntoGoalEnv),
    ('pick-out-of-hole-v1', SawyerPickOutOfHoleEnv),
    ('assembly-v1', SawyerNutAssemblyEnv),
    ('shelf-place-v1', SawyerShelfPlaceEnv),
    ('push-back-v1', SawyerPushBackEnv),
    ('lever-pull-v1', SawyerLeverPullEnv),
    ('dial-turn-v1', SawyerDialTurnEnv),
    ('bin-picking-v1', SawyerBinPickingEnv),
    ('box-close-v1', SawyerBoxCloseEnv),
    ('hand-insert-v1', SawyerHandInsertEnv),
    ('door-lock-v1', SawyerDoorLockEnv),
    ('door-unlock-v1', SawyerDoorUnlockEnv),
))
|
| 160 |
+
|
| 161 |
+
# Registry of every v2 task name -> environment class.
# Insertion order matters: the ``task_id`` values computed later in this
# module are positions in this mapping's key list.
#
# NOTE: 'peg-insert-side-v2', 'push-v2' and 'reach-v2' used to be listed a
# second time further down. A repeated key in an OrderedDict keeps its first
# position and merely re-assigns the (identical) value, so the repeats were
# no-ops; they are dropped here so each task appears exactly once.
ALL_V2_ENVIRONMENTS = OrderedDict((
    ('assembly-v2', SawyerNutAssemblyEnvV2),
    ('basketball-v2', SawyerBasketballEnvV2),
    ('bin-picking-v2', SawyerBinPickingEnvV2),
    ('box-close-v2', SawyerBoxCloseEnvV2),
    ('button-press-topdown-v2', SawyerButtonPressTopdownEnvV2),
    ('button-press-topdown-wall-v2', SawyerButtonPressTopdownWallEnvV2),
    ('button-press-v2', SawyerButtonPressEnvV2),
    ('button-press-wall-v2', SawyerButtonPressWallEnvV2),
    ('coffee-button-v2', SawyerCoffeeButtonEnvV2),
    ('coffee-pull-v2', SawyerCoffeePullEnvV2),
    ('coffee-push-v2', SawyerCoffeePushEnvV2),
    ('dial-turn-v2', SawyerDialTurnEnvV2),
    ('disassemble-v2', SawyerNutDisassembleEnvV2),
    ('door-close-v2', SawyerDoorCloseEnvV2),
    ('door-lock-v2', SawyerDoorLockEnvV2),
    ('door-open-v2', SawyerDoorEnvV2),
    ('door-unlock-v2', SawyerDoorUnlockEnvV2),
    ('hand-insert-v2', SawyerHandInsertEnvV2),
    ('drawer-close-v2', SawyerDrawerCloseEnvV2),
    ('drawer-open-v2', SawyerDrawerOpenEnvV2),
    ('faucet-open-v2', SawyerFaucetOpenEnvV2),
    ('faucet-close-v2', SawyerFaucetCloseEnvV2),
    ('hammer-v2', SawyerHammerEnvV2),
    ('handle-press-side-v2', SawyerHandlePressSideEnvV2),
    ('handle-press-v2', SawyerHandlePressEnvV2),
    ('handle-pull-side-v2', SawyerHandlePullSideEnvV2),
    ('handle-pull-v2', SawyerHandlePullEnvV2),
    ('lever-pull-v2', SawyerLeverPullEnvV2),
    ('peg-insert-side-v2', SawyerPegInsertionSideEnvV2),
    ('pick-place-wall-v2', SawyerPickPlaceWallEnvV2),
    ('pick-out-of-hole-v2', SawyerPickOutOfHoleEnvV2),
    ('reach-v2', SawyerReachEnvV2),
    ('push-back-v2', SawyerPushBackEnvV2),
    ('push-v2', SawyerPushEnvV2),
    ('pick-place-v2', SawyerPickPlaceEnvV2),
    ('plate-slide-v2', SawyerPlateSlideEnvV2),
    ('plate-slide-side-v2', SawyerPlateSlideSideEnvV2),
    ('plate-slide-back-v2', SawyerPlateSlideBackEnvV2),
    ('plate-slide-back-side-v2', SawyerPlateSlideBackSideEnvV2),
    ('peg-unplug-side-v2', SawyerPegUnplugSideEnvV2),
    ('soccer-v2', SawyerSoccerEnvV2),
    ('stick-push-v2', SawyerStickPushEnvV2),
    ('stick-pull-v2', SawyerStickPullEnvV2),
    ('push-wall-v2', SawyerPushWallEnvV2),
    ('reach-wall-v2', SawyerReachWallEnvV2),
    ('shelf-place-v2', SawyerShelfPlaceEnvV2),
    ('sweep-into-v2', SawyerSweepIntoGoalEnvV2),
    ('sweep-v2', SawyerSweepEnvV2),
    ('window-open-v2', SawyerWindowOpenEnvV2),
    ('window-close-v2', SawyerWindowCloseEnvV2),
))
|
| 216 |
+
|
| 217 |
+
# Number of entries in the v1 registry.
_NUM_METAWORLD_ENVS = len(ALL_V1_ENVIRONMENTS)

# Task name -> environment class for the v1 "easy mode" task set.
EASY_MODE_CLS_DICT = OrderedDict((
    ('reach-v1', SawyerReachPushPickPlaceEnv),
    ('push-v1', SawyerReachPushPickPlaceEnv),
    ('pick-place-v1', SawyerReachPushPickPlaceEnv),
    ('door-open-v1', SawyerDoorEnv),
    ('drawer-open-v1', SawyerDrawerOpenEnv),
    ('drawer-close-v1', SawyerDrawerCloseEnv),
    ('button-press-topdown-v1', SawyerButtonPressTopdownEnv),
    ('peg-insert-side-v1', SawyerPegInsertionSideEnv),
    ('window-open-v1', SawyerWindowOpenEnv),
    ('window-close-v1', SawyerWindowCloseEnv),
))

# Constructor arguments for each easy-mode task. ``task_id`` is the task's
# position in ALL_V1_ENVIRONMENTS.
EASY_MODE_ARGS_KWARGS = {
    key: dict(args=[],
              kwargs={'task_id': list(ALL_V1_ENVIRONMENTS.keys()).index(key)})
    for key in EASY_MODE_CLS_DICT
}

# reach / push / pick-place are registered with the same environment class,
# so each one additionally carries its own ``task_type``.
EASY_MODE_ARGS_KWARGS['reach-v1']['kwargs']['task_type'] = 'reach'
EASY_MODE_ARGS_KWARGS['push-v1']['kwargs']['task_type'] = 'push'
EASY_MODE_ARGS_KWARGS['pick-place-v1']['kwargs']['task_type'] = 'pick_place'
|
| 239 |
+
|
| 240 |
+
# Task name -> environment class for the v1 "medium mode" task set, split
# into a 'train' and a 'test' group.
MEDIUM_MODE_CLS_DICT = OrderedDict((
    ('train', OrderedDict((
        ('reach-v1', SawyerReachPushPickPlaceEnv),
        ('push-v1', SawyerReachPushPickPlaceEnv),
        ('pick-place-v1', SawyerReachPushPickPlaceEnv),
        ('door-open-v1', SawyerDoorEnv),
        ('drawer-close-v1', SawyerDrawerCloseEnv),
        ('button-press-topdown-v1', SawyerButtonPressTopdownEnv),
        ('peg-insert-side-v1', SawyerPegInsertionSideEnv),
        ('window-open-v1', SawyerWindowOpenEnv),
        ('sweep-v1', SawyerSweepEnv),
        ('basketball-v1', SawyerBasketballEnv),
    ))),
    ('test', OrderedDict((
        ('drawer-open-v1', SawyerDrawerOpenEnv),
        ('door-close-v1', SawyerDoorCloseEnv),
        ('shelf-place-v1', SawyerShelfPlaceEnv),
        ('sweep-into-v1', SawyerSweepIntoGoalEnv),
        ('lever-pull-v1', SawyerLeverPullEnv),
    ))),
))

# Constructor arguments per task. ``task_id`` is the task's position in
# ALL_V1_ENVIRONMENTS.
medium_mode_train_args_kwargs = {
    key: dict(args=[],
              kwargs={
                  'task_id': list(ALL_V1_ENVIRONMENTS.keys()).index(key),
              })
    for key in MEDIUM_MODE_CLS_DICT['train']
}

medium_mode_test_args_kwargs = {
    key: dict(args=[],
              kwargs={'task_id': list(ALL_V1_ENVIRONMENTS.keys()).index(key)})
    for key in MEDIUM_MODE_CLS_DICT['test']
}

# reach / push / pick-place are registered with the same environment class,
# so each one additionally carries its own ``task_type``.
medium_mode_train_args_kwargs['reach-v1']['kwargs']['task_type'] = 'reach'
medium_mode_train_args_kwargs['push-v1']['kwargs']['task_type'] = 'push'
medium_mode_train_args_kwargs['pick-place-v1']['kwargs'][
    'task_type'] = 'pick_place'

MEDIUM_MODE_ARGS_KWARGS = dict(
    train=medium_mode_train_args_kwargs,
    test=medium_mode_test_args_kwargs,
)
|
| 285 |
+
# ML45 environments and arguments
#
# Task name -> environment class, split into a 'train' and a 'test' group.
HARD_MODE_CLS_DICT = OrderedDict((
    ('train', OrderedDict((
        ('reach-v1', SawyerReachPushPickPlaceEnv),
        ('push-v1', SawyerReachPushPickPlaceEnv),
        ('pick-place-v1', SawyerReachPushPickPlaceEnv),
        ('door-open-v1', SawyerDoorEnv),
        ('drawer-open-v1', SawyerDrawerOpenEnv),
        ('drawer-close-v1', SawyerDrawerCloseEnv),
        ('button-press-topdown-v1', SawyerButtonPressTopdownEnv),
        ('peg-insert-side-v1', SawyerPegInsertionSideEnv),
        ('window-open-v1', SawyerWindowOpenEnv),
        ('window-close-v1', SawyerWindowCloseEnv),
        ('door-close-v1', SawyerDoorCloseEnv),
        ('reach-wall-v1', SawyerReachPushPickPlaceWallEnv),
        ('pick-place-wall-v1', SawyerReachPushPickPlaceWallEnv),
        ('push-wall-v1', SawyerReachPushPickPlaceWallEnv),
        ('button-press-v1', SawyerButtonPressEnv),
        ('button-press-topdown-wall-v1', SawyerButtonPressTopdownWallEnv),
        ('button-press-wall-v1', SawyerButtonPressWallEnv),
        ('peg-unplug-side-v1', SawyerPegUnplugSideEnv),
        ('disassemble-v1', SawyerNutDisassembleEnv),
        ('hammer-v1', SawyerHammerEnv),
        ('plate-slide-v1', SawyerPlateSlideEnv),
        ('plate-slide-side-v1', SawyerPlateSlideSideEnv),
        ('plate-slide-back-v1', SawyerPlateSlideBackEnv),
        ('plate-slide-back-side-v1', SawyerPlateSlideBackSideEnv),
        ('handle-press-v1', SawyerHandlePressEnv),
        ('handle-pull-v1', SawyerHandlePullEnv),
        ('handle-press-side-v1', SawyerHandlePressSideEnv),
        ('handle-pull-side-v1', SawyerHandlePullSideEnv),
        ('stick-push-v1', SawyerStickPushEnv),
        ('stick-pull-v1', SawyerStickPullEnv),
        ('basketball-v1', SawyerBasketballEnv),
        ('soccer-v1', SawyerSoccerEnv),
        ('faucet-open-v1', SawyerFaucetOpenEnv),
        ('faucet-close-v1', SawyerFaucetCloseEnv),
        ('coffee-push-v1', SawyerCoffeePushEnv),
        ('coffee-pull-v1', SawyerCoffeePullEnv),
        ('coffee-button-v1', SawyerCoffeeButtonEnv),
        ('sweep-v1', SawyerSweepEnv),
        ('sweep-into-v1', SawyerSweepIntoGoalEnv),
        ('pick-out-of-hole-v1', SawyerPickOutOfHoleEnv),
        ('assembly-v1', SawyerNutAssemblyEnv),
        ('shelf-place-v1', SawyerShelfPlaceEnv),
        ('push-back-v1', SawyerPushBackEnv),
        ('lever-pull-v1', SawyerLeverPullEnv),
        ('dial-turn-v1', SawyerDialTurnEnv),
    ))),
    ('test', OrderedDict((
        ('bin-picking-v1', SawyerBinPickingEnv),
        ('box-close-v1', SawyerBoxCloseEnv),
        ('hand-insert-v1', SawyerHandInsertEnv),
        ('door-lock-v1', SawyerDoorLockEnv),
        ('door-unlock-v1', SawyerDoorUnlockEnv),
    ))),
))
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def _hard_mode_args_kwargs(env_cls_, key_):
    """Build the ``args``/``kwargs`` entry for one hard-mode task.

    Args:
        env_cls_: Environment class registered under ``key_``. Unused; it is
            accepted only so callers can pass ``(class, key)`` pairs.
        key_: Task name. Must be a key of ``ALL_V1_ENVIRONMENTS``.

    Returns:
        ``dict(args=[], kwargs=...)``. ``kwargs`` always holds ``task_id``
        (the position of ``key_`` in ``ALL_V1_ENVIRONMENTS``) and, for the
        reach / push / pick-place tasks and their ``-wall`` variants, a
        matching ``task_type``.

    Raises:
        ValueError: If ``key_`` is not a key of ``ALL_V1_ENVIRONMENTS``
            (raised by ``list.index``).
    """
    del env_cls_

    kwargs = dict(task_id=list(ALL_V1_ENVIRONMENTS.keys()).index(key_))
    if key_ in ('reach-v1', 'reach-wall-v1'):
        kwargs['task_type'] = 'reach'
    elif key_ in ('push-v1', 'push-wall-v1'):
        kwargs['task_type'] = 'push'
    elif key_ in ('pick-place-v1', 'pick-place-wall-v1'):
        kwargs['task_type'] = 'pick_place'
    return dict(args=[], kwargs=kwargs)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
# Constructor arguments for every hard-mode task, grouped like
# HARD_MODE_CLS_DICT. Plain loops are kept (rather than comprehensions) so the
# module-level ``key`` / ``env_cls`` bindings they leave behind are unchanged.
HARD_MODE_ARGS_KWARGS = dict(train={}, test={})
for key, env_cls in HARD_MODE_CLS_DICT['train'].items():
    HARD_MODE_ARGS_KWARGS['train'][key] = _hard_mode_args_kwargs(env_cls, key)
for key, env_cls in HARD_MODE_CLS_DICT['test'].items():
    HARD_MODE_ARGS_KWARGS['test'][key] = _hard_mode_args_kwargs(env_cls, key)
|
| 364 |
+
|
| 365 |
+
############################## V2 DICTS ##############################
|
| 366 |
+
|
| 367 |
+
# Task name -> environment class for the MT10 (v2) task set.
MT10_V2 = OrderedDict((
    ('reach-v2', SawyerReachEnvV2),
    ('push-v2', SawyerPushEnvV2),
    ('pick-place-v2', SawyerPickPlaceEnvV2),
    ('door-open-v2', SawyerDoorEnvV2),
    ('drawer-open-v2', SawyerDrawerOpenEnvV2),
    ('drawer-close-v2', SawyerDrawerCloseEnvV2),
    ('button-press-topdown-v2', SawyerButtonPressTopdownEnvV2),
    ('peg-insert-side-v2', SawyerPegInsertionSideEnvV2),
    ('window-open-v2', SawyerWindowOpenEnvV2),
    ('window-close-v2', SawyerWindowCloseEnvV2),
))

# Constructor arguments per task. ``task_id`` is the task's position in
# ALL_V2_ENVIRONMENTS.
MT10_V2_ARGS_KWARGS = {
    key: dict(args=[],
              kwargs={'task_id': list(ALL_V2_ENVIRONMENTS.keys()).index(key)})
    for key in MT10_V2
}
|
| 383 |
+
|
| 384 |
+
# Task name -> environment class for the ML10 (v2) task set, split into a
# 'train' and a 'test' group.
#
# FIX: 'button-press-topdown-v2' was mapped to SawyerButtonPressEnvV2 here,
# while ALL_V2_ENVIRONMENTS, MT10_V2 and MT50_V2 all register that task name
# with SawyerButtonPressTopdownEnvV2. It now uses the top-down class so the
# name and the environment agree.
ML10_V2 = OrderedDict((
    ('train', OrderedDict((
        ('reach-v2', SawyerReachEnvV2),
        ('push-v2', SawyerPushEnvV2),
        ('pick-place-v2', SawyerPickPlaceEnvV2),
        ('door-open-v2', SawyerDoorEnvV2),
        ('drawer-close-v2', SawyerDrawerCloseEnvV2),
        ('button-press-topdown-v2', SawyerButtonPressTopdownEnvV2),
        ('peg-insert-side-v2', SawyerPegInsertionSideEnvV2),
        ('window-open-v2', SawyerWindowOpenEnvV2),
        ('sweep-v2', SawyerSweepEnvV2),
        ('basketball-v2', SawyerBasketballEnvV2),
    ))),
    ('test', OrderedDict((
        ('drawer-open-v2', SawyerDrawerOpenEnvV2),
        ('door-close-v2', SawyerDoorCloseEnvV2),
        ('shelf-place-v2', SawyerShelfPlaceEnvV2),
        ('sweep-into-v2', SawyerSweepIntoGoalEnvV2),
        ('lever-pull-v2', SawyerLeverPullEnvV2),
    ))),
))
|
| 406 |
+
|
| 407 |
+
# Constructor arguments per ML10 task. ``task_id`` is the task's position in
# ALL_V2_ENVIRONMENTS.
ml10_train_args_kwargs = {
    key: dict(args=[],
              kwargs={
                  'task_id': list(ALL_V2_ENVIRONMENTS.keys()).index(key),
              })
    for key in ML10_V2['train']
}

ml10_test_args_kwargs = {
    key: dict(args=[],
              kwargs={'task_id': list(ALL_V2_ENVIRONMENTS.keys()).index(key)})
    for key in ML10_V2['test']
}

ML10_ARGS_KWARGS = dict(
    train=ml10_train_args_kwargs,
    test=ml10_test_args_kwargs,
)
|
| 425 |
+
|
| 426 |
+
# ML1 (v2): both the 'train' and the 'test' group refer to the same full v2
# registry object.
ML1_V2 = OrderedDict((
    ('train', ALL_V2_ENVIRONMENTS),
    ('test', ALL_V2_ENVIRONMENTS),
))

# Constructor arguments per task. ``task_id`` is the task's position in
# ALL_V2_ENVIRONMENTS.
ML1_args_kwargs = {
    key: dict(args=[],
              kwargs={
                  'task_id': list(ALL_V2_ENVIRONMENTS.keys()).index(key),
              })
    for key in ML1_V2['train']
}
|
| 436 |
+
|
| 437 |
+
# Task name -> environment class for the MT50 (v2) task set.
#
# NOTE: 'peg-insert-side-v2', 'push-v2' and 'reach-v2' used to be listed a
# second time further down. A repeated key in an OrderedDict keeps its first
# position and merely re-assigns the (identical) value, so the repeats were
# no-ops; they are dropped here so each task appears exactly once.
MT50_V2 = OrderedDict((
    ('assembly-v2', SawyerNutAssemblyEnvV2),
    ('basketball-v2', SawyerBasketballEnvV2),
    ('bin-picking-v2', SawyerBinPickingEnvV2),
    ('box-close-v2', SawyerBoxCloseEnvV2),
    ('button-press-topdown-v2', SawyerButtonPressTopdownEnvV2),
    ('button-press-topdown-wall-v2', SawyerButtonPressTopdownWallEnvV2),
    ('button-press-v2', SawyerButtonPressEnvV2),
    ('button-press-wall-v2', SawyerButtonPressWallEnvV2),
    ('coffee-button-v2', SawyerCoffeeButtonEnvV2),
    ('coffee-pull-v2', SawyerCoffeePullEnvV2),
    ('coffee-push-v2', SawyerCoffeePushEnvV2),
    ('dial-turn-v2', SawyerDialTurnEnvV2),
    ('disassemble-v2', SawyerNutDisassembleEnvV2),
    ('door-close-v2', SawyerDoorCloseEnvV2),
    ('door-lock-v2', SawyerDoorLockEnvV2),
    ('door-open-v2', SawyerDoorEnvV2),
    ('door-unlock-v2', SawyerDoorUnlockEnvV2),
    ('hand-insert-v2', SawyerHandInsertEnvV2),
    ('drawer-close-v2', SawyerDrawerCloseEnvV2),
    ('drawer-open-v2', SawyerDrawerOpenEnvV2),
    ('faucet-open-v2', SawyerFaucetOpenEnvV2),
    ('faucet-close-v2', SawyerFaucetCloseEnvV2),
    ('hammer-v2', SawyerHammerEnvV2),
    ('handle-press-side-v2', SawyerHandlePressSideEnvV2),
    ('handle-press-v2', SawyerHandlePressEnvV2),
    ('handle-pull-side-v2', SawyerHandlePullSideEnvV2),
    ('handle-pull-v2', SawyerHandlePullEnvV2),
    ('lever-pull-v2', SawyerLeverPullEnvV2),
    ('peg-insert-side-v2', SawyerPegInsertionSideEnvV2),
    ('pick-place-wall-v2', SawyerPickPlaceWallEnvV2),
    ('pick-out-of-hole-v2', SawyerPickOutOfHoleEnvV2),
    ('reach-v2', SawyerReachEnvV2),
    ('push-back-v2', SawyerPushBackEnvV2),
    ('push-v2', SawyerPushEnvV2),
    ('pick-place-v2', SawyerPickPlaceEnvV2),
    ('plate-slide-v2', SawyerPlateSlideEnvV2),
    ('plate-slide-side-v2', SawyerPlateSlideSideEnvV2),
    ('plate-slide-back-v2', SawyerPlateSlideBackEnvV2),
    ('plate-slide-back-side-v2', SawyerPlateSlideBackSideEnvV2),
    ('peg-unplug-side-v2', SawyerPegUnplugSideEnvV2),
    ('soccer-v2', SawyerSoccerEnvV2),
    ('stick-push-v2', SawyerStickPushEnvV2),
    ('stick-pull-v2', SawyerStickPullEnvV2),
    ('push-wall-v2', SawyerPushWallEnvV2),
    ('reach-wall-v2', SawyerReachWallEnvV2),
    ('shelf-place-v2', SawyerShelfPlaceEnvV2),
    ('sweep-into-v2', SawyerSweepIntoGoalEnvV2),
    ('sweep-v2', SawyerSweepEnvV2),
    ('window-open-v2', SawyerWindowOpenEnvV2),
    ('window-close-v2', SawyerWindowCloseEnvV2),
))
|
| 492 |
+
|
| 493 |
+
MT50_V2_ARGS_KWARGS = {
|
| 494 |
+
key: dict(args=[],
|
| 495 |
+
kwargs={'task_id': list(ALL_V2_ENVIRONMENTS.keys()).index(key)})
|
| 496 |
+
for key, _ in MT50_V2.items()
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
# ML45 benchmark: 45 training tasks and 5 held-out test tasks.
# Each training task is listed exactly once.  The previous listing repeated
# 'peg-insert-side-v2', 'push-v2' and 'reach-v2'; OrderedDict silently
# collapsed those repeats, so the resulting mapping is unchanged.
ML45_V2 = OrderedDict((
    ('train', OrderedDict((
        ('assembly-v2', SawyerNutAssemblyEnvV2),
        ('basketball-v2', SawyerBasketballEnvV2),
        ('button-press-topdown-v2', SawyerButtonPressTopdownEnvV2),
        ('button-press-topdown-wall-v2', SawyerButtonPressTopdownWallEnvV2),
        ('button-press-v2', SawyerButtonPressEnvV2),
        ('button-press-wall-v2', SawyerButtonPressWallEnvV2),
        ('coffee-button-v2', SawyerCoffeeButtonEnvV2),
        ('coffee-pull-v2', SawyerCoffeePullEnvV2),
        ('coffee-push-v2', SawyerCoffeePushEnvV2),
        ('dial-turn-v2', SawyerDialTurnEnvV2),
        ('disassemble-v2', SawyerNutDisassembleEnvV2),
        ('door-close-v2', SawyerDoorCloseEnvV2),
        ('door-open-v2', SawyerDoorEnvV2),
        ('drawer-close-v2', SawyerDrawerCloseEnvV2),
        ('drawer-open-v2', SawyerDrawerOpenEnvV2),
        ('faucet-open-v2', SawyerFaucetOpenEnvV2),
        ('faucet-close-v2', SawyerFaucetCloseEnvV2),
        ('hammer-v2', SawyerHammerEnvV2),
        ('handle-press-side-v2', SawyerHandlePressSideEnvV2),
        ('handle-press-v2', SawyerHandlePressEnvV2),
        ('handle-pull-side-v2', SawyerHandlePullSideEnvV2),
        ('handle-pull-v2', SawyerHandlePullEnvV2),
        ('lever-pull-v2', SawyerLeverPullEnvV2),
        ('peg-insert-side-v2', SawyerPegInsertionSideEnvV2),
        ('pick-place-wall-v2', SawyerPickPlaceWallEnvV2),
        ('pick-out-of-hole-v2', SawyerPickOutOfHoleEnvV2),
        ('reach-v2', SawyerReachEnvV2),
        ('push-back-v2', SawyerPushBackEnvV2),
        ('push-v2', SawyerPushEnvV2),
        ('pick-place-v2', SawyerPickPlaceEnvV2),
        ('plate-slide-v2', SawyerPlateSlideEnvV2),
        ('plate-slide-side-v2', SawyerPlateSlideSideEnvV2),
        ('plate-slide-back-v2', SawyerPlateSlideBackEnvV2),
        ('plate-slide-back-side-v2', SawyerPlateSlideBackSideEnvV2),
        ('peg-unplug-side-v2', SawyerPegUnplugSideEnvV2),
        ('soccer-v2', SawyerSoccerEnvV2),
        ('stick-push-v2', SawyerStickPushEnvV2),
        ('stick-pull-v2', SawyerStickPullEnvV2),
        ('push-wall-v2', SawyerPushWallEnvV2),
        ('reach-wall-v2', SawyerReachWallEnvV2),
        ('shelf-place-v2', SawyerShelfPlaceEnvV2),
        ('sweep-into-v2', SawyerSweepIntoGoalEnvV2),
        ('sweep-v2', SawyerSweepEnvV2),
        ('window-open-v2', SawyerWindowOpenEnvV2),
        ('window-close-v2', SawyerWindowCloseEnvV2),
    ))),
    ('test', OrderedDict((
        ('bin-picking-v2', SawyerBinPickingEnvV2),
        ('box-close-v2', SawyerBoxCloseEnvV2),
        ('hand-insert-v2', SawyerHandInsertEnvV2),
        ('door-lock-v2', SawyerDoorLockEnvV2),
        ('door-unlock-v2', SawyerDoorUnlockEnvV2),
    ))),
))

# Constructor arguments per ML45 task; `task_id` is the task's index within
# ALL_V2_ENVIRONMENTS.  The keys are iterated directly (values are unused).
ml45_train_args_kwargs = {
    key: dict(args=[],
              kwargs={
                  'task_id': list(ALL_V2_ENVIRONMENTS.keys()).index(key),
              })
    for key in ML45_V2['train']
}

ml45_test_args_kwargs = {
    key: dict(args=[],
              kwargs={'task_id': list(ALL_V2_ENVIRONMENTS.keys()).index(key)})
    for key in ML45_V2['test']
}

ML45_ARGS_KWARGS = dict(
    train=ml45_train_args_kwargs,
    test=ml45_test_args_kwargs,
)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def _create_goal_envs(partially_observable, key_suffix, cls_suffix):
    """Build one ready-to-use subclass per entry of ALL_V2_ENVIRONMENTS.

    Each generated class overrides ``__init__`` so that an instance can be
    used without calling ``set_task``: it runs the base constructor, sets the
    goal-observability flag, performs one reset with the random vector
    unfrozen, and then freezes the random vector again.

    Args:
        partially_observable (bool): value stored in
            ``env._partially_observable`` on every instance.
        key_suffix (str): appended to the environment name to form the key
            of the returned mapping (e.g. ``'-goal-hidden'``).
        cls_suffix (str): appended to the CamelCase environment name to form
            the generated class name (e.g. ``'GoalHidden'``).

    Returns:
        OrderedDict: ``{env_name + key_suffix: generated class}`` in the
        iteration order of ALL_V2_ENVIRONMENTS.
    """

    def make_init(base_cls):
        # A factory is required so every generated __init__ is bound to its
        # own base class; a closure over the loop variable would see only
        # the last environment.
        def initialize(env, seed=None):
            saved_state = None
            if seed is not None:
                # Seed the global NumPy RNG for the duration of the setup
                # and remember the previous state so it can be put back.
                saved_state = np.random.get_state()
                np.random.seed(seed)
            try:
                # Call the base constructor explicitly.  The former
                # super(type(env), env).__init__() resolved back to this
                # function (infinite recursion) whenever a generated class
                # was subclassed.
                base_cls.__init__(env)
                env._partially_observable = partially_observable
                env._freeze_rand_vec = False
                env._set_task_called = True
                env.reset()
                env._freeze_rand_vec = True
                if seed is not None:
                    env.seed(seed)
            finally:
                # Restore the caller's RNG state even if setup failed.
                if saved_state is not None:
                    np.random.set_state(saved_state)

        return initialize

    goal_envs = OrderedDict()
    for env_name, env_cls in ALL_V2_ENVIRONMENTS.items():
        # 'door-open-v2' -> 'Door-Open-V2' -> 'DoorOpenV2'.  Raw string so
        # that `\s` reaches the regex engine instead of being an invalid
        # string escape.
        cls_name = re.sub(r"(^|[-])\s*([a-zA-Z])",
                          lambda p: p.group(0).upper(), env_name)
        cls_name = cls_name.replace("-", "") + cls_suffix
        goal_envs[env_name + key_suffix] = type(
            cls_name, (env_cls, ), {'__init__': make_init(env_cls)})

    return goal_envs


def create_hidden_goal_envs():
    """Return ``'<env>-goal-hidden'`` -> class for every V2 environment.

    Instances of the generated ``<Env>GoalHidden`` classes have
    ``_partially_observable`` set to True.
    """
    return _create_goal_envs(True, '-goal-hidden', 'GoalHidden')


def create_observable_goal_envs():
    """Return ``'<env>-goal-observable'`` -> class for every V2 environment.

    Instances of the generated ``<Env>GoalObservable`` classes have
    ``_partially_observable`` set to False.
    """
    return _create_goal_envs(False, '-goal-observable', 'GoalObservable')
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
ALL_V2_ENVIRONMENTS_GOAL_HIDDEN = create_hidden_goal_envs()
|
| 643 |
+
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE = create_observable_goal_envs()
|
Metaworld/metaworld/envs/mujoco/mujoco_env.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import warnings
|
| 3 |
+
|
| 4 |
+
import glfw
|
| 5 |
+
from gym import error
|
| 6 |
+
from gym.utils import seeding
|
| 7 |
+
import numpy as np
|
| 8 |
+
from os import path
|
| 9 |
+
import gym
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import mujoco_py
|
| 13 |
+
except ImportError as e:
|
| 14 |
+
raise error.DependencyNotInstalled("{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _assert_task_is_set(func):
    """Decorate an env method so it refuses to run before a task is set.

    The wrapped callable expects the environment as its first positional
    argument (i.e. it is used on instance methods) and checks the
    environment's ``_set_task_called`` flag.

    Args:
        func: the method to guard.

    Returns:
        A wrapper with the same name and docstring as ``func``.

    Raises:
        RuntimeError: (from the wrapper) if ``env._set_task_called`` is
            falsy when the method is invoked.
    """
    # Imported locally so this decorator does not depend on the module's
    # import block.
    import functools

    @functools.wraps(func)  # keep func.__name__/__doc__ on the wrapper
    def inner(*args, **kwargs):
        env = args[0]
        if not env._set_task_called:
            raise RuntimeError(
                'You must call env.set_task before using env.'
                + func.__name__
            )
        return func(*args, **kwargs)
    return inner
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
DEFAULT_SIZE = 500
|
| 30 |
+
|
| 31 |
+
class MujocoEnv(gym.Env, abc.ABC):
    """
    This is a simplified version of the gym MujocoEnv class.

    Some differences are:
     - Do not automatically set the observation/action space.
    """

    max_path_length = 500

    def __init__(self, model_path, frame_skip):
        """Load the MuJoCo model and create the simulation.

        Args:
            model_path (str): path of the model file to load.
            frame_skip (int): number of simulator steps per `do_simulation`
                call when no explicit frame count is given.

        Raises:
            IOError: if `model_path` does not exist.
        """
        if not path.exists(model_path):
            raise IOError("File %s does not exist" % model_path)

        self.frame_skip = frame_skip
        self.model = mujoco_py.load_model_from_path(model_path)
        self.sim = mujoco_py.MjSim(self.model)
        self.data = self.sim.data
        self.viewer = None
        self._viewers = {}

        self.metadata = {
            'render.modes': ['human'],
            'video.frames_per_second': int(np.round(1.0 / self.dt))
        }
        self.init_qpos = self.sim.data.qpos.ravel().copy()
        self.init_qvel = self.sim.data.qvel.ravel().copy()

        # Set once the simulator raises; further stepping is then skipped
        # until the next reset().
        self._did_see_sim_exception = False

        self.np_random, _ = seeding.np_random(None)

    def seed(self, seed):
        """Seed the env RNG and the action/observation/goal spaces.

        Args:
            seed: the seed; must not be None.

        Returns:
            list: single-element list holding the seed actually used.
        """
        assert seed is not None
        self.np_random, seed = seeding.np_random(seed)
        self.action_space.seed(seed)
        self.observation_space.seed(seed)
        self.goal_space.seed(seed)
        return [seed]

    @abc.abstractmethod
    def reset_model(self):
        """
        Reset the robot degrees of freedom (qpos and qvel).
        Implement this in each subclass.
        """
        pass

    def viewer_setup(self):
        """
        This method is called when the viewer is initialized and after every reset
        Optionally implement this method, if you need to tinker with camera position
        and so forth.
        """
        pass

    @_assert_task_is_set
    def reset(self):
        """Reset the simulation and return `reset_model()`'s observation."""
        self._did_see_sim_exception = False
        self.sim.reset()
        ob = self.reset_model()
        if self.viewer is not None:
            self.viewer_setup()
        return ob

    def set_state(self, qpos, qvel):
        """Overwrite the simulator's qpos/qvel, keeping time/act/udd_state.

        Args:
            qpos (np.ndarray): shape ``(model.nq,)``.
            qvel (np.ndarray): shape ``(model.nv,)``.
        """
        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
        old_state = self.sim.get_state()
        new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel,
                                         old_state.act, old_state.udd_state)
        self.sim.set_state(new_state)
        self.sim.forward()

    @property
    def dt(self):
        """Simulated time covered by one `do_simulation` call."""
        return self.model.opt.timestep * self.frame_skip

    def do_simulation(self, ctrl, n_frames=None):
        """Apply `ctrl` and step the simulator `n_frames` times.

        Does nothing if a simulator exception was already seen since the
        last reset.  A `MujocoException` raised while stepping is reported
        as a `RuntimeWarning` and recorded instead of being propagated.

        Args:
            ctrl: values written to ``sim.data.ctrl``.
            n_frames (int, optional): number of steps; defaults to
                `frame_skip`.

        Raises:
            ValueError: if the instance has a `curr_path_length` greater
                than `max_path_length`.
        """
        if getattr(self, 'curr_path_length', 0) > self.max_path_length:
            raise ValueError('Maximum path length allowed by the benchmark has been exceeded')
        if self._did_see_sim_exception:
            return

        if n_frames is None:
            n_frames = self.frame_skip
        self.sim.data.ctrl[:] = ctrl

        for _ in range(n_frames):
            try:
                self.sim.step()
            except mujoco_py.MujocoException as err:
                warnings.warn(str(err), category=RuntimeWarning)
                self._did_see_sim_exception = True

    def render(self, offscreen=False, camera_name="corner2", resolution=(640, 480)):
        """Render on screen, or return an offscreen image.

        Args:
            offscreen (bool): when False, render with the 'human' viewer and
                return None; when True, return ``sim.render(...)``'s result.
            camera_name (str): one of corner3, corner, corner2, topview,
                gripperPOV, behindGripper.
            resolution (tuple): ``(width, height)`` passed to the offscreen
                renderer.
        """
        # Implicit concatenation: the previous version separated the two
        # literals with a comma, making the assertion message a tuple.
        assert_string = ("camera_name should be one of "
                         "corner3, corner, corner2, topview, gripperPOV, behindGripper")
        assert camera_name in {"corner3", "corner", "corner2",
                               "topview", "gripperPOV", "behindGripper"}, assert_string
        if not offscreen:
            self._get_viewer('human').render()
        else:
            return self.sim.render(
                *resolution,
                mode='offscreen',
                camera_name=camera_name
            )

    def close(self):
        """Destroy the viewer window, if any, and forget cached viewers."""
        if self.viewer is not None:
            glfw.destroy_window(self.viewer.window)
            self.viewer = None
            # Drop the cache as well; otherwise _get_viewer() would hand
            # back the viewer whose window was just destroyed.
            self._viewers = {}

    def _get_viewer(self, mode):
        """Return the cached viewer for `mode`, creating it on first use."""
        self.viewer = self._viewers.get(mode)
        if self.viewer is None:
            if mode == 'human':
                self.viewer = mujoco_py.MjViewer(self.sim)
            self.viewer_setup()
            self._viewers[mode] = self.viewer
        self.viewer_setup()
        return self.viewer

    def get_body_com(self, body_name):
        """Return ``data.get_body_xpos(body_name)`` for the named body."""
        return self.data.get_body_xpos(body_name)
|
Metaworld/metaworld/envs/mujoco/sawyer_xyz/__init__.py
ADDED
|
File without changes
|
Metaworld/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py
ADDED
|
@@ -0,0 +1,607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import copy
|
| 3 |
+
import pickle
|
| 4 |
+
|
| 5 |
+
from gym.spaces import Box
|
| 6 |
+
from gym.spaces import Discrete
|
| 7 |
+
import mujoco_py
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from metaworld.envs import reward_utils
|
| 11 |
+
from metaworld.envs.mujoco.mujoco_env import MujocoEnv, _assert_task_is_set
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SawyerMocapBase(MujocoEnv, metaclass=abc.ABCMeta):
    """
    Shared helpers for Sawyer MuJoCo environments that are driven through a
    mocap body for XYZ control.
    """
    mocap_low = np.array([-0.2, 0.5, 0.06])
    mocap_high = np.array([0.2, 0.7, 0.6])

    def __init__(self, model_name, frame_skip=5):
        """Build the underlying MujocoEnv, then reset the mocap welds."""
        MujocoEnv.__init__(self, model_name, frame_skip=frame_skip)
        self.reset_mocap_welds()

    def get_endeff_pos(self):
        """Return a copy of the 'hand' body's position."""
        hand_xpos = self.data.get_body_xpos('hand')
        return hand_xpos.copy()

    @property
    def tcp_center(self):
        """Midpoint between the two finger end-effector sites.

        Returns:
            (np.ndarray): 3-element position
        """
        right_tip = self._get_site_pos('rightEndEffector')
        left_tip = self._get_site_pos('leftEndEffector')
        return (right_tip + left_tip) / 2.0

    def get_env_state(self):
        """Return a deep copy of (sim state, (mocap_pos, mocap_quat))."""
        snapshot = (
            self.sim.get_state(),
            (self.data.mocap_pos, self.data.mocap_quat),
        )
        return copy.deepcopy(snapshot)

    def set_env_state(self, state):
        """Restore a snapshot produced by `get_env_state`."""
        sim_state, mocap = state
        self.sim.set_state(sim_state)
        pos, quat = mocap
        self.data.set_mocap_pos('mocap', pos)
        self.data.set_mocap_quat('mocap', quat)
        self.sim.forward()

    def __getstate__(self):
        """Pickle support: attributes minus model/sim/data, plus the model
        as MJB bytes and the current env state."""
        attrs = self.__dict__.copy()
        # These are rebuilt from the MJB bytes in __setstate__.
        for transient in ('model', 'sim', 'data'):
            del attrs[transient]
        return {
            'state': attrs,
            'mjb': self.model.get_mjb(),
            'env_state': self.get_env_state(),
        }

    def __setstate__(self, state):
        """Pickle support: rebuild model/sim/data and restore env state."""
        self.__dict__ = state['state']
        self.model = mujoco_py.load_model_from_mjb(state['mjb'])
        self.sim = mujoco_py.MjSim(self.model)
        self.data = self.sim.data
        self.set_env_state(state['env_state'])

    def reset_mocap_welds(self):
        """Resets the mocap welds that we use for actuation."""
        sim = self.sim
        model = sim.model
        if model.nmocap > 0 and model.eq_data is not None:
            for row in range(model.eq_data.shape[0]):
                if model.eq_type[row] != mujoco_py.const.EQ_WELD:
                    continue
                model.eq_data[row, :] = np.array(
                    [0., 0., 0., 1., 0., 0., 0.])
        sim.forward()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class SawyerXYZEnv(SawyerMocapBase, metaclass=abc.ABCMeta):
|
| 82 |
+
_HAND_SPACE = Box(
|
| 83 |
+
np.array([-0.525, .348, -.0525]),
|
| 84 |
+
np.array([+0.525, 1.025, .7])
|
| 85 |
+
)
|
| 86 |
+
max_path_length = 500
|
| 87 |
+
|
| 88 |
+
TARGET_RADIUS = 0.05
|
| 89 |
+
|
| 90 |
+
def __init__(
        self,
        model_name,
        frame_skip=5,
        hand_low=(-0.2, 0.55, 0.05),
        hand_high=(0.2, 0.75, 0.3),
        mocap_low=None,
        mocap_high=None,
        action_scale=1./100,
        action_rot_scale=1.,
):
    """Set up the bookkeeping shared by the Sawyer XYZ environments.

    Args:
        model_name: forwarded to the parent constructor.
        frame_skip (int): forwarded to the parent constructor.
        hand_low, hand_high: bounds stored as `self.hand_low` /
            `self.hand_high`.
        mocap_low, mocap_high: mocap bounds; each falls back to the
            corresponding hand bound when None.
        action_scale (float): stored; multiplies position actions in
            `set_xyz_action`.
        action_rot_scale (float): stored as `self.action_rot_scale`.
    """
    super().__init__(model_name, frame_skip=frame_skip)
    self.random_init = True
    self.action_scale = action_scale
    self.action_rot_scale = action_rot_scale

    self.hand_low = np.array(hand_low)
    self.hand_high = np.array(hand_high)
    # Mocap bounds default to the hand bounds as passed in.
    self.mocap_low = np.hstack(hand_low if mocap_low is None else mocap_low)
    self.mocap_high = np.hstack(
        hand_high if mocap_high is None else mocap_high)

    self.curr_path_length = 0
    self.seeded_rand_vec = False
    self._freeze_rand_vec = True
    self._last_rand_vec = None

    # The goal space is continuous by default; `discretize_goal_space`
    # is the hook for switching to a discrete one.
    self.discrete_goal_space = None
    self.discrete_goals = []
    self.active_discrete_goal = None

    self.init_left_pad = self.get_body_com('leftpad')
    self.init_right_pad = self.get_body_com('rightpad')

    self.action_space = Box(
        np.array([-1] * 4),
        np.array([+1] * 4),
    )

    self.isV2 = "V2" in type(self).__name__
    # Length of the padded per-object part of the observation:
    # 14 for V2 environments, 6 otherwise.
    self._obs_obj_max_len = 14 if self.isV2 else 6
    self._obs_obj_possible_lens = (6, 14)

    self._set_task_called = False
    self._partially_observable = True

    # Placeholders that concrete environments are expected to override.
    self.hand_init_pos = None
    self._target_pos = None
    self._random_reset_space = None

    self._last_stable_obs = None
    # The poses captured here are unlikely to be correct yet; this only
    # affects frame-stacking for the very first observation.
    self._prev_obs = self._get_curr_obs_combined_no_goal()
|
| 152 |
+
|
| 153 |
+
def _set_task_inner(self):
    """Hook called by `set_task` with the remaining task fields.

    Deliberately takes no keyword arguments, so that any unexpected
    ("extra") task field raises instead of being silently ignored.
    """
    return None
|
| 156 |
+
|
| 157 |
+
def set_task(self, task):
    """Configure the environment from a pickled task and reset it.

    The unpickled payload must contain 'env_cls', 'rand_vec' and
    'partially_observable'; any remaining entries are forwarded as keyword
    arguments to `_set_task_inner`.

    Args:
        task: object whose `data` attribute holds the pickled payload.

    Raises:
        AssertionError: if this environment is not an instance of the
            payload's 'env_cls'.
        KeyError: if a required payload entry is missing.
    """
    self._set_task_called = True
    # NOTE(security): pickle.loads can execute arbitrary code; only pass
    # tasks that come from a trusted source.
    data = pickle.loads(task.data)
    env_cls = data.pop('env_cls')
    assert isinstance(self, env_cls)
    # (The random vector was previously assigned twice; once is enough.)
    self._last_rand_vec = data.pop('rand_vec')
    self._freeze_rand_vec = True
    self._partially_observable = data.pop('partially_observable')
    self._set_task_inner(**data)
    self.reset()
|
| 170 |
+
|
| 171 |
+
def set_xyz_action(self, action):
    """Move the mocap body by a clipped, scaled position action.

    The action is clipped to [-1, 1], multiplied by `self.action_scale`,
    added to the current mocap position, and the first mocap row is then
    clamped to [`self.mocap_low`, `self.mocap_high`].  The mocap quaternion
    is always set to (1, 0, 1, 0).

    Args:
        action: position action, broadcast against `data.mocap_pos`.
    """
    step = np.clip(action, -1, 1) * self.action_scale
    target = self.data.mocap_pos + step[None]
    target[0, :] = np.clip(target[0, :], self.mocap_low, self.mocap_high)

    mocap = self.data
    mocap.set_mocap_pos('mocap', target)
    mocap.set_mocap_quat('mocap', np.array([1, 0, 1, 0]))
|
| 183 |
+
|
| 184 |
+
def discretize_goal_space(self, goals):
    """Switch to a discrete goal space built from `goals`.

    Currently disabled: the unconditional `assert False` below fires before
    anything else runs (when assertions are enabled).

    Args:
        goals: non-empty sequence of goals.
    """
    assert False
    num_goals = len(goals)
    assert num_goals >= 1
    self.discrete_goals = goals
    # The goal space becomes a Discrete space with one entry per goal.
    self.discrete_goal_space = Discrete(len(self.discrete_goals))
|
| 190 |
+
|
| 191 |
+
def _set_obj_xyz(self, pos):
    """Write `pos` into qpos[9:12], zero qvel[9:15], and apply the state.

    Args:
        pos: 3-element position exposing a `copy()` method.
    """
    new_qpos = self.data.qpos.flat.copy()
    new_qvel = self.data.qvel.flat.copy()
    new_qpos[9:12] = pos.copy()
    new_qvel[9:15] = 0
    self.set_state(new_qpos, new_qvel)
|
| 197 |
+
|
| 198 |
+
def _get_site_pos(self, siteName):
    """Return a copy of the position of the site called `siteName`.

    The site index is looked up in `self.model.site_names`.
    """
    site_index = self.model.site_names.index(siteName)
    site_xyz = self.data.site_xpos[site_index]
    return site_xyz.copy()
|
| 201 |
+
|
| 202 |
+
def _set_pos_site(self, name, pos):
    """Overwrite the stored position of the site called `name`.

    Only the first three elements of `pos` are used.

    Args:
        name (str): The site's name
        pos (np.ndarray): Flat array giving the site's new location
    """
    assert isinstance(pos, np.ndarray)
    assert pos.ndim == 1

    site_id = self.model.site_name2id(name)
    self.data.site_xpos[site_id] = pos[:3]
|
| 213 |
+
|
| 214 |
+
@property
def _target_site_config(self):
    """Site name(s) and position(s) that mark the environment's targets.

    :rtype: list of (str, np.ndarray)
    """
    goal_site = ('goal', self._target_pos)
    return [goal_site]
|
| 221 |
+
|
| 222 |
+
@property
def touching_main_object(self):
    """Whether the gripper is touching the env's main object.

    Delegates to `touching_object` with the id from `_get_id_main_object`.

    Returns:
        (bool) whether the gripper is touching the object
    """
    main_geom_id = self._get_id_main_object
    return self.touching_object(main_geom_id)
|
| 231 |
+
|
| 232 |
+
def touching_object(self, object_geom_id):
    """Check whether both gripper pads press on the given geom.

    For each pad ('leftpad_geom', 'rightpad_geom') the `efc_force` entries
    of all contacts between that pad and the object are summed; the result
    is truthy only when both sums are positive.

    Args:
        object_geom_id (int): the ID of the object in question

    Returns:
        (bool): whether the gripper is touching the object
    """
    env = self.unwrapped
    left_id = env.model.geom_name2id('leftpad_geom')
    right_id = env.model.geom_name2id('rightpad_geom')

    def pad_force(pad_id):
        # Total constraint force over contacts joining this pad and the
        # object (a contact's two geoms may appear in either order).
        touching = [
            c for c in self.unwrapped.data.contact
            if pad_id in (c.geom1, c.geom2)
            and object_geom_id in (c.geom1, c.geom2)
        ]
        return sum(
            self.unwrapped.data.efc_force[c.efc_address] for c in touching)

    left_force = pad_force(left_id)
    right_force = pad_force(right_id)

    return 0 < left_force and 0 < right_force
|
| 267 |
+
|
| 268 |
+
@property
def _get_id_main_object(self):
    """Geom id of the geom named 'objGeom' in the unwrapped env's model."""
    model = self.unwrapped.model
    return model.geom_name2id('objGeom')
|
| 271 |
+
|
| 272 |
+
def _get_pos_objects(self):
    """Return the object position(s); to be provided by subclasses.

    Returns:
        np.ndarray: Flat array (usually 3 elements) representing the
            object(s)' position(s)

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # Raised at call time instead of declaring an @abc.abstractmethod, so
    # that V1 environments are not forced to implement this.
    raise NotImplementedError
|
| 282 |
+
|
| 283 |
+
def _get_quat_objects(self):
    """Return the object quaternion(s); to be provided by V2 subclasses.

    Returns:
        np.ndarray: Flat array (usually 4 elements) representing the
            object(s)' quaternion(s); this base implementation returns
            None for non-V2 environments.

    Raises:
        NotImplementedError: in this base implementation when `self.isV2`
            is true.
    """
    # Not an @abc.abstractmethod, so that V1 environments do not have to
    # implement it.
    if not self.isV2:
        return None
    raise NotImplementedError
|
| 297 |
+
|
| 298 |
+
def _get_pos_goal(self):
    """Return the goal position stored in `self._target_pos`.

    Returns:
        np.ndarray: Flat (1-D) array representing the goal position

    Raises:
        AssertionError: if `_target_pos` is not a 1-D np.ndarray.
    """
    goal = self._target_pos
    assert isinstance(goal, np.ndarray)
    assert goal.ndim == 1
    return goal
|
| 307 |
+
|
| 308 |
+
def _get_curr_obs_combined_no_goal(self):
    """Build the current flat observation, without the goal position.

    V2 environments: hand position, normalized gripper opening, then the
    objects' (pos, quat) pairs zero-padded to `self._obs_obj_max_len`.
    Other environments: hand position, then the objects' positions
    zero-padded to `self._obs_obj_max_len`.

    Returns:
        np.ndarray: The flat observation array
    """
    hand_xyz = self.get_endeff_pos()

    right_tip = self._get_site_pos('rightEndEffector')
    left_tip = self._get_site_pos('leftEndEffector')

    # The fingers are at most about 0.1 m apart, so dividing by 0.1 maps
    # the opening to [0, 1]; the clip removes the small overshoot
    # (~0.00045 m) that mujoco sometimes produces.
    opening = np.linalg.norm(right_tip - left_tip)
    opening = np.clip(opening / 0.1, 0., 1.)

    padded = np.zeros(self._obs_obj_max_len)

    positions = self._get_pos_objects()
    assert len(positions) % 3 == 0

    pos_chunks = np.split(positions, len(positions) // 3)

    if not self.isV2:
        # V1 environments expose object positions only.
        padded[:len(positions)] = positions
        assert(len(padded) in self._obs_obj_possible_lens)
        return np.hstack((hand_xyz, padded))

    quats = self._get_quat_objects()
    assert len(quats) % 4 == 0
    quat_chunks = np.split(quats, len(quats) // 4)
    # Interleave per object: pos0, quat0, pos1, quat1, ...
    interleaved = [
        np.hstack((obj_pos, obj_quat))
        for obj_pos, obj_quat in zip(pos_chunks, quat_chunks)
    ]
    padded[:len(positions) + len(quats)] = np.hstack(interleaved)
    assert(len(padded) in self._obs_obj_possible_lens)
    return np.hstack((hand_xyz, opening, padded))
|
| 355 |
+
|
| 356 |
+
def _get_obs(self):
    """Assemble the complete flat observation for the current timestep.

    The goal position is fetched first and replaced by zeros when the
    environment is partially observable. For V2 environments the current
    goal-free observation is stacked with the previously stored one and
    the goal is appended; for V1 environments only the current goal-free
    observation and the goal are concatenated.

    Side effect: the current goal-free observation is stored in
    ``self._prev_obs`` so the next call can use it for frame stacking.

    Returns:
        np.ndarray: The flat observation array (documented upstream as
            39 elements for the frame-stacked case).
    """
    goal = self._get_pos_goal()
    if self._partially_observable:
        # Hide the goal but keep the observation layout unchanged.
        goal = np.zeros_like(goal)

    current = self._get_curr_obs_combined_no_goal()

    if self.isV2:
        # Frame stacking: current frame, previous frame, then the goal.
        pieces = (current, self._prev_obs, goal)
    else:
        pieces = (current, goal)
    stacked = np.hstack(pieces)

    self._prev_obs = current
    return stacked
|
| 375 |
+
|
| 376 |
+
def _get_obs_dict(self):
    """Return the current observation packaged as a dictionary.

    Returns:
        dict: with keys
            ``state_observation``   -- the flat array from ``_get_obs()``
            ``state_desired_goal``  -- the value of ``_get_pos_goal()``
            ``state_achieved_goal`` -- the flat observation with its first
                three and last three entries sliced off
    """
    flat_obs = self._get_obs()
    obs_dict = {
        'state_observation': flat_obs,
        'state_desired_goal': self._get_pos_goal(),
        'state_achieved_goal': flat_obs[3:-3],
    }
    return obs_dict
|
| 383 |
+
|
| 384 |
+
@property
|
| 385 |
+
def observation_space(self):
    """Build the ``Box`` describing the bounds of the flat observation.

    Object entries are unbounded (+/- inf). Goal bounds come from
    ``self.goal_space`` unless the environment is partially observable,
    in which case both goal bounds are zeros. V2 environments use a
    (hand, gripper, objects) frame repeated twice followed by the goal;
    V1 environments use (hand, objects, goal) with six object entries.

    Returns:
        Box: lower/upper bounds of the observation vector.
    """
    n_obj = self._obs_obj_max_len if self.isV2 else 6
    obj_low = np.full(n_obj, -np.inf)
    obj_high = np.full(n_obj, +np.inf)

    if self._partially_observable:
        goal_low = np.zeros(3)
        goal_high = np.zeros(3)
    else:
        goal_low = self.goal_space.low
        goal_high = self.goal_space.high

    hand_low = self._HAND_SPACE.low
    hand_high = self._HAND_SPACE.high

    if not self.isV2:
        return Box(
            np.hstack((hand_low, obj_low, goal_low)),
            np.hstack((hand_high, obj_high, goal_high)),
        )

    gripper_low = -1.
    gripper_high = +1.
    # One frame = hand position, gripper value, padded object entries.
    frame_low = (hand_low, gripper_low, obj_low)
    frame_high = (hand_high, gripper_high, obj_high)
    return Box(
        np.hstack(frame_low * 2 + (goal_low,)),
        np.hstack(frame_high * 2 + (goal_high,)),
    )
|
| 404 |
+
|
| 405 |
+
@_assert_task_is_set
|
| 406 |
+
def step(self, action):
    """Advance the simulation by one control step.

    ``action[:3]` is forwarded to ``set_xyz_action`` and ``action[-1]``
    is applied (with opposite signs) as the two gripper controls.

    Args:
        action: indexable action; the first three entries drive the hand
            and the last entry is the gripper effort.

    Returns:
        For V2 environments, a tuple ``(obs, reward, done, info)`` where
        ``done`` is always ``False``. If the simulator reported an
        exception, ``obs`` is the last stable observation, the reward is
        0.0 and ``info`` holds zeroed/false metrics.

        For V1 environments (``isV2`` false), only the observation is
        returned, because V1 subclasses compute reward and info
        themselves. This also holds after a simulator exception, where
        the last stable observation is returned instead of a tuple the
        subclasses could not slice.
    """
    self.set_xyz_action(action[:3])
    self.do_simulation([action[-1], -action[-1]])
    self.curr_path_length += 1

    # Running the simulator can sometimes mess up site positions, so
    # re-position them here to make sure they're accurate
    for site in self._target_site_config:
        self._set_pos_site(*site)

    if self._did_see_sim_exception:
        if not self.isV2:
            # V1 callers do `ob = super().step(action)` and index into
            # `ob` as an array, so hand them an observation, not a tuple.
            return self._last_stable_obs
        return (
            self._last_stable_obs,  # observation just before going unstable
            0.0,  # reward (penalize for causing instability)
            False,  # termination flag always False
            {  # info
                'success': False,
                'near_object': 0.0,
                'grasp_success': False,
                'grasp_reward': 0.0,
                'in_place_reward': 0.0,
                'obj_to_target': 0.0,
                'unscaled_reward': 0.0,
            }
        )

    self._last_stable_obs = self._get_obs()
    if not self.isV2:
        # v1 environments expect this superclass step() to return only the
        # most recent observation. they override the rest of the
        # functionality and end up returning the same sort of tuple that
        # this does
        return self._last_stable_obs

    reward, info = self.evaluate_state(self._last_stable_obs, action)
    return self._last_stable_obs, reward, False, info
|
| 442 |
+
|
| 443 |
+
def evaluate_state(self, obs, action):
    """Compute the reward and training metrics for ``step()``.

    This base implementation is a placeholder. It deliberately raises
    instead of being declared an ``@abc.abstractmethod`` so that V1
    environments are not forced to implement it.

    Returns (in implementing subclasses):
        float: Reward between 0 and 10
        dict: Dictionary which contains useful metrics (success,
            near_object, grasp_success, grasp_reward, in_place_reward,
            obj_to_target, unscaled_reward)

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
|
| 457 |
+
|
| 458 |
+
def reset(self):
    """Zero the episode step counter, then delegate to the parent reset.

    Returns:
        Whatever the superclass ``reset()`` returns.
    """
    self.curr_path_length = 0
    parent_result = super().reset()
    return parent_result
|
| 461 |
+
|
| 462 |
+
def _reset_hand(self, steps=50):
    """Drive the mocap body to the initial hand pose.

    For each of ``steps`` iterations the 'mocap' position is set to
    ``self.hand_init_pos``, its quaternion to ``[1, 0, 1, 0]``, and the
    simulation is advanced with gripper controls ``[-1, 1]`` for
    ``self.frame_skip`` frames. Afterwards ``self.tcp_center`` is stored
    in ``self.init_tcp``.

    Args:
        steps (int): number of reset iterations to run.
    """
    for _iteration in range(steps):
        mocap_quat = np.array([1, 0, 1, 0])
        self.data.set_mocap_pos('mocap', self.hand_init_pos)
        self.data.set_mocap_quat('mocap', mocap_quat)
        gripper_ctrl = [-1, 1]
        self.do_simulation(gripper_ctrl, self.frame_skip)

    # Remember where the tool center point ended up after the reset.
    self.init_tcp = self.tcp_center
|
| 468 |
+
|
| 469 |
+
def _get_state_rand_vec(self):
    """Return the random vector used to initialise an episode.

    When ``self._freeze_rand_vec`` is set, the previously drawn vector is
    returned unchanged. Otherwise a new vector is drawn uniformly between
    ``self._random_reset_space.low`` and ``.high``, using
    ``self.np_random`` when ``self.seeded_rand_vec`` is set and the
    global ``np.random`` otherwise.

    The freshly drawn vector is always stored in ``self._last_rand_vec``.
    The seeded branch previously skipped this, which left nothing (or a
    stale vector) to return once freezing was enabled.

    Returns:
        np.ndarray: the random reset vector.

    Raises:
        AssertionError: if freezing is enabled before any vector has
            been drawn.
    """
    if self._freeze_rand_vec:
        assert self._last_rand_vec is not None
        return self._last_rand_vec

    space = self._random_reset_space
    rng = self.np_random if self.seeded_rand_vec else np.random
    rand_vec = rng.uniform(space.low, space.high, size=space.low.size)

    # Remember the sample in both branches so it can be frozen later.
    self._last_rand_vec = rand_vec
    return rand_vec
|
| 486 |
+
|
| 487 |
+
def _gripper_caging_reward(self,
                           action,
                           obj_pos,
                           obj_radius,
                           pad_success_thresh,
                           object_reach_radius,
                           xz_thresh,
                           desired_gripper_effort=1.0,
                           high_density=False,
                           medium_density=False):
    """Reward for the agent caging and grasping an object.

    Args:
        action(np.ndarray): (4,) array representing the action
            delta(x), delta(y), delta(z), gripper_effort
        obj_pos(np.ndarray): (3,) array representing the obj x,y,z
        obj_radius(float): radius of object's bounding sphere
        pad_success_thresh(float): successful distance of gripper_pad
            to object
        object_reach_radius(float): successful distance of gripper center
            to the object.
        xz_thresh(float): successful distance of gripper in x_z axis to the
            object. Y axis not included since the caging function handles
            successful grasping in the Y axis.
        desired_gripper_effort(float): effort at which the "closed"
            component saturates at 1.
        high_density(bool): if set, average the result with the caging
            term.
        medium_density(bool): if set, average the result with a reach
            term.

    Returns:
        The combined caging-and-gripping reward.

    Raises:
        ValueError: if both ``high_density`` and ``medium_density`` are
            set.
    """
    if high_density and medium_density:
        raise ValueError("Can only be either high_density or medium_density")

    # ---- Left/right pads: caging along the Y axis ----------------------
    pad_y = np.hstack((
        self.get_body_com('leftpad')[1],
        self.get_body_com('rightpad')[1],
    ))
    # Current pad positions vs. the *current* object position (Y axis).
    dist_to_obj = np.abs(pad_y - obj_pos[1])
    # Current pad positions vs. the *initial* object position (Y axis).
    dist_to_obj_init = np.abs(pad_y - self.obj_init_pos[1])

    # The Y caging term is crucial for success but mathematically
    # counterintuitive (the upstream authors note it was found by
    # accident). Call `dist_to_obj` "x" and `y_margins` "the margin":
    #
    #   * Before touching the object, x and the margin differ by roughly
    #     `pad_success_thresh`.
    #   * Far from the object, x = margin + pad_success_thresh, so x lies
    #     outside the margin and the reward is very small; what variation
    #     there is comes from the margin itself shifting.
    #   * Near the object (within pad_success_thresh),
    #     x = pad_success_thresh - margin, so x is well inside the margin.
    #     While x > obj_radius it is also inside the bounds and the
    #     reward is maximal; it only drops if the pads get *too close*
    #     (past the obj_radius bound).
    #
    # Before contact this is therefore almost binary: maximal between
    # obj_radius and pad_success_thresh, falling off quickly elsewhere.
    #
    # After grasping and moving the object away from its initial
    # position, x stays (mostly) constant while the margin grows a lot.
    # That penalises moving *back* toward `obj_init_pos` but does not by
    # itself encourage leaving it; individual environments handle that.
    y_margins = np.abs(dist_to_obj_init - pad_success_thresh)
    per_pad = [
        reward_utils.tolerance(
            x,
            bounds=(obj_radius, pad_success_thresh),
            margin=margin,
            sigmoid='long_tail',
        )
        for x, margin in zip(dist_to_obj, y_margins)
    ]
    caging_y = reward_utils.hamacher_product(*per_pad)

    # ---- Tool center point: caging in the X-Z plane --------------------
    tcp = self.tcp_center
    xz = [0, 2]

    # Simpler than the Y term: the margin is constant (upstream cites
    # roughly 0.3 to 0.5) and x shrinks as the gripper approaches the
    # object. Once the object is picked up the term stays near maximal.
    xz_margin = np.linalg.norm(self.obj_init_pos[xz] - self.init_tcp[xz]) - xz_thresh
    caging_xz = reward_utils.tolerance(
        np.linalg.norm(tcp[xz] - obj_pos[xz]),
        bounds=(0, xz_thresh),
        margin=xz_margin,
        sigmoid='long_tail',
    )

    # ---- How closed the gripper is, normalised to [0, 1] ---------------
    clipped_effort = min(max(0, action[-1]), desired_gripper_effort)
    gripper_closed = clipped_effort / desired_gripper_effort

    # ---- Combine the components ----------------------------------------
    caging = reward_utils.hamacher_product(caging_y, caging_xz)
    # Closing the gripper only counts once caging is nearly complete.
    gripping = gripper_closed if caging > 0.97 else 0.
    result = reward_utils.hamacher_product(caging, gripping)

    # The two density flags are mutually exclusive (checked above).
    if high_density:
        result = (result + caging) / 2
    elif medium_density:
        tcp = self.tcp_center
        tcp_to_obj = np.linalg.norm(obj_pos - tcp)
        tcp_to_obj_init = np.linalg.norm(self.obj_init_pos - self.init_tcp)
        # `object_reach_radius` is subtracted from the margin so that,
        # per the upstream note, the reach reward starts at about 0.1.
        reach_margin = abs(tcp_to_obj_init - object_reach_radius)
        reach = reward_utils.tolerance(
            tcp_to_obj,
            bounds=(0, object_reach_radius),
            margin=reach_margin,
            sigmoid='long_tail',
        )
        result = (result + reach) / 2

    return result
|
Metaworld/metaworld/envs/mujoco/sawyer_xyz/v1/sawyer_coffee_push.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from gym.spaces import Box
|
| 3 |
+
|
| 4 |
+
from metaworld.envs.asset_path_utils import full_v1_path_for
|
| 5 |
+
from metaworld.envs.mujoco.sawyer_xyz.sawyer_xyz_env import SawyerXYZEnv, _assert_task_is_set
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SawyerCoffeePushEnv(SawyerXYZEnv):
    """V1 Sawyer environment that loads the coffee scene.

    The reward first pulls the fingers toward the object and, once they
    are within 5 cm, rewards reducing the XY distance between the object
    and the 'coffee_goal' target. Success is an XY distance of at most
    0.07.
    """

    def __init__(self):
        # Hand workspace limits handed to the base environment.
        hand_bounds = {
            'hand_low': (-0.5, 0.40, 0.05),
            'hand_high': (0.5, 1, 0.5),
        }
        super().__init__(self.model_name, **hand_bounds)

        self.init_config = {
            'obj_init_angle': 0.3,
            'obj_init_pos': np.array([0., .6, 0.]),
            'hand_init_pos': np.array([0., .6, .2]),
        }
        self.goal = np.array([0., 0.8, 0])
        self.obj_init_pos = self.init_config['obj_init_pos']
        self.obj_init_angle = self.init_config['obj_init_angle']
        self.hand_init_pos = self.init_config['hand_init_pos']

        # Sampling ranges: the reset vector is (object xyz, goal xyz).
        obj_low, obj_high = (-0.1, 0.6, 0.), (0.1, 0.7, 0.)
        goal_low, goal_high = (-0.1, 0.8, -.001), (0.1, 0.9, 0.0)
        self._random_reset_space = Box(
            np.hstack((obj_low, goal_low)),
            np.hstack((obj_high, goal_high)),
        )
        self.goal_space = Box(np.array(goal_low), np.array(goal_high))

    @property
    def model_name(self):
        """Path of the MuJoCo XML model for this environment."""
        return full_v1_path_for('sawyer_xyz/sawyer_coffee.xml')

    @_assert_task_is_set
    def step(self, action):
        """Step the base environment and attach reward and info.

        Returns:
            tuple: ``(ob, reward, False, info)``.
        """
        ob = super().step(action)
        reward, reach_dist, push_dist = self.compute_reward(action, ob)
        return ob, reward, False, {
            'reachDist': reach_dist,
            'goalDist': push_dist,
            'epRew': reward,
            'pickRew': None,
            'success': float(push_dist <= 0.07)
        }

    @property
    def _target_site_config(self):
        """Sites the base ``step()`` re-positions after each simulation step."""
        return [('coffee_goal', self._target_pos)]

    def _get_pos_objects(self):
        """Position of the 'objGeom' geom."""
        return self.data.get_geom_xpos('objGeom')

    def adjust_initObjPos(self, orig_init_pos):
        """Shift an initial XY position by the body/geom misalignment.

        The meshes of the geom and of the object body are not aligned;
        without this correction the object could be initialised in an
        extreme position.

        Returns:
            list: ``[x, y, z]`` where z is the body's own z coordinate.
        """
        xy_offset = self.get_body_com('obj')[:2] - self.data.get_geom_xpos('objGeom')[:2]
        shifted_xy = orig_init_pos[:2] + xy_offset

        # Convention: body_com[2] is always 0, and geom_pos[2] is the
        # object height.
        return [shifted_xy[0], shifted_xy[1], self.get_body_com('obj')[-1]]

    def reset_model(self):
        """Reset hand, object and target; optionally randomise them.

        Returns:
            The observation from ``_get_obs()`` after the reset.
        """
        self._reset_hand()
        self._target_pos = self.goal.copy()
        self.obj_init_pos = self.adjust_initObjPos(self.init_config['obj_init_pos'])
        self.obj_init_angle = self.init_config['obj_init_angle']
        self.objHeight = self.data.get_geom_xpos('objGeom')[2]

        if self.random_init:
            # Resample until object and target are at least 0.15 apart in XY.
            while True:
                rand_vec = self._get_state_rand_vec()
                self._target_pos = rand_vec[3:]
                separation = np.linalg.norm(rand_vec[:2] - self._target_pos[:2])
                if not (separation < 0.15):
                    break

            obj_z = [self.obj_init_pos[-1]]
            self._target_pos = np.concatenate((rand_vec[-3:-1], obj_z))
            self.obj_init_pos = np.concatenate((rand_vec[:2], [self.obj_init_pos[-1]]))

            # Place the machine and its button relative to the target.
            machine_pos = self._target_pos - np.array([0, -0.1, -0.27])
            button_pos = machine_pos + np.array([0., -0.12, 0.05])
            body_pos = self.sim.model.body_pos
            body_pos[self.model.body_name2id('coffee_machine')] = machine_pos
            self.sim.model.body_pos[self.model.body_name2id('button')] = button_pos

        self._set_obj_xyz(self.obj_init_pos)
        self.maxPushDist = np.linalg.norm(
            self.obj_init_pos[:2] - np.array(self._target_pos)[:2]
        )

        return self._get_obs()

    def _reset_hand(self):
        """Reset the hand with 10 steps and record the finger midpoint."""
        super()._reset_hand(10)
        right_finger = self._get_site_pos('rightEndEffector')
        left_finger = self._get_site_pos('leftEndEffector')
        self.init_fingerCOM = (right_finger + left_finger)/2
        self.reachCompleted = False

    def compute_reward(self, actions, obs):
        """Compute the reach-then-push reward.

        Args:
            actions: unused.
            obs: flat observation; entries 3:6 are read as the object
                position.

        Returns:
            list: ``[reward, reachDist, pushDist]``.
        """
        obj_pos = obs[3:6]

        right_finger = self._get_site_pos('rightEndEffector')
        left_finger = self._get_site_pos('leftEndEffector')
        finger_com = (right_finger + left_finger)/2

        goal = self._target_pos

        c1 = 1000
        c2 = 0.01
        c3 = 0.001
        # The target site must agree with the stored target position.
        assert np.all(goal == self._get_site_pos('coffee_goal'))

        reach_dist = np.linalg.norm(finger_com - obj_pos)
        push_dist = np.linalg.norm(obj_pos[:2] - goal[:2])
        reach_rew = -reach_dist

        push_rew = 0
        if reach_dist < 0.05:
            # Linear progress term plus two sharp bonuses near the goal.
            push_rew = 1000*(self.maxPushDist - push_dist) + c1*(np.exp(-(push_dist**2)/c2) + np.exp(-(push_dist**2)/c3))
            push_rew = max(push_rew, 0)

        reward = reach_rew + push_rew

        return [reward, reach_dist, push_dist]
|
Metaworld/metaworld/envs/mujoco/sawyer_xyz/v1/sawyer_dial_turn.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from gym.spaces import Box
|
| 3 |
+
|
| 4 |
+
from metaworld.envs.asset_path_utils import full_v1_path_for
|
| 5 |
+
from metaworld.envs.mujoco.sawyer_xyz.sawyer_xyz_env import SawyerXYZEnv, _assert_task_is_set
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SawyerDialTurnEnv(SawyerXYZEnv):
    """V1 Sawyer environment that loads the dial scene.

    The reward pulls the fingers toward the 'dialStart' site and, once
    they are within 5 cm, rewards reducing the Y distance between that
    site and the target. Success is a Y distance of at most 0.03.
    """

    def __init__(self):
        # Hand workspace limits handed to the base environment.
        hand_bounds = {
            'hand_low': (-0.5, 0.40, 0.05),
            'hand_high': (0.5, 1, 0.5),
        }
        super().__init__(self.model_name, **hand_bounds)

        self.init_config = {
            'obj_init_pos': np.array([0, 0.7, 0.05]),
            'hand_init_pos': np.array([0, 0.6, 0.2], dtype=np.float32),
        }
        self.goal = np.array([0., 0.73, 0.08])
        self.obj_init_pos = self.init_config['obj_init_pos']
        self.hand_init_pos = self.init_config['hand_init_pos']

        # The goal space reuses the hand limits stored on the instance.
        goal_low = self.hand_low
        goal_high = self.hand_high

        # Only the object position is randomised on reset.
        obj_low, obj_high = (-0.1, 0.7, 0.05), (0.1, 0.8, 0.05)
        self._random_reset_space = Box(np.array(obj_low), np.array(obj_high))
        self.goal_space = Box(np.array(goal_low), np.array(goal_high))

    @property
    def model_name(self):
        """Path of the MuJoCo XML model for this environment."""
        return full_v1_path_for('sawyer_xyz/sawyer_dial.xml')

    @_assert_task_is_set
    def step(self, action):
        """Step the base environment and attach reward and info.

        Returns:
            tuple: ``(ob, reward, False, info)``.
        """
        ob = super().step(action)
        reward, reach_dist, pull_dist = self.compute_reward(action, ob)

        return ob, reward, False, {
            'reachDist': reach_dist,
            'goalDist': pull_dist,
            'epRew': reward,
            'pickRew': None,
            'success': float(pull_dist <= 0.03)
        }

    def _get_pos_objects(self):
        """Position of the 'dialStart' site."""
        return self._get_site_pos('dialStart')

    def reset_model(self):
        """Reset hand, dial and target; optionally randomise them.

        Returns:
            The observation from ``_get_obs()`` after the reset.
        """
        self._reset_hand()
        self._target_pos = self.goal.copy()
        self.obj_init_pos = self.init_config['obj_init_pos']

        if self.random_init:
            rand_vec = self._get_state_rand_vec()
            self.obj_init_pos = rand_vec[:3]
            # The target sits at a fixed offset from the sampled position.
            self._target_pos = rand_vec.copy() + np.array([0, 0.03, 0.03])

        dial_id = self.model.body_name2id('dial')
        self.sim.model.body_pos[dial_id] = self.obj_init_pos
        self.maxPullDist = np.abs(self._target_pos[1] - self.obj_init_pos[1])

        return self._get_obs()

    def _reset_hand(self):
        """Reset the hand with 10 steps and record the finger midpoint."""
        super()._reset_hand(10)

        right_finger = self._get_site_pos('rightEndEffector')
        left_finger = self._get_site_pos('leftEndEffector')
        self.init_fingerCOM = (right_finger + left_finger)/2
        self.reachCompleted = False

    def compute_reward(self, actions, obs):
        """Compute the reach-then-pull reward.

        Args:
            actions: unused.
            obs: flat observation; entries 3:6 are read as the object
                position.

        Returns:
            list: ``[reward, reachDist, pullDist]``.
        """
        obj_pos = obs[3:6]

        right_finger = self._get_site_pos('rightEndEffector')
        left_finger = self._get_site_pos('leftEndEffector')
        finger_com = (right_finger + left_finger)/2

        pull_goal = self._target_pos

        pull_dist = np.abs(obj_pos[1] - pull_goal[1])
        reach_dist = np.linalg.norm(obj_pos - finger_com)
        reach_rew = -reach_dist

        self.reachCompleted = reach_dist < 0.05

        c1 = 1000
        c2 = 0.001
        c3 = 0.0001

        pull_rew = 0
        if self.reachCompleted:
            # Linear progress term plus two sharp bonuses near the goal.
            pull_rew = 1000*(self.maxPullDist - pull_dist) + c1*(np.exp(-(pull_dist**2)/c2) + np.exp(-(pull_dist**2)/c3))
            pull_rew = max(pull_rew, 0)

        reward = reach_rew + pull_rew

        return [reward, reach_dist, pull_dist]
|
Metaworld/metaworld/envs/mujoco/sawyer_xyz/v1/sawyer_disassemble_peg.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from gym.spaces import Box
|
| 3 |
+
|
| 4 |
+
from metaworld.envs.asset_path_utils import full_v1_path_for
|
| 5 |
+
from metaworld.envs.mujoco.sawyer_xyz.sawyer_xyz_env import SawyerXYZEnv, _assert_task_is_set
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SawyerNutDisassembleEnv(SawyerXYZEnv):
    """V1 Sawyer environment that loads the assembly-peg scene.

    The reward combines reaching the nut, lifting it to a height target
    and moving it to the target position. Success is reported when the
    nut is more than 0.05 away from the peg along x or y, or when the
    placing distance is below 0.02.
    """

    def __init__(self):
        # Hand workspace limits handed to the base environment.
        hand_bounds = {
            'hand_low': (-0.5, 0.40, 0.05),
            'hand_high': (0.5, 1, 0.5),
        }
        super().__init__(self.model_name, **hand_bounds)

        self.init_config = {
            'obj_init_angle': 0.3,
            'obj_init_pos': np.array([0, 0.8, 0.02]),
            'hand_init_pos': np.array((0, 0.6, 0.2), dtype=np.float32),
        }
        self.goal = np.array([0, 0.8, 0.17])
        self.obj_init_pos = self.init_config['obj_init_pos']
        self.obj_init_angle = self.init_config['obj_init_angle']
        self.hand_init_pos = self.init_config['hand_init_pos']

        # Height the nut must be lifted above its resting height.
        self.liftThresh = 0.05

        # Sampling ranges: the reset vector is (object xyz, goal xyz).
        # NOTE(review): the x entry of obj_low (0.1) is larger than that
        # of obj_high (0.) -- kept as in the original; confirm intent.
        obj_low, obj_high = (0.1, 0.75, 0.02), (0., 0.85, 0.02)
        goal_low, goal_high = (-0.1, 0.75, 0.1699), (0.1, 0.85, 0.1701)
        self._random_reset_space = Box(
            np.hstack((obj_low, goal_low)),
            np.hstack((obj_high, goal_high)),
        )
        self.goal_space = Box(np.array(goal_low), np.array(goal_high))

    @property
    def model_name(self):
        """Path of the MuJoCo XML model for this environment."""
        return full_v1_path_for('sawyer_xyz/sawyer_assembly_peg.xml')

    @_assert_task_is_set
    def step(self, action):
        """Step the base environment and attach reward and info.

        Returns:
            tuple: ``(ob, reward, False, info)``.
        """
        ob = super().step(action)
        reward, _, reach_dist, pick_rew, _, placing_dist, success = \
            self.compute_reward(action, ob)
        return ob, reward, False, {
            'reachDist': reach_dist,
            'pickRew': pick_rew,
            'epRew': reward,
            'goalDist': placing_dist,
            'success': success
        }

    @property
    def _target_site_config(self):
        """Sites the base ``step()`` re-positions after each simulation step."""
        return [('pegTop', self._target_pos)]

    def _get_pos_objects(self):
        """Position of the 'RoundNut-8' geom."""
        return self.data.get_geom_xpos('RoundNut-8')

    def _get_obs_dict(self):
        """Base observation dict with the achieved goal set to the nut body position."""
        result = super()._get_obs_dict()
        result['state_achieved_goal'] = self.get_body_com('RoundNut')
        return result

    def reset_model(self):
        """Reset hand, nut, peg and target; optionally randomise them.

        Returns:
            The observation from ``_get_obs()`` after the reset.
        """
        self._reset_hand()
        self._target_pos = self.goal.copy()
        self.obj_init_pos = np.array(self.init_config['obj_init_pos'])
        self.obj_init_angle = self.init_config['obj_init_angle']

        if self.random_init:
            # Resample until the object and goal XY samples are at least
            # 0.1 apart.
            while True:
                rand_vec = self._get_state_rand_vec()
                if not (np.linalg.norm(rand_vec[:2] - rand_vec[-3:-1]) < 0.1):
                    break
            self.obj_init_pos = rand_vec[:3]
            self._target_pos = rand_vec[:3] + np.array([0, 0, 0.15])

        # The peg and its top site sit at fixed offsets above the nut.
        peg_pos = self.obj_init_pos + np.array([0., 0., 0.03])
        peg_top_pos = self.obj_init_pos + np.array([0., 0., 0.08])
        self.sim.model.body_pos[self.model.body_name2id('peg')] = peg_pos
        self.sim.model.site_pos[self.model.site_name2id('pegTop')] = peg_top_pos
        self._set_obj_xyz(self.obj_init_pos)

        self.objHeight = self.data.get_geom_xpos('RoundNut-8')[2]
        self.heightTarget = self.objHeight + self.liftThresh
        lifted_start = np.array([self.obj_init_pos[0], self.obj_init_pos[1], self.heightTarget])
        self.maxPlacingDist = np.linalg.norm(lifted_start - np.array(self._target_pos)) + self.heightTarget

        return self._get_obs()

    def _reset_hand(self):
        """Reset the hand with 10 steps and record the finger midpoint."""
        super()._reset_hand(10)

        right_finger = self._get_site_pos('rightEndEffector')
        left_finger = self._get_site_pos('leftEndEffector')
        self.init_fingerCOM = (right_finger + left_finger)/2
        self.pickCompleted = False

    def compute_reward(self, actions, obs):
        """Compute the reach / pick / place reward.

        Args:
            actions: action array; only the last entry (gripper effort)
                is read.
            obs: flat observation; entries 3:6 are read as the object
                position.

        Returns:
            list: ``[reward, reachRew, reachDist, pickRew, placeRew,
            placingDist, success]`` with ``success`` as a float.
        """
        obj_pos = obs[3:6]

        right_finger = self._get_site_pos('rightEndEffector')
        left_finger = self._get_site_pos('leftEndEffector')
        finger_com = (right_finger + left_finger)/2

        height_target = self.heightTarget
        placing_goal = self._target_pos

        reach_dist = np.linalg.norm(obj_pos - finger_com)
        reach_dist_xy = np.linalg.norm(obj_pos[:-1] - finger_com[:-1])
        z_dist = np.abs(finger_com[-1] - self.init_fingerCOM[-1])

        placing_dist = np.linalg.norm(obj_pos - placing_goal)

        def dropped(placing, reach):
            # Object is back near its resting height, away from both the
            # goal and the fingers.
            return (obj_pos[2] < (self.objHeight + 0.005)) and (placing > 0.02) and (reach > 0.02)

        # Latch pick completion: lifted to within tolerance of the height
        # target while the fingers are close.
        tolerance = 0.01
        if obj_pos[2] >= (height_target - tolerance) and reach_dist < 0.04:
            self.pickCompleted = True

        # ---- Reach reward ----
        if reach_dist < 0.04:
            # incentive to close fingers when reachDist is small
            reach_rew = -reach_dist + max(actions[-1], 0)/50
        elif reach_dist_xy < 0.04:
            reach_rew = -reach_dist
        else:
            reach_rew = -reach_dist_xy - 2*z_dist

        # ---- Pick reward ----
        h_scale = 100
        if self.pickCompleted and not dropped(placing_dist, reach_dist):
            pick_rew = h_scale*height_target
        elif (reach_dist < 0.04) and (obj_pos[2] > (self.objHeight + 0.005)):
            pick_rew = h_scale*min(height_target, obj_pos[2])
        else:
            pick_rew = 0

        # ---- Nut off the peg: override the distance-based terms ----
        peg_pos = self.sim.model.body_pos[self.model.body_name2id('peg')]
        nut_pos = self.get_body_com('RoundNut')
        off_peg = abs(nut_pos[0] - peg_pos[0]) > 0.05 or \
            abs(nut_pos[1] - peg_pos[1]) > 0.05
        if off_peg:
            placing_dist = 0
            reach_rew = 0
            reach_dist = 0
            pick_rew = height_target*100

        # ---- Place reward (uses the possibly overridden distances) ----
        c1 = 1000
        c2 = 0.01
        c3 = 0.001
        place_rew = 1000*(self.maxPlacingDist - placing_dist) + c1*(np.exp(-(placing_dist**2)/c2) + np.exp(-(placing_dist**2)/c3))
        place_rew = max(place_rew, 0)
        holding = self.pickCompleted and (reach_dist < 0.03) and not dropped(placing_dist, reach_dist)
        if not holding:
            place_rew = 0

        assert ((place_rew >= 0) and (pick_rew >= 0))
        reward = reach_rew + pick_rew + place_rew
        success = off_peg or placing_dist < 0.02

        return [reward, reach_rew, reach_dist, pick_rew, place_rew, placing_dist, float(success)]
|
Metaworld/metaworld/envs/mujoco/sawyer_xyz/v1/sawyer_plate_slide_back_side.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from gym.spaces import Box
|
| 3 |
+
|
| 4 |
+
from metaworld.envs.asset_path_utils import full_v1_path_for
|
| 5 |
+
from metaworld.envs.mujoco.sawyer_xyz.sawyer_xyz_env import SawyerXYZEnv, _assert_task_is_set
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SawyerPlateSlideBackSideEnv(SawyerXYZEnv):
    """Plate-slide environment variant built on ``SawyerXYZEnv``.

    Loads the ``sawyer_plate_slide_sideway.xml`` v1 model.  The reward is a
    negative hand-to-object distance plus a shaped term for moving the object
    toward the target in the x/y plane once the hand is close to it.
    """

    def __init__(self):
        # Hand workspace limits passed to the base environment.
        hand_low = (-0.5, 0.40, 0.05)
        hand_high = (0.5, 1, 0.5)
        # Sampling ranges for the randomised reset (object range is a point).
        obj_low = (-0.25, 0.6, 0.02)
        obj_high = (-0.25, 0.6, 0.02)
        goal_low = (-0.1, 0.6, 0.015)
        goal_high = (0.1, 0.6, 0.015)

        super().__init__(
            self.model_name,
            hand_low=hand_low,
            hand_high=hand_high,
        )

        self.init_config = {
            'obj_init_angle': 0.3,
            'obj_init_pos': np.array([-0.25, 0.6, 0.02], dtype=np.float32),
            'hand_init_pos': np.array((0, 0.6, 0.2), dtype=np.float32),
        }
        self.goal = np.array([0., 0.6, 0.015])
        self.obj_init_pos = self.init_config['obj_init_pos']
        self.obj_init_angle = self.init_config['obj_init_angle']
        self.hand_init_pos = self.init_config['hand_init_pos']

        # Reset vectors are laid out as [object xyz, goal xyz].
        reset_low = np.hstack((obj_low, goal_low))
        reset_high = np.hstack((obj_high, goal_high))
        self._random_reset_space = Box(reset_low, reset_high)
        self.goal_space = Box(np.array(goal_low), np.array(goal_high))

    @property
    def model_name(self):
        """Path of the MuJoCo XML model used by this environment."""
        return full_v1_path_for('sawyer_xyz/sawyer_plate_slide_sideway.xml')

    @_assert_task_is_set
    def step(self, action):
        """Advance the simulation one step.

        Returns:
            A tuple ``(observation, reward, False, info)``; the done flag is
            always ``False`` and ``info['success']`` is 1.0 when the object is
            within 0.07 of the target in the x/y plane.
        """
        observation = super().step(action)
        reward, reach_dist, pull_dist = self.compute_reward(action, observation)

        succeeded = float(pull_dist <= 0.07)
        info = {
            'reachDist': reach_dist,
            'goalDist': pull_dist,
            'epRew': reward,
            'pickRew': None,
            'success': succeeded
        }

        return observation, reward, False, info

    def _get_pos_objects(self):
        """Return the world position of the ``objGeom`` geom."""
        return self.data.get_geom_xpos('objGeom')

    def _set_obj_xyz(self, pos):
        """Write ``pos`` into ``qpos[9:11]`` and push the state to the sim.

        NOTE(review): indices 9:11 are presumably the object's joints --
        confirm against the XML model.
        """
        new_qpos = self.data.qpos.flat.copy()
        new_qvel = self.data.qvel.flat.copy()
        new_qpos[9:11] = pos
        self.set_state(new_qpos, new_qvel)

    def reset_model(self):
        """Reset hand, object and target, then return the first observation."""
        self._reset_hand()
        self._target_pos = self.goal.copy()
        self.obj_init_pos = self.init_config['obj_init_pos']
        self.objHeight = self.data.get_geom_xpos('objGeom')[2]

        if self.random_init:
            # First three entries are the object position, the rest the goal.
            sampled = self._get_state_rand_vec()
            self.obj_init_pos = sampled[:3]
            self._target_pos = sampled[3:]

        cabinet_id = self.model.body_name2id('cabinet')
        self.sim.model.body_pos[cabinet_id] = self.obj_init_pos
        self._set_obj_xyz(np.array([-0.2, 0.]))

        # Initial planar distance between the object geom and the target;
        # used to scale the shaped reward.
        start_xy = self.data.get_geom_xpos('objGeom')[:-1]
        target_xy = self._target_pos[:-1]
        self.maxDist = np.linalg.norm(start_xy - target_xy)
        self.target_reward = 1000*self.maxDist + 1000*2

        return self._get_obs()

    def _reset_hand(self):
        """Reset the hand via the base class and cache the finger midpoint."""
        super()._reset_hand(10)

        right_tip = self._get_site_pos('rightEndEffector')
        left_tip = self._get_site_pos('leftEndEffector')
        self.init_fingerCOM = (right_tip + left_tip)/2

    def compute_reward(self, actions, obs):
        """Compute the reward for the current observation.

        Args:
            actions: Unused.
            obs: Observation vector; ``obs[3:6]`` is read as the object
                position (assumed layout -- TODO confirm against the base
                class).

        Returns:
            A list ``[reward, reachDist, pullDist]``.
        """
        del actions

        obj_pos = obs[3:6]

        right_tip = self._get_site_pos('rightEndEffector')
        left_tip = self._get_site_pos('leftEndEffector')
        finger_com = (right_tip + left_tip)/2

        target = self._target_pos

        reach_dist = np.linalg.norm(obj_pos - finger_com)
        # Only the x/y components count toward the goal distance.
        pull_dist = np.linalg.norm(obj_pos[:-1] - target[:-1])

        c1 = 1000
        c2 = 0.01
        c3 = 0.001
        pull_reward = 0
        if reach_dist < 0.05:
            # Linear progress term plus two Gaussian bumps near the target.
            shaped = 1000*(self.maxDist - pull_dist) + c1*(
                np.exp(-(pull_dist**2)/c2) + np.exp(-(pull_dist**2)/c3))
            pull_reward = max(shaped, 0)

        reward = -reach_dist + pull_reward

        return [reward, reach_dist, pull_dist]
|
Metaworld/metaworld/envs/mujoco/sawyer_xyz/v1/sawyer_plate_slide_side.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from gym.spaces import Box
|
| 3 |
+
|
| 4 |
+
from metaworld.envs.asset_path_utils import full_v1_path_for
|
| 5 |
+
from metaworld.envs.mujoco.sawyer_xyz.sawyer_xyz_env import SawyerXYZEnv, _assert_task_is_set
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SawyerPlateSlideSideEnv(SawyerXYZEnv):
    """Plate-slide environment variant built on ``SawyerXYZEnv``.

    Loads the ``sawyer_plate_slide_sideway.xml`` v1 model.  The reward is a
    negative hand-to-object distance plus a shaped term for moving the object
    toward the target in the x/y plane once the hand is close to it.
    """

    def __init__(self):
        # Hand workspace limits passed to the base environment.
        hand_low = (-0.5, 0.40, 0.05)
        hand_high = (0.5, 1, 0.5)
        # Sampling ranges for the randomised reset (object range is a point).
        obj_low = (0., 0.6, 0.015)
        obj_high = (0., 0.6, 0.015)
        goal_low = (-0.3, 0.6, 0.02)
        goal_high = (-0.25, 0.7, 0.02)

        super().__init__(
            self.model_name,
            hand_low=hand_low,
            hand_high=hand_high,
        )

        self.init_config = {
            'obj_init_angle': 0.3,
            'obj_init_pos': np.array([0., 0.6, 0.015], dtype=np.float32),
            'hand_init_pos': np.array((0, 0.6, 0.2), dtype=np.float32),
        }
        self.goal = np.array([-0.25, 0.6, 0.02])
        self.obj_init_pos = self.init_config['obj_init_pos']
        self.obj_init_angle = self.init_config['obj_init_angle']
        self.hand_init_pos = self.init_config['hand_init_pos']

        # Reset vectors are laid out as [object xyz, goal xyz].
        reset_low = np.hstack((obj_low, goal_low))
        reset_high = np.hstack((obj_high, goal_high))
        self._random_reset_space = Box(reset_low, reset_high)
        self.goal_space = Box(np.array(goal_low), np.array(goal_high))

    @property
    def model_name(self):
        """Path of the MuJoCo XML model used by this environment."""
        return full_v1_path_for('sawyer_xyz/sawyer_plate_slide_sideway.xml')

    @_assert_task_is_set
    def step(self, action):
        """Advance the simulation one step.

        Returns:
            A tuple ``(observation, reward, False, info)``; the done flag is
            always ``False`` and ``info['success']`` is 1.0 when the object is
            within 0.08 of the target in the x/y plane.
        """
        observation = super().step(action)
        reward, reach_dist, pull_dist = self.compute_reward(action, observation)

        succeeded = float(pull_dist <= 0.08)
        info = {
            'reachDist': reach_dist,
            'goalDist': pull_dist,
            'epRew': reward,
            'pickRew': None,
            'success': succeeded
        }

        return observation, reward, False, info

    def _get_pos_objects(self):
        """Return the world position of the ``objGeom`` geom."""
        return self.data.get_geom_xpos('objGeom')

    def _set_objCOM_marker(self):
        """Copy the ``handle`` geom position onto the ``objSite`` site."""
        handle_pos = self.data.get_geom_xpos('handle')
        site_id = self.model.site_name2id('objSite')
        self.data.site_xpos[site_id] = handle_pos

    def _set_obj_xyz(self, pos):
        """Write ``pos`` into ``qpos[9:11]`` and push the state to the sim.

        NOTE(review): indices 9:11 are presumably the object's joints --
        confirm against the XML model.
        """
        new_qpos = self.data.qpos.flat.copy()
        new_qvel = self.data.qvel.flat.copy()
        new_qpos[9:11] = pos
        self.set_state(new_qpos, new_qvel)

    def reset_model(self):
        """Reset hand, object and target, then return the first observation."""
        self._reset_hand()
        self._target_pos = self.goal.copy()
        self.obj_init_pos = self.init_config['obj_init_pos']
        self.objHeight = self.data.get_geom_xpos('objGeom')[2]

        if self.random_init:
            # First three entries are the object position, the rest the goal.
            sampled = self._get_state_rand_vec()
            self.obj_init_pos = sampled[:3]
            self._target_pos = sampled[3:]

        # In this variant the 'cabinet' body is placed at the target.
        cabinet_id = self.model.body_name2id('cabinet')
        self.sim.model.body_pos[cabinet_id] = self._target_pos
        self._set_obj_xyz(np.zeros(2))

        # Initial planar distance from the object start to the target; used
        # to scale the shaped reward.
        start_xy = self.obj_init_pos[:-1]
        target_xy = self._target_pos[:-1]
        self.maxDist = np.linalg.norm(start_xy - target_xy)
        self.target_reward = 1000*self.maxDist + 1000*2

        return self._get_obs()

    def _reset_hand(self):
        """Reset the hand via the base class and cache the finger midpoint."""
        super()._reset_hand(10)

        right_tip = self._get_site_pos('rightEndEffector')
        left_tip = self._get_site_pos('leftEndEffector')
        self.init_fingerCOM = (right_tip + left_tip)/2

    def compute_reward(self, actions, obs):
        """Compute the reward for the current observation.

        Args:
            actions: Unused.
            obs: Observation vector; ``obs[3:6]`` is read as the object
                position (assumed layout -- TODO confirm against the base
                class).

        Returns:
            A list ``[reward, reachDist, pullDist]``.
        """
        del actions

        obj_pos = obs[3:6]

        right_tip = self._get_site_pos('rightEndEffector')
        left_tip = self._get_site_pos('leftEndEffector')
        finger_com = (right_tip + left_tip)/2

        target = self._target_pos

        reach_dist = np.linalg.norm(obj_pos - finger_com)
        # Only the x/y components count toward the goal distance.
        pull_dist = np.linalg.norm(obj_pos[:-1] - target[:-1])

        c1 = 1000
        c2 = 0.01
        c3 = 0.001
        pull_reward = 0
        if reach_dist < 0.05:
            # Linear progress term plus two Gaussian bumps near the target.
            shaped = 1000*(self.maxDist - pull_dist) + c1*(
                np.exp(-(pull_dist**2)/c2) + np.exp(-(pull_dist**2)/c3))
            pull_reward = max(shaped, 0)

        reward = -reach_dist + pull_reward

        return [reward, reach_dist, pull_dist]
|
Metaworld/metaworld/envs/reward_utils.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A set of reward utilities written by the authors of dm_control"""
|
| 2 |
+
|
| 3 |
+
from multiprocessing import Value
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
# The value returned by tolerance() at `margin` distance from `bounds` interval.
|
| 7 |
+
_DEFAULT_VALUE_AT_MARGIN = 0.1
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _sigmoids(x, value_at_1, sigmoid):
|
| 11 |
+
"""Returns 1 when `x` == 0, between 0 and 1 otherwise.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
x: A scalar or numpy array.
|
| 15 |
+
value_at_1: A float between 0 and 1 specifying the output when `x` == 1.
|
| 16 |
+
sigmoid: String, choice of sigmoid type.
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
A numpy array with values between 0.0 and 1.0.
|
| 20 |
+
|
| 21 |
+
Raises:
|
| 22 |
+
ValueError: If not 0 < `value_at_1` < 1, except for `linear`, `cosine` and
|
| 23 |
+
`quadratic` sigmoids which allow `value_at_1` == 0.
|
| 24 |
+
ValueError: If `sigmoid` is of an unknown type.
|
| 25 |
+
"""
|
| 26 |
+
if sigmoid in ('cosine', 'linear', 'quadratic'):
|
| 27 |
+
if not 0 <= value_at_1 < 1:
|
| 28 |
+
raise ValueError(
|
| 29 |
+
'`value_at_1` must be nonnegative and smaller than 1, '
|
| 30 |
+
'got {}.'.format(value_at_1))
|
| 31 |
+
else:
|
| 32 |
+
if not 0 < value_at_1 < 1:
|
| 33 |
+
raise ValueError('`value_at_1` must be strictly between 0 and 1, '
|
| 34 |
+
'got {}.'.format(value_at_1))
|
| 35 |
+
|
| 36 |
+
if sigmoid == 'gaussian':
|
| 37 |
+
scale = np.sqrt(-2 * np.log(value_at_1))
|
| 38 |
+
return np.exp(-0.5 * (x * scale)**2)
|
| 39 |
+
|
| 40 |
+
elif sigmoid == 'hyperbolic':
|
| 41 |
+
scale = np.arccosh(1 / value_at_1)
|
| 42 |
+
return 1 / np.cosh(x * scale)
|
| 43 |
+
|
| 44 |
+
elif sigmoid == 'long_tail':
|
| 45 |
+
scale = np.sqrt(1 / value_at_1 - 1)
|
| 46 |
+
return 1 / ((x * scale)**2 + 1)
|
| 47 |
+
|
| 48 |
+
elif sigmoid == 'reciprocal':
|
| 49 |
+
scale = 1 / value_at_1 - 1
|
| 50 |
+
return 1 / (abs(x) * scale + 1)
|
| 51 |
+
|
| 52 |
+
elif sigmoid == 'cosine':
|
| 53 |
+
scale = np.arccos(2 * value_at_1 - 1) / np.pi
|
| 54 |
+
scaled_x = x * scale
|
| 55 |
+
return np.where(
|
| 56 |
+
abs(scaled_x) < 1, (1 + np.cos(np.pi * scaled_x)) / 2, 0.0)
|
| 57 |
+
|
| 58 |
+
elif sigmoid == 'linear':
|
| 59 |
+
scale = 1 - value_at_1
|
| 60 |
+
scaled_x = x * scale
|
| 61 |
+
return np.where(abs(scaled_x) < 1, 1 - scaled_x, 0.0)
|
| 62 |
+
|
| 63 |
+
elif sigmoid == 'quadratic':
|
| 64 |
+
scale = np.sqrt(1 - value_at_1)
|
| 65 |
+
scaled_x = x * scale
|
| 66 |
+
return np.where(abs(scaled_x) < 1, 1 - scaled_x**2, 0.0)
|
| 67 |
+
|
| 68 |
+
elif sigmoid == 'tanh_squared':
|
| 69 |
+
scale = np.arctanh(np.sqrt(1 - value_at_1))
|
| 70 |
+
return 1 - np.tanh(x * scale)**2
|
| 71 |
+
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError('Unknown sigmoid type {!r}.'.format(sigmoid))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def tolerance(x,
              bounds=(0.0, 0.0),
              margin=0.0,
              sigmoid='gaussian',
              value_at_margin=_DEFAULT_VALUE_AT_MARGIN):
    """Returns 1 when `x` falls inside the bounds, between 0 and 1 otherwise.

    Args:
        x: A scalar or numpy array.
        bounds: A tuple of floats specifying inclusive `(lower, upper)` bounds
            for the target interval. The two may be equal when the target
            value is exact.
        margin: Float. Controls how steeply the output decreases as `x` moves
            out-of-bounds.
            * If `margin == 0` the output is 0 for every `x` outside `bounds`.
            * If `margin > 0` the output decays with the distance from the
              nearest bound, following the chosen sigmoid.
        sigmoid: String, choice of sigmoid type, forwarded to `_sigmoids`:
            'gaussian', 'hyperbolic', 'long_tail', 'reciprocal', 'cosine',
            'linear', 'quadratic' or 'tanh_squared'.
        value_at_margin: A float between 0 and 1 specifying the output value
            when the distance from `x` to the nearest bound equals `margin`.
            Ignored if `margin == 0`.

    Returns:
        A float when `np.isscalar(x)` is true, otherwise a numpy array, with
        values between 0.0 and 1.0.

    Raises:
        ValueError: If `bounds[0] > bounds[1]`.
        ValueError: If `margin` is negative.
    """
    lower, upper = bounds
    if lower > upper:
        raise ValueError('Lower bound must be <= upper bound.')
    if margin < 0:
        raise ValueError('`margin` must be non-negative. Current value: {}'.format(margin))

    inside = np.logical_and(lower <= x, x <= upper)
    if margin == 0:
        # Hard cut-off: nothing outside the interval earns any value.
        outside_value = 0.0
    else:
        # Distance to the nearest bound, expressed in units of `margin`.
        distance = np.where(x < lower, lower - x, x - upper) / margin
        outside_value = _sigmoids(distance, value_at_margin, sigmoid)
    result = np.where(inside, 1.0, outside_value)

    if np.isscalar(x):
        return float(result)
    return result
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def inverse_tolerance(x,
                      bounds=(0.0, 0.0),
                      margin=0.0,
                      sigmoid='reciprocal',
                      value_at_margin=None):
    """Returns 0 when `x` falls inside the bounds, between 0 and 1 otherwise.

    This is `1 - tolerance(...)` evaluated with the same arguments.

    Args:
        x: A scalar or numpy array.
        bounds: A tuple of floats specifying inclusive `(lower, upper)` bounds
            for the target interval. The two may be equal when the target
            value is exact.
        margin: Float. Controls how steeply the output increases as `x` moves
            out-of-bounds.
            * If `margin == 0` the output is 1 for every `x` outside `bounds`.
            * If `margin > 0` the output grows with the distance from the
              nearest bound, following the chosen sigmoid.
        sigmoid: String, choice of sigmoid type, forwarded to `tolerance`.
        value_at_margin: Value that `tolerance` takes at `margin` distance
            from the bounds (so this function returns `1 - value_at_margin`
            there). When `None`, 0 is used for the 'cosine', 'linear' and
            'quadratic' sigmoids (the previous hard-coded value) and
            `_DEFAULT_VALUE_AT_MARGIN` for all others, which reject 0.
            Ignored if `margin == 0`.

    Returns:
        A float or numpy array with values between 0.0 and 1.0.

    Raises:
        ValueError: If `bounds[0] > bounds[1]`.
        ValueError: If `margin` is negative.
    """
    if value_at_margin is None:
        # The previous hard-coded 0 is only accepted by these three sigmoids;
        # every other type (including the default 'reciprocal') raised a
        # ValueError as soon as `margin` was positive.
        if sigmoid in ('cosine', 'linear', 'quadratic'):
            value_at_margin = 0
        else:
            value_at_margin = _DEFAULT_VALUE_AT_MARGIN

    bound = tolerance(x,
                      bounds=bounds,
                      margin=margin,
                      sigmoid=sigmoid,
                      value_at_margin=value_at_margin)
    return 1 - bound
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def rect_prism_tolerance(curr, zero, one):
    """Computes a reward for `curr` relative to a rectangular prism region.

    The 3d points `zero` and `one` are two diagonally opposite corners of a
    rectangular prism. Inside the prism (faces included) the result is the
    product of the per-axis fractions of the way from `zero` to `one`, so it
    is 0 at `zero` and 1 at `one`. Outside the prism the result is 1.

    Args:
        curr(np.ndarray): The point whose reward is being assessed.
            shape is (3,).
        zero(np.ndarray): One corner of the rectangular prism, with reward 0.
            shape is (3,)
        one(np.ndarray): The diagonally opposite corner, with reward 1.
            shape is (3,)

    Returns:
        The product of the three per-axis fractions when `curr` is inside the
        prism, otherwise 1.
    """
    def _between(value, end_a, end_b):
        # Inclusive containment, whichever of the two ends is the larger.
        low, high = (end_a, end_b) if end_b >= end_a else (end_b, end_a)
        return low <= value <= high

    inside = all(_between(curr[axis], zero[axis], one[axis])
                 for axis in range(3))
    if not inside:
        return 1.

    span = one - zero
    x_scale = (curr[0] - zero[0]) / span[0]
    y_scale = (curr[1] - zero[1]) / span[1]
    z_scale = (curr[2] - zero[2]) / span[2]
    return x_scale * y_scale * z_scale
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def hamacher_product(a, b):
    """The hamacher (t-norm) product of a and b.

    Evaluates (a * b) / ((a + b) - (a * b)), falling back to 0 when the
    denominator is not positive (i.e. when both terms are 0).

    Args:
        a (float): 1st term of hamacher product.
        b (float): 2nd term of hamacher product.
    Raises:
        ValueError: a and b must range between 0 and 1

    Returns:
        float: The hamacher product of a and b
    """
    if not (0. <= a <= 1.) or not (0. <= b <= 1.):
        raise ValueError("a and b must range between 0 and 1")

    product = a * b
    denominator = a + b - product
    if denominator > 0:
        h_prod = product / denominator
    else:
        h_prod = 0

    assert 0. <= h_prod <= 1.
    return h_prod
|
Metaworld/metaworld/policies/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (9.44 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_bin_picking_v2_policy.cpython-38.pyc
ADDED
|
Binary file (2.15 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_box_close_v1_policy.cpython-38.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_button_press_topdown_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.54 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_button_press_v1_policy.cpython-38.pyc
ADDED
|
Binary file (1.6 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_button_press_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.62 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_button_press_wall_v1_policy.cpython-38.pyc
ADDED
|
Binary file (1.96 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_button_press_wall_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.99 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_coffee_pull_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.94 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_coffee_push_v2_policy.cpython-38.pyc
ADDED
|
Binary file (2.02 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_dial_turn_v1_policy.cpython-38.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_disassemble_v1_policy.cpython-38.pyc
ADDED
|
Binary file (2.03 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_disassemble_v2_policy.cpython-38.pyc
ADDED
|
Binary file (2.01 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_door_close_v1_policy.cpython-38.pyc
ADDED
|
Binary file (1.63 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_door_lock_v1_policy.cpython-38.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_door_lock_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.63 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_door_open_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.6 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_door_unlock_v1_policy.cpython-38.pyc
ADDED
|
Binary file (1.58 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_door_unlock_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.6 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_drawer_close_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.64 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_drawer_open_v1_policy.cpython-38.pyc
ADDED
|
Binary file (1.43 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_faucet_close_v1_policy.cpython-38.pyc
ADDED
|
Binary file (1.58 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_hammer_v1_policy.cpython-38.pyc
ADDED
|
Binary file (2.06 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_hammer_v2_policy.cpython-38.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_handle_press_v1_policy.cpython-38.pyc
ADDED
|
Binary file (1.53 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_handle_press_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.55 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_handle_pull_v1_policy.cpython-38.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_handle_pull_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.71 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_peg_insertion_side_v2_policy.cpython-38.pyc
ADDED
|
Binary file (2.15 kB). View file
|
|
|
Metaworld/metaworld/policies/__pycache__/sawyer_peg_unplug_side_v2_policy.cpython-38.pyc
ADDED
|
Binary file (1.97 kB). View file
|
|
|