Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_dial.xml +25 -0
- Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_lock.xml +26 -0
- Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_pull.xml +23 -0
- Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_faucet.xml +35 -0
- Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_laptop.xml +22 -0
- Metaworld/zarr_path: data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 +3 -0
- Metaworld/zarr_path: data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 +3 -0
- dexart-release/assets/sapien/102697/cues.txt +5 -0
- dexart-release/assets/sapien/102697/meta.json +1 -0
- dexart-release/assets/sapien/102697/mobility.urdf +502 -0
- dexart-release/assets/sapien/102697/mobility_v2.json +1 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_1_12.mtl +12 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.mtl +12 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.obj +109 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_1_5.mtl +12 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_3_0.obj +217 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_3_5.obj +250 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_3_6.obj +169 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_4_11.obj +113 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_4_19.obj +105 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_4_3.mtl +12 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_4_33.mtl +12 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_4_4.mtl +12 -0
- dexart-release/assets/sapien/102697/new_objs/102697_link_4_8.mtl +12 -0
- dexart-release/assets/sapien/102697/result.json +1 -0
- dexart-release/assets/sapien/102697/result_original.json +1 -0
- dexart-release/assets/sapien/102697/semantics.txt +5 -0
- dexart-release/dexart.egg-info/PKG-INFO +12 -0
- dexart-release/dexart.egg-info/SOURCES.txt +51 -0
- dexart-release/dexart.egg-info/dependency_links.txt +1 -0
- dexart-release/dexart.egg-info/requires.txt +4 -0
- dexart-release/dexart.egg-info/top_level.txt +1 -0
- dexart-release/examples/gen_demonstration_expert.py +238 -0
- dexart-release/examples/train.py +124 -0
- dexart-release/examples/utils.py +66 -0
- dexart-release/stable_baselines3/a2c/__init__.py +2 -0
- dexart-release/stable_baselines3/a2c/a2c.py +207 -0
- dexart-release/stable_baselines3/a2c/policies.py +7 -0
- dexart-release/stable_baselines3/common/__init__.py +0 -0
- dexart-release/stable_baselines3/common/base_class.py +835 -0
- dexart-release/stable_baselines3/common/buffers.py +1010 -0
- dexart-release/stable_baselines3/common/callbacks.py +602 -0
- dexart-release/stable_baselines3/common/distributions.py +699 -0
- dexart-release/stable_baselines3/common/env_util.py +104 -0
- dexart-release/stable_baselines3/common/evaluation.py +131 -0
- dexart-release/stable_baselines3/common/logger.py +644 -0
- dexart-release/stable_baselines3/common/monitor.py +239 -0
- dexart-release/stable_baselines3/common/noise.py +167 -0
- dexart-release/stable_baselines3/common/on_policy_algorithm.py +320 -0
.gitattributes
CHANGED
|
@@ -43,3 +43,5 @@ Metaworld/zarr_path:[[:space:]]data/metaworld_disassemble_expert.zarr/data/point
|
|
| 43 |
Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_hammer_expert.zarr/data/depth/0.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/16.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_drawer-open_expert.zarr/data/point_cloud/5.0.0 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 43 |
Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_hammer_expert.zarr/data/depth/0.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/16.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_drawer-open_expert.zarr/data/point_cloud/5.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
Metaworld/zarr_path:[[:space:]]data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 filter=lfs diff=lfs merge=lfs -text
|
Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_dial.xml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<mujoco>
|
| 2 |
+
<include file="../scene/basic_scene.xml"/>
|
| 3 |
+
<include file="../objects/assets/dial_dependencies.xml"/>
|
| 4 |
+
<include file="../objects/assets/xyz_base_dependencies.xml"/>
|
| 5 |
+
<worldbody>
|
| 6 |
+
<include file="../objects/assets/xyz_base.xml"/>
|
| 7 |
+
|
| 8 |
+
<body name="dial" pos="0 0.7 0.">
|
| 9 |
+
<include file="../objects/assets/dial.xml"/>
|
| 10 |
+
<site name="dialStart" pos="0 -0.05 0.035" size="0.005" rgba="0 0 1 1"/>
|
| 11 |
+
</body>
|
| 12 |
+
|
| 13 |
+
<site name="goal" pos="0. 0.74 0.07" size="0.02"
|
| 14 |
+
rgba=".8 0 0 1"/>
|
| 15 |
+
|
| 16 |
+
--> </worldbody>
|
| 17 |
+
|
| 18 |
+
<actuator>
|
| 19 |
+
<position ctrllimited="true" ctrlrange="-1 1" joint="r_close" kp="400" user="1"/>
|
| 20 |
+
<position ctrllimited="true" ctrlrange="-1 1" joint="l_close" kp="400" user="1"/>
|
| 21 |
+
</actuator>
|
| 22 |
+
<equality>
|
| 23 |
+
<weld body1="mocap" body2="hand" solref="0.02 1"></weld>
|
| 24 |
+
</equality>
|
| 25 |
+
</mujoco>
|
Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_lock.xml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<mujoco>
|
| 2 |
+
<include file="../scene/basic_scene.xml"/>
|
| 3 |
+
<include file="../objects/assets/doorlock_dependencies.xml"/>
|
| 4 |
+
<include file="../objects/assets/xyz_base_dependencies.xml"/>
|
| 5 |
+
|
| 6 |
+
<worldbody>
|
| 7 |
+
|
| 8 |
+
<include file="../objects/assets/xyz_base.xml"/>
|
| 9 |
+
|
| 10 |
+
<body name="door" pos="0 0.9 0.15">
|
| 11 |
+
<include file="../objects/assets/doorlockA.xml"/>
|
| 12 |
+
</body>
|
| 13 |
+
|
| 14 |
+
<site name="goal_lock" pos="0 0.74 0.12" size="0.01"
|
| 15 |
+
rgba="0 0.8 0 1"/>
|
| 16 |
+
<site name="goal_unlock" pos="0.09 0.74 0.211" size="0.01"
|
| 17 |
+
rgba="0 0 0.8 1"/>
|
| 18 |
+
</worldbody>
|
| 19 |
+
<actuator>
|
| 20 |
+
<position ctrllimited="true" ctrlrange="-1 1" joint="r_close" kp="400" user="1"/>
|
| 21 |
+
<position ctrllimited="true" ctrlrange="-1 1" joint="l_close" kp="400" user="1"/>
|
| 22 |
+
</actuator>
|
| 23 |
+
<equality>
|
| 24 |
+
<weld body1="mocap" body2="hand" solref="0.02 1"/>
|
| 25 |
+
</equality>
|
| 26 |
+
</mujoco>
|
Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_pull.xml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<mujoco>
|
| 2 |
+
<include file="../scene/basic_scene.xml"/>
|
| 3 |
+
<include file="../objects/assets/doorlock_dependencies.xml"/>
|
| 4 |
+
<include file="../objects/assets/xyz_base_dependencies.xml"/>
|
| 5 |
+
|
| 6 |
+
<worldbody>
|
| 7 |
+
<include file="../objects/assets/xyz_base.xml"/>
|
| 8 |
+
|
| 9 |
+
<body name="door" pos="0 0.9 0.15">
|
| 10 |
+
<include file="../objects/assets/doorlockB.xml"/>
|
| 11 |
+
</body>
|
| 12 |
+
|
| 13 |
+
<site name="goal" pos="-0.49 0.46 0.15" size="0.02"
|
| 14 |
+
rgba="0 0.8 0 1"/>
|
| 15 |
+
</worldbody>
|
| 16 |
+
<actuator>
|
| 17 |
+
<position ctrllimited="true" ctrlrange="-1 1" joint="r_close" kp="400" user="1"/>
|
| 18 |
+
<position ctrllimited="true" ctrlrange="-1 1" joint="l_close" kp="400" user="1"/>
|
| 19 |
+
</actuator>
|
| 20 |
+
<equality>
|
| 21 |
+
<weld body1="mocap" body2="hand" solref="0.02 1"/>
|
| 22 |
+
</equality>
|
| 23 |
+
</mujoco>
|
Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_faucet.xml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<mujoco>
|
| 2 |
+
<include file="../scene/basic_scene.xml"/>
|
| 3 |
+
<include file="../objects/assets/faucet_dependencies.xml"/>
|
| 4 |
+
<include file="../objects/assets/xyz_base_dependencies.xml"/>
|
| 5 |
+
<worldbody>
|
| 6 |
+
<include file="../objects/assets/xyz_base.xml"/>
|
| 7 |
+
|
| 8 |
+
<body name="faucetBase" pos="0 0.8 0">
|
| 9 |
+
<include file="../objects/assets/faucet.xml"/>
|
| 10 |
+
|
| 11 |
+
</body>
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
<site name="goal_open" pos="0.175 0.785 0.125" size="0.005"
|
| 15 |
+
rgba="0 .8 0 1"/>
|
| 16 |
+
<site name="goal_close" pos="-0.175 0.785 0.125" size="0.005"
|
| 17 |
+
rgba="0 0 0.8 1"/>
|
| 18 |
+
|
| 19 |
+
<!-- <body name="box" pos="0 0.8 0.05">
|
| 20 |
+
<geom rgba="0.3 0.3 1 1" type="box" contype="1" size="0.1 0.05 0.05" name="box_left" conaffinity="1" pos="0 0 0" mass="1000" solimp="0.99 0.99 0.01" solref="0.01 1"/>
|
| 21 |
+
<geom rgba="0.3 0.3 1 1" type="box" contype="1" size="0.1 0.05 0.05" name="box_right" conaffinity="1" pos="0 0.16 0" mass="1000" solimp="0.99 0.99 0.01" solref="0.01 1"/>
|
| 22 |
+
<geom rgba="0.3 0.3 1 1" type="box" contype="1" size="0.035 0.03 0.05" name="box_front" conaffinity="1" pos="0.065 0.08 0" mass="1000" solimp="0.99 0.99 0.01" solref="0.01 1"/>
|
| 23 |
+
<geom rgba="0.3 0.3 1 1" type="box" contype="1" size="0.035 0.03 0.05" name="box_behind" conaffinity="1" pos="-0.065 0.08 0" mass="1000" solimp="0.99 0.99 0.01" solref="0.01 1"/>
|
| 24 |
+
<joint type="slide" range="-0.2 0." axis="0 1 0" name="goal_slidey" pos="0 0 0" damping="1.0"/>
|
| 25 |
+
</body>
|
| 26 |
+
--> </worldbody>
|
| 27 |
+
|
| 28 |
+
<actuator>
|
| 29 |
+
<position ctrllimited="true" ctrlrange="-1 1" joint="r_close" kp="400" user="1"/>
|
| 30 |
+
<position ctrllimited="true" ctrlrange="-1 1" joint="l_close" kp="400" user="1"/>
|
| 31 |
+
</actuator>
|
| 32 |
+
<equality>
|
| 33 |
+
<weld body1="mocap" body2="hand" solref="0.02 1"></weld>
|
| 34 |
+
</equality>
|
| 35 |
+
</mujoco>
|
Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_laptop.xml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<mujoco>
|
| 2 |
+
<include file="../scene/basic_scene.xml"/>
|
| 3 |
+
<include file="../objects/assets/laptop_dependencies.xml"/>
|
| 4 |
+
<include file="../objects/assets/xyz_base_dependencies.xml"/>
|
| 5 |
+
|
| 6 |
+
<worldbody>
|
| 7 |
+
<include file="../objects/assets/xyz_base.xml"/>
|
| 8 |
+
|
| 9 |
+
<body name="laptop" pos="0 0.8 0">
|
| 10 |
+
<include file="../objects/assets/laptop.xml"/>
|
| 11 |
+
|
| 12 |
+
</body>
|
| 13 |
+
</worldbody>
|
| 14 |
+
|
| 15 |
+
<actuator>
|
| 16 |
+
<position ctrllimited="true" ctrlrange="-1 1" joint="r_close" kp="400" user="1"/>
|
| 17 |
+
<position ctrllimited="true" ctrlrange="-1 1" joint="l_close" kp="400" user="1"/>
|
| 18 |
+
</actuator>
|
| 19 |
+
<equality>
|
| 20 |
+
<weld body1="mocap" body2="hand" solref="0.02 1"></weld>
|
| 21 |
+
</equality>
|
| 22 |
+
</mujoco>
|
Metaworld/zarr_path: data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:902da15cc5607a827c4a6cf9b7c396bacd2bd244963aa4e36f400f543b23497b
|
| 3 |
+
size 1231019
|
Metaworld/zarr_path: data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:32e76433f75e66cccd7b4db7962d3fc1ae4fe8915409efc427235456347323c4
|
| 3 |
+
size 1213234
|
dexart-release/assets/sapien/102697/cues.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
link_0 hinge lever button
|
| 2 |
+
link_1 slider pump_lid lid
|
| 3 |
+
link_2 hinge lid lid
|
| 4 |
+
link_3 hinge seat seat
|
| 5 |
+
link_4 static base_body base_body
|
dexart-release/assets/sapien/102697/meta.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"user_id": "haozhu", "model_cat": "Toilet", "model_id": "db252ecd6286a334733badcb2e574996-0", "version": "1", "anno_id": "2697", "time_in_sec": "31"}
|
dexart-release/assets/sapien/102697/mobility.urdf
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<robot name="partnet_db252ecd6286a334733badcb2e574996-0">
|
| 2 |
+
<link name="base" />
|
| 3 |
+
<link name="link_0">
|
| 4 |
+
<collision>
|
| 5 |
+
<origin xyz="0.387179130380233 -0.5229989412652956 0.2586584187924479" />
|
| 6 |
+
<geometry>
|
| 7 |
+
<mesh filename="new_objs/102697_link_0_2.obj" />
|
| 8 |
+
</geometry>
|
| 9 |
+
</collision>
|
| 10 |
+
<collision>
|
| 11 |
+
<origin xyz="0.387179130380233 -0.5229989412652956 0.2586584187924479" />
|
| 12 |
+
<geometry>
|
| 13 |
+
<mesh filename="new_objs/102697_link_0_0.obj" />
|
| 14 |
+
</geometry>
|
| 15 |
+
</collision>
|
| 16 |
+
<collision>
|
| 17 |
+
<origin xyz="0.387179130380233 -0.5229989412652956 0.2586584187924479" />
|
| 18 |
+
<geometry>
|
| 19 |
+
<mesh filename="new_objs/102697_link_0_1.obj" />
|
| 20 |
+
</geometry>
|
| 21 |
+
</collision>
|
| 22 |
+
<visual name="button-part["id]">
|
| 23 |
+
<origin xyz="0.387179130380233 -0.5229989412652956 0.2586584187924479" />
|
| 24 |
+
<geometry>
|
| 25 |
+
<mesh filename="textured_objs/original-8.obj" />
|
| 26 |
+
</geometry>
|
| 27 |
+
</visual>
|
| 28 |
+
</link>
|
| 29 |
+
<joint name="joint_0" type="revolute">
|
| 30 |
+
<origin xyz="-0.387179130380233 0.5229989412652956 -0.2586584187924479" />
|
| 31 |
+
<axis xyz="-0.91844856752952 -6.467369864908537e-05 0.39554042096893877" />
|
| 32 |
+
<child link="link_0" />
|
| 33 |
+
<parent link="link_4" />
|
| 34 |
+
<limit lower="0.0" upper="0.5235987755982988" />
|
| 35 |
+
</joint>
|
| 36 |
+
<link name="link_1">
|
| 37 |
+
<collision>
|
| 38 |
+
<origin xyz="0 0 0" />
|
| 39 |
+
<geometry>
|
| 40 |
+
<mesh filename="new_objs/102697_link_1_6.obj" />
|
| 41 |
+
</geometry>
|
| 42 |
+
</collision>
|
| 43 |
+
<collision>
|
| 44 |
+
<origin xyz="0 0 0" />
|
| 45 |
+
<geometry>
|
| 46 |
+
<mesh filename="new_objs/102697_link_1_14.obj" />
|
| 47 |
+
</geometry>
|
| 48 |
+
</collision>
|
| 49 |
+
<collision>
|
| 50 |
+
<origin xyz="0 0 0" />
|
| 51 |
+
<geometry>
|
| 52 |
+
<mesh filename="new_objs/102697_link_1_11.obj" />
|
| 53 |
+
</geometry>
|
| 54 |
+
</collision>
|
| 55 |
+
<collision>
|
| 56 |
+
<origin xyz="0 0 0" />
|
| 57 |
+
<geometry>
|
| 58 |
+
<mesh filename="new_objs/102697_link_1_4.obj" />
|
| 59 |
+
</geometry>
|
| 60 |
+
</collision>
|
| 61 |
+
<collision>
|
| 62 |
+
<origin xyz="0 0 0" />
|
| 63 |
+
<geometry>
|
| 64 |
+
<mesh filename="new_objs/102697_link_1_8.obj" />
|
| 65 |
+
</geometry>
|
| 66 |
+
</collision>
|
| 67 |
+
<collision>
|
| 68 |
+
<origin xyz="0 0 0" />
|
| 69 |
+
<geometry>
|
| 70 |
+
<mesh filename="new_objs/102697_link_1_3.obj" />
|
| 71 |
+
</geometry>
|
| 72 |
+
</collision>
|
| 73 |
+
<collision>
|
| 74 |
+
<origin xyz="0 0 0" />
|
| 75 |
+
<geometry>
|
| 76 |
+
<mesh filename="new_objs/102697_link_1_1.obj" />
|
| 77 |
+
</geometry>
|
| 78 |
+
</collision>
|
| 79 |
+
<collision>
|
| 80 |
+
<origin xyz="0 0 0" />
|
| 81 |
+
<geometry>
|
| 82 |
+
<mesh filename="new_objs/102697_link_1_12.obj" />
|
| 83 |
+
</geometry>
|
| 84 |
+
</collision>
|
| 85 |
+
<collision>
|
| 86 |
+
<origin xyz="0 0 0" />
|
| 87 |
+
<geometry>
|
| 88 |
+
<mesh filename="new_objs/102697_link_1_5.obj" />
|
| 89 |
+
</geometry>
|
| 90 |
+
</collision>
|
| 91 |
+
<collision>
|
| 92 |
+
<origin xyz="0 0 0" />
|
| 93 |
+
<geometry>
|
| 94 |
+
<mesh filename="new_objs/102697_link_1_13.obj" />
|
| 95 |
+
</geometry>
|
| 96 |
+
</collision>
|
| 97 |
+
<collision>
|
| 98 |
+
<origin xyz="0 0 0" />
|
| 99 |
+
<geometry>
|
| 100 |
+
<mesh filename="new_objs/102697_link_1_0.obj" />
|
| 101 |
+
</geometry>
|
| 102 |
+
</collision>
|
| 103 |
+
<collision>
|
| 104 |
+
<origin xyz="0 0 0" />
|
| 105 |
+
<geometry>
|
| 106 |
+
<mesh filename="new_objs/102697_link_1_10.obj" />
|
| 107 |
+
</geometry>
|
| 108 |
+
</collision>
|
| 109 |
+
<collision>
|
| 110 |
+
<origin xyz="0 0 0" />
|
| 111 |
+
<geometry>
|
| 112 |
+
<mesh filename="new_objs/102697_link_1_2.obj" />
|
| 113 |
+
</geometry>
|
| 114 |
+
</collision>
|
| 115 |
+
<collision>
|
| 116 |
+
<origin xyz="0 0 0" />
|
| 117 |
+
<geometry>
|
| 118 |
+
<mesh filename="new_objs/102697_link_1_9.obj" />
|
| 119 |
+
</geometry>
|
| 120 |
+
</collision>
|
| 121 |
+
<collision>
|
| 122 |
+
<origin xyz="0 0 0" />
|
| 123 |
+
<geometry>
|
| 124 |
+
<mesh filename="new_objs/102697_link_1_7.obj" />
|
| 125 |
+
</geometry>
|
| 126 |
+
</collision>
|
| 127 |
+
<visual name="lid-part["id]">
|
| 128 |
+
<origin xyz="0 0 0" />
|
| 129 |
+
<geometry>
|
| 130 |
+
<mesh filename="textured_objs/original-4.obj" />
|
| 131 |
+
</geometry>
|
| 132 |
+
</visual>
|
| 133 |
+
<visual name="lid-part["id]">
|
| 134 |
+
<origin xyz="0 0 0" />
|
| 135 |
+
<geometry>
|
| 136 |
+
<mesh filename="textured_objs/original-3.obj" />
|
| 137 |
+
</geometry>
|
| 138 |
+
</visual>
|
| 139 |
+
</link>
|
| 140 |
+
<joint name="joint_1" type="prismatic">
|
| 141 |
+
<origin xyz="0 0 0" />
|
| 142 |
+
<axis xyz="0 1 0" />
|
| 143 |
+
<child link="link_1" />
|
| 144 |
+
<parent link="link_4" />
|
| 145 |
+
<limit lower="0" upper="0.02" />
|
| 146 |
+
</joint>
|
| 147 |
+
<link name="link_2">
|
| 148 |
+
<collision>
|
| 149 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 150 |
+
<geometry>
|
| 151 |
+
<mesh filename="new_objs/102697_link_2_0.obj" />
|
| 152 |
+
</geometry>
|
| 153 |
+
</collision>
|
| 154 |
+
<visual name="lid-part["id]">
|
| 155 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 156 |
+
<geometry>
|
| 157 |
+
<mesh filename="textured_objs/original-21.obj" />
|
| 158 |
+
</geometry>
|
| 159 |
+
</visual>
|
| 160 |
+
<visual name="lid-part["id]">
|
| 161 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 162 |
+
<geometry>
|
| 163 |
+
<mesh filename="textured_objs/original-20.obj" />
|
| 164 |
+
</geometry>
|
| 165 |
+
</visual>
|
| 166 |
+
</link>
|
| 167 |
+
<joint name="joint_2" type="revolute">
|
| 168 |
+
<origin xyz="0 0.05794601786221419 -0.06205458525801076" />
|
| 169 |
+
<axis xyz="-1 0 0" />
|
| 170 |
+
<child link="link_2" />
|
| 171 |
+
<parent link="link_4" />
|
| 172 |
+
<limit lower="-0.0" upper="1.6406094968746698" />
|
| 173 |
+
</joint>
|
| 174 |
+
<link name="link_3">
|
| 175 |
+
<collision>
|
| 176 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 177 |
+
<geometry>
|
| 178 |
+
<mesh filename="new_objs/102697_link_3_2.obj" />
|
| 179 |
+
</geometry>
|
| 180 |
+
</collision>
|
| 181 |
+
<collision>
|
| 182 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 183 |
+
<geometry>
|
| 184 |
+
<mesh filename="new_objs/102697_link_3_6.obj" />
|
| 185 |
+
</geometry>
|
| 186 |
+
</collision>
|
| 187 |
+
<collision>
|
| 188 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 189 |
+
<geometry>
|
| 190 |
+
<mesh filename="new_objs/102697_link_3_1.obj" />
|
| 191 |
+
</geometry>
|
| 192 |
+
</collision>
|
| 193 |
+
<collision>
|
| 194 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 195 |
+
<geometry>
|
| 196 |
+
<mesh filename="new_objs/102697_link_3_0.obj" />
|
| 197 |
+
</geometry>
|
| 198 |
+
</collision>
|
| 199 |
+
<collision>
|
| 200 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 201 |
+
<geometry>
|
| 202 |
+
<mesh filename="new_objs/102697_link_3_4.obj" />
|
| 203 |
+
</geometry>
|
| 204 |
+
</collision>
|
| 205 |
+
<collision>
|
| 206 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 207 |
+
<geometry>
|
| 208 |
+
<mesh filename="new_objs/102697_link_3_5.obj" />
|
| 209 |
+
</geometry>
|
| 210 |
+
</collision>
|
| 211 |
+
<collision>
|
| 212 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 213 |
+
<geometry>
|
| 214 |
+
<mesh filename="new_objs/102697_link_3_7.obj" />
|
| 215 |
+
</geometry>
|
| 216 |
+
</collision>
|
| 217 |
+
<collision>
|
| 218 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 219 |
+
<geometry>
|
| 220 |
+
<mesh filename="new_objs/102697_link_3_3.obj" />
|
| 221 |
+
</geometry>
|
| 222 |
+
</collision>
|
| 223 |
+
<visual name="seat-part["id]">
|
| 224 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 225 |
+
<geometry>
|
| 226 |
+
<mesh filename="textured_objs/original-18.obj" />
|
| 227 |
+
</geometry>
|
| 228 |
+
</visual>
|
| 229 |
+
<visual name="seat-part["id]">
|
| 230 |
+
<origin xyz="0 -0.05794601786221419 0.06205458525801076" />
|
| 231 |
+
<geometry>
|
| 232 |
+
<mesh filename="textured_objs/original-17.obj" />
|
| 233 |
+
</geometry>
|
| 234 |
+
</visual>
|
| 235 |
+
</link>
|
| 236 |
+
<joint name="joint_3" type="revolute">
|
| 237 |
+
<origin xyz="0 0.05794601786221419 -0.06205458525801076" />
|
| 238 |
+
<axis xyz="-1 0 0" />
|
| 239 |
+
<child link="link_3" />
|
| 240 |
+
<parent link="link_4" />
|
| 241 |
+
<limit lower="-0.0" upper="1.6406094968746698" />
|
| 242 |
+
</joint>
|
| 243 |
+
<link name="link_4">
|
| 244 |
+
<collision>
|
| 245 |
+
<origin xyz="0 0 0" />
|
| 246 |
+
<geometry>
|
| 247 |
+
<mesh filename="new_objs/102697_link_4_0.obj" />
|
| 248 |
+
</geometry>
|
| 249 |
+
</collision>
|
| 250 |
+
<collision>
|
| 251 |
+
<origin xyz="0 0 0" />
|
| 252 |
+
<geometry>
|
| 253 |
+
<mesh filename="new_objs/102697_link_4_18.obj" />
|
| 254 |
+
</geometry>
|
| 255 |
+
</collision>
|
| 256 |
+
<collision>
|
| 257 |
+
<origin xyz="0 0 0" />
|
| 258 |
+
<geometry>
|
| 259 |
+
<mesh filename="new_objs/102697_link_4_22.obj" />
|
| 260 |
+
</geometry>
|
| 261 |
+
</collision>
|
| 262 |
+
<collision>
|
| 263 |
+
<origin xyz="0 0 0" />
|
| 264 |
+
<geometry>
|
| 265 |
+
<mesh filename="new_objs/102697_link_4_12.obj" />
|
| 266 |
+
</geometry>
|
| 267 |
+
</collision>
|
| 268 |
+
<collision>
|
| 269 |
+
<origin xyz="0 0 0" />
|
| 270 |
+
<geometry>
|
| 271 |
+
<mesh filename="new_objs/102697_link_4_27.obj" />
|
| 272 |
+
</geometry>
|
| 273 |
+
</collision>
|
| 274 |
+
<collision>
|
| 275 |
+
<origin xyz="0 0 0" />
|
| 276 |
+
<geometry>
|
| 277 |
+
<mesh filename="new_objs/102697_link_4_23.obj" />
|
| 278 |
+
</geometry>
|
| 279 |
+
</collision>
|
| 280 |
+
<collision>
|
| 281 |
+
<origin xyz="0 0 0" />
|
| 282 |
+
<geometry>
|
| 283 |
+
<mesh filename="new_objs/102697_link_4_33.obj" />
|
| 284 |
+
</geometry>
|
| 285 |
+
</collision>
|
| 286 |
+
<collision>
|
| 287 |
+
<origin xyz="0 0 0" />
|
| 288 |
+
<geometry>
|
| 289 |
+
<mesh filename="new_objs/102697_link_4_2.obj" />
|
| 290 |
+
</geometry>
|
| 291 |
+
</collision>
|
| 292 |
+
<collision>
|
| 293 |
+
<origin xyz="0 0 0" />
|
| 294 |
+
<geometry>
|
| 295 |
+
<mesh filename="new_objs/102697_link_4_13.obj" />
|
| 296 |
+
</geometry>
|
| 297 |
+
</collision>
|
| 298 |
+
<collision>
|
| 299 |
+
<origin xyz="0 0 0" />
|
| 300 |
+
<geometry>
|
| 301 |
+
<mesh filename="new_objs/102697_link_4_28.obj" />
|
| 302 |
+
</geometry>
|
| 303 |
+
</collision>
|
| 304 |
+
<collision>
|
| 305 |
+
<origin xyz="0 0 0" />
|
| 306 |
+
<geometry>
|
| 307 |
+
<mesh filename="new_objs/102697_link_4_34.obj" />
|
| 308 |
+
</geometry>
|
| 309 |
+
</collision>
|
| 310 |
+
<collision>
|
| 311 |
+
<origin xyz="0 0 0" />
|
| 312 |
+
<geometry>
|
| 313 |
+
<mesh filename="new_objs/102697_link_4_7.obj" />
|
| 314 |
+
</geometry>
|
| 315 |
+
</collision>
|
| 316 |
+
<collision>
|
| 317 |
+
<origin xyz="0 0 0" />
|
| 318 |
+
<geometry>
|
| 319 |
+
<mesh filename="new_objs/102697_link_4_1.obj" />
|
| 320 |
+
</geometry>
|
| 321 |
+
</collision>
|
| 322 |
+
<collision>
|
| 323 |
+
<origin xyz="0 0 0" />
|
| 324 |
+
<geometry>
|
| 325 |
+
<mesh filename="new_objs/102697_link_4_10.obj" />
|
| 326 |
+
</geometry>
|
| 327 |
+
</collision>
|
| 328 |
+
<collision>
|
| 329 |
+
<origin xyz="0 0 0" />
|
| 330 |
+
<geometry>
|
| 331 |
+
<mesh filename="new_objs/102697_link_4_6.obj" />
|
| 332 |
+
</geometry>
|
| 333 |
+
</collision>
|
| 334 |
+
<collision>
|
| 335 |
+
<origin xyz="0 0 0" />
|
| 336 |
+
<geometry>
|
| 337 |
+
<mesh filename="new_objs/102697_link_4_19.obj" />
|
| 338 |
+
</geometry>
|
| 339 |
+
</collision>
|
| 340 |
+
<collision>
|
| 341 |
+
<origin xyz="0 0 0" />
|
| 342 |
+
<geometry>
|
| 343 |
+
<mesh filename="new_objs/102697_link_4_3.obj" />
|
| 344 |
+
</geometry>
|
| 345 |
+
</collision>
|
| 346 |
+
<collision>
|
| 347 |
+
<origin xyz="0 0 0" />
|
| 348 |
+
<geometry>
|
| 349 |
+
<mesh filename="new_objs/102697_link_4_26.obj" />
|
| 350 |
+
</geometry>
|
| 351 |
+
</collision>
|
| 352 |
+
<collision>
|
| 353 |
+
<origin xyz="0 0 0" />
|
| 354 |
+
<geometry>
|
| 355 |
+
<mesh filename="new_objs/102697_link_4_8.obj" />
|
| 356 |
+
</geometry>
|
| 357 |
+
</collision>
|
| 358 |
+
<collision>
|
| 359 |
+
<origin xyz="0 0 0" />
|
| 360 |
+
<geometry>
|
| 361 |
+
<mesh filename="new_objs/102697_link_4_25.obj" />
|
| 362 |
+
</geometry>
|
| 363 |
+
</collision>
|
| 364 |
+
<collision>
|
| 365 |
+
<origin xyz="0 0 0" />
|
| 366 |
+
<geometry>
|
| 367 |
+
<mesh filename="new_objs/102697_link_4_11.obj" />
|
| 368 |
+
</geometry>
|
| 369 |
+
</collision>
|
| 370 |
+
<collision>
|
| 371 |
+
<origin xyz="0 0 0" />
|
| 372 |
+
<geometry>
|
| 373 |
+
<mesh filename="new_objs/102697_link_4_21.obj" />
|
| 374 |
+
</geometry>
|
| 375 |
+
</collision>
|
| 376 |
+
<collision>
|
| 377 |
+
<origin xyz="0 0 0" />
|
| 378 |
+
<geometry>
|
| 379 |
+
<mesh filename="new_objs/102697_link_4_9.obj" />
|
| 380 |
+
</geometry>
|
| 381 |
+
</collision>
|
| 382 |
+
<collision>
|
| 383 |
+
<origin xyz="0 0 0" />
|
| 384 |
+
<geometry>
|
| 385 |
+
<mesh filename="new_objs/102697_link_4_17.obj" />
|
| 386 |
+
</geometry>
|
| 387 |
+
</collision>
|
| 388 |
+
<collision>
|
| 389 |
+
<origin xyz="0 0 0" />
|
| 390 |
+
<geometry>
|
| 391 |
+
<mesh filename="new_objs/102697_link_4_30.obj" />
|
| 392 |
+
</geometry>
|
| 393 |
+
</collision>
|
| 394 |
+
<collision>
|
| 395 |
+
<origin xyz="0 0 0" />
|
| 396 |
+
<geometry>
|
| 397 |
+
<mesh filename="new_objs/102697_link_4_16.obj" />
|
| 398 |
+
</geometry>
|
| 399 |
+
</collision>
|
| 400 |
+
<collision>
|
| 401 |
+
<origin xyz="0 0 0" />
|
| 402 |
+
<geometry>
|
| 403 |
+
<mesh filename="new_objs/102697_link_4_31.obj" />
|
| 404 |
+
</geometry>
|
| 405 |
+
</collision>
|
| 406 |
+
<collision>
|
| 407 |
+
<origin xyz="0 0 0" />
|
| 408 |
+
<geometry>
|
| 409 |
+
<mesh filename="new_objs/102697_link_4_20.obj" />
|
| 410 |
+
</geometry>
|
| 411 |
+
</collision>
|
| 412 |
+
<collision>
|
| 413 |
+
<origin xyz="0 0 0" />
|
| 414 |
+
<geometry>
|
| 415 |
+
<mesh filename="new_objs/102697_link_4_5.obj" />
|
| 416 |
+
</geometry>
|
| 417 |
+
</collision>
|
| 418 |
+
<collision>
|
| 419 |
+
<origin xyz="0 0 0" />
|
| 420 |
+
<geometry>
|
| 421 |
+
<mesh filename="new_objs/102697_link_4_29.obj" />
|
| 422 |
+
</geometry>
|
| 423 |
+
</collision>
|
| 424 |
+
<collision>
|
| 425 |
+
<origin xyz="0 0 0" />
|
| 426 |
+
<geometry>
|
| 427 |
+
<mesh filename="new_objs/102697_link_4_15.obj" />
|
| 428 |
+
</geometry>
|
| 429 |
+
</collision>
|
| 430 |
+
<collision>
|
| 431 |
+
<origin xyz="0 0 0" />
|
| 432 |
+
<geometry>
|
| 433 |
+
<mesh filename="new_objs/102697_link_4_4.obj" />
|
| 434 |
+
</geometry>
|
| 435 |
+
</collision>
|
| 436 |
+
<collision>
|
| 437 |
+
<origin xyz="0 0 0" />
|
| 438 |
+
<geometry>
|
| 439 |
+
<mesh filename="new_objs/102697_link_4_24.obj" />
|
| 440 |
+
</geometry>
|
| 441 |
+
</collision>
|
| 442 |
+
<collision>
|
| 443 |
+
<origin xyz="0 0 0" />
|
| 444 |
+
<geometry>
|
| 445 |
+
<mesh filename="new_objs/102697_link_4_32.obj" />
|
| 446 |
+
</geometry>
|
| 447 |
+
</collision>
|
| 448 |
+
<collision>
|
| 449 |
+
<origin xyz="0 0 0" />
|
| 450 |
+
<geometry>
|
| 451 |
+
<mesh filename="new_objs/102697_link_4_14.obj" />
|
| 452 |
+
</geometry>
|
| 453 |
+
</collision>
|
| 454 |
+
<visual name="base_body-part["id]">
|
| 455 |
+
<origin xyz="0 0 0" />
|
| 456 |
+
<geometry>
|
| 457 |
+
<mesh filename="textured_objs/original-13.obj" />
|
| 458 |
+
</geometry>
|
| 459 |
+
</visual>
|
| 460 |
+
<visual name="base_body-part["id]">
|
| 461 |
+
<origin xyz="0 0 0" />
|
| 462 |
+
<geometry>
|
| 463 |
+
<mesh filename="textured_objs/original-6.obj" />
|
| 464 |
+
</geometry>
|
| 465 |
+
</visual>
|
| 466 |
+
<visual name="base_body-part["id]">
|
| 467 |
+
<origin xyz="0 0 0" />
|
| 468 |
+
<geometry>
|
| 469 |
+
<mesh filename="textured_objs/original-15.obj" />
|
| 470 |
+
</geometry>
|
| 471 |
+
</visual>
|
| 472 |
+
<visual name="base_body-part["id]">
|
| 473 |
+
<origin xyz="0 0 0" />
|
| 474 |
+
<geometry>
|
| 475 |
+
<mesh filename="textured_objs/original-11.obj" />
|
| 476 |
+
</geometry>
|
| 477 |
+
</visual>
|
| 478 |
+
<visual name="base_body-part["id]">
|
| 479 |
+
<origin xyz="0 0 0" />
|
| 480 |
+
<geometry>
|
| 481 |
+
<mesh filename="textured_objs/original-7.obj" />
|
| 482 |
+
</geometry>
|
| 483 |
+
</visual>
|
| 484 |
+
<visual name="base_body-part["id]">
|
| 485 |
+
<origin xyz="0 0 0" />
|
| 486 |
+
<geometry>
|
| 487 |
+
<mesh filename="textured_objs/original-10.obj" />
|
| 488 |
+
</geometry>
|
| 489 |
+
</visual>
|
| 490 |
+
<visual name="base_body-part["id]">
|
| 491 |
+
<origin xyz="0 0 0" />
|
| 492 |
+
<geometry>
|
| 493 |
+
<mesh filename="textured_objs/original-14.obj" />
|
| 494 |
+
</geometry>
|
| 495 |
+
</visual>
|
| 496 |
+
</link>
|
| 497 |
+
<joint name="joint_4" type="fixed">
|
| 498 |
+
<origin rpy="1.570796326794897 0 -1.570796326794897" xyz="0 0 0" />
|
| 499 |
+
<child link="link_4" />
|
| 500 |
+
<parent link="base" />
|
| 501 |
+
</joint>
|
| 502 |
+
</robot>
|
dexart-release/assets/sapien/102697/mobility_v2.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[{"id":0,"parent":4,"joint":"hinge","name":"lever","parts":[{"id":1,"name":"button"}],"jointData":{"axis":{"origin":[-0.387179130380233,0.5229989412652956,-0.2586584187924479],"direction":[-0.91844856752952,-0.00006467369864908537,0.39554042096893877]},"limit":{"a":0,"b":30,"noLimit":false}}},{"id":1,"parent":4,"joint":"slider","name":"pump_lid","parts":[{"id":2,"name":"lid"}],"jointData":{"axis":{"origin":[0,0,0],"direction":[0,1,0]},"limit":{"a":0,"b":0.02,"noLimit":false,"rotates":false,"noRotationLimit":false,"rotationLimit":0}}},{"id":2,"parent":4,"joint":"hinge","name":"lid","parts":[{"id":3,"name":"lid"}],"jointData":{"axis":{"origin":[0,0.05794601786221419,-0.06205458525801076],"direction":[1,0,0]},"limit":{"a":0,"b":-94,"noLimit":false}}},{"id":3,"parent":4,"joint":"hinge","name":"seat","parts":[{"id":4,"name":"seat"}],"jointData":{"axis":{"origin":[0,0.05794601786221419,-0.06205458525801076],"direction":[1,0,0]},"limit":{"a":0,"b":-94,"noLimit":false}}},{"id":4,"parent":-1,"joint":"static","name":"base_body","parts":[{"id":5,"name":"base_body"}],"jointData":{}}]
|
dexart-release/assets/sapien/102697/new_objs/102697_link_1_12.mtl
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender MTL File: 'None'
|
| 2 |
+
# Material Count: 1
|
| 3 |
+
|
| 4 |
+
newmtl Shape.14067
|
| 5 |
+
Ns 400.000000
|
| 6 |
+
Ka 0.400000 0.400000 0.400000
|
| 7 |
+
Kd 0.336000 0.232000 0.584000
|
| 8 |
+
Ks 0.250000 0.250000 0.250000
|
| 9 |
+
Ke 0.000000 0.000000 0.000000
|
| 10 |
+
Ni 1.000000
|
| 11 |
+
d 0.500000
|
| 12 |
+
illum 2
|
dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.mtl
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender MTL File: 'None'
|
| 2 |
+
# Material Count: 1
|
| 3 |
+
|
| 4 |
+
newmtl Shape.14069
|
| 5 |
+
Ns 400.000000
|
| 6 |
+
Ka 0.400000 0.400000 0.400000
|
| 7 |
+
Kd 0.296000 0.784000 0.192000
|
| 8 |
+
Ks 0.250000 0.250000 0.250000
|
| 9 |
+
Ke 0.000000 0.000000 0.000000
|
| 10 |
+
Ni 1.000000
|
| 11 |
+
d 0.500000
|
| 12 |
+
illum 2
|
dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.obj
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender v2.79 (sub 0) OBJ File: ''
|
| 2 |
+
# www.blender.org
|
| 3 |
+
mtllib 102697_link_1_14.mtl
|
| 4 |
+
o Shape_IndexedFaceSet.014_Shape_IndexedFaceSet.14061
|
| 5 |
+
v 0.381286 0.716318 -0.278277
|
| 6 |
+
v 0.467366 0.669121 -0.458755
|
| 7 |
+
v 0.467366 0.666342 -0.458755
|
| 8 |
+
v 0.314641 0.669121 -0.486500
|
| 9 |
+
v 0.314641 0.666342 -0.275488
|
| 10 |
+
v 0.314641 0.716318 -0.486500
|
| 11 |
+
v 0.442356 0.716318 -0.486520
|
| 12 |
+
v 0.406252 0.666342 -0.275488
|
| 13 |
+
v 0.314641 0.716318 -0.275488
|
| 14 |
+
v 0.453478 0.702433 -0.433757
|
| 15 |
+
v 0.464367 0.666416 -0.486207
|
| 16 |
+
v 0.397925 0.702433 -0.275488
|
| 17 |
+
v 0.386833 0.666342 -0.486500
|
| 18 |
+
v 0.447917 0.713539 -0.453197
|
| 19 |
+
v 0.433456 0.670746 -0.350149
|
| 20 |
+
v 0.411813 0.705207 -0.311600
|
| 21 |
+
v 0.453478 0.710760 -0.486520
|
| 22 |
+
v 0.459025 0.702433 -0.455986
|
| 23 |
+
v 0.397925 0.713539 -0.300506
|
| 24 |
+
v 0.409033 0.674673 -0.283855
|
| 25 |
+
v 0.430634 0.693883 -0.355457
|
| 26 |
+
v 0.461805 0.669121 -0.436525
|
| 27 |
+
v 0.447917 0.710760 -0.439314
|
| 28 |
+
vn 0.6201 0.7564 0.2082
|
| 29 |
+
vn -1.0000 0.0000 0.0000
|
| 30 |
+
vn 0.0000 1.0000 0.0000
|
| 31 |
+
vn -0.0002 0.0000 -1.0000
|
| 32 |
+
vn 0.0000 -1.0000 0.0000
|
| 33 |
+
vn 0.0000 0.0000 1.0000
|
| 34 |
+
vn 0.9941 0.0000 -0.1086
|
| 35 |
+
vn 0.0406 0.2433 0.9691
|
| 36 |
+
vn -0.0385 -0.9992 -0.0132
|
| 37 |
+
vn 0.0010 -1.0000 -0.0028
|
| 38 |
+
vn 0.1781 0.9826 0.0522
|
| 39 |
+
vn -0.0001 -0.0002 -1.0000
|
| 40 |
+
vn 0.9633 0.2356 -0.1285
|
| 41 |
+
vn -0.0000 -0.0004 -1.0000
|
| 42 |
+
vn 0.0038 -0.0061 -1.0000
|
| 43 |
+
vn 0.4470 0.8945 -0.0000
|
| 44 |
+
vn 0.7133 0.6982 0.0608
|
| 45 |
+
vn 0.9470 0.2175 0.2363
|
| 46 |
+
vn 0.9624 0.2498 -0.1067
|
| 47 |
+
vn 0.5705 0.7506 0.3332
|
| 48 |
+
vn 0.2834 0.9545 0.0928
|
| 49 |
+
vn 0.6574 0.6887 0.3057
|
| 50 |
+
vn 0.8546 0.1972 0.4804
|
| 51 |
+
vn 0.9385 0.0321 0.3438
|
| 52 |
+
vn 0.8973 0.2493 0.3642
|
| 53 |
+
vn 0.9155 0.2217 0.3357
|
| 54 |
+
vn 0.8943 0.3344 0.2974
|
| 55 |
+
vn 0.9251 0.1885 0.3296
|
| 56 |
+
vn 0.9350 0.1837 0.3034
|
| 57 |
+
vn 0.9701 0.0000 0.2427
|
| 58 |
+
vn 0.8018 -0.5344 0.2674
|
| 59 |
+
vn 0.9470 0.2170 0.2369
|
| 60 |
+
vn 0.8672 -0.4031 0.2922
|
| 61 |
+
vn 0.9325 0.2086 0.2948
|
| 62 |
+
vn 0.7291 0.6431 0.2341
|
| 63 |
+
vn 0.7174 0.6831 0.1368
|
| 64 |
+
vn 0.7541 0.6292 0.1882
|
| 65 |
+
vn 0.5142 0.8410 0.1683
|
| 66 |
+
usemtl Shape.14069
|
| 67 |
+
s off
|
| 68 |
+
f 19//1 16//1 23//1
|
| 69 |
+
f 4//2 5//2 6//2
|
| 70 |
+
f 6//3 1//3 7//3
|
| 71 |
+
f 4//4 6//4 7//4
|
| 72 |
+
f 5//5 3//5 8//5
|
| 73 |
+
f 5//6 8//6 9//6
|
| 74 |
+
f 1//3 6//3 9//3
|
| 75 |
+
f 6//2 5//2 9//2
|
| 76 |
+
f 2//7 3//7 11//7
|
| 77 |
+
f 9//6 8//6 12//6
|
| 78 |
+
f 1//8 9//8 12//8
|
| 79 |
+
f 5//9 4//9 13//9
|
| 80 |
+
f 3//5 5//5 13//5
|
| 81 |
+
f 11//10 3//10 13//10
|
| 82 |
+
f 7//11 1//11 14//11
|
| 83 |
+
f 4//12 7//12 17//12
|
| 84 |
+
f 2//13 11//13 17//13
|
| 85 |
+
f 13//14 4//14 17//14
|
| 86 |
+
f 11//15 13//15 17//15
|
| 87 |
+
f 7//16 14//16 17//16
|
| 88 |
+
f 17//17 14//17 18//17
|
| 89 |
+
f 10//18 2//18 18//18
|
| 90 |
+
f 2//19 17//19 18//19
|
| 91 |
+
f 1//20 12//20 19//20
|
| 92 |
+
f 14//21 1//21 19//21
|
| 93 |
+
f 12//22 16//22 19//22
|
| 94 |
+
f 12//23 8//23 20//23
|
| 95 |
+
f 8//24 15//24 20//24
|
| 96 |
+
f 16//25 12//25 20//25
|
| 97 |
+
f 16//26 20//26 21//26
|
| 98 |
+
f 10//27 16//27 21//27
|
| 99 |
+
f 20//28 15//28 21//28
|
| 100 |
+
f 21//29 15//29 22//29
|
| 101 |
+
f 3//30 2//30 22//30
|
| 102 |
+
f 8//31 3//31 22//31
|
| 103 |
+
f 2//32 10//32 22//32
|
| 104 |
+
f 15//33 8//33 22//33
|
| 105 |
+
f 10//34 21//34 22//34
|
| 106 |
+
f 16//35 10//35 23//35
|
| 107 |
+
f 18//36 14//36 23//36
|
| 108 |
+
f 10//37 18//37 23//37
|
| 109 |
+
f 14//38 19//38 23//38
|
dexart-release/assets/sapien/102697/new_objs/102697_link_1_5.mtl
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender MTL File: 'None'
|
| 2 |
+
# Material Count: 1
|
| 3 |
+
|
| 4 |
+
newmtl Shape.14060
|
| 5 |
+
Ns 400.000000
|
| 6 |
+
Ka 0.400000 0.400000 0.400000
|
| 7 |
+
Kd 0.576000 0.288000 0.088000
|
| 8 |
+
Ks 0.250000 0.250000 0.250000
|
| 9 |
+
Ke 0.000000 0.000000 0.000000
|
| 10 |
+
Ni 1.000000
|
| 11 |
+
d 0.500000
|
| 12 |
+
illum 2
|
dexart-release/assets/sapien/102697/new_objs/102697_link_3_0.obj
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender v2.79 (sub 0) OBJ File: ''
|
| 2 |
+
# www.blender.org
|
| 3 |
+
mtllib 102697_link_3_0.mtl
|
| 4 |
+
o Shape_IndexedFaceSet_Shape_IndexedFaceSet.16309
|
| 5 |
+
v -0.318776 0.060655 0.107693
|
| 6 |
+
v -0.063870 0.047615 -0.076335
|
| 7 |
+
v -0.006959 0.042434 0.076198
|
| 8 |
+
v -0.006959 0.088979 -0.019441
|
| 9 |
+
v -0.309458 0.088979 0.114981
|
| 10 |
+
v -0.182789 0.086387 -0.042715
|
| 11 |
+
v -0.257765 0.042434 0.006414
|
| 12 |
+
v -0.312052 0.042434 0.114981
|
| 13 |
+
v -0.167256 0.073461 0.114981
|
| 14 |
+
v -0.006959 0.042434 -0.068570
|
| 15 |
+
v -0.006959 0.073461 0.060669
|
| 16 |
+
v -0.006959 0.088979 -0.065989
|
| 17 |
+
v -0.262922 0.086387 0.008996
|
| 18 |
+
v -0.151785 0.042434 0.114981
|
| 19 |
+
v -0.154809 0.063235 -0.062488
|
| 20 |
+
v -0.255171 0.088979 0.114981
|
| 21 |
+
v -0.265516 0.065707 -0.001350
|
| 22 |
+
v -0.058683 0.083799 -0.073734
|
| 23 |
+
v -0.155045 0.046362 -0.055734
|
| 24 |
+
v -0.006959 0.055369 0.078780
|
| 25 |
+
v -0.317239 0.052781 0.091707
|
| 26 |
+
v -0.309458 0.083799 0.086544
|
| 27 |
+
v -0.069027 0.088979 -0.065989
|
| 28 |
+
v -0.006959 0.070873 -0.078916
|
| 29 |
+
v -0.159536 0.083799 -0.058243
|
| 30 |
+
v -0.107812 0.042434 -0.063407
|
| 31 |
+
v -0.257765 0.045027 -0.001350
|
| 32 |
+
v -0.159811 0.051020 -0.059069
|
| 33 |
+
v -0.306895 0.088979 0.096890
|
| 34 |
+
v -0.149191 0.060536 0.114981
|
| 35 |
+
v -0.056120 0.078632 -0.076335
|
| 36 |
+
v -0.314646 0.045027 0.096890
|
| 37 |
+
v -0.006959 0.047615 -0.076335
|
| 38 |
+
v -0.319802 0.076049 0.114981
|
| 39 |
+
v -0.260328 0.078632 -0.001350
|
| 40 |
+
v -0.312052 0.068285 0.081361
|
| 41 |
+
v -0.250915 0.087266 0.009661
|
| 42 |
+
v -0.006959 0.070873 0.065852
|
| 43 |
+
v -0.304301 0.050193 0.068434
|
| 44 |
+
v -0.268110 0.070873 0.003833
|
| 45 |
+
v -0.006959 0.060536 0.076198
|
| 46 |
+
v -0.006959 0.086387 -0.071152
|
| 47 |
+
v -0.006959 0.078632 0.034833
|
| 48 |
+
v -0.167107 0.053534 -0.056233
|
| 49 |
+
v -0.162264 0.065707 -0.059140
|
| 50 |
+
v -0.317239 0.045027 0.114981
|
| 51 |
+
v -0.198290 0.081220 -0.037551
|
| 52 |
+
v -0.247420 0.042434 -0.001350
|
| 53 |
+
vn -0.2189 -0.8734 -0.4350
|
| 54 |
+
vn 0.0000 -1.0000 0.0000
|
| 55 |
+
vn 0.0000 0.0000 1.0000
|
| 56 |
+
vn 1.0000 0.0000 0.0000
|
| 57 |
+
vn -0.0000 1.0000 0.0000
|
| 58 |
+
vn 0.2540 -0.1893 0.9485
|
| 59 |
+
vn -0.0785 0.9576 -0.2772
|
| 60 |
+
vn -0.0984 0.7614 -0.6407
|
| 61 |
+
vn -0.0223 -0.8995 -0.4364
|
| 62 |
+
vn -0.1804 -0.1959 -0.9639
|
| 63 |
+
vn -0.1633 -0.5893 -0.7912
|
| 64 |
+
vn -0.1721 -0.6838 -0.7091
|
| 65 |
+
vn -0.3065 -0.7412 -0.5972
|
| 66 |
+
vn -0.5117 0.8123 -0.2799
|
| 67 |
+
vn 0.2435 0.3404 0.9082
|
| 68 |
+
vn 0.2453 -0.0352 0.9688
|
| 69 |
+
vn -0.1444 0.0361 -0.9889
|
| 70 |
+
vn -0.0504 0.0126 -0.9986
|
| 71 |
+
vn -0.1621 0.1636 -0.9731
|
| 72 |
+
vn -0.1399 0.3890 -0.9106
|
| 73 |
+
vn -0.2162 -0.9704 -0.1081
|
| 74 |
+
vn -0.4800 -0.8321 -0.2779
|
| 75 |
+
vn 0.0000 -0.8318 -0.5550
|
| 76 |
+
vn 0.0000 -0.1103 -0.9939
|
| 77 |
+
vn -0.9962 -0.0274 -0.0823
|
| 78 |
+
vn -0.7761 0.6209 -0.1100
|
| 79 |
+
vn -0.7791 0.6162 -0.1155
|
| 80 |
+
vn -0.3952 0.6848 -0.6123
|
| 81 |
+
vn -0.9554 0.1496 -0.2548
|
| 82 |
+
vn -0.9305 0.2462 -0.2713
|
| 83 |
+
vn -0.0669 0.9924 -0.1037
|
| 84 |
+
vn -0.0356 0.9974 -0.0631
|
| 85 |
+
vn -0.0693 0.9955 -0.0640
|
| 86 |
+
vn -0.0240 0.9991 -0.0350
|
| 87 |
+
vn 0.1545 0.8754 0.4580
|
| 88 |
+
vn 0.1519 0.8843 0.4415
|
| 89 |
+
vn -0.8035 -0.3012 -0.5135
|
| 90 |
+
vn -0.7750 -0.5093 -0.3742
|
| 91 |
+
vn -0.6037 -0.7165 -0.3495
|
| 92 |
+
vn -0.8716 -0.0235 -0.4897
|
| 93 |
+
vn -0.8750 -0.0297 -0.4832
|
| 94 |
+
vn -0.7861 0.4152 -0.4579
|
| 95 |
+
vn -0.7546 0.4201 -0.5041
|
| 96 |
+
vn -0.7121 0.2858 -0.6413
|
| 97 |
+
vn -0.8694 0.0560 -0.4909
|
| 98 |
+
vn -0.8355 0.2946 -0.4637
|
| 99 |
+
vn 0.2448 0.4334 0.8673
|
| 100 |
+
vn 0.2223 0.6897 0.6891
|
| 101 |
+
vn 0.0000 0.8937 -0.4487
|
| 102 |
+
vn -0.0128 0.8231 -0.5677
|
| 103 |
+
vn 0.0237 0.4474 -0.8940
|
| 104 |
+
vn 0.0214 0.4580 -0.8887
|
| 105 |
+
vn 0.1009 0.9773 0.1863
|
| 106 |
+
vn 0.1037 0.9753 0.1952
|
| 107 |
+
vn -0.4960 -0.1859 -0.8482
|
| 108 |
+
vn -0.3894 -0.0969 -0.9159
|
| 109 |
+
vn -0.4480 -0.3961 -0.8015
|
| 110 |
+
vn -0.3790 0.1027 -0.9197
|
| 111 |
+
vn -0.4224 -0.0481 -0.9051
|
| 112 |
+
vn -0.4884 -0.0141 -0.8725
|
| 113 |
+
vn -0.9926 -0.1156 -0.0385
|
| 114 |
+
vn -0.4462 -0.8926 -0.0640
|
| 115 |
+
vn -0.9107 -0.3918 -0.1305
|
| 116 |
+
vn -0.9960 -0.0823 0.0336
|
| 117 |
+
vn -0.4229 0.5449 -0.7240
|
| 118 |
+
vn -0.4253 0.5896 -0.6867
|
| 119 |
+
vn -0.5000 0.2007 -0.8425
|
| 120 |
+
vn -0.4868 0.0799 -0.8698
|
| 121 |
+
vn -0.4738 0.1147 -0.8731
|
| 122 |
+
vn -0.1247 -0.9517 -0.2806
|
| 123 |
+
vn -0.2313 -0.9228 -0.3082
|
| 124 |
+
usemtl Shape.16319
|
| 125 |
+
s off
|
| 126 |
+
f 27//1 19//1 48//1
|
| 127 |
+
f 7//2 3//2 8//2
|
| 128 |
+
f 5//3 8//3 9//3
|
| 129 |
+
f 4//4 3//4 10//4
|
| 130 |
+
f 3//2 7//2 10//2
|
| 131 |
+
f 3//4 4//4 11//4
|
| 132 |
+
f 5//5 4//5 12//5
|
| 133 |
+
f 4//4 10//4 12//4
|
| 134 |
+
f 8//2 3//2 14//2
|
| 135 |
+
f 9//3 8//3 14//3
|
| 136 |
+
f 4//5 5//5 16//5
|
| 137 |
+
f 5//3 9//3 16//3
|
| 138 |
+
f 3//4 11//4 20//4
|
| 139 |
+
f 14//6 3//6 20//6
|
| 140 |
+
f 5//5 12//5 23//5
|
| 141 |
+
f 12//4 10//4 24//4
|
| 142 |
+
f 6//7 23//7 25//7
|
| 143 |
+
f 23//8 18//8 25//8
|
| 144 |
+
f 2//9 10//9 26//9
|
| 145 |
+
f 10//2 7//2 26//2
|
| 146 |
+
f 15//10 2//10 28//10
|
| 147 |
+
f 2//11 26//11 28//11
|
| 148 |
+
f 26//12 19//12 28//12
|
| 149 |
+
f 19//13 27//13 28//13
|
| 150 |
+
f 13//14 22//14 29//14
|
| 151 |
+
f 5//5 23//5 29//5
|
| 152 |
+
f 9//3 14//3 30//3
|
| 153 |
+
f 20//15 9//15 30//15
|
| 154 |
+
f 14//16 20//16 30//16
|
| 155 |
+
f 2//17 15//17 31//17
|
| 156 |
+
f 24//18 2//18 31//18
|
| 157 |
+
f 15//19 25//19 31//19
|
| 158 |
+
f 25//20 18//20 31//20
|
| 159 |
+
f 7//21 8//21 32//21
|
| 160 |
+
f 27//22 7//22 32//22
|
| 161 |
+
f 10//23 2//23 33//23
|
| 162 |
+
f 2//24 24//24 33//24
|
| 163 |
+
f 24//4 10//4 33//4
|
| 164 |
+
f 8//3 5//3 34//3
|
| 165 |
+
f 21//25 1//25 34//25
|
| 166 |
+
f 5//26 29//26 34//26
|
| 167 |
+
f 29//27 22//27 34//27
|
| 168 |
+
f 13//28 6//28 35//28
|
| 169 |
+
f 21//29 34//29 36//29
|
| 170 |
+
f 34//30 22//30 36//30
|
| 171 |
+
f 6//31 13//31 37//31
|
| 172 |
+
f 23//32 6//32 37//32
|
| 173 |
+
f 13//33 29//33 37//33
|
| 174 |
+
f 29//34 23//34 37//34
|
| 175 |
+
f 16//35 9//35 38//35
|
| 176 |
+
f 11//36 16//36 38//36
|
| 177 |
+
f 20//4 11//4 38//4
|
| 178 |
+
f 17//37 27//37 39//37
|
| 179 |
+
f 32//38 21//38 39//38
|
| 180 |
+
f 27//39 32//39 39//39
|
| 181 |
+
f 36//40 17//40 39//40
|
| 182 |
+
f 21//41 36//41 39//41
|
| 183 |
+
f 22//42 13//42 40//42
|
| 184 |
+
f 13//43 35//43 40//43
|
| 185 |
+
f 35//44 17//44 40//44
|
| 186 |
+
f 17//45 36//45 40//45
|
| 187 |
+
f 36//46 22//46 40//46
|
| 188 |
+
f 9//47 20//47 41//47
|
| 189 |
+
f 38//48 9//48 41//48
|
| 190 |
+
f 20//4 38//4 41//4
|
| 191 |
+
f 23//49 12//49 42//49
|
| 192 |
+
f 18//50 23//50 42//50
|
| 193 |
+
f 12//4 24//4 42//4
|
| 194 |
+
f 24//51 31//51 42//51
|
| 195 |
+
f 31//52 18//52 42//52
|
| 196 |
+
f 11//4 4//4 43//4
|
| 197 |
+
f 4//53 16//53 43//53
|
| 198 |
+
f 16//54 11//54 43//54
|
| 199 |
+
f 27//55 17//55 44//55
|
| 200 |
+
f 15//56 28//56 44//56
|
| 201 |
+
f 28//57 27//57 44//57
|
| 202 |
+
f 25//58 15//58 45//58
|
| 203 |
+
f 15//59 44//59 45//59
|
| 204 |
+
f 44//60 17//60 45//60
|
| 205 |
+
f 1//61 21//61 46//61
|
| 206 |
+
f 32//62 8//62 46//62
|
| 207 |
+
f 21//63 32//63 46//63
|
| 208 |
+
f 34//64 1//64 46//64
|
| 209 |
+
f 8//3 34//3 46//3
|
| 210 |
+
f 6//65 25//65 47//65
|
| 211 |
+
f 35//66 6//66 47//66
|
| 212 |
+
f 17//67 35//67 47//67
|
| 213 |
+
f 45//68 17//68 47//68
|
| 214 |
+
f 25//69 45//69 47//69
|
| 215 |
+
f 26//2 7//2 48//2
|
| 216 |
+
f 19//70 26//70 48//70
|
| 217 |
+
f 7//71 27//71 48//71
|
dexart-release/assets/sapien/102697/new_objs/102697_link_3_5.obj
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender v2.79 (sub 0) OBJ File: ''
|
| 2 |
+
# www.blender.org
|
| 3 |
+
mtllib 102697_link_3_5.mtl
|
| 4 |
+
o Shape_IndexedFaceSet.005_Shape_IndexedFaceSet.16314
|
| 5 |
+
v -0.006959 0.055369 0.663225
|
| 6 |
+
v -0.035413 0.042434 0.789889
|
| 7 |
+
v -0.006985 0.088979 0.766613
|
| 8 |
+
v -0.224145 0.088979 0.575274
|
| 9 |
+
v -0.112985 0.042434 0.575296
|
| 10 |
+
v -0.244834 0.042434 0.624424
|
| 11 |
+
v -0.216380 0.083799 0.676129
|
| 12 |
+
v -0.128515 0.070873 0.575274
|
| 13 |
+
v -0.154338 0.047615 0.743337
|
| 14 |
+
v -0.061288 0.088979 0.779539
|
| 15 |
+
v -0.006959 0.042434 0.792466
|
| 16 |
+
v -0.006985 0.073461 0.683903
|
| 17 |
+
v -0.006959 0.042434 0.668378
|
| 18 |
+
v -0.260364 0.088979 0.590820
|
| 19 |
+
v -0.030254 0.068285 0.800239
|
| 20 |
+
v -0.221565 0.065707 0.678706
|
| 21 |
+
v -0.268103 0.042434 0.575274
|
| 22 |
+
v -0.146599 0.081220 0.745913
|
| 23 |
+
v -0.081951 0.055369 0.784714
|
| 24 |
+
v -0.006985 0.088979 0.789889
|
| 25 |
+
v -0.270683 0.060536 0.598550
|
| 26 |
+
v -0.102641 0.042434 0.764037
|
| 27 |
+
v -0.206061 0.055369 0.696785
|
| 28 |
+
v -0.150775 0.086767 0.728688
|
| 29 |
+
v -0.087137 0.076049 0.779539
|
| 30 |
+
v -0.154338 0.065707 0.745913
|
| 31 |
+
v -0.262944 0.081220 0.603725
|
| 32 |
+
v -0.273262 0.081220 0.575274
|
| 33 |
+
v -0.262944 0.047615 0.606323
|
| 34 |
+
v -0.032834 0.047615 0.797640
|
| 35 |
+
v -0.006985 0.068294 0.800239
|
| 36 |
+
v -0.244834 0.088979 0.619249
|
| 37 |
+
v -0.030254 0.083799 0.795064
|
| 38 |
+
v -0.006985 0.070878 0.676129
|
| 39 |
+
v -0.087137 0.050193 0.782138
|
| 40 |
+
v -0.006985 0.081215 0.722659
|
| 41 |
+
v -0.213800 0.078632 0.683881
|
| 42 |
+
v -0.242254 0.052781 0.645102
|
| 43 |
+
v -0.180186 0.081220 0.575274
|
| 44 |
+
v -0.143993 0.083799 0.745913
|
| 45 |
+
v -0.079372 0.083799 0.779539
|
| 46 |
+
v -0.209848 0.046828 0.679494
|
| 47 |
+
v -0.262944 0.088979 0.575274
|
| 48 |
+
v -0.009565 0.047615 0.660627
|
| 49 |
+
v -0.092296 0.068285 0.779539
|
| 50 |
+
v -0.159523 0.070873 0.740739
|
| 51 |
+
v -0.275868 0.068285 0.583047
|
| 52 |
+
v -0.273262 0.045027 0.575274
|
| 53 |
+
v -0.032834 0.055369 0.800239
|
| 54 |
+
v -0.125910 0.068285 0.575274
|
| 55 |
+
v -0.006985 0.063123 0.668378
|
| 56 |
+
v -0.074186 0.045027 0.782138
|
| 57 |
+
v -0.204443 0.044694 0.679584
|
| 58 |
+
vn -0.3528 -0.8802 0.3176
|
| 59 |
+
vn 0.0000 -1.0000 -0.0000
|
| 60 |
+
vn -0.0000 1.0000 -0.0000
|
| 61 |
+
vn 1.0000 0.0000 0.0000
|
| 62 |
+
vn 0.0000 0.0000 -1.0000
|
| 63 |
+
vn 1.0000 0.0008 0.0000
|
| 64 |
+
vn 1.0000 0.0006 0.0001
|
| 65 |
+
vn -0.6739 -0.1041 0.7314
|
| 66 |
+
vn -0.8426 0.1891 0.5042
|
| 67 |
+
vn -0.7059 0.6605 0.2560
|
| 68 |
+
vn -0.0513 -0.8223 0.5667
|
| 69 |
+
vn 1.0000 0.0006 0.0013
|
| 70 |
+
vn -0.1903 0.9645 0.1830
|
| 71 |
+
vn -0.0704 0.9943 0.0806
|
| 72 |
+
vn -0.5935 0.7366 0.3242
|
| 73 |
+
vn -0.6160 0.6947 0.3714
|
| 74 |
+
vn -0.0901 0.8767 0.4726
|
| 75 |
+
vn 0.0988 0.4453 0.8899
|
| 76 |
+
vn -0.0001 0.3164 0.9486
|
| 77 |
+
vn 1.0000 0.0023 -0.0008
|
| 78 |
+
vn 0.2531 0.9181 -0.3050
|
| 79 |
+
vn -0.3097 -0.7486 0.5862
|
| 80 |
+
vn -0.4927 -0.1227 0.8615
|
| 81 |
+
vn 1.0000 0.0017 -0.0003
|
| 82 |
+
vn 0.1515 0.9734 -0.1719
|
| 83 |
+
vn 1.0000 0.0019 -0.0004
|
| 84 |
+
vn 0.1659 0.9670 -0.1935
|
| 85 |
+
vn -0.7073 0.1481 0.6912
|
| 86 |
+
vn -0.7933 0.3506 0.4977
|
| 87 |
+
vn -0.8130 0.2852 0.5077
|
| 88 |
+
vn -0.8528 0.0077 0.5221
|
| 89 |
+
vn -0.7975 -0.2017 0.5686
|
| 90 |
+
vn -0.6081 -0.6770 0.4146
|
| 91 |
+
vn -0.8486 -0.2184 0.4819
|
| 92 |
+
vn 0.1702 0.9644 -0.2025
|
| 93 |
+
vn 0.1908 0.9528 -0.2362
|
| 94 |
+
vn -0.2498 0.9330 0.2591
|
| 95 |
+
vn -0.1520 0.9623 0.2257
|
| 96 |
+
vn -0.5596 0.5915 0.5805
|
| 97 |
+
vn -0.5675 0.5734 0.5909
|
| 98 |
+
vn -0.2916 0.2921 0.9108
|
| 99 |
+
vn -0.4189 0.4197 0.8052
|
| 100 |
+
vn -0.2076 0.7248 0.6569
|
| 101 |
+
vn -0.2872 0.3031 0.9086
|
| 102 |
+
vn -0.2434 0.8497 0.4677
|
| 103 |
+
vn -0.4183 0.4227 0.8039
|
| 104 |
+
vn -0.5264 -0.7109 0.4664
|
| 105 |
+
vn -0.5764 -0.6997 0.4220
|
| 106 |
+
vn -0.6036 -0.6544 0.4554
|
| 107 |
+
vn -0.5981 0.7953 0.0993
|
| 108 |
+
vn 0.6344 0.0453 -0.7717
|
| 109 |
+
vn 0.8968 -0.1637 -0.4110
|
| 110 |
+
vn 0.5171 -0.6211 -0.5890
|
| 111 |
+
vn -0.3140 0.1256 0.9411
|
| 112 |
+
vn -0.3097 0.2058 0.9283
|
| 113 |
+
vn -0.4499 0.2990 0.8416
|
| 114 |
+
vn -0.4707 0.2348 0.8505
|
| 115 |
+
vn -0.4460 0.0014 0.8950
|
| 116 |
+
vn -0.4762 -0.0095 0.8793
|
| 117 |
+
vn -0.5346 0.2667 0.8020
|
| 118 |
+
vn -0.6914 0.0290 0.7219
|
| 119 |
+
vn -0.6165 0.4454 0.6493
|
| 120 |
+
vn -0.7048 0.1501 0.6933
|
| 121 |
+
vn -0.8834 0.2281 0.4094
|
| 122 |
+
vn -0.8746 0.3668 0.3172
|
| 123 |
+
vn -0.9481 0.0000 -0.3179
|
| 124 |
+
vn -0.4393 -0.8740 0.2080
|
| 125 |
+
vn -0.4713 -0.8521 0.2276
|
| 126 |
+
vn -0.8850 -0.3363 0.3221
|
| 127 |
+
vn -0.9562 -0.1834 0.2281
|
| 128 |
+
vn -0.3008 0.0601 0.9518
|
| 129 |
+
vn 0.1250 -0.3153 0.9407
|
| 130 |
+
vn 0.1424 -0.2848 0.9480
|
| 131 |
+
vn 0.0000 0.0000 1.0000
|
| 132 |
+
vn -0.2970 -0.1701 0.9396
|
| 133 |
+
vn -0.2748 -0.3056 0.9117
|
| 134 |
+
vn 0.5887 0.2937 -0.7531
|
| 135 |
+
vn 0.0001 -0.0008 -1.0000
|
| 136 |
+
vn 0.9999 0.0100 -0.0100
|
| 137 |
+
vn 0.5062 0.6097 -0.6100
|
| 138 |
+
vn 0.5596 0.4600 -0.6894
|
| 139 |
+
vn 0.5341 0.5376 -0.6524
|
| 140 |
+
vn -0.1298 -0.9323 0.3377
|
| 141 |
+
vn -0.1701 -0.7922 0.5861
|
| 142 |
+
vn -0.3013 -0.7554 0.5819
|
| 143 |
+
vn -0.2439 -0.6115 0.7527
|
| 144 |
+
vn -0.1402 -0.9798 0.1428
|
| 145 |
+
vn -0.1678 -0.9699 0.1763
|
| 146 |
+
vn -0.3550 -0.8867 0.2963
|
| 147 |
+
usemtl Shape.16324
|
| 148 |
+
s off
|
| 149 |
+
f 9//1 42//1 53//1
|
| 150 |
+
f 5//2 2//2 6//2
|
| 151 |
+
f 3//3 4//3 10//3
|
| 152 |
+
f 2//2 5//2 11//2
|
| 153 |
+
f 1//4 11//4 13//4
|
| 154 |
+
f 11//2 5//2 13//2
|
| 155 |
+
f 10//3 4//3 14//3
|
| 156 |
+
f 5//2 6//2 17//2
|
| 157 |
+
f 4//5 8//5 17//5
|
| 158 |
+
f 1//6 3//6 20//6
|
| 159 |
+
f 3//3 10//3 20//3
|
| 160 |
+
f 11//7 1//7 20//7
|
| 161 |
+
f 6//2 2//2 22//2
|
| 162 |
+
f 23//8 9//8 26//8
|
| 163 |
+
f 21//9 16//9 27//9
|
| 164 |
+
f 4//5 17//5 28//5
|
| 165 |
+
f 27//10 14//10 28//10
|
| 166 |
+
f 2//11 11//11 30//11
|
| 167 |
+
f 11//12 20//12 31//12
|
| 168 |
+
f 10//3 14//3 32//3
|
| 169 |
+
f 7//13 24//13 32//13
|
| 170 |
+
f 24//14 10//14 32//14
|
| 171 |
+
f 14//15 27//15 32//15
|
| 172 |
+
f 27//16 7//16 32//16
|
| 173 |
+
f 20//17 10//17 33//17
|
| 174 |
+
f 31//18 20//18 33//18
|
| 175 |
+
f 15//19 31//19 33//19
|
| 176 |
+
f 12//20 1//20 34//20
|
| 177 |
+
f 8//21 12//21 34//21
|
| 178 |
+
f 9//22 22//22 35//22
|
| 179 |
+
f 26//23 9//23 35//23
|
| 180 |
+
f 3//24 1//24 36//24
|
| 181 |
+
f 4//25 3//25 36//25
|
| 182 |
+
f 1//26 12//26 36//26
|
| 183 |
+
f 12//27 4//27 36//27
|
| 184 |
+
f 16//28 23//28 37//28
|
| 185 |
+
f 7//29 27//29 37//29
|
| 186 |
+
f 27//30 16//30 37//30
|
| 187 |
+
f 16//31 21//31 38//31
|
| 188 |
+
f 23//32 16//32 38//32
|
| 189 |
+
f 29//33 6//33 38//33
|
| 190 |
+
f 21//34 29//34 38//34
|
| 191 |
+
f 8//5 4//5 39//5
|
| 192 |
+
f 4//35 12//35 39//35
|
| 193 |
+
f 12//36 8//36 39//36
|
| 194 |
+
f 24//37 7//37 40//37
|
| 195 |
+
f 10//38 24//38 40//38
|
| 196 |
+
f 7//39 37//39 40//39
|
| 197 |
+
f 37//40 18//40 40//40
|
| 198 |
+
f 25//41 15//41 41//41
|
| 199 |
+
f 18//42 25//42 41//42
|
| 200 |
+
f 33//43 10//43 41//43
|
| 201 |
+
f 15//44 33//44 41//44
|
| 202 |
+
f 10//45 40//45 41//45
|
| 203 |
+
f 40//46 18//46 41//46
|
| 204 |
+
f 9//47 23//47 42//47
|
| 205 |
+
f 38//48 6//48 42//48
|
| 206 |
+
f 23//49 38//49 42//49
|
| 207 |
+
f 14//3 4//3 43//3
|
| 208 |
+
f 4//5 28//5 43//5
|
| 209 |
+
f 28//50 14//50 43//50
|
| 210 |
+
f 5//51 1//51 44//51
|
| 211 |
+
f 1//52 13//52 44//52
|
| 212 |
+
f 13//53 5//53 44//53
|
| 213 |
+
f 19//54 15//54 45//54
|
| 214 |
+
f 15//55 25//55 45//55
|
| 215 |
+
f 25//56 18//56 45//56
|
| 216 |
+
f 18//57 26//57 45//57
|
| 217 |
+
f 35//58 19//58 45//58
|
| 218 |
+
f 26//59 35//59 45//59
|
| 219 |
+
f 26//60 18//60 46//60
|
| 220 |
+
f 23//61 26//61 46//61
|
| 221 |
+
f 18//62 37//62 46//62
|
| 222 |
+
f 37//63 23//63 46//63
|
| 223 |
+
f 21//64 27//64 47//64
|
| 224 |
+
f 27//65 28//65 47//65
|
| 225 |
+
f 47//66 28//66 48//66
|
| 226 |
+
f 17//67 6//67 48//67
|
| 227 |
+
f 28//5 17//5 48//5
|
| 228 |
+
f 6//68 29//68 48//68
|
| 229 |
+
f 29//69 21//69 48//69
|
| 230 |
+
f 21//70 47//70 48//70
|
| 231 |
+
f 15//71 19//71 49//71
|
| 232 |
+
f 30//72 11//72 49//72
|
| 233 |
+
f 11//73 31//73 49//73
|
| 234 |
+
f 31//74 15//74 49//74
|
| 235 |
+
f 19//75 35//75 49//75
|
| 236 |
+
f 35//76 30//76 49//76
|
| 237 |
+
f 1//77 5//77 50//77
|
| 238 |
+
f 5//78 17//78 50//78
|
| 239 |
+
f 17//5 8//5 50//5
|
| 240 |
+
f 34//79 1//79 51//79
|
| 241 |
+
f 8//80 34//80 51//80
|
| 242 |
+
f 1//81 50//81 51//81
|
| 243 |
+
f 50//82 8//82 51//82
|
| 244 |
+
f 22//83 2//83 52//83
|
| 245 |
+
f 2//84 30//84 52//84
|
| 246 |
+
f 35//85 22//85 52//85
|
| 247 |
+
f 30//86 35//86 52//86
|
| 248 |
+
f 6//87 22//87 53//87
|
| 249 |
+
f 22//88 9//88 53//88
|
| 250 |
+
f 42//89 6//89 53//89
|
dexart-release/assets/sapien/102697/new_objs/102697_link_3_6.obj
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender v2.79 (sub 0) OBJ File: ''
|
| 2 |
+
# www.blender.org
|
| 3 |
+
mtllib 102697_link_3_6.mtl
|
| 4 |
+
o Shape_IndexedFaceSet.006_Shape_IndexedFaceSet.16315
|
| 5 |
+
v 0.267139 0.042434 0.557173
|
| 6 |
+
v 0.174056 0.055369 0.316720
|
| 7 |
+
v 0.220597 0.088979 0.557173
|
| 8 |
+
v 0.117181 0.047615 0.557173
|
| 9 |
+
v 0.321419 0.088979 0.316720
|
| 10 |
+
v 0.176650 0.042434 0.316720
|
| 11 |
+
v 0.192149 0.073461 0.316720
|
| 12 |
+
v 0.272305 0.086387 0.557173
|
| 13 |
+
v 0.324013 0.042434 0.316720
|
| 14 |
+
v 0.132681 0.070873 0.549382
|
| 15 |
+
v 0.311086 0.052781 0.466641
|
| 16 |
+
v 0.272305 0.088979 0.316720
|
| 17 |
+
v 0.124942 0.042434 0.557173
|
| 18 |
+
v 0.308492 0.081220 0.461463
|
| 19 |
+
v 0.137869 0.073461 0.557173
|
| 20 |
+
v 0.282638 0.068285 0.549382
|
| 21 |
+
v 0.298159 0.088979 0.469231
|
| 22 |
+
v 0.334346 0.068290 0.319286
|
| 23 |
+
v 0.308492 0.045027 0.458873
|
| 24 |
+
v 0.124942 0.065707 0.554560
|
| 25 |
+
v 0.334346 0.052781 0.316697
|
| 26 |
+
v 0.184389 0.070873 0.321875
|
| 27 |
+
v 0.285232 0.050193 0.539048
|
| 28 |
+
v 0.324013 0.083799 0.358079
|
| 29 |
+
v 0.267139 0.088979 0.557173
|
| 30 |
+
v 0.137869 0.073461 0.551971
|
| 31 |
+
v 0.119775 0.057957 0.554560
|
| 32 |
+
v 0.174056 0.045027 0.316720
|
| 33 |
+
v 0.316253 0.042434 0.383948
|
| 34 |
+
v 0.331774 0.047610 0.321875
|
| 35 |
+
v 0.329180 0.083803 0.316697
|
| 36 |
+
v 0.272305 0.045027 0.557173
|
| 37 |
+
v 0.220597 0.088979 0.551971
|
| 38 |
+
v 0.119775 0.045027 0.551971
|
| 39 |
+
v 0.313680 0.088979 0.386514
|
| 40 |
+
v 0.282638 0.042434 0.520946
|
| 41 |
+
v 0.179222 0.063128 0.316720
|
| 42 |
+
vn -0.0002 0.0002 -1.0000
|
| 43 |
+
vn 0.0000 0.0000 1.0000
|
| 44 |
+
vn 0.0000 -1.0000 -0.0000
|
| 45 |
+
vn 0.0000 1.0000 0.0000
|
| 46 |
+
vn 0.8767 0.3663 0.3117
|
| 47 |
+
vn 0.9441 0.1404 0.2983
|
| 48 |
+
vn 0.6868 0.6920 0.2223
|
| 49 |
+
vn 0.9787 0.1197 0.1671
|
| 50 |
+
vn -0.3480 0.2786 0.8951
|
| 51 |
+
vn -0.5230 0.8497 0.0660
|
| 52 |
+
vn 0.0000 -0.0022 -1.0000
|
| 53 |
+
vn 0.9879 -0.0256 0.1532
|
| 54 |
+
vn -0.6125 0.7781 -0.1392
|
| 55 |
+
vn 0.9409 -0.0559 0.3340
|
| 56 |
+
vn 0.8012 -0.5355 0.2670
|
| 57 |
+
vn 0.9537 0.2610 0.1497
|
| 58 |
+
vn 0.4429 0.8828 0.1562
|
| 59 |
+
vn -0.1899 0.9808 -0.0438
|
| 60 |
+
vn -0.1844 0.9829 0.0000
|
| 61 |
+
vn -0.4464 0.8948 0.0000
|
| 62 |
+
vn -0.3652 0.9271 -0.0843
|
| 63 |
+
vn -0.4071 0.9087 -0.0925
|
| 64 |
+
vn -0.9578 0.1845 -0.2206
|
| 65 |
+
vn -0.4906 0.3271 0.8076
|
| 66 |
+
vn -0.9731 0.0000 -0.2302
|
| 67 |
+
vn -0.0001 -0.0001 -1.0000
|
| 68 |
+
vn 0.8869 -0.4394 0.1424
|
| 69 |
+
vn 0.6665 -0.6663 -0.3344
|
| 70 |
+
vn 0.9361 -0.3202 0.1452
|
| 71 |
+
vn 0.5256 -0.8486 0.0607
|
| 72 |
+
vn 0.6247 -0.7755 0.0915
|
| 73 |
+
vn 0.0000 0.0044 -1.0000
|
| 74 |
+
vn -0.0003 0.0014 -1.0000
|
| 75 |
+
vn 0.7031 0.1171 -0.7014
|
| 76 |
+
vn 0.9363 0.3313 0.1169
|
| 77 |
+
vn -0.0001 0.0000 -1.0000
|
| 78 |
+
vn 0.6020 0.0000 0.7985
|
| 79 |
+
vn 0.8242 -0.1871 0.5345
|
| 80 |
+
vn -0.1842 0.9821 -0.0405
|
| 81 |
+
vn -0.5501 -0.8240 0.1356
|
| 82 |
+
vn -0.3801 -0.9213 -0.0817
|
| 83 |
+
vn -0.8628 -0.4647 -0.1991
|
| 84 |
+
vn -0.6977 -0.6980 -0.1610
|
| 85 |
+
vn 0.6532 0.7472 0.1226
|
| 86 |
+
vn 0.6915 0.7120 0.1216
|
| 87 |
+
vn 0.5539 0.8303 0.0614
|
| 88 |
+
vn 0.6054 0.7923 0.0757
|
| 89 |
+
vn 0.6261 -0.7451 0.2297
|
| 90 |
+
vn 0.2367 -0.9698 0.0581
|
| 91 |
+
vn 0.4406 -0.8777 0.1885
|
| 92 |
+
vn 0.6228 -0.7475 0.2311
|
| 93 |
+
vn -0.5476 0.6851 -0.4804
|
| 94 |
+
vn -0.7588 0.6260 -0.1800
|
| 95 |
+
vn -0.8168 0.5439 -0.1923
|
| 96 |
+
vn -0.8165 0.5444 -0.1922
|
| 97 |
+
vn -0.0002 0.0001 -1.0000
|
| 98 |
+
usemtl Shape.16325
|
| 99 |
+
s off
|
| 100 |
+
f 7//1 31//1 37//1
|
| 101 |
+
f 1//2 3//2 4//2
|
| 102 |
+
f 3//2 1//2 8//2
|
| 103 |
+
f 1//3 6//3 9//3
|
| 104 |
+
f 3//4 5//4 12//4
|
| 105 |
+
f 1//2 4//2 13//2
|
| 106 |
+
f 6//3 1//3 13//3
|
| 107 |
+
f 4//2 3//2 15//2
|
| 108 |
+
f 14//5 8//5 16//5
|
| 109 |
+
f 11//6 14//6 16//6
|
| 110 |
+
f 5//4 3//4 17//4
|
| 111 |
+
f 8//7 14//7 17//7
|
| 112 |
+
f 14//8 11//8 18//8
|
| 113 |
+
f 4//9 15//9 20//9
|
| 114 |
+
f 15//10 10//10 20//10
|
| 115 |
+
f 9//11 6//11 21//11
|
| 116 |
+
f 18//12 11//12 21//12
|
| 117 |
+
f 20//13 10//13 22//13
|
| 118 |
+
f 11//14 16//14 23//14
|
| 119 |
+
f 19//15 11//15 23//15
|
| 120 |
+
f 14//16 18//16 24//16
|
| 121 |
+
f 3//2 8//2 25//2
|
| 122 |
+
f 17//4 3//4 25//4
|
| 123 |
+
f 8//17 17//17 25//17
|
| 124 |
+
f 12//18 7//18 26//18
|
| 125 |
+
f 15//19 3//19 26//19
|
| 126 |
+
f 10//20 15//20 26//20
|
| 127 |
+
f 7//21 22//21 26//21
|
| 128 |
+
f 22//22 10//22 26//22
|
| 129 |
+
f 2//23 4//23 27//23
|
| 130 |
+
f 4//24 20//24 27//24
|
| 131 |
+
f 4//25 2//25 28//25
|
| 132 |
+
f 21//26 6//26 28//26
|
| 133 |
+
f 1//3 9//3 29//3
|
| 134 |
+
f 11//27 19//27 30//27
|
| 135 |
+
f 9//28 21//28 30//28
|
| 136 |
+
f 21//29 11//29 30//29
|
| 137 |
+
f 29//30 9//30 30//30
|
| 138 |
+
f 19//31 29//31 30//31
|
| 139 |
+
f 12//32 5//32 31//32
|
| 140 |
+
f 7//33 12//33 31//33
|
| 141 |
+
f 18//34 21//34 31//34
|
| 142 |
+
f 24//35 18//35 31//35
|
| 143 |
+
f 28//36 2//36 31//36
|
| 144 |
+
f 21//36 28//36 31//36
|
| 145 |
+
f 8//2 1//2 32//2
|
| 146 |
+
f 16//37 8//37 32//37
|
| 147 |
+
f 23//38 16//38 32//38
|
| 148 |
+
f 3//4 12//4 33//4
|
| 149 |
+
f 26//19 3//19 33//19
|
| 150 |
+
f 12//39 26//39 33//39
|
| 151 |
+
f 13//40 4//40 34//40
|
| 152 |
+
f 6//41 13//41 34//41
|
| 153 |
+
f 4//42 28//42 34//42
|
| 154 |
+
f 28//43 6//43 34//43
|
| 155 |
+
f 5//4 17//4 35//4
|
| 156 |
+
f 17//44 14//44 35//44
|
| 157 |
+
f 14//45 24//45 35//45
|
| 158 |
+
f 31//46 5//46 35//46
|
| 159 |
+
f 24//47 31//47 35//47
|
| 160 |
+
f 19//48 23//48 36//48
|
| 161 |
+
f 1//3 29//3 36//3
|
| 162 |
+
f 29//49 19//49 36//49
|
| 163 |
+
f 32//50 1//50 36//50
|
| 164 |
+
f 23//51 32//51 36//51
|
| 165 |
+
f 22//52 7//52 37//52
|
| 166 |
+
f 20//53 22//53 37//53
|
| 167 |
+
f 2//54 27//54 37//54
|
| 168 |
+
f 27//55 20//55 37//55
|
| 169 |
+
f 31//56 2//56 37//56
|
dexart-release/assets/sapien/102697/new_objs/102697_link_4_11.obj
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender v2.79 (sub 0) OBJ File: ''
|
| 2 |
+
# www.blender.org
|
| 3 |
+
mtllib 102697_link_4_11.mtl
|
| 4 |
+
o Shape_IndexedFaceSet.011_Shape_IndexedFaceSet.7497
|
| 5 |
+
v -0.323117 -0.045453 0.089563
|
| 6 |
+
v -0.106232 0.036714 0.031262
|
| 7 |
+
v -0.106232 -0.014976 0.031262
|
| 8 |
+
v -0.207356 0.064630 -0.188472
|
| 9 |
+
v -0.342601 0.058882 0.201139
|
| 10 |
+
v -0.106232 -0.022362 -0.168161
|
| 11 |
+
v -0.269912 -0.053693 0.229603
|
| 12 |
+
v -0.099529 0.061840 -0.178611
|
| 13 |
+
v -0.268728 0.058882 0.223327
|
| 14 |
+
v -0.217054 -0.014976 -0.168161
|
| 15 |
+
v -0.114215 -0.048395 -0.086531
|
| 16 |
+
v -0.354324 -0.050570 0.218569
|
| 17 |
+
v -0.106232 -0.051916 0.009150
|
| 18 |
+
v -0.349976 0.051485 0.178989
|
| 19 |
+
v -0.239154 0.021942 0.215931
|
| 20 |
+
v -0.224454 0.036714 -0.168161
|
| 21 |
+
v -0.170368 -0.047738 -0.078248
|
| 22 |
+
v -0.331607 -0.019798 0.104076
|
| 23 |
+
v -0.246554 -0.022362 0.223327
|
| 24 |
+
v -0.202304 -0.022362 -0.168161
|
| 25 |
+
v -0.126343 0.050010 0.012102
|
| 26 |
+
v -0.246554 0.044099 0.223327
|
| 27 |
+
v -0.224454 -0.000205 -0.168161
|
| 28 |
+
vn -0.9305 0.0000 -0.3663
|
| 29 |
+
vn 0.9995 0.0000 0.0319
|
| 30 |
+
vn 0.0870 -0.1295 -0.9878
|
| 31 |
+
vn -0.0040 0.9999 0.0134
|
| 32 |
+
vn 0.0243 0.9996 0.0176
|
| 33 |
+
vn -0.2829 0.1803 0.9420
|
| 34 |
+
vn -0.1273 0.0565 0.9902
|
| 35 |
+
vn 0.9995 -0.0157 0.0262
|
| 36 |
+
vn 0.9966 -0.0810 -0.0135
|
| 37 |
+
vn 0.8380 -0.5383 -0.0897
|
| 38 |
+
vn -0.4515 0.8806 -0.1437
|
| 39 |
+
vn -0.9523 0.1448 0.2687
|
| 40 |
+
vn 0.8116 0.0000 0.5842
|
| 41 |
+
vn -0.8922 0.3024 -0.3355
|
| 42 |
+
vn -0.0793 -0.9951 -0.0587
|
| 43 |
+
vn -0.0330 -0.9990 -0.0300
|
| 44 |
+
vn -0.0269 -0.9992 -0.0280
|
| 45 |
+
vn -0.0169 -0.9992 -0.0354
|
| 46 |
+
vn -0.9542 -0.1811 -0.2380
|
| 47 |
+
vn -0.9782 -0.0375 -0.2042
|
| 48 |
+
vn -0.9321 0.1193 -0.3421
|
| 49 |
+
vn 0.7252 -0.4335 0.5349
|
| 50 |
+
vn 0.7686 -0.3286 0.5489
|
| 51 |
+
vn 0.8076 -0.0366 0.5886
|
| 52 |
+
vn -0.0000 -0.2274 -0.9738
|
| 53 |
+
vn -0.4300 -0.8588 -0.2785
|
| 54 |
+
vn -0.1163 -0.2322 -0.9657
|
| 55 |
+
vn 0.0000 -0.9527 -0.3038
|
| 56 |
+
vn -0.2235 -0.9559 -0.1904
|
| 57 |
+
vn -0.0489 -0.9657 -0.2552
|
| 58 |
+
vn 0.4655 0.8769 0.1198
|
| 59 |
+
vn 0.3830 0.8970 0.2205
|
| 60 |
+
vn 0.1907 0.9778 0.0875
|
| 61 |
+
vn 0.5197 0.7795 0.3497
|
| 62 |
+
vn 0.0368 0.0552 0.9978
|
| 63 |
+
vn 0.8066 0.0736 0.5865
|
| 64 |
+
vn 0.2595 0.0000 0.9657
|
| 65 |
+
vn 0.7069 0.0000 0.7073
|
| 66 |
+
vn -0.8241 -0.4128 -0.3879
|
| 67 |
+
vn -0.3724 -0.1866 -0.9091
|
| 68 |
+
vn -0.7650 0.0000 -0.6440
|
| 69 |
+
vn -0.9238 -0.0961 -0.3705
|
| 70 |
+
usemtl Shape.7503
|
| 71 |
+
s off
|
| 72 |
+
f 18//1 16//1 23//1
|
| 73 |
+
f 2//2 3//2 8//2
|
| 74 |
+
f 6//3 4//3 8//3
|
| 75 |
+
f 4//4 5//4 9//4
|
| 76 |
+
f 8//5 4//5 9//5
|
| 77 |
+
f 9//6 5//6 12//6
|
| 78 |
+
f 7//7 9//7 12//7
|
| 79 |
+
f 8//8 3//8 13//8
|
| 80 |
+
f 6//9 8//9 13//9
|
| 81 |
+
f 11//10 6//10 13//10
|
| 82 |
+
f 5//11 4//11 14//11
|
| 83 |
+
f 12//12 5//12 14//12
|
| 84 |
+
f 3//13 2//13 15//13
|
| 85 |
+
f 14//14 4//14 16//14
|
| 86 |
+
f 12//15 1//15 17//15
|
| 87 |
+
f 7//16 12//16 17//16
|
| 88 |
+
f 13//17 7//17 17//17
|
| 89 |
+
f 11//18 13//18 17//18
|
| 90 |
+
f 1//19 12//19 18//19
|
| 91 |
+
f 12//20 14//20 18//20
|
| 92 |
+
f 14//21 16//21 18//21
|
| 93 |
+
f 7//22 13//22 19//22
|
| 94 |
+
f 13//23 3//23 19//23
|
| 95 |
+
f 3//24 15//24 19//24
|
| 96 |
+
f 4//25 6//25 20//25
|
| 97 |
+
f 1//26 10//26 20//26
|
| 98 |
+
f 10//27 4//27 20//27
|
| 99 |
+
f 6//28 11//28 20//28
|
| 100 |
+
f 17//29 1//29 20//29
|
| 101 |
+
f 11//30 17//30 20//30
|
| 102 |
+
f 2//31 8//31 21//31
|
| 103 |
+
f 9//32 2//32 21//32
|
| 104 |
+
f 8//33 9//33 21//33
|
| 105 |
+
f 2//34 9//34 22//34
|
| 106 |
+
f 9//35 7//35 22//35
|
| 107 |
+
f 15//36 2//36 22//36
|
| 108 |
+
f 7//37 19//37 22//37
|
| 109 |
+
f 19//38 15//38 22//38
|
| 110 |
+
f 10//39 1//39 23//39
|
| 111 |
+
f 4//40 10//40 23//40
|
| 112 |
+
f 16//41 4//41 23//41
|
| 113 |
+
f 1//42 18//42 23//42
|
dexart-release/assets/sapien/102697/new_objs/102697_link_4_19.obj
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender v2.79 (sub 0) OBJ File: ''
|
| 2 |
+
# www.blender.org
|
| 3 |
+
mtllib 102697_link_4_19.mtl
|
| 4 |
+
o Shape_IndexedFaceSet.019_Shape_IndexedFaceSet.7505
|
| 5 |
+
v 0.418781 0.627924 -0.508767
|
| 6 |
+
v 0.428328 0.622978 -0.508489
|
| 7 |
+
v -0.056475 0.145332 -0.517066
|
| 8 |
+
v -0.076686 0.620168 -0.567029
|
| 9 |
+
v 0.366468 0.620168 -0.567029
|
| 10 |
+
v -0.076686 0.125411 -0.567029
|
| 11 |
+
v -0.054501 0.599577 -0.517235
|
| 12 |
+
v 0.320187 0.115689 -0.504922
|
| 13 |
+
v 0.322183 0.125411 -0.567029
|
| 14 |
+
v 0.418075 0.269311 -0.508853
|
| 15 |
+
v 0.411683 0.362817 -0.537486
|
| 16 |
+
v -0.067659 0.626136 -0.530364
|
| 17 |
+
v 0.388636 0.635034 -0.507943
|
| 18 |
+
v 0.389069 0.133686 -0.508902
|
| 19 |
+
v 0.351689 0.214056 -0.567029
|
| 20 |
+
v 0.314606 0.110688 -0.506236
|
| 21 |
+
v 0.413343 0.608120 -0.537486
|
| 22 |
+
v 0.390810 0.209582 -0.537486
|
| 23 |
+
v 0.432230 0.376647 -0.508872
|
| 24 |
+
v 0.373857 0.605402 -0.567029
|
| 25 |
+
v 0.359079 0.620168 -0.507943
|
| 26 |
+
v -0.067415 0.625895 -0.551722
|
| 27 |
+
vn 0.0041 0.9694 -0.2455
|
| 28 |
+
vn 0.4154 0.7751 -0.4762
|
| 29 |
+
vn 0.0000 0.0000 -1.0000
|
| 30 |
+
vn -0.0322 0.0005 0.9995
|
| 31 |
+
vn -0.9254 -0.0111 0.3788
|
| 32 |
+
vn -0.9710 -0.0000 0.2391
|
| 33 |
+
vn -0.7030 0.0033 0.7112
|
| 34 |
+
vn -0.0527 0.4214 0.9053
|
| 35 |
+
vn 0.0729 0.1956 0.9780
|
| 36 |
+
vn 0.2100 0.9266 -0.3119
|
| 37 |
+
vn 0.0149 0.0038 0.9999
|
| 38 |
+
vn 0.0612 -0.0134 0.9980
|
| 39 |
+
vn -0.0967 -0.9105 0.4021
|
| 40 |
+
vn -0.0476 -0.2038 0.9778
|
| 41 |
+
vn 0.0000 -0.9719 -0.2354
|
| 42 |
+
vn 0.2833 -0.9396 -0.1922
|
| 43 |
+
vn 0.1599 -0.4139 0.8961
|
| 44 |
+
vn 0.5278 0.6137 -0.5872
|
| 45 |
+
vn 0.8262 -0.1125 -0.5520
|
| 46 |
+
vn 0.6491 -0.2811 -0.7069
|
| 47 |
+
vn 0.8763 -0.1873 -0.4438
|
| 48 |
+
vn 0.5947 -0.0810 -0.7998
|
| 49 |
+
vn 0.5773 -0.1922 -0.7936
|
| 50 |
+
vn 0.0375 -0.0010 0.9993
|
| 51 |
+
vn 0.0502 -0.0064 0.9987
|
| 52 |
+
vn 0.8317 -0.1098 -0.5442
|
| 53 |
+
vn 0.8852 0.0147 -0.4650
|
| 54 |
+
vn 0.8135 -0.0055 -0.5815
|
| 55 |
+
vn 0.4969 -0.0281 -0.8673
|
| 56 |
+
vn 0.5624 0.2814 -0.7775
|
| 57 |
+
vn 0.5992 -0.0041 -0.8006
|
| 58 |
+
vn -0.0228 0.0077 0.9997
|
| 59 |
+
vn -0.0249 0.0495 0.9985
|
| 60 |
+
vn -0.0031 0.0062 1.0000
|
| 61 |
+
vn 0.0000 0.9366 -0.3504
|
| 62 |
+
vn -0.5068 0.8619 -0.0155
|
| 63 |
+
vn -0.0189 0.9998 -0.0115
|
| 64 |
+
usemtl Shape.7511
|
| 65 |
+
s off
|
| 66 |
+
f 13//1 5//1 22//1
|
| 67 |
+
f 1//2 2//2 5//2
|
| 68 |
+
f 4//3 5//3 6//3
|
| 69 |
+
f 7//4 3//4 8//4
|
| 70 |
+
f 6//3 5//3 9//3
|
| 71 |
+
f 6//5 3//5 12//5
|
| 72 |
+
f 4//6 6//6 12//6
|
| 73 |
+
f 3//7 7//7 12//7
|
| 74 |
+
f 12//8 7//8 13//8
|
| 75 |
+
f 2//9 1//9 13//9
|
| 76 |
+
f 1//10 5//10 13//10
|
| 77 |
+
f 8//11 2//11 13//11
|
| 78 |
+
f 10//12 8//12 14//12
|
| 79 |
+
f 9//3 5//3 15//3
|
| 80 |
+
f 3//13 6//13 16//13
|
| 81 |
+
f 8//14 3//14 16//14
|
| 82 |
+
f 6//15 9//15 16//15
|
| 83 |
+
f 9//16 14//16 16//16
|
| 84 |
+
f 14//17 8//17 16//17
|
| 85 |
+
f 5//18 2//18 17//18
|
| 86 |
+
f 11//19 10//19 18//19
|
| 87 |
+
f 14//20 9//20 18//20
|
| 88 |
+
f 10//21 14//21 18//21
|
| 89 |
+
f 15//22 11//22 18//22
|
| 90 |
+
f 9//23 15//23 18//23
|
| 91 |
+
f 2//24 8//24 19//24
|
| 92 |
+
f 8//25 10//25 19//25
|
| 93 |
+
f 10//26 11//26 19//26
|
| 94 |
+
f 17//27 2//27 19//27
|
| 95 |
+
f 11//28 17//28 19//28
|
| 96 |
+
f 11//29 15//29 20//29
|
| 97 |
+
f 15//3 5//3 20//3
|
| 98 |
+
f 5//30 17//30 20//30
|
| 99 |
+
f 17//31 11//31 20//31
|
| 100 |
+
f 7//32 8//32 21//32
|
| 101 |
+
f 13//33 7//33 21//33
|
| 102 |
+
f 8//34 13//34 21//34
|
| 103 |
+
f 5//35 4//35 22//35
|
| 104 |
+
f 4//36 12//36 22//36
|
| 105 |
+
f 12//37 13//37 22//37
|
dexart-release/assets/sapien/102697/new_objs/102697_link_4_3.mtl
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender MTL File: 'None'
|
| 2 |
+
# Material Count: 1
|
| 3 |
+
|
| 4 |
+
newmtl Shape.7495
|
| 5 |
+
Ns 400.000000
|
| 6 |
+
Ka 0.400000 0.400000 0.400000
|
| 7 |
+
Kd 0.168000 0.496000 0.216000
|
| 8 |
+
Ks 0.250000 0.250000 0.250000
|
| 9 |
+
Ke 0.000000 0.000000 0.000000
|
| 10 |
+
Ni 1.000000
|
| 11 |
+
d 0.500000
|
| 12 |
+
illum 2
|
dexart-release/assets/sapien/102697/new_objs/102697_link_4_33.mtl
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender MTL File: 'None'
|
| 2 |
+
# Material Count: 1
|
| 3 |
+
|
| 4 |
+
newmtl Shape.7525
|
| 5 |
+
Ns 400.000000
|
| 6 |
+
Ka 0.400000 0.400000 0.400000
|
| 7 |
+
Kd 0.272000 0.624000 0.536000
|
| 8 |
+
Ks 0.250000 0.250000 0.250000
|
| 9 |
+
Ke 0.000000 0.000000 0.000000
|
| 10 |
+
Ni 1.000000
|
| 11 |
+
d 0.500000
|
| 12 |
+
illum 2
|
dexart-release/assets/sapien/102697/new_objs/102697_link_4_4.mtl
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender MTL File: 'None'
|
| 2 |
+
# Material Count: 1
|
| 3 |
+
|
| 4 |
+
newmtl Shape.7496
|
| 5 |
+
Ns 400.000000
|
| 6 |
+
Ka 0.400000 0.400000 0.400000
|
| 7 |
+
Kd 0.720000 0.472000 0.504000
|
| 8 |
+
Ks 0.250000 0.250000 0.250000
|
| 9 |
+
Ke 0.000000 0.000000 0.000000
|
| 10 |
+
Ni 1.000000
|
| 11 |
+
d 0.500000
|
| 12 |
+
illum 2
|
dexart-release/assets/sapien/102697/new_objs/102697_link_4_8.mtl
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Blender MTL File: 'None'
|
| 2 |
+
# Material Count: 1
|
| 3 |
+
|
| 4 |
+
newmtl Shape.7500
|
| 5 |
+
Ns 400.000000
|
| 6 |
+
Ka 0.400000 0.400000 0.400000
|
| 7 |
+
Kd 0.184000 0.536000 0.280000
|
| 8 |
+
Ks 0.250000 0.250000 0.250000
|
| 9 |
+
Ke 0.000000 0.000000 0.000000
|
| 10 |
+
Ni 1.000000
|
| 11 |
+
d 0.500000
|
| 12 |
+
illum 2
|
dexart-release/assets/sapien/102697/result.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[{"text": "Toilet", "name": "Toilet", "id": 0, "children": [{"text": "button", "name": "button", "id": 1, "objs": ["original-8"]}, {"text": "lid", "name": "lid", "id": 2, "objs": ["original-4", "original-3"]}, {"text": "lid", "name": "lid", "id": 3, "objs": ["original-21", "original-20"]}, {"text": "seat", "name": "seat", "id": 4, "objs": ["original-18", "original-17"]}, {"text": "base_body", "name": "base_body", "id": 5, "objs": ["original-13", "original-6", "original-15", "original-11", "original-7", "original-10", "original-14"]}]}]
|
dexart-release/assets/sapien/102697/result_original.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[{"id": 0, "name": "Toilet", "text": "Toilet", "children": [{"id": 1, "name": "button", "text": "button", "objs": ["original-8"]}, {"id": 2, "name": "lid", "text": "lid", "objs": ["original-4", "original-3"]}, {"id": 3, "name": "lid", "text": "lid", "objs": ["original-21", "original-20"]}, {"id": 4, "name": "seat", "text": "seat", "objs": ["original-18", "original-17"]}]}]
|
dexart-release/assets/sapien/102697/semantics.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
link_0 hinge lever
|
| 2 |
+
link_1 slider pump_lid
|
| 3 |
+
link_2 hinge lid
|
| 4 |
+
link_3 hinge seat
|
| 5 |
+
link_4 static toilet_body
|
dexart-release/dexart.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: dexart
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: UNKNOWN
|
| 5 |
+
Home-page: https://github.com/Kami-code/dexart-release
|
| 6 |
+
Author: Xiaolong Wang's Lab
|
| 7 |
+
License: UNKNOWN
|
| 8 |
+
Platform: UNKNOWN
|
| 9 |
+
License-File: LICENSE
|
| 10 |
+
|
| 11 |
+
UNKNOWN
|
| 12 |
+
|
dexart-release/dexart.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
setup.py
|
| 4 |
+
dexart.egg-info/PKG-INFO
|
| 5 |
+
dexart.egg-info/SOURCES.txt
|
| 6 |
+
dexart.egg-info/dependency_links.txt
|
| 7 |
+
dexart.egg-info/requires.txt
|
| 8 |
+
dexart.egg-info/top_level.txt
|
| 9 |
+
stable_baselines3/__init__.py
|
| 10 |
+
stable_baselines3/pickle_utils.py
|
| 11 |
+
stable_baselines3/a2c/__init__.py
|
| 12 |
+
stable_baselines3/a2c/a2c.py
|
| 13 |
+
stable_baselines3/a2c/policies.py
|
| 14 |
+
stable_baselines3/common/__init__.py
|
| 15 |
+
stable_baselines3/common/base_class.py
|
| 16 |
+
stable_baselines3/common/buffers.py
|
| 17 |
+
stable_baselines3/common/callbacks.py
|
| 18 |
+
stable_baselines3/common/distributions.py
|
| 19 |
+
stable_baselines3/common/env_util.py
|
| 20 |
+
stable_baselines3/common/evaluation.py
|
| 21 |
+
stable_baselines3/common/logger.py
|
| 22 |
+
stable_baselines3/common/monitor.py
|
| 23 |
+
stable_baselines3/common/noise.py
|
| 24 |
+
stable_baselines3/common/on_policy_algorithm.py
|
| 25 |
+
stable_baselines3/common/policies.py
|
| 26 |
+
stable_baselines3/common/preprocessing.py
|
| 27 |
+
stable_baselines3/common/running_mean_std.py
|
| 28 |
+
stable_baselines3/common/save_util.py
|
| 29 |
+
stable_baselines3/common/torch_layers.py
|
| 30 |
+
stable_baselines3/common/type_aliases.py
|
| 31 |
+
stable_baselines3/common/utils.py
|
| 32 |
+
stable_baselines3/common/vec_env/__init__.py
|
| 33 |
+
stable_baselines3/common/vec_env/base_vec_env.py
|
| 34 |
+
stable_baselines3/common/vec_env/dummy_vec_env.py
|
| 35 |
+
stable_baselines3/common/vec_env/maniskill2_utils_common.py
|
| 36 |
+
stable_baselines3/common/vec_env/maniskill2_utils_wrappers_obs.py
|
| 37 |
+
stable_baselines3/common/vec_env/maniskill2_vec_env.py
|
| 38 |
+
stable_baselines3/common/vec_env/maniskill2_wrapper_obs.py
|
| 39 |
+
stable_baselines3/common/vec_env/stacked_observations.py
|
| 40 |
+
stable_baselines3/common/vec_env/subproc_vec_env.py
|
| 41 |
+
stable_baselines3/common/vec_env/util.py
|
| 42 |
+
stable_baselines3/common/vec_env/vec_check_nan.py
|
| 43 |
+
stable_baselines3/common/vec_env/vec_extract_dict_obs.py
|
| 44 |
+
stable_baselines3/common/vec_env/vec_frame_stack.py
|
| 45 |
+
stable_baselines3/common/vec_env/vec_monitor.py
|
| 46 |
+
stable_baselines3/common/vec_env/vec_normalize.py
|
| 47 |
+
stable_baselines3/common/vec_env/vec_transpose.py
|
| 48 |
+
stable_baselines3/common/vec_env/vec_video_recorder.py
|
| 49 |
+
stable_baselines3/ppo/__init__.py
|
| 50 |
+
stable_baselines3/ppo/policies.py
|
| 51 |
+
stable_baselines3/ppo/ppo.py
|
dexart-release/dexart.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
dexart-release/dexart.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transforms3d
|
| 2 |
+
sapien==2.2.1
|
| 3 |
+
numpy
|
| 4 |
+
open3d>=0.15.1
|
dexart-release/dexart.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
stable_baselines3
|
dexart-release/examples/gen_demonstration_expert.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import zarr
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import pytorch3d.ops as torch3d_ops
|
| 12 |
+
from dexart.env.task_setting import TRAIN_CONFIG, RANDOM_CONFIG
|
| 13 |
+
from dexart.env.create_env import create_env
|
| 14 |
+
from stable_baselines3 import PPO
|
| 15 |
+
# from examples.train import get_3d_policy_kwargs
|
| 16 |
+
from train import get_3d_policy_kwargs
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from termcolor import cprint
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parse_args():
|
| 22 |
+
parser = argparse.ArgumentParser()
|
| 23 |
+
parser.add_argument('--task_name', type=str, required=True)
|
| 24 |
+
parser.add_argument('--checkpoint_path', type=str, required=True)
|
| 25 |
+
parser.add_argument('--num_episodes', type=int, default=10, help='number of total episodes')
|
| 26 |
+
parser.add_argument('--use_test_set', dest='use_test_set', action='store_true', default=False)
|
| 27 |
+
parser.add_argument('--root_dir', type=str, default='data', help='directory to save data')
|
| 28 |
+
parser.add_argument('--img_size', type=int, default=84, help='image size')
|
| 29 |
+
parser.add_argument('--num_points', type=int, default=1024, help='number of points in point cloud')
|
| 30 |
+
args = parser.parse_args()
|
| 31 |
+
return args
|
| 32 |
+
|
| 33 |
+
def downsample_with_fps(points: np.ndarray, num_points: int = 512):
|
| 34 |
+
# fast point cloud sampling using torch3d
|
| 35 |
+
points = torch.from_numpy(points).unsqueeze(0).cuda()
|
| 36 |
+
num_points = torch.tensor([num_points]).cuda()
|
| 37 |
+
# remember to only use coord to sample
|
| 38 |
+
_, sampled_indices = torch3d_ops.sample_farthest_points(points=points[...,:3], K=num_points)
|
| 39 |
+
points = points.squeeze(0).cpu().numpy()
|
| 40 |
+
points = points[sampled_indices.squeeze(0).cpu().numpy()]
|
| 41 |
+
return points
|
| 42 |
+
|
| 43 |
+
def main():
|
| 44 |
+
args = parse_args()
|
| 45 |
+
task_name = args.task_name
|
| 46 |
+
use_test_set = args.use_test_set
|
| 47 |
+
checkpoint_path = args.checkpoint_path
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
save_dir = os.path.join(args.root_dir, 'dexart_'+args.task_name+'_expert.zarr')
|
| 51 |
+
if os.path.exists(save_dir):
|
| 52 |
+
cprint('Data already exists at {}'.format(save_dir), 'red')
|
| 53 |
+
cprint("If you want to overwrite, delete the existing directory first.", "red")
|
| 54 |
+
cprint("Do you want to overwrite? (y/n)", "red")
|
| 55 |
+
# user_input = input()
|
| 56 |
+
user_input = 'y'
|
| 57 |
+
if user_input == 'y':
|
| 58 |
+
cprint('Overwriting {}'.format(save_dir), 'red')
|
| 59 |
+
os.system('rm -rf {}'.format(save_dir))
|
| 60 |
+
else:
|
| 61 |
+
cprint('Exiting', 'red')
|
| 62 |
+
return
|
| 63 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
if use_test_set:
|
| 67 |
+
indeces = TRAIN_CONFIG[task_name]['unseen']
|
| 68 |
+
cprint(f"using unseen instances {indeces}", 'yellow')
|
| 69 |
+
else:
|
| 70 |
+
indeces = TRAIN_CONFIG[task_name]['seen']
|
| 71 |
+
cprint(f"using seen instances {indeces}", 'yellow')
|
| 72 |
+
|
| 73 |
+
rand_pos = RANDOM_CONFIG[task_name]['rand_pos']
|
| 74 |
+
rand_degree = RANDOM_CONFIG[task_name]['rand_degree']
|
| 75 |
+
env = create_env(task_name=task_name,
|
| 76 |
+
use_visual_obs=True,
|
| 77 |
+
use_gui=False,
|
| 78 |
+
is_eval=True,
|
| 79 |
+
pc_noise=True,
|
| 80 |
+
pc_seg=True,
|
| 81 |
+
index=indeces,
|
| 82 |
+
img_type='robot',
|
| 83 |
+
rand_pos=rand_pos,
|
| 84 |
+
rand_degree=rand_degree)
|
| 85 |
+
|
| 86 |
+
policy = PPO.load(checkpoint_path, env, 'cuda:0',
|
| 87 |
+
policy_kwargs=get_3d_policy_kwargs(extractor_name='smallpn'),
|
| 88 |
+
check_obs_space=False, force_load=True)
|
| 89 |
+
|
| 90 |
+
eval_instances = len(env.instance_list)
|
| 91 |
+
num_episodes = args.num_episodes
|
| 92 |
+
cprint(f"generate {num_episodes} episodes in total", 'yellow')
|
| 93 |
+
|
| 94 |
+
success_list = []
|
| 95 |
+
reward_list = []
|
| 96 |
+
|
| 97 |
+
total_count = 0
|
| 98 |
+
img_arrays = []
|
| 99 |
+
point_cloud_arrays = []
|
| 100 |
+
depth_arrays = []
|
| 101 |
+
state_arrays = []
|
| 102 |
+
imagin_robot_arrays = []
|
| 103 |
+
action_arrays = []
|
| 104 |
+
episode_ends_arrays = []
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
with tqdm(total=num_episodes) as pbar:
|
| 108 |
+
num_success = 0
|
| 109 |
+
while num_success < num_episodes:
|
| 110 |
+
|
| 111 |
+
# obs dict keys: 'instance_1-seg_gt', 'instance_1-point_cloud',
|
| 112 |
+
# 'instance_1-rgb', 'imagination_robot', 'state', 'oracle_state'
|
| 113 |
+
obs = env.reset()
|
| 114 |
+
eval_success = False
|
| 115 |
+
reward_sum = 0
|
| 116 |
+
|
| 117 |
+
img_arrays_sub = []
|
| 118 |
+
point_cloud_arrays_sub = []
|
| 119 |
+
depth_arrays_sub = []
|
| 120 |
+
state_arrays_sub = []
|
| 121 |
+
imagin_robot_arrays_sub = []
|
| 122 |
+
action_arrays_sub = []
|
| 123 |
+
total_count_sub = 0
|
| 124 |
+
for j in range(env.horizon):
|
| 125 |
+
|
| 126 |
+
if isinstance(obs, dict):
|
| 127 |
+
for key, value in obs.items():
|
| 128 |
+
obs[key] = value[np.newaxis, :]
|
| 129 |
+
else:
|
| 130 |
+
obs = obs[np.newaxis, :]
|
| 131 |
+
action = policy.predict(observation=obs, deterministic=True)[0]
|
| 132 |
+
|
| 133 |
+
# fetch data
|
| 134 |
+
total_count_sub += 1
|
| 135 |
+
obs_state = obs['state'][0] # (32)
|
| 136 |
+
obs_imagin_robot = obs['imagination_robot'][0] # (96,7)
|
| 137 |
+
obs_point_cloud = obs['instance_1-point_cloud'][0] # (1024,3)
|
| 138 |
+
obs_depth = obs['instance_1-depth'][0] # (84,84)
|
| 139 |
+
|
| 140 |
+
if obs_point_cloud.shape[0] > args.num_points:
|
| 141 |
+
obs_point_cloud = downsample_with_fps(obs_point_cloud, num_points=args.num_points)
|
| 142 |
+
obs_image = obs['instance_1-rgb'][0] # (84,84,3), [0,1]
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# to 0-255
|
| 146 |
+
obs_image = (obs_image*255).astype(np.uint8)
|
| 147 |
+
|
| 148 |
+
# interpolate to target image size
|
| 149 |
+
if obs_image.shape[0] != args.img_size:
|
| 150 |
+
obs_image = F.interpolate(torch.from_numpy(obs_image).permute(2,0,1).unsqueeze(0),
|
| 151 |
+
size=args.img_size).squeeze().permute(1,2,0).numpy()
|
| 152 |
+
# save data
|
| 153 |
+
img_arrays_sub.append(obs_image)
|
| 154 |
+
imagin_robot_arrays_sub.append(obs_imagin_robot)
|
| 155 |
+
point_cloud_arrays_sub.append(obs_point_cloud)
|
| 156 |
+
depth_arrays_sub.append(obs_depth)
|
| 157 |
+
state_arrays_sub.append(obs_state)
|
| 158 |
+
action_arrays_sub.append(action)
|
| 159 |
+
|
| 160 |
+
# step
|
| 161 |
+
obs, reward, done, _ = env.step(action)
|
| 162 |
+
reward_sum += reward
|
| 163 |
+
if env.is_eval_done:
|
| 164 |
+
eval_success = True
|
| 165 |
+
if done:
|
| 166 |
+
break
|
| 167 |
+
|
| 168 |
+
if eval_success:
|
| 169 |
+
total_count += total_count_sub
|
| 170 |
+
episode_ends_arrays.append(total_count) # the index of the last step of the episode
|
| 171 |
+
reward_list.append(reward_sum)
|
| 172 |
+
success_list.append(int(eval_success))
|
| 173 |
+
|
| 174 |
+
img_arrays.extend(img_arrays_sub)
|
| 175 |
+
imagin_robot_arrays.extend(imagin_robot_arrays_sub)
|
| 176 |
+
point_cloud_arrays.extend(point_cloud_arrays_sub)
|
| 177 |
+
depth_arrays.extend(depth_arrays_sub)
|
| 178 |
+
state_arrays.extend(state_arrays_sub)
|
| 179 |
+
action_arrays.extend(action_arrays_sub)
|
| 180 |
+
|
| 181 |
+
num_success += 1
|
| 182 |
+
|
| 183 |
+
pbar.update(1)
|
| 184 |
+
pbar.set_description(f"reward = {reward_sum}, success = {eval_success}")
|
| 185 |
+
else:
|
| 186 |
+
print("episode failed. continue.")
|
| 187 |
+
continue
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
cprint(f"reward_mean = {np.mean(reward_list)}, success rate = {np.mean(success_list)}", 'yellow')
|
| 191 |
+
|
| 192 |
+
###############################
|
| 193 |
+
# save data
|
| 194 |
+
###############################
|
| 195 |
+
# create zarr file
|
| 196 |
+
zarr_root = zarr.group(save_dir)
|
| 197 |
+
zarr_data = zarr_root.create_group('data')
|
| 198 |
+
zarr_meta = zarr_root.create_group('meta')
|
| 199 |
+
# save img, state, action arrays into data, and episode ends arrays into meta
|
| 200 |
+
img_arrays = np.stack(img_arrays, axis=0)
|
| 201 |
+
if img_arrays.shape[1] == 3: # make channel last
|
| 202 |
+
img_arrays = np.transpose(img_arrays, (0,2,3,1))
|
| 203 |
+
state_arrays = np.stack(state_arrays, axis=0)
|
| 204 |
+
imagin_robot_arrays = np.stack(imagin_robot_arrays, axis=0)
|
| 205 |
+
point_cloud_arrays = np.stack(point_cloud_arrays, axis=0)
|
| 206 |
+
depth_arrays = np.stack(depth_arrays, axis=0)
|
| 207 |
+
action_arrays = np.stack(action_arrays, axis=0)
|
| 208 |
+
episode_ends_arrays = np.array(episode_ends_arrays)
|
| 209 |
+
|
| 210 |
+
compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=1)
|
| 211 |
+
img_chunk_size = (env.horizon, img_arrays.shape[1], img_arrays.shape[2], img_arrays.shape[3])
|
| 212 |
+
imagin_robot_chunk_size = (env.horizon, imagin_robot_arrays.shape[1], imagin_robot_arrays.shape[2])
|
| 213 |
+
point_cloud_chunk_size = (env.horizon, point_cloud_arrays.shape[1], point_cloud_arrays.shape[2])
|
| 214 |
+
depth_chunk_size = (env.horizon, depth_arrays.shape[1], depth_arrays.shape[2])
|
| 215 |
+
state_chunk_size = (env.horizon, state_arrays.shape[1])
|
| 216 |
+
action_chunk_size = (env.horizon, action_arrays.shape[1])
|
| 217 |
+
zarr_data.create_dataset('img', data=img_arrays, chunks=img_chunk_size, dtype='uint8', overwrite=True, compressor=compressor)
|
| 218 |
+
zarr_data.create_dataset('state', data=state_arrays, chunks=state_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
|
| 219 |
+
zarr_data.create_dataset('imagin_robot', data=imagin_robot_arrays, chunks=imagin_robot_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
|
| 220 |
+
zarr_data.create_dataset('point_cloud', data=point_cloud_arrays, chunks=point_cloud_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
|
| 221 |
+
zarr_data.create_dataset('depth', data=depth_arrays, chunks=depth_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
|
| 222 |
+
zarr_data.create_dataset('action', data=action_arrays, chunks=action_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
|
| 223 |
+
zarr_meta.create_dataset('episode_ends', data=episode_ends_arrays, dtype='int64', overwrite=True, compressor=compressor)
|
| 224 |
+
|
| 225 |
+
# print shape
|
| 226 |
+
cprint(f'img shape: {img_arrays.shape}, range: [{np.min(img_arrays)}, {np.max(img_arrays)}]', 'green')
|
| 227 |
+
cprint(f'imagin_robot shape: {imagin_robot_arrays.shape}, range: [{np.min(imagin_robot_arrays)}, {np.max(imagin_robot_arrays)}]', 'green')
|
| 228 |
+
cprint(f'point_cloud shape: {point_cloud_arrays.shape}, range: [{np.min(point_cloud_arrays)}, {np.max(point_cloud_arrays)}]', 'green')
|
| 229 |
+
cprint(f'depth shape: {depth_arrays.shape}, range: [{np.min(depth_arrays)}, {np.max(depth_arrays)}]', 'green')
|
| 230 |
+
cprint(f'state shape: {state_arrays.shape}, range: [{np.min(state_arrays)}, {np.max(state_arrays)}]', 'green')
|
| 231 |
+
cprint(f'action shape: {action_arrays.shape}, range: [{np.min(action_arrays)}, {np.max(action_arrays)}]', 'green')
|
| 232 |
+
cprint(f'Saved zarr file to {save_dir}', 'green')
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
if __name__ == "__main__":
|
| 236 |
+
main()
|
| 237 |
+
|
| 238 |
+
|
dexart-release/examples/train.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 4 |
+
import random
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import argparse
|
| 8 |
+
from dexart.env.create_env import create_env
|
| 9 |
+
from dexart.env.task_setting import TRAIN_CONFIG, IMG_CONFIG, RANDOM_CONFIG
|
| 10 |
+
from stable_baselines3.common.torch_layers import PointNetImaginationExtractorGP
|
| 11 |
+
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
|
| 12 |
+
from stable_baselines3.ppo import PPO
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_3d_policy_kwargs(extractor_name):
|
| 19 |
+
feature_extractor_class = PointNetImaginationExtractorGP
|
| 20 |
+
feature_extractor_kwargs = {"pc_key": "instance_1-point_cloud", "gt_key": "instance_1-seg_gt",
|
| 21 |
+
"extractor_name": extractor_name,
|
| 22 |
+
"imagination_keys": [f'imagination_{key}' for key in IMG_CONFIG['robot'].keys()],
|
| 23 |
+
"state_key": "state"}
|
| 24 |
+
|
| 25 |
+
policy_kwargs = {
|
| 26 |
+
"features_extractor_class": feature_extractor_class,
|
| 27 |
+
"features_extractor_kwargs": feature_extractor_kwargs,
|
| 28 |
+
"net_arch": [dict(pi=[64, 64], vf=[64, 64])],
|
| 29 |
+
"activation_fn": nn.ReLU,
|
| 30 |
+
}
|
| 31 |
+
return policy_kwargs
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == '__main__':
|
| 35 |
+
parser = argparse.ArgumentParser()
|
| 36 |
+
parser.add_argument('--n', type=int, default=10)
|
| 37 |
+
parser.add_argument('--workers', type=int, default=1)
|
| 38 |
+
parser.add_argument('--lr', type=float, default=3e-4)
|
| 39 |
+
parser.add_argument('--ep', type=int, default=10)
|
| 40 |
+
parser.add_argument('--bs', type=int, default=10)
|
| 41 |
+
parser.add_argument('--seed', type=int, default=100)
|
| 42 |
+
parser.add_argument('--iter', type=int, default=1000)
|
| 43 |
+
parser.add_argument('--freeze', dest='freeze', action='store_true', default=False)
|
| 44 |
+
parser.add_argument('--task_name', type=str, default="laptop")
|
| 45 |
+
parser.add_argument('--extractor_name', type=str, default="smallpn")
|
| 46 |
+
parser.add_argument('--pretrain_path', type=str, default=None)
|
| 47 |
+
args = parser.parse_args()
|
| 48 |
+
|
| 49 |
+
task_name = args.task_name
|
| 50 |
+
extractor_name = args.extractor_name
|
| 51 |
+
seed = args.seed if args.seed >= 0 else random.randint(0, 100000)
|
| 52 |
+
pretrain_path = args.pretrain_path
|
| 53 |
+
horizon = 200
|
| 54 |
+
env_iter = args.iter * horizon * args.n
|
| 55 |
+
print(f"freeze: {args.freeze}")
|
| 56 |
+
|
| 57 |
+
rand_pos = RANDOM_CONFIG[task_name]['rand_pos']
|
| 58 |
+
rand_degree = RANDOM_CONFIG[task_name]['rand_degree']
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def create_env_fn():
|
| 62 |
+
seen_indeces = TRAIN_CONFIG[task_name]['seen']
|
| 63 |
+
environment = create_env(task_name=task_name,
|
| 64 |
+
use_visual_obs=True,
|
| 65 |
+
use_gui=False,
|
| 66 |
+
is_eval=False,
|
| 67 |
+
pc_noise=True,
|
| 68 |
+
index=seen_indeces,
|
| 69 |
+
img_type='robot',
|
| 70 |
+
rand_pos=rand_pos,
|
| 71 |
+
rand_degree=rand_degree
|
| 72 |
+
)
|
| 73 |
+
return environment
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def create_eval_env_fn():
|
| 77 |
+
unseen_indeces = TRAIN_CONFIG[task_name]['unseen']
|
| 78 |
+
environment = create_env(task_name=task_name,
|
| 79 |
+
use_visual_obs=True,
|
| 80 |
+
use_gui=False,
|
| 81 |
+
is_eval=True,
|
| 82 |
+
pc_noise=True,
|
| 83 |
+
index=unseen_indeces,
|
| 84 |
+
img_type='robot',
|
| 85 |
+
rand_pos=rand_pos,
|
| 86 |
+
rand_degree=rand_degree)
|
| 87 |
+
return environment
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
env = SubprocVecEnv([create_env_fn] * args.workers, "spawn") # train on a list of envs.
|
| 91 |
+
|
| 92 |
+
model = PPO("PointCloudPolicy", env, verbose=1,
|
| 93 |
+
n_epochs=args.ep,
|
| 94 |
+
n_steps=(args.n // args.workers) * horizon,
|
| 95 |
+
learning_rate=args.lr,
|
| 96 |
+
batch_size=args.bs,
|
| 97 |
+
seed=seed,
|
| 98 |
+
policy_kwargs=get_3d_policy_kwargs(extractor_name=extractor_name),
|
| 99 |
+
min_lr=args.lr,
|
| 100 |
+
max_lr=args.lr,
|
| 101 |
+
adaptive_kl=0.02,
|
| 102 |
+
target_kl=0.2,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
if pretrain_path is not None:
|
| 106 |
+
state_dict: OrderedDict = torch.load(pretrain_path)
|
| 107 |
+
model.policy.features_extractor.extractor.load_state_dict(state_dict, strict=False)
|
| 108 |
+
print("load pretrained model: ", pretrain_path)
|
| 109 |
+
|
| 110 |
+
rollout = int(model.num_timesteps / (horizon * args.n))
|
| 111 |
+
|
| 112 |
+
# after loading or init the model, then freeze it if needed
|
| 113 |
+
if args.freeze:
|
| 114 |
+
model.policy.features_extractor.extractor.eval()
|
| 115 |
+
for param in model.policy.features_extractor.extractor.parameters():
|
| 116 |
+
param.requires_grad = False
|
| 117 |
+
print("freeze model!")
|
| 118 |
+
|
| 119 |
+
model.learn(
|
| 120 |
+
total_timesteps=int(env_iter),
|
| 121 |
+
reset_num_timesteps=False,
|
| 122 |
+
iter_start=rollout,
|
| 123 |
+
callback=None
|
| 124 |
+
)
|
dexart-release/examples/utils.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: UTF-8 -*-
|
| 3 |
+
import numpy as np
|
| 4 |
+
from dexart.env.task_setting import ROBUSTNESS_INIT_CAMERA_CONFIG
|
| 5 |
+
import open3d as o3d
|
| 6 |
+
|
| 7 |
+
def visualize_observation(obs, use_seg=False, img_type=None):
|
| 8 |
+
def visualize_pc_with_seg_label(cloud):
|
| 9 |
+
pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(cloud[:, :3]))
|
| 10 |
+
|
| 11 |
+
def map(feature):
|
| 12 |
+
color = np.zeros((feature.shape[0], 3))
|
| 13 |
+
COLOR20 = np.array(
|
| 14 |
+
[[230, 25, 75], [60, 180, 75], [255, 225, 25], [0, 130, 200], [245, 130, 48],
|
| 15 |
+
[145, 30, 180], [70, 240, 240], [240, 50, 230], [210, 245, 60], [250, 190, 190],
|
| 16 |
+
[0, 128, 128], [230, 190, 255], [170, 110, 40], [255, 250, 200], [128, 0, 0],
|
| 17 |
+
[170, 255, 195], [128, 128, 0], [255, 215, 180], [0, 0, 128], [128, 128, 128]]) / 255
|
| 18 |
+
for i in range(feature.shape[0]):
|
| 19 |
+
for j in range(feature.shape[1]):
|
| 20 |
+
if feature[i, j] == 1:
|
| 21 |
+
color[i, :] = COLOR20[j, :]
|
| 22 |
+
return color
|
| 23 |
+
|
| 24 |
+
color = map(cloud[:, 3:])
|
| 25 |
+
pc.colors = o3d.utility.Vector3dVector(color)
|
| 26 |
+
return pc
|
| 27 |
+
|
| 28 |
+
pc = obs["instance_1-point_cloud"]
|
| 29 |
+
if use_seg:
|
| 30 |
+
gt_seg = obs["instance_1-seg_gt"]
|
| 31 |
+
pc = np.concatenate([pc, gt_seg], axis=1)
|
| 32 |
+
pc = visualize_pc_with_seg_label(pc)
|
| 33 |
+
if img_type == "robot":
|
| 34 |
+
robot_pc = obs["imagination_robot"]
|
| 35 |
+
pc += visualize_pc_with_seg_label(robot_pc)
|
| 36 |
+
else:
|
| 37 |
+
raise NotImplementedError
|
| 38 |
+
return pc
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_viewpoint_camera_parameter():
|
| 42 |
+
robustness_init_camera_config = ROBUSTNESS_INIT_CAMERA_CONFIG['laptop']
|
| 43 |
+
r = robustness_init_camera_config['r']
|
| 44 |
+
phi = robustness_init_camera_config['phi']
|
| 45 |
+
theta = robustness_init_camera_config['theta']
|
| 46 |
+
center = robustness_init_camera_config['center']
|
| 47 |
+
|
| 48 |
+
x0, y0, z0 = center
|
| 49 |
+
# phi in [0, pi/2]
|
| 50 |
+
# theta in [0, 2 * pi]
|
| 51 |
+
x = x0 + r * np.sin(phi) * np.cos(theta)
|
| 52 |
+
y = y0 + r * np.sin(phi) * np.sin(theta)
|
| 53 |
+
z = z0 + r * np.cos(phi)
|
| 54 |
+
|
| 55 |
+
cam_pos = np.array([x, y, z])
|
| 56 |
+
forward = np.array([x0 - x, y0 - y, z0 - z])
|
| 57 |
+
forward /= np.linalg.norm(forward)
|
| 58 |
+
|
| 59 |
+
left = np.cross([0, 0, 1], forward)
|
| 60 |
+
left = left / np.linalg.norm(left)
|
| 61 |
+
|
| 62 |
+
up = np.cross(forward, left)
|
| 63 |
+
mat44 = np.eye(4)
|
| 64 |
+
mat44[:3, :3] = np.stack([forward, left, up], axis=1)
|
| 65 |
+
mat44[:3, 3] = cam_pos
|
| 66 |
+
return cam_pos, center, up, mat44
|
dexart-release/stable_baselines3/a2c/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stable_baselines3.a2c.a2c import A2C
|
| 2 |
+
from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
dexart-release/stable_baselines3/a2c/a2c.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional, Type, Union
|
| 2 |
+
|
| 3 |
+
import torch as th
|
| 4 |
+
from gym import spaces
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
| 8 |
+
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
|
| 9 |
+
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
| 10 |
+
from stable_baselines3.common.utils import explained_variance
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class A2C(OnPolicyAlgorithm):
|
| 14 |
+
"""
|
| 15 |
+
Advantage Actor Critic (A2C)
|
| 16 |
+
|
| 17 |
+
Paper: https://arxiv.org/abs/1602.01783
|
| 18 |
+
Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
|
| 19 |
+
and Stable Baselines (https://github.com/hill-a/stable-baselines)
|
| 20 |
+
|
| 21 |
+
Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
|
| 22 |
+
|
| 23 |
+
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
|
| 24 |
+
:param env: The environment to learn from (if registered in Gym, can be str)
|
| 25 |
+
:param learning_rate: The learning rate, it can be a function
|
| 26 |
+
of the current progress remaining (from 1 to 0)
|
| 27 |
+
:param n_steps: The number of steps to run for each environment per update
|
| 28 |
+
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
|
| 29 |
+
:param gamma: Discount factor
|
| 30 |
+
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
| 31 |
+
Equivalent to classic advantage when set to 1.
|
| 32 |
+
:param ent_coef: Entropy coefficient for the loss calculation
|
| 33 |
+
:param vf_coef: Value function coefficient for the loss calculation
|
| 34 |
+
:param max_grad_norm: The maximum value for the gradient clipping
|
| 35 |
+
:param rms_prop_eps: RMSProp epsilon. It stabilizes square root computation in denominator
|
| 36 |
+
of RMSProp update
|
| 37 |
+
:param use_rms_prop: Whether to use RMSprop (default) or Adam as optimizer
|
| 38 |
+
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
|
| 39 |
+
instead of action noise exploration (default: False)
|
| 40 |
+
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
| 41 |
+
Default: -1 (only sample at the beginning of the rollout)
|
| 42 |
+
:param normalize_advantage: Whether to normalize or not the advantage
|
| 43 |
+
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
| 44 |
+
:param create_eval_env: Whether to create a second environment that will be
|
| 45 |
+
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
| 46 |
+
:param policy_kwargs: additional arguments to be passed to the policy on creation
|
| 47 |
+
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
|
| 48 |
+
:param seed: Seed for the pseudo random generators
|
| 49 |
+
:param device: Device (cpu, cuda, ...) on which the code should be run.
|
| 50 |
+
Setting it to auto, the code will be run on the GPU if possible.
|
| 51 |
+
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
| 55 |
+
"MlpPolicy": ActorCriticPolicy,
|
| 56 |
+
"CnnPolicy": ActorCriticCnnPolicy,
|
| 57 |
+
"MultiInputPolicy": MultiInputActorCriticPolicy,
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
policy: Union[str, Type[ActorCriticPolicy]],
|
| 63 |
+
env: Union[GymEnv, str],
|
| 64 |
+
learning_rate: Union[float, Schedule] = 7e-4,
|
| 65 |
+
n_steps: int = 5,
|
| 66 |
+
gamma: float = 0.99,
|
| 67 |
+
gae_lambda: float = 1.0,
|
| 68 |
+
ent_coef: float = 0.0,
|
| 69 |
+
vf_coef: float = 0.5,
|
| 70 |
+
max_grad_norm: float = 0.5,
|
| 71 |
+
rms_prop_eps: float = 1e-5,
|
| 72 |
+
use_rms_prop: bool = True,
|
| 73 |
+
use_sde: bool = False,
|
| 74 |
+
sde_sample_freq: int = -1,
|
| 75 |
+
normalize_advantage: bool = False,
|
| 76 |
+
tensorboard_log: Optional[str] = None,
|
| 77 |
+
create_eval_env: bool = False,
|
| 78 |
+
policy_kwargs: Optional[Dict[str, Any]] = None,
|
| 79 |
+
verbose: int = 0,
|
| 80 |
+
seed: Optional[int] = None,
|
| 81 |
+
device: Union[th.device, str] = "auto",
|
| 82 |
+
_init_setup_model: bool = True,
|
| 83 |
+
):
|
| 84 |
+
|
| 85 |
+
super().__init__(
|
| 86 |
+
policy,
|
| 87 |
+
env,
|
| 88 |
+
learning_rate=learning_rate,
|
| 89 |
+
n_steps=n_steps,
|
| 90 |
+
gamma=gamma,
|
| 91 |
+
gae_lambda=gae_lambda,
|
| 92 |
+
ent_coef=ent_coef,
|
| 93 |
+
vf_coef=vf_coef,
|
| 94 |
+
max_grad_norm=max_grad_norm,
|
| 95 |
+
use_sde=use_sde,
|
| 96 |
+
sde_sample_freq=sde_sample_freq,
|
| 97 |
+
tensorboard_log=tensorboard_log,
|
| 98 |
+
policy_kwargs=policy_kwargs,
|
| 99 |
+
verbose=verbose,
|
| 100 |
+
device=device,
|
| 101 |
+
create_eval_env=create_eval_env,
|
| 102 |
+
seed=seed,
|
| 103 |
+
_init_setup_model=False,
|
| 104 |
+
supported_action_spaces=(
|
| 105 |
+
spaces.Box,
|
| 106 |
+
spaces.Discrete,
|
| 107 |
+
spaces.MultiDiscrete,
|
| 108 |
+
spaces.MultiBinary,
|
| 109 |
+
),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
self.normalize_advantage = normalize_advantage
|
| 113 |
+
|
| 114 |
+
# Update optimizer inside the policy if we want to use RMSProp
|
| 115 |
+
# (original implementation) rather than Adam
|
| 116 |
+
if use_rms_prop and "optimizer_class" not in self.policy_kwargs:
|
| 117 |
+
self.policy_kwargs["optimizer_class"] = th.optim.RMSprop
|
| 118 |
+
self.policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=rms_prop_eps, weight_decay=0)
|
| 119 |
+
|
| 120 |
+
if _init_setup_model:
|
| 121 |
+
self._setup_model()
|
| 122 |
+
|
| 123 |
+
def train(self) -> None:
|
| 124 |
+
"""
|
| 125 |
+
Update policy using the currently gathered
|
| 126 |
+
rollout buffer (one gradient step over whole data).
|
| 127 |
+
"""
|
| 128 |
+
# Switch to train mode (this affects batch norm / dropout)
|
| 129 |
+
self.policy.set_training_mode(True)
|
| 130 |
+
|
| 131 |
+
# Update optimizer learning rate
|
| 132 |
+
self._update_learning_rate(self.policy.optimizer)
|
| 133 |
+
|
| 134 |
+
# This will only loop once (get all data in one go)
|
| 135 |
+
for rollout_data in self.rollout_buffer.get(batch_size=None):
|
| 136 |
+
|
| 137 |
+
actions = rollout_data.actions
|
| 138 |
+
if isinstance(self.action_space, spaces.Discrete):
|
| 139 |
+
# Convert discrete action from float to long
|
| 140 |
+
actions = actions.long().flatten()
|
| 141 |
+
|
| 142 |
+
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
|
| 143 |
+
values = values.flatten()
|
| 144 |
+
|
| 145 |
+
# Normalize advantage (not present in the original implementation)
|
| 146 |
+
advantages = rollout_data.advantages
|
| 147 |
+
if self.normalize_advantage:
|
| 148 |
+
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
| 149 |
+
|
| 150 |
+
# Policy gradient loss
|
| 151 |
+
policy_loss = -(advantages * log_prob).mean()
|
| 152 |
+
|
| 153 |
+
# Value loss using the TD(gae_lambda) target
|
| 154 |
+
value_loss = F.mse_loss(rollout_data.returns, values)
|
| 155 |
+
|
| 156 |
+
# Entropy loss favor exploration
|
| 157 |
+
if entropy is None:
|
| 158 |
+
# Approximate entropy when no analytical form
|
| 159 |
+
entropy_loss = -th.mean(-log_prob)
|
| 160 |
+
else:
|
| 161 |
+
entropy_loss = -th.mean(entropy)
|
| 162 |
+
|
| 163 |
+
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
| 164 |
+
|
| 165 |
+
# Optimization step
|
| 166 |
+
self.policy.optimizer.zero_grad()
|
| 167 |
+
loss.backward()
|
| 168 |
+
|
| 169 |
+
# Clip grad norm
|
| 170 |
+
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
| 171 |
+
self.policy.optimizer.step()
|
| 172 |
+
|
| 173 |
+
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
|
| 174 |
+
|
| 175 |
+
self._n_updates += 1
|
| 176 |
+
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
| 177 |
+
self.logger.record("train/explained_variance", explained_var)
|
| 178 |
+
self.logger.record("train/entropy_loss", entropy_loss.item())
|
| 179 |
+
self.logger.record("train/policy_loss", policy_loss.item())
|
| 180 |
+
self.logger.record("train/value_loss", value_loss.item())
|
| 181 |
+
if hasattr(self.policy, "log_std"):
|
| 182 |
+
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
| 183 |
+
|
| 184 |
+
def learn(
|
| 185 |
+
self,
|
| 186 |
+
total_timesteps: int,
|
| 187 |
+
callback: MaybeCallback = None,
|
| 188 |
+
log_interval: int = 100,
|
| 189 |
+
eval_env: Optional[GymEnv] = None,
|
| 190 |
+
eval_freq: int = -1,
|
| 191 |
+
n_eval_episodes: int = 5,
|
| 192 |
+
tb_log_name: str = "A2C",
|
| 193 |
+
eval_log_path: Optional[str] = None,
|
| 194 |
+
reset_num_timesteps: bool = True,
|
| 195 |
+
) -> "A2C":
|
| 196 |
+
|
| 197 |
+
return super().learn(
|
| 198 |
+
total_timesteps=total_timesteps,
|
| 199 |
+
callback=callback,
|
| 200 |
+
log_interval=log_interval,
|
| 201 |
+
eval_env=eval_env,
|
| 202 |
+
eval_freq=eval_freq,
|
| 203 |
+
n_eval_episodes=n_eval_episodes,
|
| 204 |
+
tb_log_name=tb_log_name,
|
| 205 |
+
eval_log_path=eval_log_path,
|
| 206 |
+
reset_num_timesteps=reset_num_timesteps,
|
| 207 |
+
)
|
dexart-release/stable_baselines3/a2c/policies.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is here just to define MlpPolicy/CnnPolicy
|
| 2 |
+
# that work for A2C
|
| 3 |
+
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
| 4 |
+
|
| 5 |
+
MlpPolicy = ActorCriticPolicy
|
| 6 |
+
CnnPolicy = ActorCriticCnnPolicy
|
| 7 |
+
MultiInputPolicy = MultiInputActorCriticPolicy
|
dexart-release/stable_baselines3/common/__init__.py
ADDED
|
File without changes
|
dexart-release/stable_baselines3/common/base_class.py
ADDED
|
@@ -0,0 +1,835 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Abstract base classes for RL algorithms."""
|
| 2 |
+
|
| 3 |
+
import io
|
| 4 |
+
import pathlib
|
| 5 |
+
import time
|
| 6 |
+
from abc import ABC, abstractmethod
|
| 7 |
+
from collections import deque
|
| 8 |
+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
|
| 9 |
+
|
| 10 |
+
import gym
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch as th
|
| 13 |
+
|
| 14 |
+
from stable_baselines3.common import utils
|
| 15 |
+
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
|
| 16 |
+
from stable_baselines3.common.env_util import is_wrapped
|
| 17 |
+
from stable_baselines3.common.logger import Logger
|
| 18 |
+
from stable_baselines3.common.monitor import Monitor
|
| 19 |
+
from stable_baselines3.common.noise import ActionNoise
|
| 20 |
+
from stable_baselines3.common.policies import BasePolicy
|
| 21 |
+
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
|
| 22 |
+
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
|
| 23 |
+
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
| 24 |
+
from stable_baselines3.common.utils import (
|
| 25 |
+
check_for_correct_spaces,
|
| 26 |
+
get_device,
|
| 27 |
+
get_schedule_fn,
|
| 28 |
+
get_system_info,
|
| 29 |
+
set_random_seed,
|
| 30 |
+
update_learning_rate,
|
| 31 |
+
)
|
| 32 |
+
from stable_baselines3.common.vec_env import (
|
| 33 |
+
DummyVecEnv,
|
| 34 |
+
VecEnv,
|
| 35 |
+
VecNormalize,
|
| 36 |
+
VecTransposeImage,
|
| 37 |
+
is_vecenv_wrapped,
|
| 38 |
+
unwrap_vec_normalize,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]:
|
| 43 |
+
"""If env is a string, make the environment; otherwise, return env.
|
| 44 |
+
|
| 45 |
+
:param env: The environment to learn from.
|
| 46 |
+
:param verbose: logging verbosity
|
| 47 |
+
:return A Gym (vector) environment.
|
| 48 |
+
"""
|
| 49 |
+
if isinstance(env, str):
|
| 50 |
+
if verbose >= 1:
|
| 51 |
+
print(f"Creating environment from the given name '{env}'")
|
| 52 |
+
env = gym.make(env)
|
| 53 |
+
return env
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class BaseAlgorithm(ABC):
|
| 57 |
+
"""
|
| 58 |
+
The base of RL algorithms
|
| 59 |
+
|
| 60 |
+
:param policy: Policy object
|
| 61 |
+
:param env: The environment to learn from
|
| 62 |
+
(if registered in Gym, can be str. Can be None for loading trained models)
|
| 63 |
+
:param learning_rate: learning rate for the optimizer,
|
| 64 |
+
it can be a function of the current progress remaining (from 1 to 0)
|
| 65 |
+
:param policy_kwargs: Additional arguments to be passed to the policy on creation
|
| 66 |
+
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
| 67 |
+
:param verbose: The verbosity level: 0 none, 1 training information, 2 debug
|
| 68 |
+
:param device: Device on which the code should run.
|
| 69 |
+
By default, it will try to use a Cuda compatible device and fallback to cpu
|
| 70 |
+
if it is not possible.
|
| 71 |
+
:param support_multi_env: Whether the algorithm supports training
|
| 72 |
+
with multiple environments (as in A2C)
|
| 73 |
+
:param create_eval_env: Whether to create a second environment that will be
|
| 74 |
+
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
| 75 |
+
:param monitor_wrapper: When creating an environment, whether to wrap it
|
| 76 |
+
or not in a Monitor wrapper.
|
| 77 |
+
:param seed: Seed for the pseudo random generators
|
| 78 |
+
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
|
| 79 |
+
instead of action noise exploration (default: False)
|
| 80 |
+
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
| 81 |
+
Default: -1 (only sample at the beginning of the rollout)
|
| 82 |
+
:param supported_action_spaces: The action spaces supported by the algorithm.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
# Policy aliases (see _get_policy_from_name())
|
| 86 |
+
policy_aliases: Dict[str, Type[BasePolicy]] = {}
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
policy: Type[BasePolicy],
|
| 91 |
+
env: Union[GymEnv, str, None],
|
| 92 |
+
learning_rate: Union[float, Schedule],
|
| 93 |
+
policy_kwargs: Optional[Dict[str, Any]] = None,
|
| 94 |
+
tensorboard_log: Optional[str] = None,
|
| 95 |
+
verbose: int = 0,
|
| 96 |
+
device: Union[th.device, str] = "auto",
|
| 97 |
+
support_multi_env: bool = False,
|
| 98 |
+
create_eval_env: bool = False,
|
| 99 |
+
monitor_wrapper: bool = True,
|
| 100 |
+
seed: Optional[int] = None,
|
| 101 |
+
use_sde: bool = False,
|
| 102 |
+
sde_sample_freq: int = -1,
|
| 103 |
+
supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
|
| 104 |
+
):
|
| 105 |
+
if isinstance(policy, str):
|
| 106 |
+
self.policy_class = self._get_policy_from_name(policy)
|
| 107 |
+
else:
|
| 108 |
+
self.policy_class = policy
|
| 109 |
+
|
| 110 |
+
self.device = get_device(device)
|
| 111 |
+
if verbose > 0:
|
| 112 |
+
print(f"Using {self.device} device")
|
| 113 |
+
|
| 114 |
+
self.env = None # type: Optional[GymEnv]
|
| 115 |
+
# get VecNormalize object if needed
|
| 116 |
+
self._vec_normalize_env = unwrap_vec_normalize(env)
|
| 117 |
+
self.verbose = verbose
|
| 118 |
+
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
|
| 119 |
+
self.observation_space = None # type: Optional[gym.spaces.Space]
|
| 120 |
+
self.action_space = None # type: Optional[gym.spaces.Space]
|
| 121 |
+
self.n_envs = None
|
| 122 |
+
self.num_timesteps = 0
|
| 123 |
+
# Used for updating schedules
|
| 124 |
+
self._total_timesteps = 0
|
| 125 |
+
# Used for computing fps, it is updated at each call of learn()
|
| 126 |
+
self._num_timesteps_at_start = 0
|
| 127 |
+
self.eval_env = None
|
| 128 |
+
self.seed = seed
|
| 129 |
+
self.action_noise = None # type: Optional[ActionNoise]
|
| 130 |
+
self.start_time = None
|
| 131 |
+
self.policy = None
|
| 132 |
+
self.learning_rate = learning_rate
|
| 133 |
+
self.tensorboard_log = tensorboard_log
|
| 134 |
+
self.lr_schedule = None # type: Optional[Schedule]
|
| 135 |
+
self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
|
| 136 |
+
self._last_episode_starts = None # type: Optional[np.ndarray]
|
| 137 |
+
# When using VecNormalize:
|
| 138 |
+
self._last_original_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
|
| 139 |
+
self._episode_num = 0
|
| 140 |
+
# Used for gSDE only
|
| 141 |
+
self.use_sde = use_sde
|
| 142 |
+
self.sde_sample_freq = sde_sample_freq
|
| 143 |
+
# Track the training progress remaining (from 1 to 0)
|
| 144 |
+
# this is used to update the learning rate
|
| 145 |
+
self._current_progress_remaining = 1
|
| 146 |
+
# Buffers for logging
|
| 147 |
+
self.ep_info_buffer = None # type: Optional[deque]
|
| 148 |
+
self.ep_success_buffer = None # type: Optional[deque]
|
| 149 |
+
# For logging (and TD3 delayed updates)
|
| 150 |
+
self._n_updates = 0 # type: int
|
| 151 |
+
# The logger object
|
| 152 |
+
self._logger = None # type: Logger
|
| 153 |
+
# Whether the user passed a custom logger or not
|
| 154 |
+
self._custom_logger = False
|
| 155 |
+
|
| 156 |
+
# Create and wrap the env if needed
|
| 157 |
+
if env is not None:
|
| 158 |
+
if isinstance(env, str):
|
| 159 |
+
if create_eval_env:
|
| 160 |
+
self.eval_env = maybe_make_env(env, self.verbose)
|
| 161 |
+
|
| 162 |
+
env = maybe_make_env(env, self.verbose)
|
| 163 |
+
env = self._wrap_env(env, self.verbose, monitor_wrapper)
|
| 164 |
+
|
| 165 |
+
self.observation_space = env.observation_space
|
| 166 |
+
self.action_space = env.action_space
|
| 167 |
+
self.n_envs = env.num_envs
|
| 168 |
+
self.env = env
|
| 169 |
+
|
| 170 |
+
if supported_action_spaces is not None:
|
| 171 |
+
assert isinstance(self.action_space, supported_action_spaces), (
|
| 172 |
+
f"The algorithm only supports {supported_action_spaces} as action spaces "
|
| 173 |
+
f"but {self.action_space} was provided"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
if not support_multi_env and self.n_envs > 1:
|
| 177 |
+
raise ValueError(
|
| 178 |
+
"Error: the model does not support multiple envs; it requires " "a single vectorized environment."
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
|
| 182 |
+
if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict):
|
| 183 |
+
raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")
|
| 184 |
+
|
| 185 |
+
if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
|
| 186 |
+
raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
|
| 187 |
+
|
| 188 |
+
@staticmethod
def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
    """
    Wrap environment with the appropriate wrappers if needed.
    For instance, to have a vectorized environment
    or to re-order the image channels.

    :param env: Environment to wrap.
    :param verbose: Verbosity level: >= 1 prints a message each time a wrapper is applied.
    :param monitor_wrapper: Whether to wrap the env in a ``Monitor`` when possible.
    :return: The wrapped environment.
    """
    if not isinstance(env, VecEnv):
        # A raw (non-vectorized) env: optionally add episode-stat Monitor,
        # then vectorize it with a single-env DummyVecEnv.
        if not is_wrapped(env, Monitor) and monitor_wrapper:
            if verbose >= 1:
                print("Wrapping the env with a `Monitor` wrapper")
            env = Monitor(env)
        if verbose >= 1:
            print("Wrapping the env in a DummyVecEnv.")
        env = DummyVecEnv([lambda: env])

    # Make sure that dict-spaces are not nested (not supported)
    check_for_nested_spaces(env.observation_space)

    if not is_vecenv_wrapped(env, VecTransposeImage):
        wrap_with_vectranspose = False
        if isinstance(env.observation_space, gym.spaces.Dict):
            # If even one of the keys is a image-space in need of transpose, apply transpose
            # If the image spaces are not consistent (for instance one is channel first,
            # the other channel last), VecTransposeImage will throw an error
            for space in env.observation_space.spaces.values():
                wrap_with_vectranspose = wrap_with_vectranspose or (
                    is_image_space(space) and not is_image_space_channels_first(space)
                )
        else:
            wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first(
                env.observation_space
            )

        if wrap_with_vectranspose:
            if verbose >= 1:
                print("Wrapping the env in a VecTransposeImage.")
            env = VecTransposeImage(env)

    return env
|
| 233 |
+
|
| 234 |
+
@abstractmethod
def _setup_model(self) -> None:
    """
    Create networks, buffer and optimizers.

    Implemented by each concrete algorithm; also invoked by ``load``
    before the saved parameters are restored with ``set_parameters``.
    """
|
| 237 |
+
|
| 238 |
+
def set_logger(self, logger: Logger) -> None:
    """
    Setter for the logger object.

    .. warning::

      When passing a custom logger object,
      this will overwrite ``tensorboard_log`` and ``verbose`` settings
      passed to the constructor.
    """
    # Remember that the logger came from the user, so automatic logger
    # configuration is skipped later on.
    self._custom_logger = True
    self._logger = logger
|
| 251 |
+
|
| 252 |
+
@property
def logger(self) -> Logger:
    """Getter for the logger object (set via ``set_logger`` or during learn setup)."""
    return self._logger
|
| 256 |
+
|
| 257 |
+
def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
|
| 258 |
+
"""
|
| 259 |
+
Return the environment that will be used for evaluation.
|
| 260 |
+
|
| 261 |
+
:param eval_env:)
|
| 262 |
+
:return:
|
| 263 |
+
"""
|
| 264 |
+
if eval_env is None:
|
| 265 |
+
eval_env = self.eval_env
|
| 266 |
+
|
| 267 |
+
if eval_env is not None:
|
| 268 |
+
eval_env = self._wrap_env(eval_env, self.verbose)
|
| 269 |
+
assert eval_env.num_envs == 1
|
| 270 |
+
return eval_env
|
| 271 |
+
|
| 272 |
+
def _setup_lr_schedule(self) -> None:
    """Transform ``self.learning_rate`` to a callable schedule if needed."""
    # get_schedule_fn presumably wraps a constant learning rate into a
    # function of the remaining progress — see its definition to confirm.
    self.lr_schedule = get_schedule_fn(self.learning_rate)
|
| 275 |
+
|
| 276 |
+
def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
|
| 277 |
+
"""
|
| 278 |
+
Compute current progress remaining (starts from 1 and ends to 0)
|
| 279 |
+
|
| 280 |
+
:param num_timesteps: current number of timesteps
|
| 281 |
+
:param total_timesteps:
|
| 282 |
+
"""
|
| 283 |
+
self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
|
| 284 |
+
|
| 285 |
+
def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
    """
    Update the optimizers learning rate using the current learning rate schedule
    and the current progress remaining (from 1 to 0).

    :param optimizers: An optimizer or a list of optimizers.
    """
    # Log the current learning rate
    self.logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining))

    # Normalize to a list so a single optimizer and many are handled alike
    optimizer_list = optimizers if isinstance(optimizers, list) else [optimizers]
    for single_optimizer in optimizer_list:
        update_learning_rate(single_optimizer, self.lr_schedule(self._current_progress_remaining))
|
| 300 |
+
|
| 301 |
+
def _excluded_save_params(self) -> List[str]:
|
| 302 |
+
"""
|
| 303 |
+
Returns the names of the parameters that should be excluded from being
|
| 304 |
+
saved by pickling. E.g. replay buffers are skipped by default
|
| 305 |
+
as they take up a lot of space. PyTorch variables should be excluded
|
| 306 |
+
with this so they can be stored with ``th.save``.
|
| 307 |
+
|
| 308 |
+
:return: List of parameters that should be excluded from being saved with pickle.
|
| 309 |
+
"""
|
| 310 |
+
return [
|
| 311 |
+
"policy",
|
| 312 |
+
"device",
|
| 313 |
+
"env",
|
| 314 |
+
"eval_env",
|
| 315 |
+
"replay_buffer",
|
| 316 |
+
"rollout_buffer",
|
| 317 |
+
"_vec_normalize_env",
|
| 318 |
+
"_episode_storage",
|
| 319 |
+
"_logger",
|
| 320 |
+
"_custom_logger",
|
| 321 |
+
]
|
| 322 |
+
|
| 323 |
+
def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
|
| 324 |
+
"""
|
| 325 |
+
Get a policy class from its name representation.
|
| 326 |
+
|
| 327 |
+
The goal here is to standardize policy naming, e.g.
|
| 328 |
+
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
|
| 329 |
+
and they receive respective policies that work for them.
|
| 330 |
+
|
| 331 |
+
:param policy_name: Alias of the policy
|
| 332 |
+
:return: A policy class (type)
|
| 333 |
+
"""
|
| 334 |
+
|
| 335 |
+
if policy_name in self.policy_aliases:
|
| 336 |
+
return self.policy_aliases[policy_name]
|
| 337 |
+
else:
|
| 338 |
+
raise ValueError(f"Policy {policy_name} unknown")
|
| 339 |
+
|
| 340 |
+
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
|
| 341 |
+
"""
|
| 342 |
+
Get the name of the torch variables that will be saved with
|
| 343 |
+
PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
|
| 344 |
+
pickling strategy. This is to handle device placement correctly.
|
| 345 |
+
|
| 346 |
+
Names can point to specific variables under classes, e.g.
|
| 347 |
+
"policy.optimizer" would point to ``optimizer`` object of ``self.policy``
|
| 348 |
+
if this object.
|
| 349 |
+
|
| 350 |
+
:return:
|
| 351 |
+
List of Torch variables whose state dicts to save (e.g. th.nn.Modules),
|
| 352 |
+
and list of other Torch variables to store with ``th.save``.
|
| 353 |
+
"""
|
| 354 |
+
state_dicts = ["policy"]
|
| 355 |
+
|
| 356 |
+
return state_dicts, []
|
| 357 |
+
|
| 358 |
+
def _init_callback(
    self,
    callback: MaybeCallback,
    eval_env: Optional[VecEnv] = None,
    eval_freq: int = 10000,
    n_eval_episodes: int = 5,
    log_path: Optional[str] = None,
) -> BaseCallback:
    """
    Build a single ``BaseCallback`` from user callbacks plus an optional eval callback.

    :param callback: Callback(s) called at every step with state of the algorithm.
    :param eval_env: Environment used for evaluation; if None, no eval callback is added.
    :param eval_freq: How many steps between evaluations.
    :param n_eval_episodes: Number of episodes to rollout during evaluation.
    :param log_path: Path to a folder where the evaluations will be saved.
    :return: A hybrid callback calling `callback` and performing evaluation.
    """
    # Convert a list of callbacks into a callback
    if isinstance(callback, list):
        callback = CallbackList(callback)

    # Convert functional callback to object
    if not isinstance(callback, BaseCallback):
        callback = ConvertCallback(callback)

    # Create eval callback in charge of the evaluation
    if eval_env is not None:
        eval_callback = EvalCallback(
            eval_env,
            best_model_save_path=log_path,
            log_path=log_path,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
        )
        callback = CallbackList([callback, eval_callback])

    # Give the combined callback a reference to this algorithm instance
    callback.init_callback(self)
    return callback
|
| 395 |
+
|
| 396 |
+
def _setup_learn(
    self,
    total_timesteps: int,
    eval_env: Optional[GymEnv],
    callback: MaybeCallback = None,
    eval_freq: int = 10000,
    n_eval_episodes: int = 5,
    log_path: Optional[str] = None,
    reset_num_timesteps: bool = True,
    tb_log_name: str = "run",
) -> Tuple[int, BaseCallback]:
    """
    Initialize different variables needed for training.

    :param total_timesteps: The total number of samples (env steps) to train on
    :param eval_env: Environment to use for evaluation.
    :param callback: Callback(s) called at every step with state of the algorithm.
    :param eval_freq: How many steps between evaluations
    :param n_eval_episodes: How many episodes to play per evaluation
    :param log_path: Path to a folder where the evaluations will be saved
    :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
    :param tb_log_name: the name of the run for tensorboard log
    :return: The (possibly adjusted) ``total_timesteps`` and the initialized callback.
    """
    self.start_time = time.time()

    if self.ep_info_buffer is None or reset_num_timesteps:
        # Initialize buffers if they don't exist, or reinitialize if resetting counters
        self.ep_info_buffer = deque(maxlen=100)
        self.ep_success_buffer = deque(maxlen=100)

    if self.action_noise is not None:
        self.action_noise.reset()

    if reset_num_timesteps:
        self.num_timesteps = 0
        self._episode_num = 0
    else:
        # Make sure training timesteps are ahead of the internal counter
        total_timesteps += self.num_timesteps
    self._total_timesteps = total_timesteps
    self._num_timesteps_at_start = self.num_timesteps

    # Avoid resetting the environment when calling ``.learn()`` consecutive times
    if reset_num_timesteps or self._last_obs is None:
        self._last_obs = self.env.reset()  # pytype: disable=annotation-type-mismatch
        self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
        # Retrieve unnormalized observation for saving into the buffer
        if self._vec_normalize_env is not None:
            self._last_original_obs = self._vec_normalize_env.get_original_obs()

    # Seed the eval env before it gets wrapped, for reproducible evaluation
    if eval_env is not None and self.seed is not None:
        eval_env.seed(self.seed)

    eval_env = self._get_eval_env(eval_env)

    # Configure logger's outputs if no logger was passed
    if not self._custom_logger:
        self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)

    # Create eval callback if needed
    callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)

    return total_timesteps, callback
|
| 460 |
+
|
| 461 |
+
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
|
| 462 |
+
"""
|
| 463 |
+
Retrieve reward, episode length, episode success and update the buffer
|
| 464 |
+
if using Monitor wrapper or a GoalEnv.
|
| 465 |
+
|
| 466 |
+
:param infos: List of additional information about the transition.
|
| 467 |
+
:param dones: Termination signals
|
| 468 |
+
"""
|
| 469 |
+
if dones is None:
|
| 470 |
+
dones = np.array([False] * len(infos))
|
| 471 |
+
for idx, info in enumerate(infos):
|
| 472 |
+
maybe_ep_info = info.get("episode")
|
| 473 |
+
maybe_is_success = info.get("is_success")
|
| 474 |
+
if maybe_ep_info is not None:
|
| 475 |
+
self.ep_info_buffer.extend([maybe_ep_info])
|
| 476 |
+
if maybe_is_success is not None and dones[idx]:
|
| 477 |
+
self.ep_success_buffer.append(maybe_is_success)
|
| 478 |
+
|
| 479 |
+
def get_env(self) -> Optional[VecEnv]:
    """
    Returns the current environment (can be None if not defined).

    :return: The current environment
    """
    current_env = self.env
    return current_env
|
| 486 |
+
|
| 487 |
+
def get_vec_normalize_env(self) -> Optional[VecNormalize]:
    """
    Return the ``VecNormalize`` wrapper of the training env
    if it exists.

    :return: The ``VecNormalize`` env.
    """
    wrapper = self._vec_normalize_env
    return wrapper
|
| 495 |
+
|
| 496 |
+
def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
    """
    Checks the validity of the environment, and if it is coherent, set it as the current environment.
    Furthermore wrap any non vectorized env into a vectorized one.

    Checked parameters:
    - observation_space
    - action_space

    :param env: The environment for learning a policy
    :param force_reset: Force call to ``reset()`` before training
        to avoid unexpected behavior.
        See issue https://github.com/DLR-RM/stable-baselines3/issues/597
    """
    # if it is not a VecEnv, make it a VecEnv
    # and do other transformations (dict obs, image transpose) if needed
    env = self._wrap_env(env, self.verbose)
    # Check that the observation spaces match
    check_for_correct_spaces(env, self.observation_space, self.action_space)
    # Update VecNormalize object
    # otherwise the wrong env may be used, see https://github.com/DLR-RM/stable-baselines3/issues/637
    self._vec_normalize_env = unwrap_vec_normalize(env)

    # Discard `_last_obs`, this will force the env to reset before training
    # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
    if force_reset:
        self._last_obs = None

    self.n_envs = env.num_envs
    self.env = env
|
| 525 |
+
|
| 526 |
+
@abstractmethod
def learn(
    self,
    total_timesteps: int,
    callback: MaybeCallback = None,
    log_interval: int = 100,
    tb_log_name: str = "run",
    eval_env: Optional[GymEnv] = None,
    eval_freq: int = -1,
    n_eval_episodes: int = 5,
    eval_log_path: Optional[str] = None,
    reset_num_timesteps: bool = True,
) -> "BaseAlgorithm":
    """
    Return a trained model.

    :param total_timesteps: The total number of samples (env steps) to train on
    :param callback: callback(s) called at every step with state of the algorithm.
    :param log_interval: The number of timesteps before logging.
    :param tb_log_name: the name of the run for TensorBoard logging
    :param eval_env: Environment that will be used to evaluate the agent
    :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
    :param n_eval_episodes: Number of episode to evaluate the agent
    :param eval_log_path: Path to a folder where the evaluations will be saved
    :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
    :return: the trained model
    """
|
| 553 |
+
|
| 554 |
+
def predict(
    self,
    observation: np.ndarray,
    state: Optional[Tuple[np.ndarray, ...]] = None,
    episode_start: Optional[np.ndarray] = None,
    deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
    """
    Get the policy action from an observation (and optional hidden state).
    Includes sugar-coating to handle different observations (e.g. normalizing images).

    :param observation: the input observation
    :param state: The last hidden states (can be None, used in recurrent policies)
    :param episode_start: The last masks (can be None, used in recurrent policies)
        this correspond to beginning of episodes,
        where the hidden states of the RNN must be reset.
    :param deterministic: Whether or not to return deterministic actions.
    :return: the model's action and the next hidden state
        (used in recurrent policies)
    """
    # All the work is delegated to the policy object.
    prediction = self.policy.predict(observation, state, episode_start, deterministic)
    return prediction
|
| 575 |
+
|
| 576 |
+
def set_random_seed(self, seed: Optional[int] = None) -> None:
    """
    Set the seed of the pseudo-random generators
    (python, numpy, pytorch, gym, action_space)

    :param seed: Seed value; if None, this is a no-op.
    """
    if seed is None:
        return
    # Calls the module-level ``set_random_seed`` utility (not this method);
    # CUDA seeding is enabled only when the model's device is a CUDA device.
    set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type)
    self.action_space.seed(seed)
    # Propagate the seed to the training and evaluation envs when present
    if self.env is not None:
        self.env.seed(seed)
    if self.eval_env is not None:
        self.eval_env.seed(seed)
|
| 591 |
+
|
| 592 |
+
def set_parameters(
    self,
    load_path_or_dict: Union[str, Dict[str, Dict]],
    exact_match: bool = True,
    device: Union[th.device, str] = "auto",
) -> None:
    """
    Load parameters from a given zip-file or a nested dictionary containing parameters for
    different modules (see ``get_parameters``).

    :param load_path_or_dict: Location of the saved data (path or file-like, see ``save``), or a nested
        dictionary containing nn.Module parameters used by the policy. The dictionary maps
        object names to a state-dictionary returned by ``torch.nn.Module.state_dict()``.
    :param exact_match: If True, the given parameters should include parameters for each
        module and each of their parameters, otherwise raises an Exception. If set to False, this
        can be used to update only specific parameters.
    :param device: Device on which the code should run.
    """
    params = None
    if isinstance(load_path_or_dict, dict):
        params = load_path_or_dict
    else:
        # A path (or file-like): read the params out of the zip archive
        _, params, _ = load_from_zip_file(load_path_or_dict, device=device)

    # Keep track which objects were updated.
    # `_get_torch_save_params` returns [params, other_pytorch_variables].
    # We are only interested in former here.
    objects_needing_update = set(self._get_torch_save_params()[0])
    updated_objects = set()

    for name in params:
        attr = None
        try:
            attr = recursive_getattr(self, name)
        except Exception:
            # What errors recursive_getattr could throw? KeyError, but
            # possible something else too (e.g. if key is an int?).
            # Catch anything for now.
            raise ValueError(f"Key {name} is an invalid object name.")

        if isinstance(attr, th.optim.Optimizer):
            # Optimizers do not support "strict" keyword...
            # Seems like they will just replace the whole
            # optimizer state with the given one.
            # On top of this, optimizer state-dict
            # seems to change (e.g. first ``optim.step()``),
            # which makes comparing state dictionary keys
            # invalid (there is also a nesting of dictionaries
            # with lists with dictionaries with ...), adding to the
            # mess.
            #
            # TL;DR: We might not be able to reliably say
            # if given state-dict is missing keys.
            #
            # Solution: Just load the state-dict as is, and trust
            # the user has provided a sensible state dictionary.
            attr.load_state_dict(params[name])
        else:
            # Assume attr is th.nn.Module
            attr.load_state_dict(params[name], strict=exact_match)
        updated_objects.add(name)

    if exact_match and updated_objects != objects_needing_update:
        raise ValueError(
            "Names of parameters do not match agents' parameters: "
            f"expected {objects_needing_update}, got {updated_objects}"
        )
|
| 659 |
+
|
| 660 |
+
@classmethod
def load(
    cls,
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    env: Optional[GymEnv] = None,
    device: Union[th.device, str] = "auto",
    custom_objects: Optional[Dict[str, Any]] = None,
    print_system_info: bool = False,
    force_reset: bool = True,
    check_obs_space: bool = True,
    force_load: bool = False,
    **kwargs,
) -> "BaseAlgorithm":
    """
    Load the model from a zip-file.
    Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
    For an in-place load use ``set_parameters`` instead.

    :param path: path to the file (or a file-like) where to
        load the agent from
    :param env: the new environment to run the loaded model on
        (can be None if you only need prediction from a trained model) has priority over any saved environment
    :param device: Device on which the code should run.
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :param print_system_info: Whether to print system info from the saved model
        and the current system info (useful to debug loading issues)
    :param force_reset: Force call to ``reset()`` before training
        to avoid unexpected behavior.
        See https://github.com/DLR-RM/stable-baselines3/issues/597
    :param check_obs_space: If True, verify that the given env's observation and
        action spaces match the saved ones.
    :param force_load: If True, skip the policy-kwargs equality check and do NOT
        restore the saved attribute dict onto the new model (fork-specific flag).
    :param kwargs: extra arguments to change the model when loading
    :return: new model instance with loaded parameters
    """
    if print_system_info:
        print("== CURRENT SYSTEM INFO ==")
        get_system_info()

    data, params, pytorch_variables = load_from_zip_file(
        path, device=device, custom_objects=custom_objects, print_system_info=print_system_info
    )

    # Remove stored device information and replace with ours
    if "policy_kwargs" in data:
        if "device" in data["policy_kwargs"]:
            del data["policy_kwargs"]["device"]

    if not force_load:
        # Refuse to silently override the policy architecture stored in the file
        if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
            raise ValueError(
                f"The specified policy kwargs do not equal the stored policy kwargs."
                f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
            )

    if "observation_space" not in data or "action_space" not in data:
        raise KeyError("The observation_space and action_space were not given, can't verify new environments")

    if env is not None:
        # Wrap first if needed
        env = cls._wrap_env(env, data["verbose"])
        # Check if given env is valid
        if check_obs_space:
            check_for_correct_spaces(env, data["observation_space"], data["action_space"])
        # Discard `_last_obs`, this will force the env to reset before training
        # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
        if force_reset and data is not None:
            data["_last_obs"] = None
    else:
        # Use stored env, if one exists. If not, continue as is (can be used for predict)
        if "env" in data:
            env = data["env"]

    # noinspection PyArgumentList
    model = cls(  # pytype: disable=not-instantiable,wrong-keyword-args
        policy=data["policy_class"],
        env=env,
        device=device,
        _init_setup_model=False,  # pytype: disable=not-instantiable,wrong-keyword-args
    )

    # load parameters
    if not force_load:
        model.__dict__.update(data)  # TODO: Helin cancelled this
    model.__dict__.update(kwargs)
    model._setup_model()

    # put state_dicts back in place
    model.set_parameters(params, exact_match=True, device=device)

    # put other pytorch variables back in place
    if pytorch_variables is not None:
        for name in pytorch_variables:
            # Skip if PyTorch variable was not defined (to ensure backward compatibility).
            # This happens when using SAC/TQC.
            # SAC has an entropy coefficient which can be fixed or optimized.
            # If it is optimized, an additional PyTorch variable `log_ent_coef` is defined,
            # otherwise it is initialized to `None`.
            if pytorch_variables[name] is None:
                continue
            # Set the data attribute directly to avoid issue when using optimizers
            # See https://github.com/DLR-RM/stable-baselines3/issues/391
            recursive_setattr(model, name + ".data", pytorch_variables[name].data)

    # Sample gSDE exploration matrix, so it uses the right device
    # see issue #44
    if model.use_sde:
        model.policy.reset_noise()  # pytype: disable=attribute-error
    return model
|
| 771 |
+
|
| 772 |
+
def get_parameters(self) -> Dict[str, Dict]:
    """
    Return the parameters of the agent. This includes parameters from different networks, e.g.
    critics (value functions) and policies (pi functions).

    :return: Mapping of from names of the objects to PyTorch state-dicts.
    """
    state_dicts_names, _ = self._get_torch_save_params()
    # One state-dict per saved torch module (e.g. "policy")
    return {name: recursive_getattr(self, name).state_dict() for name in state_dicts_names}
|
| 786 |
+
|
| 787 |
+
def save(
    self,
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    exclude: Optional[Iterable[str]] = None,
    include: Optional[Iterable[str]] = None,
) -> None:
    """
    Save all the attributes of the object and the model parameters in a zip-file.

    :param path: path to the file where the rl agent should be saved
    :param exclude: name of parameters that should be excluded in addition to the default ones
    :param include: name of parameters that might be excluded but should be included anyway
    """
    # Copy parameter list so we don't mutate the original dict
    data = self.__dict__.copy()

    # Exclude is union of specified parameters (if any) and standard exclusions
    if exclude is None:
        exclude = []
    exclude = set(exclude).union(self._excluded_save_params())

    # Do not exclude params if they are specifically included
    if include is not None:
        exclude = exclude.difference(include)

    state_dicts_names, torch_variable_names = self._get_torch_save_params()
    all_pytorch_variables = state_dicts_names + torch_variable_names
    for torch_var in all_pytorch_variables:
        # We need to get only the name of the top most module as we'll remove that
        var_name = torch_var.split(".")[0]
        # Any params that are in the save vars must not be saved by data
        exclude.add(var_name)

    # Remove parameter entries of parameters which are to be excluded
    for param_name in exclude:
        data.pop(param_name, None)

    # Build dict of torch variables
    pytorch_variables = None
    if torch_variable_names is not None:
        pytorch_variables = {}
        for name in torch_variable_names:
            attr = recursive_getattr(self, name)
            pytorch_variables[name] = attr

    # Build dict of state_dicts
    params_to_save = self.get_parameters()

    save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
|
dexart-release/stable_baselines3/common/buffers.py
ADDED
|
@@ -0,0 +1,1010 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import Any, Dict, Generator, List, Optional, Union
|
| 4 |
+
import stable_baselines3.pickle_utils as pickle_utils
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch as th
|
| 7 |
+
from gym import spaces
|
| 8 |
+
|
| 9 |
+
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
|
| 10 |
+
from stable_baselines3.common.type_aliases import (
|
| 11 |
+
DictReplayBufferSamples,
|
| 12 |
+
DictRolloutBufferSamples,
|
| 13 |
+
DictSSLRolloutBufferSamples,
|
| 14 |
+
ReplayBufferSamples,
|
| 15 |
+
RolloutBufferSamples,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
from stable_baselines3.common.vec_env import VecNormalize
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
# Check memory used by replay buffer when possible
|
| 22 |
+
import psutil
|
| 23 |
+
except ImportError:
|
| 24 |
+
psutil = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BaseBuffer(ABC):
    """
    Abstract parent for rollout and replay buffers.

    Keeps up to ``buffer_size`` entries and tracks the write cursor (``pos``)
    plus whether the storage has wrapped around (``full``).

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device to which sampled values are converted
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
    ):
        super().__init__()
        self.buffer_size = buffer_size
        self.observation_space = observation_space
        self.action_space = action_space
        self.obs_shape = get_obs_shape(observation_space)
        self.action_dim = get_action_dim(action_space)
        # Write cursor and wrap-around flag
        self.pos = 0
        self.full = False
        self.device = device
        self.n_envs = n_envs

    @staticmethod
    def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
        """
        Merge the step and env axes, converting
        [n_steps, n_envs, ...] into [n_steps * n_envs, ...]
        while preserving per-env ordering.

        :param arr:
        :return:
        """
        dims = arr.shape
        if len(dims) < 3:
            # Promote 2-D input so a feature axis always exists
            dims = (*dims, 1)
        return arr.swapaxes(0, 1).reshape(dims[0] * dims[1], *dims[2:])

    def size(self) -> int:
        """
        :return: The current size of the buffer
        """
        return self.buffer_size if self.full else self.pos

    def add(self, *args, **kwargs) -> None:
        """
        Add elements to the buffer (must be implemented by subclasses).
        """
        raise NotImplementedError()

    def extend(self, *args, **kwargs) -> None:
        """
        Add a new batch of transitions to the buffer, one element at a time.
        """
        for data in zip(*args):
            self.add(*data)

    def reset(self) -> None:
        """
        Empty the buffer by rewinding the write cursor.
        """
        self.pos = 0
        self.full = False

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
        """
        :param batch_size: Number of element to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return:
        """
        limit = self.buffer_size if self.full else self.pos
        return self._get_samples(np.random.randint(0, limit, size=batch_size), env=env)

    @abstractmethod
    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
        """
        :param batch_inds:
        :param env:
        :return:
        """
        raise NotImplementedError()

    def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
        """
        Convert a numpy array to a PyTorch tensor on ``self.device``.
        Note: it copies the data by default.

        :param array:
        :param copy: Whether to copy or not the data
            (may be useful to avoid changing things by reference)
        :return:
        """
        tensor = th.tensor(array) if copy else th.as_tensor(array)
        return tensor.to(self.device)

    @staticmethod
    def _normalize_obs(
        obs: Union[np.ndarray, Dict[str, np.ndarray]],
        env: Optional[VecNormalize] = None,
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        # Pass-through when no normalizing VecEnv is attached
        return obs if env is None else env.normalize_obs(obs)

    @staticmethod
    def _normalize_reward(reward: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray:
        # Pass-through when no normalizing VecEnv is attached
        return reward if env is None else env.normalize_reward(reward).astype(np.float32)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class ExpertBuffer(BaseBuffer):
    """
    Read-only buffer of expert demonstrations.

    Loads trajectories from ``dataset_path`` via ``pickle_utils.load_data``
    and concatenates their observations and actions into two flat arrays.
    The buffer is fixed at load time: ``add`` is intentionally unsupported,
    ``buffer_size`` is overwritten with the dataset length and ``full`` is True.

    :param buffer_size: Ignored for storage — replaced by the dataset length
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device for sampled tensors
    :param n_envs: Number of parallel environments
    :param dataset_path: Path to the pickled list of trajectory dicts; each
        trajectory must contain 'observations' and 'actions' entries
    """

    def __init__(self,
                 buffer_size: int,
                 observation_space: spaces.Space,
                 action_space: spaces.Space,
                 device: Union[th.device, str] = "cpu",
                 n_envs: int = 1,
                 dataset_path=''
                 ):
        super(ExpertBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        data = pickle_utils.load_data(dataset_path)

        # This buffer never shares obs/next_obs storage
        self.optimize_memory_usage = False

        # Removed the per-trajectory debug print of trajectory.keys()
        data_obs = [trajectory['observations'] for trajectory in data]
        data_action = [trajectory['actions'] for trajectory in data]
        self.observations = np.concatenate(data_obs, axis=0)
        self.actions = np.concatenate(data_action, axis=0)

        if len(self.observations) != len(self.actions):
            # Raise unconditionally instead of `assert`, which is stripped under -O
            raise ValueError("Demo Dataset Error: Obs num does not match Action num.")
        print('Expert buffer info:', self.observations.shape, self.actions.shape)
        self.buffer_size = len(self.observations)
        self.full = True

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """The demonstration buffer is immutable; adding transitions is unsupported."""
        # Was `assert False, ...`: under `python -O` the assert vanishes and this
        # silently became a no-op. Raise explicitly, consistent with BaseBuffer.add.
        raise NotImplementedError("We do not expect user to use this method.")

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """Return (normalized) observations and actions for the given flat indices."""
        data = (
            self._normalize_obs(self.observations[batch_inds, :], env),
            self.actions[batch_inds, :],
        )
        # NOTE(review): only two fields are passed here while `get_all_samples`
        # pads to five — presumably the project's ReplayBufferSamples accepts a
        # partial tuple; confirm against its definition in type_aliases.
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements uniformly from the demonstration dataset.

        :param batch_size: Number of element to sample
        :param env: associated gym VecEnv
            to normalize the observations when sampling
        :return:
        """
        # `optimize_memory_usage` is hard-coded to False for this buffer, so the
        # memory-efficient branch copied from ReplayBuffer.sample was unreachable
        # dead code and has been removed.
        return super().sample(batch_size=batch_size, env=env)

    def get_all_samples(self, env=None):
        """Return the whole dataset as one batch; the remaining fields are None."""
        data = (
            self._normalize_obs(self.observations, env),
            self.actions,
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)), None, None, None)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class ReplayBuffer(BaseBuffer):
    """
    Replay buffer used in off-policy algorithms like SAC/TD3.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device:
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        of the replay buffer which reduces by almost a factor two the memory used,
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
        and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        # Adjust buffer size: `buffer_size` counts transitions across all envs,
        # storage is indexed per step with n_envs entries each
        self.buffer_size = max(buffer_size // n_envs, 1)

        # Check that the replay buffer can fit into the memory
        # (mem_available is only defined/used when psutil is importable)
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        self.optimize_memory_usage = optimize_memory_usage

        self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)

        if optimize_memory_usage:
            # `observations` contains also the next observation:
            # next_obs of slot i lives at slot (i + 1) % buffer_size
            self.next_observations = None
        else:
            self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)

        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)

        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes

            if self.next_observations is not None:
                total_memory_usage += self.next_observations.nbytes

            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Store one transition per parallel env at the current cursor position,
        overwriting the oldest entries once the buffer has wrapped.
        """

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs,) + self.obs_shape)
            next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)

        # Same, for actions
        if isinstance(self.action_space, spaces.Discrete):
            action = action.reshape((self.n_envs, self.action_dim))

        # Copy to avoid modification by reference
        self.observations[self.pos] = np.array(obs).copy()

        if self.optimize_memory_usage:
            # In the memory-efficient layout, next_obs is written into the slot
            # that the NEXT transition's obs will occupy (they are identical)
            self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy()
        else:
            self.next_observations[self.pos] = np.array(next_obs).copy()

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.dones[self.pos] = np.array(done).copy()

        if self.handle_timeout_termination:
            # True when the episode ended only because of the time limit
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the replay buffer.
        Custom sampling when using memory efficient variant,
        as we should not sample the element with index `self.pos`
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

        :param batch_size: Number of element to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return:
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        # Do not sample the element with index `self.pos` as the transitions is invalid
        # (we use only one array to store `obs` and `next_obs`)
        if self.full:
            # Offsetting by self.pos and starting the range at 1 excludes exactly
            # the slot whose "next_obs" has been clobbered by the newest write
            batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        # Sample randomly the env idx (decorrelates parallel envs within a batch)
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        if self.optimize_memory_usage:
            # next_obs is the following slot in the shared observations array
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)

        data = (
            self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
            self.actions[batch_inds, env_indices, :],
            next_obs,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
            self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class RolloutBuffer(BaseBuffer):
    """
    Buffer for on-policy algorithms (A2C/PPO).

    Holds ``buffer_size`` steps of experience gathered with the current policy;
    the content is discarded after each policy update. State values and action
    log-probabilities are stored alongside transitions so the PPO objective and
    GAE(lambda) advantages can be computed after collection.

    The term "rollout" here is the model-free notion — this buffer plays no part
    in action selection, only in policy and value-function training.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device:
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized
        Advantage Estimator (1 recovers the classic advantage estimate)
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):

        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        # Storage arrays are allocated by reset()
        self.observations = None
        self.actions = None
        self.rewards = None
        self.advantages = None
        self.returns = None
        self.episode_starts = None
        self.values = None
        self.log_probs = None
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        """Reallocate all storage arrays and rewind the cursor."""
        base = (self.buffer_size, self.n_envs)
        self.observations = np.zeros(base + self.obs_shape, dtype=np.float32)
        self.actions = np.zeros(base + (self.action_dim,), dtype=np.float32)
        for name in ("rewards", "returns", "episode_starts", "values", "log_probs", "advantages"):
            setattr(self, name, np.zeros(base, dtype=np.float32))
        self.generator_ready = False
        super().reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438).
        With ``gae_lambda=1.0`` this reduces to the Monte-Carlo advantage
        A(s) = R - V(s), where R bootstraps with ``last_values`` (episodes are
        not always complete). The TD(lambda) estimator likewise has two special
        cases: TD(1) is the Monte-Carlo estimate, TD(0) the one-step bootstrap
        r_t + gamma * v(s_{t+1}).
        For details see https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()

        running_gae = 0
        # The GAE recursion runs backwards through the rollout
        for step in reversed(range(self.buffer_size)):
            is_last = step == self.buffer_size - 1
            next_non_terminal = 1.0 - (dones if is_last else self.episode_starts[step + 1])
            next_values = last_values if is_last else self.values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            running_gae = delta + self.gamma * self.gae_lambda * next_non_terminal * running_gae
            self.advantages[step] = running_gae
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # A 0-d tensor cannot be written into the 2-D storage; reshape it
            log_prob = log_prob.reshape(-1, 1)

        # Discrete observations arrive as (n_envs,) and numpy cannot broadcast
        # that to (n_envs, 1), hence the explicit reshape
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs,) + self.obs_shape)

        cursor = self.pos
        self.observations[cursor] = np.array(obs).copy()
        self.actions[cursor] = np.array(action).copy()
        self.rewards[cursor] = np.array(reward).copy()
        self.episode_starts[cursor] = np.array(episode_start).copy()
        self.values[cursor] = value.clone().cpu().numpy().flatten()
        self.log_probs[cursor] = log_prob.clone().cpu().numpy()
        self.pos = cursor + 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        """Yield shuffled minibatches covering the whole (full) rollout."""
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Flatten the (step, env) axes once before the first epoch
        if not self.generator_ready:
            for name in (
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
            ):
                self.__dict__[name] = self.swap_and_flatten(self.__dict__[name])
            self.generator_ready = True

        total = self.buffer_size * self.n_envs
        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = total

        for start in range(0, total, batch_size):
            yield self._get_samples(indices[start : start + batch_size])

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
        """Gather one minibatch (flattened layout) as torch tensors."""
        batch = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return RolloutBufferSamples(*map(self.to_torch, batch))
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
class DictReplayBuffer(ReplayBuffer):
|
| 562 |
+
"""
|
| 563 |
+
Dict Replay buffer used in off-policy algorithms like SAC/TD3.
|
| 564 |
+
Extends the ReplayBuffer to use dictionary observations
|
| 565 |
+
|
| 566 |
+
:param buffer_size: Max number of element in the buffer
|
| 567 |
+
:param observation_space: Observation space
|
| 568 |
+
:param action_space: Action space
|
| 569 |
+
:param device:
|
| 570 |
+
:param n_envs: Number of parallel environments
|
| 571 |
+
:param optimize_memory_usage: Enable a memory efficient variant
|
| 572 |
+
Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
|
| 573 |
+
:param handle_timeout_termination: Handle timeout termination (due to timelimit)
|
| 574 |
+
separately and treat the task as infinite horizon task.
|
| 575 |
+
https://github.com/DLR-RM/stable-baselines3/issues/284
|
| 576 |
+
"""
|
| 577 |
+
|
| 578 |
+
def __init__(
|
| 579 |
+
self,
|
| 580 |
+
buffer_size: int,
|
| 581 |
+
observation_space: spaces.Space,
|
| 582 |
+
action_space: spaces.Space,
|
| 583 |
+
device: Union[th.device, str] = "cpu",
|
| 584 |
+
n_envs: int = 1,
|
| 585 |
+
optimize_memory_usage: bool = False,
|
| 586 |
+
handle_timeout_termination: bool = True,
|
| 587 |
+
):
|
| 588 |
+
super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
|
| 589 |
+
|
| 590 |
+
assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
|
| 591 |
+
self.buffer_size = max(buffer_size // n_envs, 1)
|
| 592 |
+
|
| 593 |
+
# Check that the replay buffer can fit into the memory
|
| 594 |
+
if psutil is not None:
|
| 595 |
+
mem_available = psutil.virtual_memory().available
|
| 596 |
+
|
| 597 |
+
assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage"
|
| 598 |
+
# disabling as this adds quite a bit of complexity
|
| 599 |
+
# https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
|
| 600 |
+
self.optimize_memory_usage = optimize_memory_usage
|
| 601 |
+
|
| 602 |
+
self.observations = {
|
| 603 |
+
key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space[key].dtype)
|
| 604 |
+
for key, _obs_shape in self.obs_shape.items()
|
| 605 |
+
}
|
| 606 |
+
self.next_observations = {
|
| 607 |
+
key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space[key].dtype)
|
| 608 |
+
for key, _obs_shape in self.obs_shape.items()
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
|
| 612 |
+
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
| 613 |
+
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
| 614 |
+
|
| 615 |
+
# Handle timeouts termination properly if needed
|
| 616 |
+
# see https://github.com/DLR-RM/stable-baselines3/issues/284
|
| 617 |
+
self.handle_timeout_termination = handle_timeout_termination
|
| 618 |
+
self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
|
| 619 |
+
|
| 620 |
+
if psutil is not None:
|
| 621 |
+
obs_nbytes = 0
|
| 622 |
+
for _, obs in self.observations.items():
|
| 623 |
+
obs_nbytes += obs.nbytes
|
| 624 |
+
|
| 625 |
+
total_memory_usage = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
|
| 626 |
+
if self.next_observations is not None:
|
| 627 |
+
next_obs_nbytes = 0
|
| 628 |
+
for _, obs in self.observations.items():
|
| 629 |
+
next_obs_nbytes += obs.nbytes
|
| 630 |
+
total_memory_usage += next_obs_nbytes
|
| 631 |
+
|
| 632 |
+
if total_memory_usage > mem_available:
|
| 633 |
+
# Convert to GB
|
| 634 |
+
total_memory_usage /= 1e9
|
| 635 |
+
mem_available /= 1e9
|
| 636 |
+
warnings.warn(
|
| 637 |
+
"This system does not have apparently enough memory to store the complete "
|
| 638 |
+
f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
def add(
    self,
    obs: Dict[str, np.ndarray],
    next_obs: Dict[str, np.ndarray],
    action: np.ndarray,
    reward: np.ndarray,
    done: np.ndarray,
    infos: List[Dict[str, Any]],
) -> None:
    """
    Store one transition (per parallel env) at the current write position.

    :param obs: current observation, one array per dict key
    :param next_obs: observation after the transition
    :param action: action taken by the agent
    :param reward: reward received
    :param done: episode termination signal
    :param infos: per-env info dicts; ``TimeLimit.truncated`` is read from them
        when ``handle_timeout_termination`` is enabled
    """
    # Copy to avoid modification by reference.
    # Unlike the previous version, the reshaped arrays are kept local instead of
    # being written back into the caller's dict (no hidden side effect on `obs`),
    # matching how DictRolloutBuffer.add handles the same case.
    for key in self.observations.keys():
        obs_ = np.array(obs[key])
        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space.spaces[key], spaces.Discrete):
            obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
        self.observations[key][self.pos] = obs_

    for key in self.next_observations.keys():
        next_obs_ = np.array(next_obs[key])
        if isinstance(self.observation_space.spaces[key], spaces.Discrete):
            next_obs_ = next_obs_.reshape((self.n_envs,) + self.obs_shape[key])
        self.next_observations[key][self.pos] = next_obs_

    # Same reshape, for actions
    if isinstance(self.action_space, spaces.Discrete):
        action = action.reshape((self.n_envs, self.action_dim))

    self.actions[self.pos] = np.array(action).copy()
    self.rewards[self.pos] = np.array(reward).copy()
    self.dones[self.pos] = np.array(done).copy()

    if self.handle_timeout_termination:
        self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

    # Advance the ring-buffer cursor, wrapping around once the buffer is full
    self.pos += 1
    if self.pos == self.buffer_size:
        self.full = True
        self.pos = 0
|
| 678 |
+
|
| 679 |
+
def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
    """
    Sample elements from the replay buffer.

    :param batch_size: Number of element to sample
    :param env: associated gym VecEnv
        to normalize the observations/rewards when sampling
    :return: a batch of sampled transitions with dict observations
    """
    # ``super(ReplayBuffer, self)`` deliberately skips ``ReplayBuffer`` in the
    # MRO, so ReplayBuffer.sample is bypassed and dispatch goes to the class
    # above it (presumably the base buffer class — its definition is outside
    # this view, confirm against the top of the file).
    return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)
|
| 689 |
+
|
| 690 |
+
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
    """
    Gather the transitions stored at ``batch_inds`` and return them as torch tensors.

    :param batch_inds: indices along the buffer (time) axis
    :param env: optional VecNormalize wrapper used to normalize observations/rewards
    :return: the sampled batch, with dict observations
    """
    # Sample randomly the env idx (an independent env column for each sampled row)
    env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

    # Normalize if needed and remove extra dimension (we are using only one env for now)
    obs_ = self._normalize_obs({key: obs[batch_inds, env_indices, :] for key, obs in self.observations.items()}, env)
    next_obs_ = self._normalize_obs(
        {key: obs[batch_inds, env_indices, :] for key, obs in self.next_observations.items()}, env
    )

    # Convert to torch tensor
    observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
    next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()}

    return DictReplayBufferSamples(
        observations=observations,
        actions=self.to_torch(self.actions[batch_inds, env_indices]),
        next_observations=next_observations,
        # Only use dones that are not due to timeouts
        # deactivated by default (timeouts is initialized as an array of False)
        dones=self.to_torch(self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(
            -1, 1
        ),
        rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env)),
    )
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
class DictRolloutBuffer(RolloutBuffer):
    """
    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RolloutBuffer to use dictionary observations

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device on which sampled tensors are created
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to Monte-Carlo advantage estimate when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):

        # Skip RolloutBuffer.__init__ in the MRO and call the class above it
        # directly — presumably to avoid RolloutBuffer's own allocation/reset
        # logic running before this class's attributes exist (TODO confirm
        # against the base class, which is outside this view).
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma
        # Storage arrays are allocated in reset()
        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        # (Re)allocate all storage as zeroed float32 arrays of shape
        # (buffer_size, n_envs, *feature_shape) and rewind the write cursor.
        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # get() flattens the arrays in place; a reset invalidates that state
        self.generator_ready = False
        super(RolloutBuffer, self).reset()

    def add(
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward: reward obtained at this step (one value per env)
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key]).copy()
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        # Detach torch tensors from the graph/device before storing as numpy
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]:
        """
        Yield minibatches over a random permutation of the whole (full) buffer.

        :param batch_size: minibatch size; ``None`` yields the entire buffer at once
        """
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:

            # NOTE: this flattens the storage arrays IN PLACE from
            # (buffer_size, n_envs, ...) to (buffer_size * n_envs, ...);
            # generator_ready guards against doing it twice.
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples:
        # Assumes get() has already flattened the storage (generator_ready),
        # so batch_inds index the flat (buffer_size * n_envs) axis.

        return DictRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            actions=self.to_torch(self.actions[batch_inds]),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
class DictSSLRolloutBuffer(RolloutBuffer):
    """
    Dict Rollout buffer, for on-policy algorithms like A2C/PPO, that additionally
    returns a short window of *future* observations and actions with each sample,
    for use by self-supervised (SSL) auxiliary objectives.

    Extends the RolloutBuffer to use dictionary observations.

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device on which sampled tensors are created
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to Monte-Carlo advantage estimate when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    :param ssl_horizon: Number of time offsets (0 .. ssl_horizon - 1) of future
        observations/actions returned per sample. Defaults to 4, the previously
        hard-coded value, so existing callers are unaffected.
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
        ssl_horizon: int = 4,
    ):

        # Skip RolloutBuffer.__init__ in the MRO (same pattern as DictRolloutBuffer)
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictSSLRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma
        # How many future offsets to serve per sample (see _get_samples)
        self.ssl_horizon = ssl_horizon
        # Storage arrays are allocated in reset()
        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        # (Re)allocate all storage as zeroed float32 arrays and rewind the cursor
        assert isinstance(self.obs_shape, dict), "DictSSLRolloutBuffer must be used with Dict obs space only"
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super(RolloutBuffer, self).reset()

    def add(
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward: reward obtained at this step (one value per env)
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key]).copy()
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        # Detach torch tensors from the graph/device before storing as numpy
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]:
        """
        Yield minibatches over a random permutation of the whole (full) buffer.

        :param batch_size: minibatch size; ``None`` yields the entire buffer at once
        """
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:

            # Flatten the storage IN PLACE from (buffer_size, n_envs, ...)
            # to (buffer_size * n_envs, ...); guarded by generator_ready
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictSSLRolloutBufferSamples:
        # In addition to the usual rollout fields, build lists of the next
        # `ssl_horizon` observations/actions (offset 0 is the current step).

        def get_future_obs(observations: Dict[str, np.ndarray], inds: np.ndarray, offset: int) -> Dict[str, th.Tensor]:
            # Shift indices `offset` steps ahead, clamped to the last flat entry.
            # NOTE(review): indices are clipped against the flattened buffer, so
            # a shifted index near an env/episode boundary can land in a
            # different trajectory — this mirrors the original behavior; confirm
            # it is intended for the SSL objective.
            result = {}
            for key, arr in observations.items():
                future_inds = np.clip(inds + offset, 0, len(arr) - 1)
                result[key] = self.to_torch(arr[future_inds])
            return result

        def get_future_actions(actions: np.ndarray, inds: np.ndarray, offset: int) -> th.Tensor:
            future_inds = np.clip(inds + offset, 0, len(actions) - 1)
            return self.to_torch(actions[future_inds])

        next_observations = [get_future_obs(self.observations, batch_inds, i) for i in range(self.ssl_horizon)]
        next_actions = [get_future_actions(self.actions, batch_inds, i) for i in range(self.ssl_horizon)]

        return DictSSLRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            next_observations=next_observations,
            actions=self.to_torch(self.actions[batch_inds]),
            next_actions=next_actions,
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )
|
| 1010 |
+
|
dexart-release/stable_baselines3/common/callbacks.py
ADDED
|
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import warnings
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
import gym
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from stable_baselines3.common import base_class # pytype: disable=pyi-error
|
| 10 |
+
from stable_baselines3.common.evaluation import evaluate_policy
|
| 11 |
+
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseCallback(ABC):
    """
    Abstract base class for callbacks hooked into the training loop.

    :param verbose:
    """

    def __init__(self, verbose: int = 0):
        super().__init__()
        # Filled in by ``init_callback`` once the algorithm is known
        self.model = None  # type: Optional[base_class.BaseAlgorithm]
        # Alias for ``self.model.get_env()``, the training environment
        self.training_env = None  # type: Union[gym.Env, VecEnv, None]
        # How many times ``on_step`` has been invoked
        self.n_calls = 0  # type: int
        # Mirrors ``model.num_timesteps`` (= n_envs * number of env.step() calls)
        self.num_timesteps = 0  # type: int
        self.verbose = verbose
        self.locals: Dict[str, Any] = {}
        self.globals: Dict[str, Any] = {}
        self.logger = None
        # An event callback may need a handle on the callback that owns it
        self.parent = None  # type: Optional[BaseCallback]

    # Type hint as string to avoid circular import
    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
        """
        Save references to the RL model and the training environment,
        then run the subclass hook ``_init_callback``.
        """
        self.model = model
        self.training_env = model.get_env()
        self.logger = model.logger
        self._init_callback()

    def _init_callback(self) -> None:
        pass

    def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        # These are references and are kept up to date automatically
        self.locals = locals_
        self.globals = globals_
        self._on_training_start()

    def _on_training_start(self) -> None:
        pass

    def on_rollout_start(self) -> None:
        self._on_rollout_start()

    def _on_rollout_start(self) -> None:
        pass

    @abstractmethod
    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        return True

    def on_step(self) -> bool:
        """
        Called by the model after each call to ``env.step()``.

        For a child callback (of an ``EventCallback``), this runs
        when the event is triggered instead.

        :return: If the callback returns False, training is aborted early.
        """
        self.n_calls += 1
        # timesteps start at zero
        self.num_timesteps = self.model.num_timesteps

        return self._on_step()

    def on_training_end(self) -> None:
        self._on_training_end()

    def _on_training_end(self) -> None:
        pass

    def on_rollout_end(self) -> None:
        self._on_rollout_end()

    def _on_rollout_end(self) -> None:
        pass

    def update_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Update the references to the local variables.

        :param locals_: the local variables during rollout collection
        """
        self.locals.update(locals_)
        self.update_child_locals(locals_)

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Update the references to the local variables on sub callbacks.

        :param locals_: the local variables during rollout collection
        """
        pass
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class EventCallback(BaseCallback):
    """
    Base class for callbacks that fire a child callback when an event occurs.

    :param callback: Callback that will be called
        when an event is triggered.
    :param verbose:
    """

    def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.callback = callback
        # Let the child callback reach back to this (parent) callback
        if callback is not None:
            self.callback.parent = self

    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
        super().init_callback(model)
        if self.callback is not None:
            self.callback.init_callback(self.model)

    def _on_training_start(self) -> None:
        if self.callback is not None:
            self.callback.on_training_start(self.locals, self.globals)

    def _on_event(self) -> bool:
        # Delegate the event to the child; no child means "keep training"
        if self.callback is None:
            return True
        return self.callback.on_step()

    def _on_step(self) -> bool:
        return True

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Forward updated local variables to the child callback.

        :param locals_: the local variables during rollout collection
        """
        if self.callback is not None:
            self.callback.update_locals(locals_)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class CallbackList(BaseCallback):
    """
    Chain several callbacks so they behave as a single one,
    invoked sequentially.

    :param callbacks: A list of callbacks that will be called
        sequentially.
    """

    def __init__(self, callbacks: List[BaseCallback]):
        super().__init__()
        assert isinstance(callbacks, list)
        self.callbacks = callbacks

    def _init_callback(self) -> None:
        for callback in self.callbacks:
            callback.init_callback(self.model)

    def _on_training_start(self) -> None:
        for callback in self.callbacks:
            callback.on_training_start(self.locals, self.globals)

    def _on_rollout_start(self) -> None:
        for callback in self.callbacks:
            callback.on_rollout_start()

    def _on_step(self) -> bool:
        # Step EVERY child, even if an earlier one already requested a stop;
        # training halts if at least one callback returned False.
        step_results = [callback.on_step() for callback in self.callbacks]
        return all(step_results)

    def _on_rollout_end(self) -> None:
        for callback in self.callbacks:
            callback.on_rollout_end()

    def _on_training_end(self) -> None:
        for callback in self.callbacks:
            callback.on_training_end()

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Propagate updated local variables to every child callback.

        :param locals_: the local variables during rollout collection
        """
        for callback in self.callbacks:
            callback.update_locals(locals_)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class CheckpointCallback(BaseCallback):
    """
    Save the model every ``save_freq`` calls to ``env.step()``.

    .. warning::

      When using multiple environments, each call to  ``env.step()``
      will effectively correspond to ``n_envs`` steps.
      To account for that, you can use ``save_freq = max(save_freq // n_envs, 1)``

    :param save_freq: save a checkpoint every that many callback calls
    :param save_path: Path to the folder where the model will be saved.
    :param name_prefix: Common prefix to the saved models
    :param verbose:
    """

    def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
        super().__init__(verbose)
        self.save_freq = save_freq
        self.save_path = save_path
        self.name_prefix = name_prefix

    def _init_callback(self) -> None:
        # Make sure the target folder exists before the first save
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.save_freq == 0:
            checkpoint_path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps")
            self.model.save(checkpoint_path)
            if self.verbose > 1:
                print(f"Saving model checkpoint to {checkpoint_path}")
        return True
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class ConvertCallback(BaseCallback):
    """
    Wrap an old-style functional callback into a ``BaseCallback`` object.

    :param callback: a ``(locals, globals) -> bool`` function (may be None)
    :param verbose:
    """

    def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0):
        super().__init__(verbose)
        self.callback = callback

    def _on_step(self) -> bool:
        # No wrapped function means "keep training"
        if self.callback is None:
            return True
        return self.callback(self.locals, self.globals)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class EvalCallback(EventCallback):
    """
    Callback for evaluating an agent.

    .. warning::

      When using multiple environments, each call to ``env.step()``
      will effectively correspond to ``n_envs`` steps.
      To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)``

    :param eval_env: The environment used for initialization
    :param callback_on_new_best: Callback to trigger
        when there is a new best model according to the ``mean_reward``
    :param callback_after_eval: Callback to trigger after every evaluation
    :param n_eval_episodes: The number of episodes to test the agent
    :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
    :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
        will be saved. It will be updated at each evaluation.
    :param best_model_save_path: Path to a folder where the best model
        according to performance on the eval env will be saved.
    :param deterministic: Whether the evaluation should
        use a stochastic or deterministic actions.
    :param render: Whether to render or not the environment during evaluation
    :param verbose: Verbosity level.
    :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
        wrapped with a Monitor wrapper)
    """

    def __init__(
        self,
        eval_env: Union[gym.Env, VecEnv],
        callback_on_new_best: Optional[BaseCallback] = None,
        callback_after_eval: Optional[BaseCallback] = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: Optional[str] = None,
        best_model_save_path: Optional[str] = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ):
        super().__init__(callback_after_eval, verbose=verbose)

        self.callback_on_new_best = callback_on_new_best
        if self.callback_on_new_best is not None:
            # Give access to the parent
            self.callback_on_new_best.parent = self

        self.n_eval_episodes = n_eval_episodes
        self.eval_freq = eval_freq
        self.best_mean_reward = -np.inf
        self.last_mean_reward = -np.inf
        self.deterministic = deterministic
        self.render = render
        self.warn = warn

        # Convert to VecEnv for consistency
        if not isinstance(eval_env, VecEnv):
            eval_env = DummyVecEnv([lambda: eval_env])

        self.eval_env = eval_env
        self.best_model_save_path = best_model_save_path
        # Logs will be written in ``evaluations.npz``
        if log_path is not None:
            log_path = os.path.join(log_path, "evaluations")
        self.log_path = log_path
        self.evaluations_results = []
        self.evaluations_timesteps = []
        self.evaluations_length = []
        # For computing success rate
        self._is_success_buffer = []
        self.evaluations_successes = []

    def _init_callback(self) -> None:
        # Does not work in some corner cases, where the wrapper is not the same
        if not isinstance(self.training_env, type(self.eval_env)):
            # Fixed: the two halves of this message used to be concatenated without
            # a separator, producing e.g. "...same typeDummyVecEnv(...)".
            warnings.warn(f"Training and eval env are not of the same type: {self.training_env} != {self.eval_env}")

        # Create folders if needed
        if self.best_model_save_path is not None:
            os.makedirs(self.best_model_save_path, exist_ok=True)
        if self.log_path is not None:
            os.makedirs(os.path.dirname(self.log_path), exist_ok=True)

        # Init callback called on new best model
        if self.callback_on_new_best is not None:
            self.callback_on_new_best.init_callback(self.model)

    def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        """
        Callback passed to the ``evaluate_policy`` function
        in order to log the success rate (when applicable),
        for instance when using HER.

        :param locals_:
        :param globals_:
        """
        info = locals_["info"]

        if locals_["done"]:
            maybe_is_success = info.get("is_success")
            if maybe_is_success is not None:
                self._is_success_buffer.append(maybe_is_success)

    def _on_step(self) -> bool:

        continue_training = True

        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:

            # Sync training and eval env if there is VecNormalize
            if self.model.get_vec_normalize_env() is not None:
                try:
                    sync_envs_normalization(self.training_env, self.eval_env)
                except AttributeError as e:
                    # Chain the original error so the root cause stays visible
                    raise AssertionError(
                        "Training and eval env are not wrapped the same way, "
                        "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
                        "and warning above."
                    ) from e

            # Reset success rate buffer
            self._is_success_buffer = []

            episode_rewards, episode_lengths = evaluate_policy(
                self.model,
                self.eval_env,
                n_eval_episodes=self.n_eval_episodes,
                render=self.render,
                deterministic=self.deterministic,
                return_episode_rewards=True,
                warn=self.warn,
                callback=self._log_success_callback,
            )

            if self.log_path is not None:
                self.evaluations_timesteps.append(self.num_timesteps)
                self.evaluations_results.append(episode_rewards)
                self.evaluations_length.append(episode_lengths)

                kwargs = {}
                # Save success log if present
                if len(self._is_success_buffer) > 0:
                    self.evaluations_successes.append(self._is_success_buffer)
                    kwargs = dict(successes=self.evaluations_successes)

                np.savez(
                    self.log_path,
                    timesteps=self.evaluations_timesteps,
                    results=self.evaluations_results,
                    ep_lengths=self.evaluations_length,
                    **kwargs,
                )

            mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
            mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
            self.last_mean_reward = mean_reward

            if self.verbose > 0:
                print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
                print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
            # Add to current Logger
            self.logger.record("eval/mean_reward", float(mean_reward))
            self.logger.record("eval/mean_ep_length", mean_ep_length)

            if len(self._is_success_buffer) > 0:
                success_rate = np.mean(self._is_success_buffer)
                if self.verbose > 0:
                    print(f"Success rate: {100 * success_rate:.2f}%")
                self.logger.record("eval/success_rate", success_rate)

            # Dump log so the evaluation results are printed with the correct timestep
            self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
            self.logger.dump(self.num_timesteps)

            if mean_reward > self.best_mean_reward:
                if self.verbose > 0:
                    print("New best mean reward!")
                if self.best_model_save_path is not None:
                    self.model.save(os.path.join(self.best_model_save_path, "best_model"))
                self.best_mean_reward = mean_reward
                # Trigger callback on new best model, if needed
                if self.callback_on_new_best is not None:
                    continue_training = self.callback_on_new_best.on_step()

            # Trigger callback after every evaluation, if needed
            if self.callback is not None:
                continue_training = continue_training and self._on_event()

        return continue_training

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Update the references to the local variables.

        :param locals_: the local variables during rollout collection
        """
        if self.callback:
            self.callback.update_locals(locals_)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class StopTrainingOnRewardThreshold(BaseCallback):
    """
    Stop the training once a threshold in episodic reward
    has been reached (i.e. when the model is good enough).

    It must be used with the ``EvalCallback``.

    :param reward_threshold: Minimum expected reward per episode
        to stop training.
    :param verbose: Verbosity level.
    """

    def __init__(self, reward_threshold: float, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.reward_threshold = reward_threshold

    def _on_step(self) -> bool:
        # Fixed: the assertion message previously referred to a nonexistent
        # ``StopTrainingOnMinimumReward`` class.
        assert self.parent is not None, "``StopTrainingOnRewardThreshold`` callback must be used with an ``EvalCallback``"
        # Convert np.bool_ to bool, otherwise callback() is False won't work
        continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
        if self.verbose > 0 and not continue_training:
            # Fixed: the implicit string concatenation produced a doubled space
            # before "is above".
            print(
                f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
                f"is above the threshold {self.reward_threshold}"
            )
        return continue_training
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class EveryNTimesteps(EventCallback):
    """
    Trigger a callback every ``n_steps`` timesteps.

    :param n_steps: Number of timesteps between two trigger.
    :param callback: Callback that will be called when the event is triggered.
    """

    def __init__(self, n_steps: int, callback: BaseCallback):
        super().__init__(callback)
        self.n_steps = n_steps
        # Timestep count at the last trigger, so intervals survive uneven step sizes
        self.last_time_trigger = 0

    def _on_step(self) -> bool:
        elapsed = self.num_timesteps - self.last_time_trigger
        if elapsed < self.n_steps:
            return True
        self.last_time_trigger = self.num_timesteps
        return self._on_event()
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class StopTrainingOnMaxEpisodes(BaseCallback):
    """
    Stop the training once a maximum number of episodes are played.

    For multiple environments presumes that, the desired behavior is that the agent trains
    on each env for ``max_episodes`` and in total for ``max_episodes * n_envs`` episodes.

    :param max_episodes: Maximum number of episodes to stop training.
    :param verbose: Select whether to print information about when training ended by reaching ``max_episodes``
    """

    def __init__(self, max_episodes: int, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_episodes = max_episodes
        # Scaled by the number of parallel environments in ``_init_callback``
        self._total_max_episodes = max_episodes
        self.n_episodes = 0

    def _init_callback(self) -> None:
        # Multiply the per-env budget by the number of environments
        self._total_max_episodes = self.max_episodes * self.training_env.num_envs

    def _on_step(self) -> bool:
        # Check that the `dones` local variable is defined
        assert "dones" in self.locals, "`dones` variable is not defined, please check your code next to `callback.on_step()`"
        # Every finished sub-environment counts as one completed episode
        self.n_episodes += np.sum(self.locals["dones"]).item()

        continue_training = self.n_episodes < self._total_max_episodes

        if self.verbose > 0 and not continue_training:
            mean_episodes_per_env = self.n_episodes / self.training_env.num_envs
            if self.training_env.num_envs > 1:
                mean_ep_str = f"with an average of {mean_episodes_per_env:.2f} episodes per env"
            else:
                mean_ep_str = ""

            print(
                f"Stopping training with a total of {self.num_timesteps} steps because the "
                f"{self.locals.get('tb_log_name')} model reached max_episodes={self.max_episodes}, "
                f"by playing for {self.n_episodes} episodes "
                f"{mean_ep_str}"
            )
        return continue_training
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.

    It is possible to define a minimum number of evaluations before start to count evaluations without improvement.

    It must be used with the ``EvalCallback``.

    :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
    :param min_evals: Number of evaluations before start to count evaluations without improvements.
    :param verbose: Verbosity of the output (set to 1 for info messages)
    """

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        self.last_best_mean_reward = -np.inf
        self.no_improvement_evals = 0

    def _on_step(self) -> bool:
        assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``"

        continue_training = True

        # Ignore the warm-up period before counting stagnation
        if self.n_calls > self.min_evals:
            if self.parent.best_mean_reward > self.last_best_mean_reward:
                # A new best model resets the stagnation counter
                self.no_improvement_evals = 0
            else:
                self.no_improvement_evals += 1
                continue_training = self.no_improvement_evals <= self.max_no_improvement_evals

        self.last_best_mean_reward = self.parent.best_mean_reward

        if self.verbose > 0 and not continue_training:
            print(
                f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
            )

        return continue_training
|
dexart-release/stable_baselines3/common/distributions.py
ADDED
|
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Probability distributions."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import gym
|
| 7 |
+
import torch as th
|
| 8 |
+
from gym import spaces
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.distributions import Bernoulli, Categorical, Normal
|
| 11 |
+
|
| 12 |
+
from stable_baselines3.common.preprocessing import get_action_dim
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Distribution(ABC):
    """Abstract base class for distributions."""

    def __init__(self):
        super().__init__()
        # Framework-level distribution object, filled in by ``proba_distribution``
        self.distribution = None

    @abstractmethod
    def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
        """Create the layers and parameters that represent the distribution.

        Subclasses must define this, but the arguments and return type vary between
        concrete classes."""

    @abstractmethod
    def proba_distribution(self, *args, **kwargs) -> "Distribution":
        """Set parameters of the distribution.

        :return: self
        """

    @abstractmethod
    def log_prob(self, x: th.Tensor) -> th.Tensor:
        """
        Returns the log likelihood

        :param x: the taken action
        :return: The log likelihood of the distribution
        """

    @abstractmethod
    def entropy(self) -> Optional[th.Tensor]:
        """
        Returns Shannon's entropy of the probability

        :return: the entropy, or None if no analytical form is known
        """

    @abstractmethod
    def sample(self) -> th.Tensor:
        """
        Returns a sample from the probability distribution

        :return: the stochastic action
        """

    @abstractmethod
    def mode(self) -> th.Tensor:
        """
        Returns the most likely action (deterministic output)
        from the probability distribution

        :return: the stochastic action
        """

    def get_actions(self, deterministic: bool = False) -> th.Tensor:
        """
        Return actions according to the probability distribution.

        :param deterministic: If True, return the distribution mode instead of a sample.
        :return: actions
        """
        return self.mode() if deterministic else self.sample()

    @abstractmethod
    def actions_from_params(self, *args, **kwargs) -> th.Tensor:
        """
        Returns samples from the probability distribution
        given its parameters.

        :return: actions
        """

    @abstractmethod
    def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
        """
        Returns samples and the associated log probabilities
        from the probability distribution given its parameters.

        :return: actions and log prob
        """
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
    """
    Continuous actions are usually considered to be independent,
    so we can sum components of the ``log_prob`` or the entropy.

    :param tensor: shape: (n_batch, n_actions) or (n_batch,)
    :return: shape: (n_batch,)
    """
    # Batched input: sum over the action dimension; flat input: sum everything
    if tensor.dim() > 1:
        return tensor.sum(dim=1)
    return tensor.sum()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class DiagGaussianDistribution(Distribution):
    """
    Gaussian distribution with diagonal covariance matrix, for continuous actions.

    :param action_dim: Dimension of the action space.
    """

    def __init__(self, action_dim: int):
        super().__init__()
        self.action_dim = action_dim
        self.mean_actions = None
        self.log_std = None

    def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
        """
        Create the layers and parameter that represent the distribution:
        one output will be the mean of the Gaussian, the other parameter will be the
        standard deviation (log std in fact to allow negative values)

        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
        :param log_std_init: Initial value for the log standard deviation
        :return:
        """
        mean_net = nn.Linear(latent_dim, self.action_dim)
        # TODO: allow action dependent std
        log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
        return mean_net, log_std

    def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
        """
        Create the distribution given its parameters (mean, std)

        :param mean_actions:
        :param log_std:
        :return:
        """
        # Broadcast the shared std to the shape of the mean
        action_std = log_std.exp() * th.ones_like(mean_actions)
        self.distribution = Normal(mean_actions, action_std)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """
        Get the log probabilities of actions according to the distribution.
        Note that you must first call the ``proba_distribution()`` method.

        :param actions:
        :return:
        """
        return sum_independent_dims(self.distribution.log_prob(actions))

    def entropy(self) -> th.Tensor:
        return sum_independent_dims(self.distribution.entropy())

    def sample(self) -> th.Tensor:
        # Reparametrization trick to pass gradients
        return self.distribution.rsample()

    def mode(self) -> th.Tensor:
        return self.distribution.mean

    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Refresh the distribution, then draw from it
        return self.proba_distribution(mean_actions, log_std).get_actions(deterministic=deterministic)

    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Compute the log probability of taking an action
        given the distribution parameters.

        :param mean_actions:
        :param log_std:
        :return:
        """
        actions = self.actions_from_params(mean_actions, log_std)
        return actions, self.log_prob(actions)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
    """
    Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds.

    :param action_dim: Dimension of the action space.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, action_dim: int, epsilon: float = 1e-6):
        super().__init__(action_dim)
        # Avoid NaN (prevents division by zero or log of zero)
        self.epsilon = epsilon
        # Cache of the pre-squash sample, reused by ``log_prob_from_params``
        self.gaussian_actions = None

    def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution":
        super().proba_distribution(mean_actions, log_std)
        return self

    def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
        # The naive inverse tanh, 0.5 * log((1 + x) / (1 - x)), is numerically
        # unstable near +/-1, so TanhBijector clips before inverting.
        if gaussian_actions is None:
            gaussian_actions = TanhBijector.inverse(actions)

        # Log likelihood for a Gaussian distribution
        log_prob = super().log_prob(gaussian_actions)
        # Squash correction (from original SAC implementation):
        # tanh is bijective and differentiable, hence the change-of-variables term
        log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1)
        return log_prob

    def entropy(self) -> Optional[th.Tensor]:
        # No analytical form; estimate with -log_prob.mean() instead
        return None

    def sample(self) -> th.Tensor:
        # Reparametrization trick to pass gradients; keep the unsquashed sample
        self.gaussian_actions = super().sample()
        return th.tanh(self.gaussian_actions)

    def mode(self) -> th.Tensor:
        self.gaussian_actions = super().mode()
        # Squash the output
        return th.tanh(self.gaussian_actions)

    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        action = self.actions_from_params(mean_actions, log_std)
        # Reuse the cached Gaussian sample to skip the unstable inverse tanh
        log_prob = self.log_prob(action, self.gaussian_actions)
        return action, log_prob
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class CategoricalDistribution(Distribution):
    """
    Categorical distribution for discrete actions.

    :param action_dim: Number of discrete actions
    """

    def __init__(self, action_dim: int):
        super().__init__()
        self.action_dim = action_dim

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits of the Categorical distribution.
        You can then get probabilities using a softmax.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return:
        """
        return nn.Linear(latent_dim, self.action_dim)

    def proba_distribution(self, action_logits: th.Tensor) -> "CategoricalDistribution":
        self.distribution = Categorical(logits=action_logits)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        return self.distribution.log_prob(actions)

    def entropy(self) -> th.Tensor:
        return self.distribution.entropy()

    def sample(self) -> th.Tensor:
        return self.distribution.sample()

    def mode(self) -> th.Tensor:
        # Deterministic action = most probable class
        return th.argmax(self.distribution.probs, dim=1)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Refresh the distribution, then draw from it
        return self.proba_distribution(action_logits).get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        actions = self.actions_from_params(action_logits)
        return actions, self.log_prob(actions)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class MultiCategoricalDistribution(Distribution):
|
| 301 |
+
"""
|
| 302 |
+
MultiCategorical distribution for multi discrete actions.
|
| 303 |
+
|
| 304 |
+
:param action_dims: List of sizes of discrete action spaces
|
| 305 |
+
"""
|
| 306 |
+
|
| 307 |
+
def __init__(self, action_dims: List[int]):
    """Store the sizes of the discrete sub-action spaces."""
    super().__init__()
    self.action_dims = action_dims
|
| 310 |
+
|
| 311 |
+
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
|
| 312 |
+
"""
|
| 313 |
+
Create the layer that represents the distribution:
|
| 314 |
+
it will be the logits (flattened) of the MultiCategorical distribution.
|
| 315 |
+
You can then get probabilities using a softmax on each sub-space.
|
| 316 |
+
|
| 317 |
+
:param latent_dim: Dimension of the last layer
|
| 318 |
+
of the policy network (before the action layer)
|
| 319 |
+
:return:
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
|
| 323 |
+
return action_logits
|
| 324 |
+
|
| 325 |
+
def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution":
|
| 326 |
+
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
|
| 327 |
+
return self
|
| 328 |
+
|
| 329 |
+
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
| 330 |
+
# Extract each discrete action and compute log prob for their respective distributions
|
| 331 |
+
return th.stack(
|
| 332 |
+
[dist.log_prob(action) for dist, action in zip(self.distribution, th.unbind(actions, dim=1))], dim=1
|
| 333 |
+
).sum(dim=1)
|
| 334 |
+
|
| 335 |
+
def entropy(self) -> th.Tensor:
|
| 336 |
+
return th.stack([dist.entropy() for dist in self.distribution], dim=1).sum(dim=1)
|
| 337 |
+
|
| 338 |
+
def sample(self) -> th.Tensor:
|
| 339 |
+
return th.stack([dist.sample() for dist in self.distribution], dim=1)
|
| 340 |
+
|
| 341 |
+
def mode(self) -> th.Tensor:
|
| 342 |
+
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distribution], dim=1)
|
| 343 |
+
|
| 344 |
+
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
| 345 |
+
# Update the proba distribution
|
| 346 |
+
self.proba_distribution(action_logits)
|
| 347 |
+
return self.get_actions(deterministic=deterministic)
|
| 348 |
+
|
| 349 |
+
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
| 350 |
+
actions = self.actions_from_params(action_logits)
|
| 351 |
+
log_prob = self.log_prob(actions)
|
| 352 |
+
return actions, log_prob
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class BernoulliDistribution(Distribution):
    """
    Bernoulli distribution for MultiBinary action spaces.

    :param action_dims: Number of binary actions
    """

    def __init__(self, action_dims: int):
        super().__init__()
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits of the Bernoulli distribution.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: a linear layer producing one logit per binary action
        """
        return nn.Linear(latent_dim, self.action_dims)

    def proba_distribution(self, action_logits: th.Tensor) -> "BernoulliDistribution":
        """Build the underlying ``Bernoulli`` distribution from raw logits."""
        self.distribution = Bernoulli(logits=action_logits)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """Joint log-probability: the binary components are independent, so sum over dim 1."""
        per_bit = self.distribution.log_prob(actions)
        return per_bit.sum(dim=1)

    def entropy(self) -> th.Tensor:
        """Total entropy of the independent binary components."""
        return self.distribution.entropy().sum(dim=1)

    def sample(self) -> th.Tensor:
        """Sample one 0/1 value per binary action."""
        return self.distribution.sample()

    def mode(self) -> th.Tensor:
        """Deterministic action: round each probability to the nearest of {0, 1}."""
        return th.round(self.distribution.probs)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """Refresh the distribution from the logits and return actions from it."""
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """Refresh the distribution, select actions and return them with their log-prob."""
        actions = self.actions_from_params(action_logits)
        return actions, self.log_prob(actions)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class StateDependentNoiseDistribution(Distribution):
    """
    Distribution class for using generalized State Dependent Exploration (gSDE).
    Paper: https://arxiv.org/abs/2005.05719

    It is used to create the noise exploration matrix and
    compute the log probability of an action with that noise.

    :param action_dim: Dimension of the action space.
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,)
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this ensures bounds are satisfied.
    :param learn_features: Whether to learn features for gSDE or not.
        This will enable gradients to be backpropagated through the features
        ``latent_sde`` in the code.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(
        self,
        action_dim: int,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        learn_features: bool = False,
        epsilon: float = 1e-6,
    ):
        super().__init__()
        self.action_dim = action_dim
        # The attributes below are populated later by proba_distribution_net /
        # sample_weights / proba_distribution; they start as None.
        self.latent_sde_dim = None
        self.mean_actions = None
        self.log_std = None
        self.weights_dist = None
        self.exploration_mat = None
        self.exploration_matrices = None
        self._latent_sde = None
        self.use_expln = use_expln
        self.full_std = full_std
        self.epsilon = epsilon
        self.learn_features = learn_features
        if squash_output:
            # Squash actions into (-1, 1) with a tanh bijector
            self.bijector = TanhBijector(epsilon)
        else:
            self.bijector = None

    def get_std(self, log_std: th.Tensor) -> th.Tensor:
        """
        Get the standard deviation from the learned parameter
        (log of it by default). This ensures that the std is positive.

        :param log_std: learned parameter, interpreted through ``exp`` or ``expln``
        :return: positive standard deviation, shape (latent_sde_dim, action_dim)
        """
        if self.use_expln:
            # From gSDE paper, it allows to keep variance
            # above zero and prevent it from growing too fast
            below_threshold = th.exp(log_std) * (log_std <= 0)
            # Avoid NaN: zeros values that are below zero
            safe_log_std = log_std * (log_std > 0) + self.epsilon
            above_threshold = (th.log1p(safe_log_std) + 1.0) * (log_std > 0)
            # Piecewise combination: exp below 0, log1p + 1 above 0
            std = below_threshold + above_threshold
        else:
            # Use normal exponential
            std = th.exp(log_std)

        if self.full_std:
            return std
        # Reduce the number of parameters:
        # broadcast the single per-feature std across all action dimensions
        return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std

    def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
        """
        Sample weights for the noise exploration matrix,
        using a centered Gaussian distribution.

        Side effect: updates ``self.weights_dist``, ``self.exploration_mat``
        and ``self.exploration_matrices``.

        :param log_std: learned log-std parameter
        :param batch_size: number of exploration matrices for parallel envs
        """
        std = self.get_std(log_std)
        self.weights_dist = Normal(th.zeros_like(std), std)
        # Reparametrization trick to pass gradients
        self.exploration_mat = self.weights_dist.rsample()
        # Pre-compute matrices in case of parallel exploration
        self.exploration_matrices = self.weights_dist.rsample((batch_size,))

    def proba_distribution_net(
        self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None
    ) -> Tuple[nn.Module, nn.Parameter]:
        """
        Create the layers and parameter that represent the distribution:
        one output will be the deterministic action, the other parameter will be the
        standard deviation of the distribution that control the weights of the noise matrix.

        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
        :param log_std_init: Initial value for the log standard deviation
        :param latent_sde_dim: Dimension of the last layer of the features extractor
            for gSDE. By default, it is shared with the policy network.
        :return: (mean-action network, log-std parameter)
        """
        # Network for the deterministic action, it represents the mean of the distribution
        mean_actions_net = nn.Linear(latent_dim, self.action_dim)
        # When we learn features for the noise, the feature dimension
        # can be different between the policy and the noise network
        self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
        # Reduce the number of parameters if needed
        log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1)
        # Transform it to a parameter so it can be optimized
        log_std = nn.Parameter(log_std * log_std_init, requires_grad=True)
        # Sample an exploration matrix
        self.sample_weights(log_std)
        return mean_actions_net, log_std

    def proba_distribution(
        self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
    ) -> "StateDependentNoiseDistribution":
        """
        Create the distribution given its parameters (mean, std)

        :param mean_actions: output of the mean-action network
        :param log_std: learned log-std parameter
        :param latent_sde: features used to compute the state-dependent variance
        :return: self, so calls can be chained
        """
        # Stop gradient if we don't want to influence the features
        self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        # Var[a] = sum_j phi_j^2 * sigma_j^2 (cf gSDE paper)
        variance = th.mm(self._latent_sde**2, self.get_std(log_std) ** 2)
        self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """
        Log-likelihood of ``actions``, undoing the tanh squashing first if enabled.

        :param actions: (possibly squashed) actions, shape (batch, action_dim)
        :return: log-probabilities, shape (batch,)
        """
        if self.bijector is not None:
            gaussian_actions = self.bijector.inverse(actions)
        else:
            gaussian_actions = actions
        # log likelihood for a gaussian
        log_prob = self.distribution.log_prob(gaussian_actions)
        # Sum along action dim
        log_prob = sum_independent_dims(log_prob)

        if self.bijector is not None:
            # Squash correction (from original SAC implementation)
            log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
        return log_prob

    def entropy(self) -> Optional[th.Tensor]:
        """Entropy of the Gaussian, or None when squashing (no analytical form)."""
        if self.bijector is not None:
            # No analytical form,
            # entropy needs to be estimated using -log_prob.mean()
            return None
        return sum_independent_dims(self.distribution.entropy())

    def sample(self) -> th.Tensor:
        """Sample an action as mean + state-dependent noise, squashed if enabled."""
        noise = self.get_noise(self._latent_sde)
        actions = self.distribution.mean + noise
        if self.bijector is not None:
            return self.bijector.forward(actions)
        return actions

    def mode(self) -> th.Tensor:
        """Deterministic action: the distribution mean, squashed if enabled."""
        actions = self.distribution.mean
        if self.bijector is not None:
            return self.bijector.forward(actions)
        return actions

    def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
        """
        Compute the state-dependent noise ``latent_sde @ exploration_mat``.

        :param latent_sde: features, shape (batch, latent_sde_dim)
        :return: noise, shape (batch, action_dim)
        """
        latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        # Default case: only one exploration matrix
        if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
            return th.mm(latent_sde, self.exploration_mat)
        # Use batch matrix multiplication for efficient computation
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = latent_sde.unsqueeze(1)
        # (batch_size, 1, n_actions)
        noise = th.bmm(latent_sde, self.exploration_matrices)
        return noise.squeeze(1)

    def actions_from_params(
        self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
    ) -> th.Tensor:
        """Refresh the distribution from the parameters and return actions from it."""
        # Update the proba distribution
        self.proba_distribution(mean_actions, log_std, latent_sde)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
        self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
    ) -> Tuple[th.Tensor, th.Tensor]:
        """Refresh the distribution, select actions and return them with their log-prob."""
        actions = self.actions_from_params(mean_actions, log_std, latent_sde)
        log_prob = self.log_prob(actions)
        return actions, log_prob
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class TanhBijector:
|
| 601 |
+
"""
|
| 602 |
+
Bijective transformation of a probability distribution
|
| 603 |
+
using a squashing function (tanh)
|
| 604 |
+
TODO: use Pyro instead (https://pyro.ai/)
|
| 605 |
+
|
| 606 |
+
:param epsilon: small value to avoid NaN due to numerical imprecision.
|
| 607 |
+
"""
|
| 608 |
+
|
| 609 |
+
def __init__(self, epsilon: float = 1e-6):
|
| 610 |
+
super().__init__()
|
| 611 |
+
self.epsilon = epsilon
|
| 612 |
+
|
| 613 |
+
@staticmethod
|
| 614 |
+
def forward(x: th.Tensor) -> th.Tensor:
|
| 615 |
+
return th.tanh(x)
|
| 616 |
+
|
| 617 |
+
@staticmethod
|
| 618 |
+
def atanh(x: th.Tensor) -> th.Tensor:
|
| 619 |
+
"""
|
| 620 |
+
Inverse of Tanh
|
| 621 |
+
|
| 622 |
+
Taken from Pyro: https://github.com/pyro-ppl/pyro
|
| 623 |
+
0.5 * torch.log((1 + x ) / (1 - x))
|
| 624 |
+
"""
|
| 625 |
+
return 0.5 * (x.log1p() - (-x).log1p())
|
| 626 |
+
|
| 627 |
+
@staticmethod
|
| 628 |
+
def inverse(y: th.Tensor) -> th.Tensor:
|
| 629 |
+
"""
|
| 630 |
+
Inverse tanh.
|
| 631 |
+
|
| 632 |
+
:param y:
|
| 633 |
+
:return:
|
| 634 |
+
"""
|
| 635 |
+
eps = th.finfo(y.dtype).eps
|
| 636 |
+
# Clip the action to avoid NaN
|
| 637 |
+
return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))
|
| 638 |
+
|
| 639 |
+
def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
|
| 640 |
+
# Squash correction (from original SAC implementation)
|
| 641 |
+
return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
def make_proba_distribution(
    action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> Distribution:
    """
    Return an instance of Distribution for the correct type of action space

    :param action_space: the input action space
    :param use_sde: Force the use of StateDependentNoiseDistribution
        instead of DiagGaussianDistribution
    :param dist_kwargs: Keyword arguments to pass to the probability distribution
    :return: the appropriate Distribution object
    :raises NotImplementedError: if the action space type is not supported
    """
    if dist_kwargs is None:
        dist_kwargs = {}

    if isinstance(action_space, spaces.Box):
        assert len(action_space.shape) == 1, "Error: the action space must be a vector"
        cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
        return cls(get_action_dim(action_space), **dist_kwargs)
    elif isinstance(action_space, spaces.Discrete):
        return CategoricalDistribution(action_space.n, **dist_kwargs)
    elif isinstance(action_space, spaces.MultiDiscrete):
        # ``nvec`` is a numpy array; convert to a plain list to match the
        # ``List[int]`` signature of MultiCategoricalDistribution.
        return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
    elif isinstance(action_space, spaces.MultiBinary):
        return BernoulliDistribution(action_space.n, **dist_kwargs)
    else:
        raise NotImplementedError(
            # Fixed: the original message read "action spaceof type" (missing space)
            "Error: probability distribution, not implemented for action space "
            f"of type {type(action_space)}."
            " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
        )
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor:
    """
    Wrapper for the PyTorch implementation of the full form KL Divergence

    :param dist_true: the p distribution
    :param dist_pred: the q distribution
    :return: KL(dist_true||dist_pred)
    """
    # KL Divergence for different distribution types is out of scope
    assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type"

    if not isinstance(dist_pred, MultiCategoricalDistribution):
        # Both wrappers hold a single PyTorch distribution: delegate directly.
        return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution)

    # MultiCategoricalDistribution stores a *list* of Categorical objects
    # (not a PyTorch Distribution subclass), so compute the KL per sub-space
    # and sum over the independent sub-spaces.
    assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space"
    per_dim = [th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)]
    return th.stack(per_dim, dim=1).sum(dim=1)
|
dexart-release/stable_baselines3/common/env_util.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any, Callable, Dict, Optional, Type, Union
|
| 3 |
+
|
| 4 |
+
import gym
|
| 5 |
+
|
| 6 |
+
from stable_baselines3.common.monitor import Monitor
|
| 7 |
+
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
    """
    Retrieve a ``VecEnvWrapper`` object by recursively searching.

    :param env: Environment to unwrap
    :param wrapper_class: Wrapper to look for
    :return: Environment unwrapped till ``wrapper_class`` if it has been wrapped with it
    """
    current = env
    # Walk down the wrapper chain until the innermost (non-Wrapper) env.
    while isinstance(current, gym.Wrapper):
        if isinstance(current, wrapper_class):
            return current
        current = current.env
    return None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool:
    """
    Check if a given environment has been wrapped with a given wrapper.

    :param env: Environment to check
    :param wrapper_class: Wrapper class to look for
    :return: True if environment has been wrapped with ``wrapper_class``.
    """
    found = unwrap_wrapper(env, wrapper_class)
    return found is not None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def make_vec_env(
    env_id: Union[str, Type[gym.Env]],
    n_envs: int = 1,
    seed: Optional[int] = None,
    start_index: int = 0,
    monitor_dir: Optional[str] = None,
    wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None,
    vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
    monitor_kwargs: Optional[Dict[str, Any]] = None,
    wrapper_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
    """
    Create a wrapped, monitored ``VecEnv``.
    By default it uses a ``DummyVecEnv`` which is usually faster
    than a ``SubprocVecEnv``.

    :param env_id: the environment ID or the environment class
    :param n_envs: the number of environments you wish to have in parallel
    :param seed: the initial seed for the random number generator
    :param start_index: start rank index
    :param monitor_dir: Path to a folder where the monitor files will be saved.
        If None, no file will be written, however, the env will still be wrapped
        in a Monitor wrapper to provide additional information about training.
    :param wrapper_class: Additional wrapper to use on the environment.
        This can also be a function with single argument that wraps the environment in many things.
    :param env_kwargs: Optional keyword argument to pass to the env constructor
    :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
    :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
    :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
    :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor.
    :return: The wrapped environment
    """
    # Normalize all optional kwargs dicts to empty dicts (avoids mutable defaults).
    env_kwargs = {} if env_kwargs is None else env_kwargs
    vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
    monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs
    wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs

    def make_env(rank):
        # Returns a thunk that builds env number ``rank``; the VecEnv calls it
        # lazily (SubprocVecEnv calls it inside the child process).
        def _init():
            if isinstance(env_id, str):
                env = gym.make(env_id, **env_kwargs)
            else:
                env = env_id(**env_kwargs)
            if seed is not None:
                # Offset the seed by rank so parallel envs differ deterministically.
                env.seed(seed + rank)
                env.action_space.seed(seed + rank)
            # Wrap the env in a Monitor wrapper
            # to have additional training information
            monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
            # Create the monitor folder if needed
            if monitor_path is not None:
                os.makedirs(monitor_dir, exist_ok=True)
            env = Monitor(env, filename=monitor_path, **monitor_kwargs)
            # Optionally, wrap the environment with the provided wrapper
            if wrapper_class is not None:
                env = wrapper_class(env, **wrapper_kwargs)
            return env

        return _init

    # No custom VecEnv is passed
    if vec_env_cls is None:
        # Default: use a DummyVecEnv
        vec_env_cls = DummyVecEnv

    return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
|
dexart-release/stable_baselines3/common/evaluation.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import gym
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from stable_baselines3.common import base_class
|
| 8 |
+
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def evaluate_policy(
    model: "base_class.BaseAlgorithm",
    env: Union[gym.Env, VecEnv],
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
    """
    Runs policy for ``n_eval_episodes`` episodes and returns average reward.
    If a vector env is passed in, this divides the episodes to evaluate onto the
    different elements of the vector env. This static division of work is done to
    remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
    details and discussion.

    .. note::
        If environment has not been wrapped with ``Monitor`` wrapper, reward and
        episode lengths are counted as it appears with ``env.step`` calls. If
        the environment contains wrappers that modify rewards or episode lengths
        (e.g. reward scaling, early episode reset), these will affect the evaluation
        results as well. You can avoid this by wrapping environment with ``Monitor``
        wrapper before anything else.

    :param model: The RL agent you want to evaluate.
    :param env: The gym environment or ``VecEnv`` environment.
    :param n_eval_episodes: Number of episode to evaluate the agent
    :param deterministic: Whether to use deterministic or stochastic actions
    :param render: Whether to render the environment or not
    :param callback: callback function to do additional checks,
        called after each step. Gets locals() and globals() passed as parameters.
    :param reward_threshold: Minimum expected reward per episode,
        this will raise an error if the performance is not met
    :param return_episode_rewards: If True, a list of rewards and episode lengths
        per episode will be returned instead of the mean.
    :param warn: If True (default), warns user about lack of a Monitor wrapper in the
        evaluation environment.
    :return: Mean reward per episode, std of reward per episode.
        Returns ([float], [int]) when ``return_episode_rewards`` is True, first
        list containing per-episode rewards and second containing per-episode lengths
        (in number of steps).
    """
    is_monitor_wrapped = False
    # Avoid circular import
    from stable_baselines3.common.monitor import Monitor

    if not isinstance(env, VecEnv):
        # Wrap a plain gym env so the rest of the code can assume a VecEnv API.
        env = DummyVecEnv([lambda: env])

    is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]

    if not is_monitor_wrapped and warn:
        warnings.warn(
            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
            "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
            "Consider wrapping environment first with ``Monitor`` wrapper.",
            UserWarning,
        )

    n_envs = env.num_envs
    episode_rewards = []
    episode_lengths = []

    # Per-env counter of completed episodes.
    episode_counts = np.zeros(n_envs, dtype="int")
    # Divides episodes among different sub environments in the vector as evenly as possible
    episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")

    current_rewards = np.zeros(n_envs)
    current_lengths = np.zeros(n_envs, dtype="int")
    observations = env.reset()
    states = None
    episode_starts = np.ones((env.num_envs,), dtype=bool)
    # Loop until every sub-env has finished its assigned number of episodes.
    while (episode_counts < episode_count_targets).any():
        actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
        observations, rewards, dones, infos = env.step(actions)
        current_rewards += rewards
        current_lengths += 1
        for i in range(n_envs):
            if episode_counts[i] < episode_count_targets[i]:

                # unpack values so that the callback can access the local variables
                reward = rewards[i]
                done = dones[i]
                info = infos[i]
                episode_starts[i] = done

                if callback is not None:
                    callback(locals(), globals())

                if dones[i]:
                    if is_monitor_wrapped:
                        # Atari wrapper can send a "done" signal when
                        # the agent loses a life, but it does not correspond
                        # to the true end of episode
                        if "episode" in info.keys():
                            # Do not trust "done" with episode endings.
                            # Monitor wrapper includes "episode" key in info if environment
                            # has been wrapped with it. Use those rewards instead.
                            episode_rewards.append(info["episode"]["r"])
                            episode_lengths.append(info["episode"]["l"])
                            # Only increment at the real end of an episode
                            episode_counts[i] += 1
                    else:
                        episode_rewards.append(current_rewards[i])
                        episode_lengths.append(current_lengths[i])
                        episode_counts[i] += 1
                    # Reset the running accumulators for this sub-env.
                    current_rewards[i] = 0
                    current_lengths[i] = 0

        if render:
            env.render()

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    if reward_threshold is not None:
        assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
    if return_episode_rewards:
        return episode_rewards, episode_lengths
    return mean_reward, std_reward
|
dexart-release/stable_baselines3/common/logger.py
ADDED
|
@@ -0,0 +1,644 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import tempfile
|
| 6 |
+
import warnings
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas
|
| 12 |
+
import torch as th
|
| 13 |
+
from matplotlib import pyplot as plt
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 17 |
+
except ImportError:
|
| 18 |
+
SummaryWriter = None
|
| 19 |
+
|
| 20 |
+
DEBUG = 10
|
| 21 |
+
INFO = 20
|
| 22 |
+
WARN = 30
|
| 23 |
+
ERROR = 40
|
| 24 |
+
DISABLED = 50
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Video:
|
| 28 |
+
"""
|
| 29 |
+
Video data class storing the video frames and the frame per seconds
|
| 30 |
+
|
| 31 |
+
:param frames: frames to create the video from
|
| 32 |
+
:param fps: frames per second
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, frames: th.Tensor, fps: Union[float, int]):
|
| 36 |
+
self.frames = frames
|
| 37 |
+
self.fps = fps
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class Figure:
|
| 41 |
+
"""
|
| 42 |
+
Figure data class storing a matplotlib figure and whether to close the figure after logging it
|
| 43 |
+
|
| 44 |
+
:param figure: figure to log
|
| 45 |
+
:param close: if true, close the figure after logging it
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, figure: plt.figure, close: bool):
|
| 49 |
+
self.figure = figure
|
| 50 |
+
self.close = close
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Image:
|
| 54 |
+
"""
|
| 55 |
+
Image data class storing an image and data format
|
| 56 |
+
|
| 57 |
+
:param image: image to log
|
| 58 |
+
:param dataformats: Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc.
|
| 59 |
+
More info in add_image method doc at https://pytorch.org/docs/stable/tensorboard.html
|
| 60 |
+
Gym envs normally use 'HWC' (channel last)
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str):
|
| 64 |
+
self.image = image
|
| 65 |
+
self.dataformats = dataformats
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class FormatUnsupportedError(NotImplementedError):
|
| 69 |
+
"""
|
| 70 |
+
Custom error to display informative message when
|
| 71 |
+
a value is not supported by some formats.
|
| 72 |
+
|
| 73 |
+
:param unsupported_formats: A sequence of unsupported formats,
|
| 74 |
+
for instance ``["stdout"]``.
|
| 75 |
+
:param value_description: Description of the value that cannot be logged by this format.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(self, unsupported_formats: Sequence[str], value_description: str):
|
| 79 |
+
if len(unsupported_formats) > 1:
|
| 80 |
+
format_str = f"formats {', '.join(unsupported_formats)} are"
|
| 81 |
+
else:
|
| 82 |
+
format_str = f"format {unsupported_formats[0]} is"
|
| 83 |
+
super().__init__(
|
| 84 |
+
f"The {format_str} not supported for the {value_description} value logged.\n"
|
| 85 |
+
f"You can exclude formats via the `exclude` parameter of the logger's `record` function."
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class KVWriter:
|
| 90 |
+
"""
|
| 91 |
+
Key Value writer
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
|
| 95 |
+
step: int = 0) -> None:
|
| 96 |
+
"""
|
| 97 |
+
Write a dictionary to file
|
| 98 |
+
|
| 99 |
+
:param key_values:
|
| 100 |
+
:param key_excluded:
|
| 101 |
+
:param step:
|
| 102 |
+
"""
|
| 103 |
+
raise NotImplementedError
|
| 104 |
+
|
| 105 |
+
def close(self) -> None:
|
| 106 |
+
"""
|
| 107 |
+
Close owned resources
|
| 108 |
+
"""
|
| 109 |
+
raise NotImplementedError
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class SeqWriter:
|
| 113 |
+
"""
|
| 114 |
+
sequence writer
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
def write_sequence(self, sequence: List) -> None:
|
| 118 |
+
"""
|
| 119 |
+
write_sequence an array to file
|
| 120 |
+
|
| 121 |
+
:param sequence:
|
| 122 |
+
"""
|
| 123 |
+
raise NotImplementedError
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class HumanOutputFormat(KVWriter, SeqWriter):
|
| 127 |
+
"""A human-readable output format producing ASCII tables of key-value pairs.
|
| 128 |
+
|
| 129 |
+
Set attribute ``max_length`` to change the maximum length of keys and values
|
| 130 |
+
to write to output (or specify it when calling ``__init__``).
|
| 131 |
+
|
| 132 |
+
:param filename_or_file: the file to write the log to
|
| 133 |
+
:param max_length: the maximum length of keys and values to write to output.
|
| 134 |
+
Outputs longer than this will be truncated. An error will be raised
|
| 135 |
+
if multiple keys are truncated to the same value. The maximum output
|
| 136 |
+
width will be ``2*max_length + 7``. The default of 36 produces output
|
| 137 |
+
no longer than 79 characters wide.
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
|
| 141 |
+
self.max_length = max_length
|
| 142 |
+
if isinstance(filename_or_file, str):
|
| 143 |
+
self.file = open(filename_or_file, "wt")
|
| 144 |
+
self.own_file = True
|
| 145 |
+
else:
|
| 146 |
+
assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}"
|
| 147 |
+
self.file = filename_or_file
|
| 148 |
+
self.own_file = False
|
| 149 |
+
|
| 150 |
+
def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
|
| 151 |
+
# Create strings for printing
|
| 152 |
+
key2str = {}
|
| 153 |
+
tag = None
|
| 154 |
+
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
|
| 155 |
+
|
| 156 |
+
if excluded is not None and ("stdout" in excluded or "log" in excluded):
|
| 157 |
+
continue
|
| 158 |
+
|
| 159 |
+
elif isinstance(value, Video):
|
| 160 |
+
raise FormatUnsupportedError(["stdout", "log"], "video")
|
| 161 |
+
|
| 162 |
+
elif isinstance(value, Figure):
|
| 163 |
+
raise FormatUnsupportedError(["stdout", "log"], "figure")
|
| 164 |
+
|
| 165 |
+
elif isinstance(value, Image):
|
| 166 |
+
raise FormatUnsupportedError(["stdout", "log"], "image")
|
| 167 |
+
|
| 168 |
+
elif isinstance(value, float):
|
| 169 |
+
# Align left
|
| 170 |
+
value_str = f"{value:<8.3g}"
|
| 171 |
+
else:
|
| 172 |
+
value_str = str(value)
|
| 173 |
+
|
| 174 |
+
if key.find("/") > 0: # Find tag and add it to the dict
|
| 175 |
+
tag = key[: key.find("/") + 1]
|
| 176 |
+
key2str[self._truncate(tag)] = ""
|
| 177 |
+
# Remove tag from key
|
| 178 |
+
if tag is not None and tag in key:
|
| 179 |
+
key = str(" " + key[len(tag):])
|
| 180 |
+
|
| 181 |
+
truncated_key = self._truncate(key)
|
| 182 |
+
if truncated_key in key2str:
|
| 183 |
+
raise ValueError(
|
| 184 |
+
f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
|
| 185 |
+
)
|
| 186 |
+
key2str[truncated_key] = self._truncate(value_str)
|
| 187 |
+
|
| 188 |
+
# Find max widths
|
| 189 |
+
if len(key2str) == 0:
|
| 190 |
+
warnings.warn("Tried to write empty key-value dict")
|
| 191 |
+
return
|
| 192 |
+
else:
|
| 193 |
+
key_width = max(map(len, key2str.keys()))
|
| 194 |
+
val_width = max(map(len, key2str.values()))
|
| 195 |
+
|
| 196 |
+
# Write out the data
|
| 197 |
+
dashes = "-" * (key_width + val_width + 7)
|
| 198 |
+
lines = [dashes]
|
| 199 |
+
for key, value in key2str.items():
|
| 200 |
+
key_space = " " * (key_width - len(key))
|
| 201 |
+
val_space = " " * (val_width - len(value))
|
| 202 |
+
lines.append(f"| {key}{key_space} | {value}{val_space} |")
|
| 203 |
+
lines.append(dashes)
|
| 204 |
+
self.file.write("\n".join(lines) + "\n")
|
| 205 |
+
|
| 206 |
+
# Flush the output to the file
|
| 207 |
+
self.file.flush()
|
| 208 |
+
|
| 209 |
+
def _truncate(self, string: str) -> str:
|
| 210 |
+
if len(string) > self.max_length:
|
| 211 |
+
string = string[: self.max_length - 3] + "..."
|
| 212 |
+
return string
|
| 213 |
+
|
| 214 |
+
def write_sequence(self, sequence: List) -> None:
|
| 215 |
+
sequence = list(sequence)
|
| 216 |
+
for i, elem in enumerate(sequence):
|
| 217 |
+
self.file.write(elem)
|
| 218 |
+
if i < len(sequence) - 1: # add space unless this is the last one
|
| 219 |
+
self.file.write(" ")
|
| 220 |
+
self.file.write("\n")
|
| 221 |
+
self.file.flush()
|
| 222 |
+
|
| 223 |
+
def close(self) -> None:
|
| 224 |
+
"""
|
| 225 |
+
closes the file
|
| 226 |
+
"""
|
| 227 |
+
if self.own_file:
|
| 228 |
+
self.file.close()
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def filter_excluded_keys(
|
| 232 |
+
key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str
|
| 233 |
+
) -> Dict[str, Any]:
|
| 234 |
+
"""
|
| 235 |
+
Filters the keys specified by ``key_exclude`` for the specified format
|
| 236 |
+
|
| 237 |
+
:param key_values: log dictionary to be filtered
|
| 238 |
+
:param key_excluded: keys to be excluded per format
|
| 239 |
+
:param _format: format for which this filter is run
|
| 240 |
+
:return: dict without the excluded keys
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
def is_excluded(key: str) -> bool:
|
| 244 |
+
return key in key_excluded and key_excluded[key] is not None and _format in key_excluded[key]
|
| 245 |
+
|
| 246 |
+
return {key: value for key, value in key_values.items() if not is_excluded(key)}
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class JSONOutputFormat(KVWriter):
|
| 250 |
+
"""
|
| 251 |
+
Log to a file, in the JSON format
|
| 252 |
+
|
| 253 |
+
:param filename: the file to write the log to
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
def __init__(self, filename: str):
|
| 257 |
+
self.file = open(filename, "wt")
|
| 258 |
+
|
| 259 |
+
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
|
| 260 |
+
step: int = 0) -> None:
|
| 261 |
+
def cast_to_json_serializable(value: Any):
|
| 262 |
+
if isinstance(value, Video):
|
| 263 |
+
raise FormatUnsupportedError(["json"], "video")
|
| 264 |
+
if isinstance(value, Figure):
|
| 265 |
+
raise FormatUnsupportedError(["json"], "figure")
|
| 266 |
+
if isinstance(value, Image):
|
| 267 |
+
raise FormatUnsupportedError(["json"], "image")
|
| 268 |
+
if hasattr(value, "dtype"):
|
| 269 |
+
if value.shape == () or len(value) == 1:
|
| 270 |
+
# if value is a dimensionless numpy array or of length 1, serialize as a float
|
| 271 |
+
return float(value)
|
| 272 |
+
else:
|
| 273 |
+
# otherwise, a value is a numpy array, serialize as a list or nested lists
|
| 274 |
+
return value.tolist()
|
| 275 |
+
return value
|
| 276 |
+
|
| 277 |
+
key_values = {
|
| 278 |
+
key: cast_to_json_serializable(value)
|
| 279 |
+
for key, value in filter_excluded_keys(key_values, key_excluded, "json").items()
|
| 280 |
+
}
|
| 281 |
+
self.file.write(json.dumps(key_values) + "\n")
|
| 282 |
+
self.file.flush()
|
| 283 |
+
|
| 284 |
+
def close(self) -> None:
|
| 285 |
+
"""
|
| 286 |
+
closes the file
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
self.file.close()
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class CSVOutputFormat(KVWriter):
|
| 293 |
+
"""
|
| 294 |
+
Log to a file, in a CSV format
|
| 295 |
+
|
| 296 |
+
:param filename: the file to write the log to
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
def __init__(self, filename: str):
|
| 300 |
+
self.file = open(filename, "w+t")
|
| 301 |
+
self.keys = []
|
| 302 |
+
self.separator = ","
|
| 303 |
+
self.quotechar = '"'
|
| 304 |
+
|
| 305 |
+
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
|
| 306 |
+
step: int = 0) -> None:
|
| 307 |
+
# Add our current row to the history
|
| 308 |
+
key_values = filter_excluded_keys(key_values, key_excluded, "csv")
|
| 309 |
+
extra_keys = key_values.keys() - self.keys
|
| 310 |
+
if extra_keys:
|
| 311 |
+
self.keys.extend(extra_keys)
|
| 312 |
+
self.file.seek(0)
|
| 313 |
+
lines = self.file.readlines()
|
| 314 |
+
self.file.seek(0)
|
| 315 |
+
for (i, key) in enumerate(self.keys):
|
| 316 |
+
if i > 0:
|
| 317 |
+
self.file.write(",")
|
| 318 |
+
self.file.write(key)
|
| 319 |
+
self.file.write("\n")
|
| 320 |
+
for line in lines[1:]:
|
| 321 |
+
self.file.write(line[:-1])
|
| 322 |
+
self.file.write(self.separator * len(extra_keys))
|
| 323 |
+
self.file.write("\n")
|
| 324 |
+
for i, key in enumerate(self.keys):
|
| 325 |
+
if i > 0:
|
| 326 |
+
self.file.write(",")
|
| 327 |
+
value = key_values.get(key)
|
| 328 |
+
|
| 329 |
+
if isinstance(value, Video):
|
| 330 |
+
raise FormatUnsupportedError(["csv"], "video")
|
| 331 |
+
|
| 332 |
+
elif isinstance(value, Figure):
|
| 333 |
+
raise FormatUnsupportedError(["csv"], "figure")
|
| 334 |
+
|
| 335 |
+
elif isinstance(value, Image):
|
| 336 |
+
raise FormatUnsupportedError(["csv"], "image")
|
| 337 |
+
|
| 338 |
+
elif isinstance(value, str):
|
| 339 |
+
# escape quotechars by prepending them with another quotechar
|
| 340 |
+
value = value.replace(self.quotechar, self.quotechar + self.quotechar)
|
| 341 |
+
|
| 342 |
+
# additionally wrap text with quotechars so that any delimiters in the text are ignored by csv readers
|
| 343 |
+
self.file.write(self.quotechar + value + self.quotechar)
|
| 344 |
+
|
| 345 |
+
elif value is not None:
|
| 346 |
+
self.file.write(str(value))
|
| 347 |
+
self.file.write("\n")
|
| 348 |
+
self.file.flush()
|
| 349 |
+
|
| 350 |
+
def close(self) -> None:
|
| 351 |
+
"""
|
| 352 |
+
closes the file
|
| 353 |
+
"""
|
| 354 |
+
self.file.close()
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class TensorBoardOutputFormat(KVWriter):
|
| 358 |
+
"""
|
| 359 |
+
Dumps key/value pairs into TensorBoard's numeric format.
|
| 360 |
+
|
| 361 |
+
:param folder: the folder to write the log to
|
| 362 |
+
"""
|
| 363 |
+
|
| 364 |
+
def __init__(self, folder: str):
|
| 365 |
+
assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
|
| 366 |
+
self.writer = SummaryWriter(log_dir=folder)
|
| 367 |
+
|
| 368 |
+
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
|
| 369 |
+
step: int = 0) -> None:
|
| 370 |
+
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
|
| 371 |
+
|
| 372 |
+
if excluded is not None and "tensorboard" in excluded:
|
| 373 |
+
continue
|
| 374 |
+
|
| 375 |
+
if isinstance(value, np.ScalarType):
|
| 376 |
+
if isinstance(value, str):
|
| 377 |
+
# str is considered a np.ScalarType
|
| 378 |
+
self.writer.add_text(key, value, step)
|
| 379 |
+
else:
|
| 380 |
+
self.writer.add_scalar(key, value, step)
|
| 381 |
+
|
| 382 |
+
if isinstance(value, th.Tensor):
|
| 383 |
+
self.writer.add_histogram(key, value, step)
|
| 384 |
+
|
| 385 |
+
if isinstance(value, Video):
|
| 386 |
+
self.writer.add_video(key, value.frames, step, value.fps)
|
| 387 |
+
|
| 388 |
+
if isinstance(value, Figure):
|
| 389 |
+
self.writer.add_figure(key, value.figure, step, close=value.close)
|
| 390 |
+
|
| 391 |
+
if isinstance(value, Image):
|
| 392 |
+
self.writer.add_image(key, value.image, step, dataformats=value.dataformats)
|
| 393 |
+
|
| 394 |
+
# Flush the output to the file
|
| 395 |
+
self.writer.flush()
|
| 396 |
+
|
| 397 |
+
def close(self) -> None:
|
| 398 |
+
"""
|
| 399 |
+
closes the file
|
| 400 |
+
"""
|
| 401 |
+
if self.writer:
|
| 402 |
+
self.writer.close()
|
| 403 |
+
self.writer = None
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter:
|
| 407 |
+
"""
|
| 408 |
+
return a logger for the requested format
|
| 409 |
+
|
| 410 |
+
:param _format: the requested format to log to ('stdout', 'log', 'json' or 'csv' or 'tensorboard')
|
| 411 |
+
:param log_dir: the logging directory
|
| 412 |
+
:param log_suffix: the suffix for the log file
|
| 413 |
+
:return: the logger
|
| 414 |
+
"""
|
| 415 |
+
os.makedirs(log_dir, exist_ok=True)
|
| 416 |
+
if _format == "stdout":
|
| 417 |
+
return HumanOutputFormat(sys.stdout)
|
| 418 |
+
elif _format == "log":
|
| 419 |
+
return HumanOutputFormat(os.path.join(log_dir, f"log{log_suffix}.txt"))
|
| 420 |
+
elif _format == "json":
|
| 421 |
+
return JSONOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.json"))
|
| 422 |
+
elif _format == "csv":
|
| 423 |
+
return CSVOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.csv"))
|
| 424 |
+
elif _format == "tensorboard":
|
| 425 |
+
return TensorBoardOutputFormat(log_dir)
|
| 426 |
+
else:
|
| 427 |
+
raise ValueError(f"Unknown format specified: {_format}")
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# ================================================================
|
| 431 |
+
# Backend
|
| 432 |
+
# ================================================================
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class Logger:
|
| 436 |
+
"""
|
| 437 |
+
The logger class.
|
| 438 |
+
|
| 439 |
+
:param folder: the logging location
|
| 440 |
+
:param output_formats: the list of output formats
|
| 441 |
+
"""
|
| 442 |
+
|
| 443 |
+
def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
|
| 444 |
+
self.name_to_value = defaultdict(float) # values this iteration
|
| 445 |
+
self.name_to_count = defaultdict(int)
|
| 446 |
+
self.name_to_excluded = defaultdict(str)
|
| 447 |
+
self.level = INFO
|
| 448 |
+
self.dir = folder
|
| 449 |
+
self.output_formats = output_formats
|
| 450 |
+
|
| 451 |
+
def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
|
| 452 |
+
"""
|
| 453 |
+
Log a value of some diagnostic
|
| 454 |
+
Call this once for each diagnostic quantity, each iteration
|
| 455 |
+
If called many times, last value will be used.
|
| 456 |
+
|
| 457 |
+
:param key: save to log this key
|
| 458 |
+
:param value: save to log this value
|
| 459 |
+
:param exclude: outputs to be excluded
|
| 460 |
+
"""
|
| 461 |
+
self.name_to_value[key] = value
|
| 462 |
+
self.name_to_excluded[key] = exclude
|
| 463 |
+
|
| 464 |
+
def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
|
| 465 |
+
"""
|
| 466 |
+
The same as record(), but if called many times, values averaged.
|
| 467 |
+
|
| 468 |
+
:param key: save to log this key
|
| 469 |
+
:param value: save to log this value
|
| 470 |
+
:param exclude: outputs to be excluded
|
| 471 |
+
"""
|
| 472 |
+
if value is None:
|
| 473 |
+
self.name_to_value[key] = None
|
| 474 |
+
return
|
| 475 |
+
old_val, count = self.name_to_value[key], self.name_to_count[key]
|
| 476 |
+
self.name_to_value[key] = old_val * count / (count + 1) + value / (count + 1)
|
| 477 |
+
self.name_to_count[key] = count + 1
|
| 478 |
+
self.name_to_excluded[key] = exclude
|
| 479 |
+
|
| 480 |
+
def dump(self, step: int = 0) -> None:
|
| 481 |
+
"""
|
| 482 |
+
Write all of the diagnostics from the current iteration
|
| 483 |
+
"""
|
| 484 |
+
if self.level == DISABLED:
|
| 485 |
+
return
|
| 486 |
+
for _format in self.output_formats:
|
| 487 |
+
if isinstance(_format, KVWriter):
|
| 488 |
+
_format.write(self.name_to_value, self.name_to_excluded, step)
|
| 489 |
+
|
| 490 |
+
self.name_to_value.clear()
|
| 491 |
+
self.name_to_count.clear()
|
| 492 |
+
self.name_to_excluded.clear()
|
| 493 |
+
|
| 494 |
+
def log(self, *args, level: int = INFO) -> None:
|
| 495 |
+
"""
|
| 496 |
+
Write the sequence of args, with no separators,
|
| 497 |
+
to the console and output files (if you've configured an output file).
|
| 498 |
+
|
| 499 |
+
level: int. (see logger.py docs) If the global logger level is higher than
|
| 500 |
+
the level argument here, don't print to stdout.
|
| 501 |
+
|
| 502 |
+
:param args: log the arguments
|
| 503 |
+
:param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
|
| 504 |
+
"""
|
| 505 |
+
if self.level <= level:
|
| 506 |
+
self._do_log(args)
|
| 507 |
+
|
| 508 |
+
def debug(self, *args) -> None:
|
| 509 |
+
"""
|
| 510 |
+
Write the sequence of args, with no separators,
|
| 511 |
+
to the console and output files (if you've configured an output file).
|
| 512 |
+
Using the DEBUG level.
|
| 513 |
+
|
| 514 |
+
:param args: log the arguments
|
| 515 |
+
"""
|
| 516 |
+
self.log(*args, level=DEBUG)
|
| 517 |
+
|
| 518 |
+
def info(self, *args) -> None:
|
| 519 |
+
"""
|
| 520 |
+
Write the sequence of args, with no separators,
|
| 521 |
+
to the console and output files (if you've configured an output file).
|
| 522 |
+
Using the INFO level.
|
| 523 |
+
|
| 524 |
+
:param args: log the arguments
|
| 525 |
+
"""
|
| 526 |
+
self.log(*args, level=INFO)
|
| 527 |
+
|
| 528 |
+
def warn(self, *args) -> None:
|
| 529 |
+
"""
|
| 530 |
+
Write the sequence of args, with no separators,
|
| 531 |
+
to the console and output files (if you've configured an output file).
|
| 532 |
+
Using the WARN level.
|
| 533 |
+
|
| 534 |
+
:param args: log the arguments
|
| 535 |
+
"""
|
| 536 |
+
self.log(*args, level=WARN)
|
| 537 |
+
|
| 538 |
+
def error(self, *args) -> None:
|
| 539 |
+
"""
|
| 540 |
+
Write the sequence of args, with no separators,
|
| 541 |
+
to the console and output files (if you've configured an output file).
|
| 542 |
+
Using the ERROR level.
|
| 543 |
+
|
| 544 |
+
:param args: log the arguments
|
| 545 |
+
"""
|
| 546 |
+
self.log(*args, level=ERROR)
|
| 547 |
+
|
| 548 |
+
# Configuration
|
| 549 |
+
# ----------------------------------------
|
| 550 |
+
def set_level(self, level: int) -> None:
|
| 551 |
+
"""
|
| 552 |
+
Set logging threshold on current logger.
|
| 553 |
+
|
| 554 |
+
:param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
|
| 555 |
+
"""
|
| 556 |
+
self.level = level
|
| 557 |
+
|
| 558 |
+
def get_dir(self) -> str:
|
| 559 |
+
"""
|
| 560 |
+
Get directory that log files are being written to.
|
| 561 |
+
will be None if there is no output directory (i.e., if you didn't call start)
|
| 562 |
+
|
| 563 |
+
:return: the logging directory
|
| 564 |
+
"""
|
| 565 |
+
return self.dir
|
| 566 |
+
|
| 567 |
+
def close(self) -> None:
|
| 568 |
+
"""
|
| 569 |
+
closes the file
|
| 570 |
+
"""
|
| 571 |
+
for _format in self.output_formats:
|
| 572 |
+
_format.close()
|
| 573 |
+
|
| 574 |
+
# Misc
|
| 575 |
+
# ----------------------------------------
|
| 576 |
+
def _do_log(self, args) -> None:
|
| 577 |
+
"""
|
| 578 |
+
log to the requested format outputs
|
| 579 |
+
|
| 580 |
+
:param args: the arguments to log
|
| 581 |
+
"""
|
| 582 |
+
for _format in self.output_formats:
|
| 583 |
+
if isinstance(_format, SeqWriter):
|
| 584 |
+
_format.write_sequence(map(str, args))
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger:
|
| 588 |
+
"""
|
| 589 |
+
Configure the current logger.
|
| 590 |
+
|
| 591 |
+
:param folder: the save location
|
| 592 |
+
(if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time])
|
| 593 |
+
:param format_strings: the output logging format
|
| 594 |
+
(if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv'])
|
| 595 |
+
:return: The logger object.
|
| 596 |
+
"""
|
| 597 |
+
if folder is None:
|
| 598 |
+
folder = os.getenv("SB3_LOGDIR")
|
| 599 |
+
if folder is None:
|
| 600 |
+
folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f"))
|
| 601 |
+
assert isinstance(folder, str)
|
| 602 |
+
os.makedirs(folder, exist_ok=True)
|
| 603 |
+
|
| 604 |
+
log_suffix = ""
|
| 605 |
+
if format_strings is None:
|
| 606 |
+
format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",")
|
| 607 |
+
|
| 608 |
+
format_strings = list(filter(None, format_strings))
|
| 609 |
+
output_formats = [make_output_format(f, folder, log_suffix) for f in format_strings]
|
| 610 |
+
|
| 611 |
+
logger = Logger(folder=folder, output_formats=output_formats)
|
| 612 |
+
# Only print when some files will be saved
|
| 613 |
+
if len(format_strings) > 0 and format_strings != ["stdout"]:
|
| 614 |
+
logger.log(f"Logging to {folder}")
|
| 615 |
+
return logger
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
# ================================================================
|
| 619 |
+
# Readers
|
| 620 |
+
# ================================================================
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def read_json(filename: str) -> pandas.DataFrame:
|
| 624 |
+
"""
|
| 625 |
+
read a json file using pandas
|
| 626 |
+
|
| 627 |
+
:param filename: the file path to read
|
| 628 |
+
:return: the data in the json
|
| 629 |
+
"""
|
| 630 |
+
data = []
|
| 631 |
+
with open(filename) as file_handler:
|
| 632 |
+
for line in file_handler:
|
| 633 |
+
data.append(json.loads(line))
|
| 634 |
+
return pandas.DataFrame(data)
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def read_csv(filename: str) -> pandas.DataFrame:
|
| 638 |
+
"""
|
| 639 |
+
read a csv file using pandas
|
| 640 |
+
|
| 641 |
+
:param filename: the file path to read
|
| 642 |
+
:return: the data in the csv
|
| 643 |
+
"""
|
| 644 |
+
return pandas.read_csv(filename, index_col=None, comment="#")
|
dexart-release/stable_baselines3/common/monitor.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"]
|
| 2 |
+
|
| 3 |
+
import csv
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from glob import glob
|
| 8 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import gym
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas
|
| 13 |
+
|
| 14 |
+
from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Monitor(gym.Wrapper):
    """
    A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.

    :param env: The environment
    :param filename: the location to save a log file, can be None for no log
    :param allow_early_resets: allows the reset of the environment before it is done
    :param reset_keywords: extra keywords for the reset call,
        if extra parameters are needed at reset
    :param info_keywords: extra information to log, from the information return of env.step()
    """

    # Suffix shared by every monitor log file; also used by get_monitor_files()/load_results().
    EXT = "monitor.csv"

    def __init__(
        self,
        env: gym.Env,
        filename: Optional[str] = None,
        allow_early_resets: bool = True,
        reset_keywords: Tuple[str, ...] = (),
        info_keywords: Tuple[str, ...] = (),
    ):
        super().__init__(env=env)
        self.t_start = time.time()
        if filename is not None:
            # `env.spec and env.spec.id` is None for spec-less envs, otherwise the registered id.
            self.results_writer = ResultsWriter(
                filename,
                header={"t_start": self.t_start, "env_id": env.spec and env.spec.id},
                extra_keys=reset_keywords + info_keywords,
            )
        else:
            self.results_writer = None
        self.reset_keywords = reset_keywords
        self.info_keywords = info_keywords
        self.allow_early_resets = allow_early_resets
        # Per-step rewards of the episode in progress; (re)created on reset().
        self.rewards = None
        self.needs_reset = True
        self.episode_returns = []
        self.episode_lengths = []
        self.episode_times = []
        self.total_steps = 0
        self.current_reset_info = {}  # extra info about the current episode, that was passed in during reset()

    def reset(self, **kwargs) -> GymObs:
        """
        Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True

        :param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords
        :return: the first observation of the environment
        :raises RuntimeError: if called mid-episode while ``allow_early_resets`` is False
        :raises ValueError: if a keyword listed in ``reset_keywords`` is missing from ``kwargs``
        """
        if not self.allow_early_resets and not self.needs_reset:
            raise RuntimeError(
                "Tried to reset an environment before done. If you want to allow early resets, "
                "wrap your env with Monitor(env, path, allow_early_resets=True)"
            )
        self.rewards = []
        self.needs_reset = False
        # Capture the extra reset kwargs so they can be logged alongside episode stats.
        for key in self.reset_keywords:
            value = kwargs.get(key)
            if value is None:
                raise ValueError(f"Expected you to pass keyword argument {key} into reset")
            self.current_reset_info[key] = value
        return self.env.reset(**kwargs)

    def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
        """
        Step the environment with the given action

        :param action: the action
        :return: observation, reward, done, information
        :raises RuntimeError: if the previous episode ended and reset() was not called
        """
        if self.needs_reset:
            raise RuntimeError("Tried to step environment that needs reset")
        observation, reward, done, info = self.env.step(action)
        self.rewards.append(reward)
        if done:
            self.needs_reset = True
            ep_rew = sum(self.rewards)
            ep_len = len(self.rewards)
            # "r" = episode return, "l" = episode length, "t" = wall-clock seconds since Monitor creation.
            ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)}
            for key in self.info_keywords:
                ep_info[key] = info[key]
            self.episode_returns.append(ep_rew)
            self.episode_lengths.append(ep_len)
            self.episode_times.append(time.time() - self.t_start)
            ep_info.update(self.current_reset_info)
            if self.results_writer:
                self.results_writer.write_row(ep_info)
            # Expose episode stats to callers (e.g. training loops) via the info dict.
            info["episode"] = ep_info
        self.total_steps += 1
        return observation, reward, done, info

    def close(self) -> None:
        """
        Closes the environment
        """
        super().close()
        if self.results_writer is not None:
            self.results_writer.close()

    def get_total_steps(self) -> int:
        """
        Returns the total number of timesteps

        :return:
        """
        return self.total_steps

    def get_episode_rewards(self) -> List[float]:
        """
        Returns the rewards of all the episodes

        :return:
        """
        return self.episode_returns

    def get_episode_lengths(self) -> List[int]:
        """
        Returns the number of timesteps of all the episodes

        :return:
        """
        return self.episode_lengths

    def get_episode_times(self) -> List[float]:
        """
        Returns the runtime in seconds of all the episodes

        :return:
        """
        return self.episode_times
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class LoadMonitorResultsError(Exception):
    """Raised when loading the monitor log fails."""
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class ResultsWriter:
    """
    A result writer that saves the data from the `Monitor` class

    :param filename: the location to save a log file, can be None for no log
    :param header: the header dictionary object of the saved csv
    :param extra_keys: the extra information to log, typically is composed of
        ``reset_keywords`` and ``info_keywords``
    """

    def __init__(
        self,
        filename: str = "",
        header: Optional[Dict[str, Union[float, str]]] = None,
        extra_keys: Tuple[str, ...] = (),
    ):
        if header is None:
            header = {}
        # Normalize the target path so it always ends with Monitor.EXT:
        # a directory gets the default file name appended, anything else gets the suffix.
        if not filename.endswith(Monitor.EXT):
            if os.path.isdir(filename):
                filename = os.path.join(filename, Monitor.EXT)
            else:
                filename = filename + "." + Monitor.EXT
        # Prevent newline issue on Windows, see GH issue #692
        self.file_handler = open(filename, "wt", newline="\n")
        # First line stores run metadata as '#'-prefixed JSON (read back by load_results()).
        self.file_handler.write("#%s\n" % json.dumps(header))
        self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
        self.logger.writeheader()
        self.file_handler.flush()

    def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None:
        """
        Write one episode's statistics to the log and flush to disk

        :param epinfo: the information on episodic return, length, and time
        """
        if self.logger:
            self.logger.writerow(epinfo)
            # Flush immediately so external readers (e.g. plotting tools) see fresh data.
            self.file_handler.flush()

    def close(self) -> None:
        """
        Close the file handler
        """
        self.file_handler.close()
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def get_monitor_files(path: str) -> List[str]:
    """
    Collect every monitor log file (``*monitor.csv``) stored in a folder.

    :param path: the logging folder
    :return: the log files
    """
    pattern = os.path.join(path, "*" + Monitor.EXT)
    return glob(pattern)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def load_results(path: str) -> pandas.DataFrame:
    """
    Load all Monitor logs from a given directory path matching ``*monitor.csv``

    :param path: the directory path containing the log file(s)
    :return: the logged data
    :raises LoadMonitorResultsError: if no monitor file is found under ``path``
    """
    monitor_files = get_monitor_files(path)
    if len(monitor_files) == 0:
        raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
    data_frames, headers = [], []
    for file_name in monitor_files:
        with open(file_name) as file_handler:
            # First line is the '#'-prefixed JSON header written by ResultsWriter.
            first_line = file_handler.readline()
            assert first_line[0] == "#"
            header = json.loads(first_line[1:])
            # The file handle now points past the header, so read_csv only sees the CSV body.
            data_frame = pandas.read_csv(file_handler, index_col=None)
            headers.append(header)
            # Convert per-run relative times to absolute wall-clock times so runs can be merged.
            data_frame["t"] += header["t_start"]
        data_frames.append(data_frame)
    data_frame = pandas.concat(data_frames)
    data_frame.sort_values("t", inplace=True)
    data_frame.reset_index(inplace=True)
    # Re-base on the earliest run start, so "t" is seconds since training began.
    data_frame["t"] -= min(header["t_start"] for header in headers)
    return data_frame
|
dexart-release/stable_baselines3/common/noise.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import Iterable, List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ActionNoise(ABC):
    """
    Abstract base class for action-noise processes.

    Subclasses must implement ``__call__`` to produce one noise sample;
    stateful processes may additionally override ``reset``.
    """

    def __init__(self) -> None:
        super().__init__()

    def reset(self) -> None:
        """Reset any internal state at the end of an episode (no-op by default)."""

    @abstractmethod
    def __call__(self) -> np.ndarray:
        raise NotImplementedError()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class NormalActionNoise(ActionNoise):
    """
    Gaussian (white) action noise: each call is an independent draw from N(mean, sigma**2).

    :param mean: the mean value of the noise
    :param sigma: the scale of the noise (std here)
    """

    def __init__(self, mean: np.ndarray, sigma: np.ndarray):
        super().__init__()
        self._mu = mean
        self._sigma = sigma

    def __call__(self) -> np.ndarray:
        # Stateless: nothing is carried over between calls.
        return np.random.normal(self._mu, self._sigma)

    def __repr__(self) -> str:
        return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class OrnsteinUhlenbeckActionNoise(ActionNoise):
    """
    An Ornstein Uhlenbeck action noise, this is designed to approximate Brownian motion with friction.

    Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab

    :param mean: the mean of the noise
    :param sigma: the scale of the noise
    :param theta: the rate of mean reversion
    :param dt: the timestep for the noise
    :param initial_noise: the initial value for the noise output, (if None: 0)
    """

    def __init__(
        self,
        mean: np.ndarray,
        sigma: np.ndarray,
        theta: float = 0.15,
        dt: float = 1e-2,
        initial_noise: Optional[np.ndarray] = None,
    ):
        self._theta = theta
        self._mu = mean
        self._sigma = sigma
        self._dt = dt
        self.initial_noise = initial_noise
        self.noise_prev = np.zeros_like(self._mu)
        self.reset()
        super().__init__()

    def __call__(self) -> np.ndarray:
        # Euler-Maruyama step: mean-reverting drift plus scaled Gaussian diffusion.
        drift = self._theta * (self._mu - self.noise_prev) * self._dt
        diffusion = self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
        self.noise_prev = self.noise_prev + drift + diffusion
        return self.noise_prev

    def reset(self) -> None:
        """
        reset the Ornstein Uhlenbeck noise, to the initial position
        """
        if self.initial_noise is not None:
            self.noise_prev = self.initial_noise
        else:
            self.noise_prev = np.zeros_like(self._mu)

    def __repr__(self) -> str:
        return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})"
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class VectorizedActionNoise(ActionNoise):
    """
    A Vectorized action noise for parallel environments.

    Keeps one independent deep copy of ``base_noise`` per environment and
    stacks their samples into a single ``(n_envs, *noise_shape)`` array.

    :param base_noise: ActionNoise The noise generator to use
    :param n_envs: The number of parallel environments
    """

    def __init__(self, base_noise: ActionNoise, n_envs: int):
        try:
            self.n_envs = int(n_envs)
            # Explicit check instead of `assert`: assertions are stripped under `python -O`.
            if self.n_envs <= 0:
                raise ValueError
        # `int()` raises ValueError for non-numeric strings and TypeError for
        # non-convertible objects; the original only caught TypeError/AssertionError.
        except (TypeError, ValueError) as exc:
            raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from exc

        self.base_noise = base_noise
        # Use the converted value: `range(n_envs)` on the raw argument crashed for e.g. floats.
        self.noises = [copy.deepcopy(self.base_noise) for _ in range(self.n_envs)]

    def reset(self, indices: Optional[Iterable[int]] = None) -> None:
        """
        Reset all the noise processes, or those listed in indices

        :param indices: Optional[Iterable[int]] The indices to reset. Default: None.
            If the parameter is None, then all processes are reset to their initial position.
        """
        if indices is None:
            indices = range(len(self.noises))

        for index in indices:
            self.noises[index].reset()

    def __repr__(self) -> str:
        # Fixed: the original f-string had a stray ')' after BaseNoise=..., producing
        # unbalanced parentheses in the repr.
        return f"VecNoise(BaseNoise={repr(self.base_noise)}, n_envs={len(self.noises)})"

    def __call__(self) -> np.ndarray:
        """
        Generate and stack the action noise from each noise object
        """
        noise = np.stack([noise() for noise in self.noises])
        return noise

    @property
    def base_noise(self) -> ActionNoise:
        return self._base_noise

    @base_noise.setter
    def base_noise(self, base_noise: ActionNoise) -> None:
        if base_noise is None:
            raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
        if not isinstance(base_noise, ActionNoise):
            raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
        self._base_noise = base_noise

    @property
    def noises(self) -> List[ActionNoise]:
        return self._noises

    @noises.setter
    def noises(self, noises: List[ActionNoise]) -> None:
        noises = list(noises)  # raises TypeError if not iterable
        assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."

        different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))]

        if len(different_types):
            raise ValueError(
                f"Noise instances at indices {different_types} don't match the type of base_noise", type(self.base_noise)
            )

        self._noises = noises
        # Every copy starts from its initial position.
        for noise in noises:
            noise.reset()
|
dexart-release/stable_baselines3/common/on_policy_algorithm.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
| 3 |
+
|
| 4 |
+
import gym
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch as th
|
| 7 |
+
|
| 8 |
+
from stable_baselines3.common.base_class import BaseAlgorithm
|
| 9 |
+
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
|
| 10 |
+
from stable_baselines3.common.callbacks import BaseCallback
|
| 11 |
+
from stable_baselines3.common.policies import ActorCriticPolicy
|
| 12 |
+
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
|
| 13 |
+
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
|
| 14 |
+
from stable_baselines3.common.vec_env import VecEnv
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class OnPolicyAlgorithm(BaseAlgorithm):
    """
    The base for On-Policy algorithms (ex: A2C/PPO).

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
        Equivalent to classic advantage when set to 1.
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param monitor_wrapper: When creating an environment, whether to wrap it
        or not in a Monitor wrapper.
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    :param supported_action_spaces: The action spaces supported by the algorithm.
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        n_steps: int,
        gamma: float,
        gae_lambda: float,
        ent_coef: float,
        vf_coef: float,
        max_grad_norm: float,
        use_sde: bool,
        sde_sample_freq: int,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        monitor_wrapper: bool = True,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
    ):

        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            create_eval_env=create_eval_env,
            support_multi_env=True,
            seed=seed,
            tensorboard_log=tensorboard_log,
            supported_action_spaces=supported_action_spaces,
        )

        self.n_steps = n_steps
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.rollout_buffer = None

        # State for the "restore on performance drop" mechanism used in learn().
        # last_policy_saved holds the two most recent parameter snapshots.
        self.last_rollout_reward = -np.inf
        self.need_restore = False
        self.last_policy_saved: List[Dict] = [{}, {}]
        self.current_restore_step = 0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Create the rollout buffer and instantiate the policy on the target device."""
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        # Dict observation spaces need the dict-aware buffer.
        buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RolloutBuffer

        self.rollout_buffer = buffer_cls(
            self.n_steps,
            self.observation_space,
            self.action_space,
            device=self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
        )
        self.policy = self.policy_class(  # pytype:disable=not-instantiable
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            use_sde=self.use_sde,
            **self.policy_kwargs  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)
        last_episode_reward = self.last_rollout_reward
        self.last_rollout_reward = 0
        num_rollouts = 0
        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)
        callback.on_rollout_start()
        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs = self.policy(obs_tensor)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.action_space, gym.spaces.Box):
                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos)
            n_steps += 1

            if isinstance(self.action_space, gym.spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstraping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(terminal_obs)[0]
                    rewards[idx] += self.gamma * terminal_value

                if done:
                    num_rollouts += 1

            self.last_rollout_reward += rewards.sum()

            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
            self._last_obs = new_obs
            self._last_episode_starts = dones
        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        # FIX: with long episodes and a short rollout it is possible that no
        # episode terminates during collection (num_rollouts == 0), which
        # previously raised ZeroDivisionError. In that case keep the raw sum.
        if num_rollouts > 0:
            self.last_rollout_reward /= num_rollouts
        # NOTE(review): reward_gap is currently unused (vestige of the restore
        # mechanism below, which is effectively disabled by need_restore=False).
        reward_gap = last_episode_reward - self.last_rollout_reward
        self.need_restore = False
        self.current_restore_step = 0

        callback.on_rollout_end()

        return True

    def train(self) -> None:
        """
        Consume current rollout data and update policy parameters.
        Implemented by individual algorithms.
        """
        raise NotImplementedError

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "OnPolicyAlgorithm",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        iter_start=0,
    ) -> "OnPolicyAlgorithm":
        """
        Alternate between rollout collection and policy updates until
        ``total_timesteps`` environment steps have been taken.

        :param iter_start: iteration counter offset (used when resuming training)
        :return: the trained model (self)
        """
        iteration = iter_start

        total_timesteps, callback = self._setup_learn(
            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps,
            tb_log_name
        )

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:

            x = time.time()
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer,
                                                      n_rollout_steps=self.n_steps)
            print("Rollout time:", time.time() - x)

            if continue_training is False:
                break

            # If a large performance drop was flagged, roll back to the older
            # of the two saved parameter snapshots (at most 5 times in a row).
            if self.need_restore and self.current_restore_step < 5:
                print(f"Large performance drop detected. Restore previous model.")
                self.set_parameters(self.last_policy_saved[0], exact_match=True, device=self.device)
                continue
            else:
                # Shift the snapshot window: keep the previous one, save the current one.
                self.last_policy_saved[0] = self.last_policy_saved[1]
                self.last_policy_saved[1] = self.get_parameters()

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos

            if log_interval is not None and iteration % log_interval == 0:
                fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
                self.logger.record("time/iterations", iteration, exclude="wandb")
                self.logger.record("rollout/rollout_rew_mean", self.last_rollout_reward)
                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                    self.logger.record("rollout/ep_rew_mean",
                                       safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                    self.logger.record("rollout/ep_len_mean",
                                       safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                self.logger.record("time/fps", fps)
                self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
                self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                self.logger.dump(step=iteration)

            x = time.time()
            self.train()
            print("Train time:", time.time() - x)

        callback.on_training_end()

        return self

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Return the names of attributes whose state dicts are saved with ``th.save``."""
        state_dicts = ["policy", "policy.optimizer"]

        return state_dicts, []

    def _excluded_save_params(self) -> List[str]:
        """
        Returns the names of the parameters that should be excluded from being
        saved by pickling. E.g. replay buffers are skipped by default
        as they take up a lot of space. PyTorch variables should be excluded
        with this so they can be stored with ``th.save``.

        :return: List of parameters that should be excluded from being saved with pickle.
        """
        return super()._excluded_save_params() + ["last_policy_saved"]
|