chenhaojun commited on
Commit
c8144dc
·
verified ·
1 Parent(s): b1c8657

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_dial.xml +25 -0
  3. Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_lock.xml +26 -0
  4. Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_pull.xml +23 -0
  5. Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_faucet.xml +35 -0
  6. Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_laptop.xml +22 -0
  7. Metaworld/zarr_path: data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 +3 -0
  8. Metaworld/zarr_path: data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 +3 -0
  9. dexart-release/assets/sapien/102697/cues.txt +5 -0
  10. dexart-release/assets/sapien/102697/meta.json +1 -0
  11. dexart-release/assets/sapien/102697/mobility.urdf +502 -0
  12. dexart-release/assets/sapien/102697/mobility_v2.json +1 -0
  13. dexart-release/assets/sapien/102697/new_objs/102697_link_1_12.mtl +12 -0
  14. dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.mtl +12 -0
  15. dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.obj +109 -0
  16. dexart-release/assets/sapien/102697/new_objs/102697_link_1_5.mtl +12 -0
  17. dexart-release/assets/sapien/102697/new_objs/102697_link_3_0.obj +217 -0
  18. dexart-release/assets/sapien/102697/new_objs/102697_link_3_5.obj +250 -0
  19. dexart-release/assets/sapien/102697/new_objs/102697_link_3_6.obj +169 -0
  20. dexart-release/assets/sapien/102697/new_objs/102697_link_4_11.obj +113 -0
  21. dexart-release/assets/sapien/102697/new_objs/102697_link_4_19.obj +105 -0
  22. dexart-release/assets/sapien/102697/new_objs/102697_link_4_3.mtl +12 -0
  23. dexart-release/assets/sapien/102697/new_objs/102697_link_4_33.mtl +12 -0
  24. dexart-release/assets/sapien/102697/new_objs/102697_link_4_4.mtl +12 -0
  25. dexart-release/assets/sapien/102697/new_objs/102697_link_4_8.mtl +12 -0
  26. dexart-release/assets/sapien/102697/result.json +1 -0
  27. dexart-release/assets/sapien/102697/result_original.json +1 -0
  28. dexart-release/assets/sapien/102697/semantics.txt +5 -0
  29. dexart-release/dexart.egg-info/PKG-INFO +12 -0
  30. dexart-release/dexart.egg-info/SOURCES.txt +51 -0
  31. dexart-release/dexart.egg-info/dependency_links.txt +1 -0
  32. dexart-release/dexart.egg-info/requires.txt +4 -0
  33. dexart-release/dexart.egg-info/top_level.txt +1 -0
  34. dexart-release/examples/gen_demonstration_expert.py +238 -0
  35. dexart-release/examples/train.py +124 -0
  36. dexart-release/examples/utils.py +66 -0
  37. dexart-release/stable_baselines3/a2c/__init__.py +2 -0
  38. dexart-release/stable_baselines3/a2c/a2c.py +207 -0
  39. dexart-release/stable_baselines3/a2c/policies.py +7 -0
  40. dexart-release/stable_baselines3/common/__init__.py +0 -0
  41. dexart-release/stable_baselines3/common/base_class.py +835 -0
  42. dexart-release/stable_baselines3/common/buffers.py +1010 -0
  43. dexart-release/stable_baselines3/common/callbacks.py +602 -0
  44. dexart-release/stable_baselines3/common/distributions.py +699 -0
  45. dexart-release/stable_baselines3/common/env_util.py +104 -0
  46. dexart-release/stable_baselines3/common/evaluation.py +131 -0
  47. dexart-release/stable_baselines3/common/logger.py +644 -0
  48. dexart-release/stable_baselines3/common/monitor.py +239 -0
  49. dexart-release/stable_baselines3/common/noise.py +167 -0
  50. dexart-release/stable_baselines3/common/on_policy_algorithm.py +320 -0
.gitattributes CHANGED
@@ -43,3 +43,5 @@ Metaworld/zarr_path:[[:space:]]data/metaworld_disassemble_expert.zarr/data/point
43
  Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_hammer_expert.zarr/data/depth/0.0.0 filter=lfs diff=lfs merge=lfs -text
44
  Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/16.0.0 filter=lfs diff=lfs merge=lfs -text
45
  Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_drawer-open_expert.zarr/data/point_cloud/5.0.0 filter=lfs diff=lfs merge=lfs -text
 
 
 
43
  Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_hammer_expert.zarr/data/depth/0.0.0 filter=lfs diff=lfs merge=lfs -text
44
  Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/16.0.0 filter=lfs diff=lfs merge=lfs -text
45
  Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_drawer-open_expert.zarr/data/point_cloud/5.0.0 filter=lfs diff=lfs merge=lfs -text
46
+ Metaworld/zarr_path:[[:space:]]data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 filter=lfs diff=lfs merge=lfs -text
47
+ Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 filter=lfs diff=lfs merge=lfs -text
Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_dial.xml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <mujoco>
2
+ <include file="../scene/basic_scene.xml"/>
3
+ <include file="../objects/assets/dial_dependencies.xml"/>
4
+ <include file="../objects/assets/xyz_base_dependencies.xml"/>
5
+ <worldbody>
6
+ <include file="../objects/assets/xyz_base.xml"/>
7
+
8
+ <body name="dial" pos="0 0.7 0.">
9
+ <include file="../objects/assets/dial.xml"/>
10
+ <site name="dialStart" pos="0 -0.05 0.035" size="0.005" rgba="0 0 1 1"/>
11
+ </body>
12
+
13
+ <site name="goal" pos="0. 0.74 0.07" size="0.02"
14
+ rgba=".8 0 0 1"/>
15
+
16
+ --> </worldbody>
17
+
18
+ <actuator>
19
+ <position ctrllimited="true" ctrlrange="-1 1" joint="r_close" kp="400" user="1"/>
20
+ <position ctrllimited="true" ctrlrange="-1 1" joint="l_close" kp="400" user="1"/>
21
+ </actuator>
22
+ <equality>
23
+ <weld body1="mocap" body2="hand" solref="0.02 1"></weld>
24
+ </equality>
25
+ </mujoco>
Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_lock.xml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <mujoco>
2
+ <include file="../scene/basic_scene.xml"/>
3
+ <include file="../objects/assets/doorlock_dependencies.xml"/>
4
+ <include file="../objects/assets/xyz_base_dependencies.xml"/>
5
+
6
+ <worldbody>
7
+
8
+ <include file="../objects/assets/xyz_base.xml"/>
9
+
10
+ <body name="door" pos="0 0.9 0.15">
11
+ <include file="../objects/assets/doorlockA.xml"/>
12
+ </body>
13
+
14
+ <site name="goal_lock" pos="0 0.74 0.12" size="0.01"
15
+ rgba="0 0.8 0 1"/>
16
+ <site name="goal_unlock" pos="0.09 0.74 0.211" size="0.01"
17
+ rgba="0 0 0.8 1"/>
18
+ </worldbody>
19
+ <actuator>
20
+ <position ctrllimited="true" ctrlrange="-1 1" joint="r_close" kp="400" user="1"/>
21
+ <position ctrllimited="true" ctrlrange="-1 1" joint="l_close" kp="400" user="1"/>
22
+ </actuator>
23
+ <equality>
24
+ <weld body1="mocap" body2="hand" solref="0.02 1"/>
25
+ </equality>
26
+ </mujoco>
Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_pull.xml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <mujoco>
2
+ <include file="../scene/basic_scene.xml"/>
3
+ <include file="../objects/assets/doorlock_dependencies.xml"/>
4
+ <include file="../objects/assets/xyz_base_dependencies.xml"/>
5
+
6
+ <worldbody>
7
+ <include file="../objects/assets/xyz_base.xml"/>
8
+
9
+ <body name="door" pos="0 0.9 0.15">
10
+ <include file="../objects/assets/doorlockB.xml"/>
11
+ </body>
12
+
13
+ <site name="goal" pos="-0.49 0.46 0.15" size="0.02"
14
+ rgba="0 0.8 0 1"/>
15
+ </worldbody>
16
+ <actuator>
17
+ <position ctrllimited="true" ctrlrange="-1 1" joint="r_close" kp="400" user="1"/>
18
+ <position ctrllimited="true" ctrlrange="-1 1" joint="l_close" kp="400" user="1"/>
19
+ </actuator>
20
+ <equality>
21
+ <weld body1="mocap" body2="hand" solref="0.02 1"/>
22
+ </equality>
23
+ </mujoco>
Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_faucet.xml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <mujoco>
2
+ <include file="../scene/basic_scene.xml"/>
3
+ <include file="../objects/assets/faucet_dependencies.xml"/>
4
+ <include file="../objects/assets/xyz_base_dependencies.xml"/>
5
+ <worldbody>
6
+ <include file="../objects/assets/xyz_base.xml"/>
7
+
8
+ <body name="faucetBase" pos="0 0.8 0">
9
+ <include file="../objects/assets/faucet.xml"/>
10
+
11
+ </body>
12
+
13
+
14
+ <site name="goal_open" pos="0.175 0.785 0.125" size="0.005"
15
+ rgba="0 .8 0 1"/>
16
+ <site name="goal_close" pos="-0.175 0.785 0.125" size="0.005"
17
+ rgba="0 0 0.8 1"/>
18
+
19
+ <!-- <body name="box" pos="0 0.8 0.05">
20
+ <geom rgba="0.3 0.3 1 1" type="box" contype="1" size="0.1 0.05 0.05" name="box_left" conaffinity="1" pos="0 0 0" mass="1000" solimp="0.99 0.99 0.01" solref="0.01 1"/>
21
+ <geom rgba="0.3 0.3 1 1" type="box" contype="1" size="0.1 0.05 0.05" name="box_right" conaffinity="1" pos="0 0.16 0" mass="1000" solimp="0.99 0.99 0.01" solref="0.01 1"/>
22
+ <geom rgba="0.3 0.3 1 1" type="box" contype="1" size="0.035 0.03 0.05" name="box_front" conaffinity="1" pos="0.065 0.08 0" mass="1000" solimp="0.99 0.99 0.01" solref="0.01 1"/>
23
+ <geom rgba="0.3 0.3 1 1" type="box" contype="1" size="0.035 0.03 0.05" name="box_behind" conaffinity="1" pos="-0.065 0.08 0" mass="1000" solimp="0.99 0.99 0.01" solref="0.01 1"/>
24
+ <joint type="slide" range="-0.2 0." axis="0 1 0" name="goal_slidey" pos="0 0 0" damping="1.0"/>
25
+ </body>
26
+ --> </worldbody>
27
+
28
+ <actuator>
29
+ <position ctrllimited="true" ctrlrange="-1 1" joint="r_close" kp="400" user="1"/>
30
+ <position ctrllimited="true" ctrlrange="-1 1" joint="l_close" kp="400" user="1"/>
31
+ </actuator>
32
+ <equality>
33
+ <weld body1="mocap" body2="hand" solref="0.02 1"></weld>
34
+ </equality>
35
+ </mujoco>
Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_laptop.xml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <mujoco>
2
+ <include file="../scene/basic_scene.xml"/>
3
+ <include file="../objects/assets/laptop_dependencies.xml"/>
4
+ <include file="../objects/assets/xyz_base_dependencies.xml"/>
5
+
6
+ <worldbody>
7
+ <include file="../objects/assets/xyz_base.xml"/>
8
+
9
+ <body name="laptop" pos="0 0.8 0">
10
+ <include file="../objects/assets/laptop.xml"/>
11
+
12
+ </body>
13
+ </worldbody>
14
+
15
+ <actuator>
16
+ <position ctrllimited="true" ctrlrange="-1 1" joint="r_close" kp="400" user="1"/>
17
+ <position ctrllimited="true" ctrlrange="-1 1" joint="l_close" kp="400" user="1"/>
18
+ </actuator>
19
+ <equality>
20
+ <weld body1="mocap" body2="hand" solref="0.02 1"></weld>
21
+ </equality>
22
+ </mujoco>
Metaworld/zarr_path: data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:902da15cc5607a827c4a6cf9b7c396bacd2bd244963aa4e36f400f543b23497b
3
+ size 1231019
Metaworld/zarr_path: data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32e76433f75e66cccd7b4db7962d3fc1ae4fe8915409efc427235456347323c4
3
+ size 1213234
dexart-release/assets/sapien/102697/cues.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ link_0 hinge lever button
2
+ link_1 slider pump_lid lid
3
+ link_2 hinge lid lid
4
+ link_3 hinge seat seat
5
+ link_4 static base_body base_body
dexart-release/assets/sapien/102697/meta.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"user_id": "haozhu", "model_cat": "Toilet", "model_id": "db252ecd6286a334733badcb2e574996-0", "version": "1", "anno_id": "2697", "time_in_sec": "31"}
dexart-release/assets/sapien/102697/mobility.urdf ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <robot name="partnet_db252ecd6286a334733badcb2e574996-0">
2
+ <link name="base" />
3
+ <link name="link_0">
4
+ <collision>
5
+ <origin xyz="0.387179130380233 -0.5229989412652956 0.2586584187924479" />
6
+ <geometry>
7
+ <mesh filename="new_objs/102697_link_0_2.obj" />
8
+ </geometry>
9
+ </collision>
10
+ <collision>
11
+ <origin xyz="0.387179130380233 -0.5229989412652956 0.2586584187924479" />
12
+ <geometry>
13
+ <mesh filename="new_objs/102697_link_0_0.obj" />
14
+ </geometry>
15
+ </collision>
16
+ <collision>
17
+ <origin xyz="0.387179130380233 -0.5229989412652956 0.2586584187924479" />
18
+ <geometry>
19
+ <mesh filename="new_objs/102697_link_0_1.obj" />
20
+ </geometry>
21
+ </collision>
22
+ <visual name="button-part[&quot;id]">
23
+ <origin xyz="0.387179130380233 -0.5229989412652956 0.2586584187924479" />
24
+ <geometry>
25
+ <mesh filename="textured_objs/original-8.obj" />
26
+ </geometry>
27
+ </visual>
28
+ </link>
29
+ <joint name="joint_0" type="revolute">
30
+ <origin xyz="-0.387179130380233 0.5229989412652956 -0.2586584187924479" />
31
+ <axis xyz="-0.91844856752952 -6.467369864908537e-05 0.39554042096893877" />
32
+ <child link="link_0" />
33
+ <parent link="link_4" />
34
+ <limit lower="0.0" upper="0.5235987755982988" />
35
+ </joint>
36
+ <link name="link_1">
37
+ <collision>
38
+ <origin xyz="0 0 0" />
39
+ <geometry>
40
+ <mesh filename="new_objs/102697_link_1_6.obj" />
41
+ </geometry>
42
+ </collision>
43
+ <collision>
44
+ <origin xyz="0 0 0" />
45
+ <geometry>
46
+ <mesh filename="new_objs/102697_link_1_14.obj" />
47
+ </geometry>
48
+ </collision>
49
+ <collision>
50
+ <origin xyz="0 0 0" />
51
+ <geometry>
52
+ <mesh filename="new_objs/102697_link_1_11.obj" />
53
+ </geometry>
54
+ </collision>
55
+ <collision>
56
+ <origin xyz="0 0 0" />
57
+ <geometry>
58
+ <mesh filename="new_objs/102697_link_1_4.obj" />
59
+ </geometry>
60
+ </collision>
61
+ <collision>
62
+ <origin xyz="0 0 0" />
63
+ <geometry>
64
+ <mesh filename="new_objs/102697_link_1_8.obj" />
65
+ </geometry>
66
+ </collision>
67
+ <collision>
68
+ <origin xyz="0 0 0" />
69
+ <geometry>
70
+ <mesh filename="new_objs/102697_link_1_3.obj" />
71
+ </geometry>
72
+ </collision>
73
+ <collision>
74
+ <origin xyz="0 0 0" />
75
+ <geometry>
76
+ <mesh filename="new_objs/102697_link_1_1.obj" />
77
+ </geometry>
78
+ </collision>
79
+ <collision>
80
+ <origin xyz="0 0 0" />
81
+ <geometry>
82
+ <mesh filename="new_objs/102697_link_1_12.obj" />
83
+ </geometry>
84
+ </collision>
85
+ <collision>
86
+ <origin xyz="0 0 0" />
87
+ <geometry>
88
+ <mesh filename="new_objs/102697_link_1_5.obj" />
89
+ </geometry>
90
+ </collision>
91
+ <collision>
92
+ <origin xyz="0 0 0" />
93
+ <geometry>
94
+ <mesh filename="new_objs/102697_link_1_13.obj" />
95
+ </geometry>
96
+ </collision>
97
+ <collision>
98
+ <origin xyz="0 0 0" />
99
+ <geometry>
100
+ <mesh filename="new_objs/102697_link_1_0.obj" />
101
+ </geometry>
102
+ </collision>
103
+ <collision>
104
+ <origin xyz="0 0 0" />
105
+ <geometry>
106
+ <mesh filename="new_objs/102697_link_1_10.obj" />
107
+ </geometry>
108
+ </collision>
109
+ <collision>
110
+ <origin xyz="0 0 0" />
111
+ <geometry>
112
+ <mesh filename="new_objs/102697_link_1_2.obj" />
113
+ </geometry>
114
+ </collision>
115
+ <collision>
116
+ <origin xyz="0 0 0" />
117
+ <geometry>
118
+ <mesh filename="new_objs/102697_link_1_9.obj" />
119
+ </geometry>
120
+ </collision>
121
+ <collision>
122
+ <origin xyz="0 0 0" />
123
+ <geometry>
124
+ <mesh filename="new_objs/102697_link_1_7.obj" />
125
+ </geometry>
126
+ </collision>
127
+ <visual name="lid-part[&quot;id]">
128
+ <origin xyz="0 0 0" />
129
+ <geometry>
130
+ <mesh filename="textured_objs/original-4.obj" />
131
+ </geometry>
132
+ </visual>
133
+ <visual name="lid-part[&quot;id]">
134
+ <origin xyz="0 0 0" />
135
+ <geometry>
136
+ <mesh filename="textured_objs/original-3.obj" />
137
+ </geometry>
138
+ </visual>
139
+ </link>
140
+ <joint name="joint_1" type="prismatic">
141
+ <origin xyz="0 0 0" />
142
+ <axis xyz="0 1 0" />
143
+ <child link="link_1" />
144
+ <parent link="link_4" />
145
+ <limit lower="0" upper="0.02" />
146
+ </joint>
147
+ <link name="link_2">
148
+ <collision>
149
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
150
+ <geometry>
151
+ <mesh filename="new_objs/102697_link_2_0.obj" />
152
+ </geometry>
153
+ </collision>
154
+ <visual name="lid-part[&quot;id]">
155
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
156
+ <geometry>
157
+ <mesh filename="textured_objs/original-21.obj" />
158
+ </geometry>
159
+ </visual>
160
+ <visual name="lid-part[&quot;id]">
161
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
162
+ <geometry>
163
+ <mesh filename="textured_objs/original-20.obj" />
164
+ </geometry>
165
+ </visual>
166
+ </link>
167
+ <joint name="joint_2" type="revolute">
168
+ <origin xyz="0 0.05794601786221419 -0.06205458525801076" />
169
+ <axis xyz="-1 0 0" />
170
+ <child link="link_2" />
171
+ <parent link="link_4" />
172
+ <limit lower="-0.0" upper="1.6406094968746698" />
173
+ </joint>
174
+ <link name="link_3">
175
+ <collision>
176
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
177
+ <geometry>
178
+ <mesh filename="new_objs/102697_link_3_2.obj" />
179
+ </geometry>
180
+ </collision>
181
+ <collision>
182
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
183
+ <geometry>
184
+ <mesh filename="new_objs/102697_link_3_6.obj" />
185
+ </geometry>
186
+ </collision>
187
+ <collision>
188
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
189
+ <geometry>
190
+ <mesh filename="new_objs/102697_link_3_1.obj" />
191
+ </geometry>
192
+ </collision>
193
+ <collision>
194
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
195
+ <geometry>
196
+ <mesh filename="new_objs/102697_link_3_0.obj" />
197
+ </geometry>
198
+ </collision>
199
+ <collision>
200
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
201
+ <geometry>
202
+ <mesh filename="new_objs/102697_link_3_4.obj" />
203
+ </geometry>
204
+ </collision>
205
+ <collision>
206
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
207
+ <geometry>
208
+ <mesh filename="new_objs/102697_link_3_5.obj" />
209
+ </geometry>
210
+ </collision>
211
+ <collision>
212
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
213
+ <geometry>
214
+ <mesh filename="new_objs/102697_link_3_7.obj" />
215
+ </geometry>
216
+ </collision>
217
+ <collision>
218
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
219
+ <geometry>
220
+ <mesh filename="new_objs/102697_link_3_3.obj" />
221
+ </geometry>
222
+ </collision>
223
+ <visual name="seat-part[&quot;id]">
224
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
225
+ <geometry>
226
+ <mesh filename="textured_objs/original-18.obj" />
227
+ </geometry>
228
+ </visual>
229
+ <visual name="seat-part[&quot;id]">
230
+ <origin xyz="0 -0.05794601786221419 0.06205458525801076" />
231
+ <geometry>
232
+ <mesh filename="textured_objs/original-17.obj" />
233
+ </geometry>
234
+ </visual>
235
+ </link>
236
+ <joint name="joint_3" type="revolute">
237
+ <origin xyz="0 0.05794601786221419 -0.06205458525801076" />
238
+ <axis xyz="-1 0 0" />
239
+ <child link="link_3" />
240
+ <parent link="link_4" />
241
+ <limit lower="-0.0" upper="1.6406094968746698" />
242
+ </joint>
243
+ <link name="link_4">
244
+ <collision>
245
+ <origin xyz="0 0 0" />
246
+ <geometry>
247
+ <mesh filename="new_objs/102697_link_4_0.obj" />
248
+ </geometry>
249
+ </collision>
250
+ <collision>
251
+ <origin xyz="0 0 0" />
252
+ <geometry>
253
+ <mesh filename="new_objs/102697_link_4_18.obj" />
254
+ </geometry>
255
+ </collision>
256
+ <collision>
257
+ <origin xyz="0 0 0" />
258
+ <geometry>
259
+ <mesh filename="new_objs/102697_link_4_22.obj" />
260
+ </geometry>
261
+ </collision>
262
+ <collision>
263
+ <origin xyz="0 0 0" />
264
+ <geometry>
265
+ <mesh filename="new_objs/102697_link_4_12.obj" />
266
+ </geometry>
267
+ </collision>
268
+ <collision>
269
+ <origin xyz="0 0 0" />
270
+ <geometry>
271
+ <mesh filename="new_objs/102697_link_4_27.obj" />
272
+ </geometry>
273
+ </collision>
274
+ <collision>
275
+ <origin xyz="0 0 0" />
276
+ <geometry>
277
+ <mesh filename="new_objs/102697_link_4_23.obj" />
278
+ </geometry>
279
+ </collision>
280
+ <collision>
281
+ <origin xyz="0 0 0" />
282
+ <geometry>
283
+ <mesh filename="new_objs/102697_link_4_33.obj" />
284
+ </geometry>
285
+ </collision>
286
+ <collision>
287
+ <origin xyz="0 0 0" />
288
+ <geometry>
289
+ <mesh filename="new_objs/102697_link_4_2.obj" />
290
+ </geometry>
291
+ </collision>
292
+ <collision>
293
+ <origin xyz="0 0 0" />
294
+ <geometry>
295
+ <mesh filename="new_objs/102697_link_4_13.obj" />
296
+ </geometry>
297
+ </collision>
298
+ <collision>
299
+ <origin xyz="0 0 0" />
300
+ <geometry>
301
+ <mesh filename="new_objs/102697_link_4_28.obj" />
302
+ </geometry>
303
+ </collision>
304
+ <collision>
305
+ <origin xyz="0 0 0" />
306
+ <geometry>
307
+ <mesh filename="new_objs/102697_link_4_34.obj" />
308
+ </geometry>
309
+ </collision>
310
+ <collision>
311
+ <origin xyz="0 0 0" />
312
+ <geometry>
313
+ <mesh filename="new_objs/102697_link_4_7.obj" />
314
+ </geometry>
315
+ </collision>
316
+ <collision>
317
+ <origin xyz="0 0 0" />
318
+ <geometry>
319
+ <mesh filename="new_objs/102697_link_4_1.obj" />
320
+ </geometry>
321
+ </collision>
322
+ <collision>
323
+ <origin xyz="0 0 0" />
324
+ <geometry>
325
+ <mesh filename="new_objs/102697_link_4_10.obj" />
326
+ </geometry>
327
+ </collision>
328
+ <collision>
329
+ <origin xyz="0 0 0" />
330
+ <geometry>
331
+ <mesh filename="new_objs/102697_link_4_6.obj" />
332
+ </geometry>
333
+ </collision>
334
+ <collision>
335
+ <origin xyz="0 0 0" />
336
+ <geometry>
337
+ <mesh filename="new_objs/102697_link_4_19.obj" />
338
+ </geometry>
339
+ </collision>
340
+ <collision>
341
+ <origin xyz="0 0 0" />
342
+ <geometry>
343
+ <mesh filename="new_objs/102697_link_4_3.obj" />
344
+ </geometry>
345
+ </collision>
346
+ <collision>
347
+ <origin xyz="0 0 0" />
348
+ <geometry>
349
+ <mesh filename="new_objs/102697_link_4_26.obj" />
350
+ </geometry>
351
+ </collision>
352
+ <collision>
353
+ <origin xyz="0 0 0" />
354
+ <geometry>
355
+ <mesh filename="new_objs/102697_link_4_8.obj" />
356
+ </geometry>
357
+ </collision>
358
+ <collision>
359
+ <origin xyz="0 0 0" />
360
+ <geometry>
361
+ <mesh filename="new_objs/102697_link_4_25.obj" />
362
+ </geometry>
363
+ </collision>
364
+ <collision>
365
+ <origin xyz="0 0 0" />
366
+ <geometry>
367
+ <mesh filename="new_objs/102697_link_4_11.obj" />
368
+ </geometry>
369
+ </collision>
370
+ <collision>
371
+ <origin xyz="0 0 0" />
372
+ <geometry>
373
+ <mesh filename="new_objs/102697_link_4_21.obj" />
374
+ </geometry>
375
+ </collision>
376
+ <collision>
377
+ <origin xyz="0 0 0" />
378
+ <geometry>
379
+ <mesh filename="new_objs/102697_link_4_9.obj" />
380
+ </geometry>
381
+ </collision>
382
+ <collision>
383
+ <origin xyz="0 0 0" />
384
+ <geometry>
385
+ <mesh filename="new_objs/102697_link_4_17.obj" />
386
+ </geometry>
387
+ </collision>
388
+ <collision>
389
+ <origin xyz="0 0 0" />
390
+ <geometry>
391
+ <mesh filename="new_objs/102697_link_4_30.obj" />
392
+ </geometry>
393
+ </collision>
394
+ <collision>
395
+ <origin xyz="0 0 0" />
396
+ <geometry>
397
+ <mesh filename="new_objs/102697_link_4_16.obj" />
398
+ </geometry>
399
+ </collision>
400
+ <collision>
401
+ <origin xyz="0 0 0" />
402
+ <geometry>
403
+ <mesh filename="new_objs/102697_link_4_31.obj" />
404
+ </geometry>
405
+ </collision>
406
+ <collision>
407
+ <origin xyz="0 0 0" />
408
+ <geometry>
409
+ <mesh filename="new_objs/102697_link_4_20.obj" />
410
+ </geometry>
411
+ </collision>
412
+ <collision>
413
+ <origin xyz="0 0 0" />
414
+ <geometry>
415
+ <mesh filename="new_objs/102697_link_4_5.obj" />
416
+ </geometry>
417
+ </collision>
418
+ <collision>
419
+ <origin xyz="0 0 0" />
420
+ <geometry>
421
+ <mesh filename="new_objs/102697_link_4_29.obj" />
422
+ </geometry>
423
+ </collision>
424
+ <collision>
425
+ <origin xyz="0 0 0" />
426
+ <geometry>
427
+ <mesh filename="new_objs/102697_link_4_15.obj" />
428
+ </geometry>
429
+ </collision>
430
+ <collision>
431
+ <origin xyz="0 0 0" />
432
+ <geometry>
433
+ <mesh filename="new_objs/102697_link_4_4.obj" />
434
+ </geometry>
435
+ </collision>
436
+ <collision>
437
+ <origin xyz="0 0 0" />
438
+ <geometry>
439
+ <mesh filename="new_objs/102697_link_4_24.obj" />
440
+ </geometry>
441
+ </collision>
442
+ <collision>
443
+ <origin xyz="0 0 0" />
444
+ <geometry>
445
+ <mesh filename="new_objs/102697_link_4_32.obj" />
446
+ </geometry>
447
+ </collision>
448
+ <collision>
449
+ <origin xyz="0 0 0" />
450
+ <geometry>
451
+ <mesh filename="new_objs/102697_link_4_14.obj" />
452
+ </geometry>
453
+ </collision>
454
+ <visual name="base_body-part[&quot;id]">
455
+ <origin xyz="0 0 0" />
456
+ <geometry>
457
+ <mesh filename="textured_objs/original-13.obj" />
458
+ </geometry>
459
+ </visual>
460
+ <visual name="base_body-part[&quot;id]">
461
+ <origin xyz="0 0 0" />
462
+ <geometry>
463
+ <mesh filename="textured_objs/original-6.obj" />
464
+ </geometry>
465
+ </visual>
466
+ <visual name="base_body-part[&quot;id]">
467
+ <origin xyz="0 0 0" />
468
+ <geometry>
469
+ <mesh filename="textured_objs/original-15.obj" />
470
+ </geometry>
471
+ </visual>
472
+ <visual name="base_body-part[&quot;id]">
473
+ <origin xyz="0 0 0" />
474
+ <geometry>
475
+ <mesh filename="textured_objs/original-11.obj" />
476
+ </geometry>
477
+ </visual>
478
+ <visual name="base_body-part[&quot;id]">
479
+ <origin xyz="0 0 0" />
480
+ <geometry>
481
+ <mesh filename="textured_objs/original-7.obj" />
482
+ </geometry>
483
+ </visual>
484
+ <visual name="base_body-part[&quot;id]">
485
+ <origin xyz="0 0 0" />
486
+ <geometry>
487
+ <mesh filename="textured_objs/original-10.obj" />
488
+ </geometry>
489
+ </visual>
490
+ <visual name="base_body-part[&quot;id]">
491
+ <origin xyz="0 0 0" />
492
+ <geometry>
493
+ <mesh filename="textured_objs/original-14.obj" />
494
+ </geometry>
495
+ </visual>
496
+ </link>
497
+ <joint name="joint_4" type="fixed">
498
+ <origin rpy="1.570796326794897 0 -1.570796326794897" xyz="0 0 0" />
499
+ <child link="link_4" />
500
+ <parent link="base" />
501
+ </joint>
502
+ </robot>
dexart-release/assets/sapien/102697/mobility_v2.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [{"id":0,"parent":4,"joint":"hinge","name":"lever","parts":[{"id":1,"name":"button"}],"jointData":{"axis":{"origin":[-0.387179130380233,0.5229989412652956,-0.2586584187924479],"direction":[-0.91844856752952,-0.00006467369864908537,0.39554042096893877]},"limit":{"a":0,"b":30,"noLimit":false}}},{"id":1,"parent":4,"joint":"slider","name":"pump_lid","parts":[{"id":2,"name":"lid"}],"jointData":{"axis":{"origin":[0,0,0],"direction":[0,1,0]},"limit":{"a":0,"b":0.02,"noLimit":false,"rotates":false,"noRotationLimit":false,"rotationLimit":0}}},{"id":2,"parent":4,"joint":"hinge","name":"lid","parts":[{"id":3,"name":"lid"}],"jointData":{"axis":{"origin":[0,0.05794601786221419,-0.06205458525801076],"direction":[1,0,0]},"limit":{"a":0,"b":-94,"noLimit":false}}},{"id":3,"parent":4,"joint":"hinge","name":"seat","parts":[{"id":4,"name":"seat"}],"jointData":{"axis":{"origin":[0,0.05794601786221419,-0.06205458525801076],"direction":[1,0,0]},"limit":{"a":0,"b":-94,"noLimit":false}}},{"id":4,"parent":-1,"joint":"static","name":"base_body","parts":[{"id":5,"name":"base_body"}],"jointData":{}}]
dexart-release/assets/sapien/102697/new_objs/102697_link_1_12.mtl ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender MTL File: 'None'
2
+ # Material Count: 1
3
+
4
+ newmtl Shape.14067
5
+ Ns 400.000000
6
+ Ka 0.400000 0.400000 0.400000
7
+ Kd 0.336000 0.232000 0.584000
8
+ Ks 0.250000 0.250000 0.250000
9
+ Ke 0.000000 0.000000 0.000000
10
+ Ni 1.000000
11
+ d 0.500000
12
+ illum 2
dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.mtl ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender MTL File: 'None'
2
+ # Material Count: 1
3
+
4
+ newmtl Shape.14069
5
+ Ns 400.000000
6
+ Ka 0.400000 0.400000 0.400000
7
+ Kd 0.296000 0.784000 0.192000
8
+ Ks 0.250000 0.250000 0.250000
9
+ Ke 0.000000 0.000000 0.000000
10
+ Ni 1.000000
11
+ d 0.500000
12
+ illum 2
dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.obj ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender v2.79 (sub 0) OBJ File: ''
2
+ # www.blender.org
3
+ mtllib 102697_link_1_14.mtl
4
+ o Shape_IndexedFaceSet.014_Shape_IndexedFaceSet.14061
5
+ v 0.381286 0.716318 -0.278277
6
+ v 0.467366 0.669121 -0.458755
7
+ v 0.467366 0.666342 -0.458755
8
+ v 0.314641 0.669121 -0.486500
9
+ v 0.314641 0.666342 -0.275488
10
+ v 0.314641 0.716318 -0.486500
11
+ v 0.442356 0.716318 -0.486520
12
+ v 0.406252 0.666342 -0.275488
13
+ v 0.314641 0.716318 -0.275488
14
+ v 0.453478 0.702433 -0.433757
15
+ v 0.464367 0.666416 -0.486207
16
+ v 0.397925 0.702433 -0.275488
17
+ v 0.386833 0.666342 -0.486500
18
+ v 0.447917 0.713539 -0.453197
19
+ v 0.433456 0.670746 -0.350149
20
+ v 0.411813 0.705207 -0.311600
21
+ v 0.453478 0.710760 -0.486520
22
+ v 0.459025 0.702433 -0.455986
23
+ v 0.397925 0.713539 -0.300506
24
+ v 0.409033 0.674673 -0.283855
25
+ v 0.430634 0.693883 -0.355457
26
+ v 0.461805 0.669121 -0.436525
27
+ v 0.447917 0.710760 -0.439314
28
+ vn 0.6201 0.7564 0.2082
29
+ vn -1.0000 0.0000 0.0000
30
+ vn 0.0000 1.0000 0.0000
31
+ vn -0.0002 0.0000 -1.0000
32
+ vn 0.0000 -1.0000 0.0000
33
+ vn 0.0000 0.0000 1.0000
34
+ vn 0.9941 0.0000 -0.1086
35
+ vn 0.0406 0.2433 0.9691
36
+ vn -0.0385 -0.9992 -0.0132
37
+ vn 0.0010 -1.0000 -0.0028
38
+ vn 0.1781 0.9826 0.0522
39
+ vn -0.0001 -0.0002 -1.0000
40
+ vn 0.9633 0.2356 -0.1285
41
+ vn -0.0000 -0.0004 -1.0000
42
+ vn 0.0038 -0.0061 -1.0000
43
+ vn 0.4470 0.8945 -0.0000
44
+ vn 0.7133 0.6982 0.0608
45
+ vn 0.9470 0.2175 0.2363
46
+ vn 0.9624 0.2498 -0.1067
47
+ vn 0.5705 0.7506 0.3332
48
+ vn 0.2834 0.9545 0.0928
49
+ vn 0.6574 0.6887 0.3057
50
+ vn 0.8546 0.1972 0.4804
51
+ vn 0.9385 0.0321 0.3438
52
+ vn 0.8973 0.2493 0.3642
53
+ vn 0.9155 0.2217 0.3357
54
+ vn 0.8943 0.3344 0.2974
55
+ vn 0.9251 0.1885 0.3296
56
+ vn 0.9350 0.1837 0.3034
57
+ vn 0.9701 0.0000 0.2427
58
+ vn 0.8018 -0.5344 0.2674
59
+ vn 0.9470 0.2170 0.2369
60
+ vn 0.8672 -0.4031 0.2922
61
+ vn 0.9325 0.2086 0.2948
62
+ vn 0.7291 0.6431 0.2341
63
+ vn 0.7174 0.6831 0.1368
64
+ vn 0.7541 0.6292 0.1882
65
+ vn 0.5142 0.8410 0.1683
66
+ usemtl Shape.14069
67
+ s off
68
+ f 19//1 16//1 23//1
69
+ f 4//2 5//2 6//2
70
+ f 6//3 1//3 7//3
71
+ f 4//4 6//4 7//4
72
+ f 5//5 3//5 8//5
73
+ f 5//6 8//6 9//6
74
+ f 1//3 6//3 9//3
75
+ f 6//2 5//2 9//2
76
+ f 2//7 3//7 11//7
77
+ f 9//6 8//6 12//6
78
+ f 1//8 9//8 12//8
79
+ f 5//9 4//9 13//9
80
+ f 3//5 5//5 13//5
81
+ f 11//10 3//10 13//10
82
+ f 7//11 1//11 14//11
83
+ f 4//12 7//12 17//12
84
+ f 2//13 11//13 17//13
85
+ f 13//14 4//14 17//14
86
+ f 11//15 13//15 17//15
87
+ f 7//16 14//16 17//16
88
+ f 17//17 14//17 18//17
89
+ f 10//18 2//18 18//18
90
+ f 2//19 17//19 18//19
91
+ f 1//20 12//20 19//20
92
+ f 14//21 1//21 19//21
93
+ f 12//22 16//22 19//22
94
+ f 12//23 8//23 20//23
95
+ f 8//24 15//24 20//24
96
+ f 16//25 12//25 20//25
97
+ f 16//26 20//26 21//26
98
+ f 10//27 16//27 21//27
99
+ f 20//28 15//28 21//28
100
+ f 21//29 15//29 22//29
101
+ f 3//30 2//30 22//30
102
+ f 8//31 3//31 22//31
103
+ f 2//32 10//32 22//32
104
+ f 15//33 8//33 22//33
105
+ f 10//34 21//34 22//34
106
+ f 16//35 10//35 23//35
107
+ f 18//36 14//36 23//36
108
+ f 10//37 18//37 23//37
109
+ f 14//38 19//38 23//38
dexart-release/assets/sapien/102697/new_objs/102697_link_1_5.mtl ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender MTL File: 'None'
2
+ # Material Count: 1
3
+
4
+ newmtl Shape.14060
5
+ Ns 400.000000
6
+ Ka 0.400000 0.400000 0.400000
7
+ Kd 0.576000 0.288000 0.088000
8
+ Ks 0.250000 0.250000 0.250000
9
+ Ke 0.000000 0.000000 0.000000
10
+ Ni 1.000000
11
+ d 0.500000
12
+ illum 2
dexart-release/assets/sapien/102697/new_objs/102697_link_3_0.obj ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender v2.79 (sub 0) OBJ File: ''
2
+ # www.blender.org
3
+ mtllib 102697_link_3_0.mtl
4
+ o Shape_IndexedFaceSet_Shape_IndexedFaceSet.16309
5
+ v -0.318776 0.060655 0.107693
6
+ v -0.063870 0.047615 -0.076335
7
+ v -0.006959 0.042434 0.076198
8
+ v -0.006959 0.088979 -0.019441
9
+ v -0.309458 0.088979 0.114981
10
+ v -0.182789 0.086387 -0.042715
11
+ v -0.257765 0.042434 0.006414
12
+ v -0.312052 0.042434 0.114981
13
+ v -0.167256 0.073461 0.114981
14
+ v -0.006959 0.042434 -0.068570
15
+ v -0.006959 0.073461 0.060669
16
+ v -0.006959 0.088979 -0.065989
17
+ v -0.262922 0.086387 0.008996
18
+ v -0.151785 0.042434 0.114981
19
+ v -0.154809 0.063235 -0.062488
20
+ v -0.255171 0.088979 0.114981
21
+ v -0.265516 0.065707 -0.001350
22
+ v -0.058683 0.083799 -0.073734
23
+ v -0.155045 0.046362 -0.055734
24
+ v -0.006959 0.055369 0.078780
25
+ v -0.317239 0.052781 0.091707
26
+ v -0.309458 0.083799 0.086544
27
+ v -0.069027 0.088979 -0.065989
28
+ v -0.006959 0.070873 -0.078916
29
+ v -0.159536 0.083799 -0.058243
30
+ v -0.107812 0.042434 -0.063407
31
+ v -0.257765 0.045027 -0.001350
32
+ v -0.159811 0.051020 -0.059069
33
+ v -0.306895 0.088979 0.096890
34
+ v -0.149191 0.060536 0.114981
35
+ v -0.056120 0.078632 -0.076335
36
+ v -0.314646 0.045027 0.096890
37
+ v -0.006959 0.047615 -0.076335
38
+ v -0.319802 0.076049 0.114981
39
+ v -0.260328 0.078632 -0.001350
40
+ v -0.312052 0.068285 0.081361
41
+ v -0.250915 0.087266 0.009661
42
+ v -0.006959 0.070873 0.065852
43
+ v -0.304301 0.050193 0.068434
44
+ v -0.268110 0.070873 0.003833
45
+ v -0.006959 0.060536 0.076198
46
+ v -0.006959 0.086387 -0.071152
47
+ v -0.006959 0.078632 0.034833
48
+ v -0.167107 0.053534 -0.056233
49
+ v -0.162264 0.065707 -0.059140
50
+ v -0.317239 0.045027 0.114981
51
+ v -0.198290 0.081220 -0.037551
52
+ v -0.247420 0.042434 -0.001350
53
+ vn -0.2189 -0.8734 -0.4350
54
+ vn 0.0000 -1.0000 0.0000
55
+ vn 0.0000 0.0000 1.0000
56
+ vn 1.0000 0.0000 0.0000
57
+ vn -0.0000 1.0000 0.0000
58
+ vn 0.2540 -0.1893 0.9485
59
+ vn -0.0785 0.9576 -0.2772
60
+ vn -0.0984 0.7614 -0.6407
61
+ vn -0.0223 -0.8995 -0.4364
62
+ vn -0.1804 -0.1959 -0.9639
63
+ vn -0.1633 -0.5893 -0.7912
64
+ vn -0.1721 -0.6838 -0.7091
65
+ vn -0.3065 -0.7412 -0.5972
66
+ vn -0.5117 0.8123 -0.2799
67
+ vn 0.2435 0.3404 0.9082
68
+ vn 0.2453 -0.0352 0.9688
69
+ vn -0.1444 0.0361 -0.9889
70
+ vn -0.0504 0.0126 -0.9986
71
+ vn -0.1621 0.1636 -0.9731
72
+ vn -0.1399 0.3890 -0.9106
73
+ vn -0.2162 -0.9704 -0.1081
74
+ vn -0.4800 -0.8321 -0.2779
75
+ vn 0.0000 -0.8318 -0.5550
76
+ vn 0.0000 -0.1103 -0.9939
77
+ vn -0.9962 -0.0274 -0.0823
78
+ vn -0.7761 0.6209 -0.1100
79
+ vn -0.7791 0.6162 -0.1155
80
+ vn -0.3952 0.6848 -0.6123
81
+ vn -0.9554 0.1496 -0.2548
82
+ vn -0.9305 0.2462 -0.2713
83
+ vn -0.0669 0.9924 -0.1037
84
+ vn -0.0356 0.9974 -0.0631
85
+ vn -0.0693 0.9955 -0.0640
86
+ vn -0.0240 0.9991 -0.0350
87
+ vn 0.1545 0.8754 0.4580
88
+ vn 0.1519 0.8843 0.4415
89
+ vn -0.8035 -0.3012 -0.5135
90
+ vn -0.7750 -0.5093 -0.3742
91
+ vn -0.6037 -0.7165 -0.3495
92
+ vn -0.8716 -0.0235 -0.4897
93
+ vn -0.8750 -0.0297 -0.4832
94
+ vn -0.7861 0.4152 -0.4579
95
+ vn -0.7546 0.4201 -0.5041
96
+ vn -0.7121 0.2858 -0.6413
97
+ vn -0.8694 0.0560 -0.4909
98
+ vn -0.8355 0.2946 -0.4637
99
+ vn 0.2448 0.4334 0.8673
100
+ vn 0.2223 0.6897 0.6891
101
+ vn 0.0000 0.8937 -0.4487
102
+ vn -0.0128 0.8231 -0.5677
103
+ vn 0.0237 0.4474 -0.8940
104
+ vn 0.0214 0.4580 -0.8887
105
+ vn 0.1009 0.9773 0.1863
106
+ vn 0.1037 0.9753 0.1952
107
+ vn -0.4960 -0.1859 -0.8482
108
+ vn -0.3894 -0.0969 -0.9159
109
+ vn -0.4480 -0.3961 -0.8015
110
+ vn -0.3790 0.1027 -0.9197
111
+ vn -0.4224 -0.0481 -0.9051
112
+ vn -0.4884 -0.0141 -0.8725
113
+ vn -0.9926 -0.1156 -0.0385
114
+ vn -0.4462 -0.8926 -0.0640
115
+ vn -0.9107 -0.3918 -0.1305
116
+ vn -0.9960 -0.0823 0.0336
117
+ vn -0.4229 0.5449 -0.7240
118
+ vn -0.4253 0.5896 -0.6867
119
+ vn -0.5000 0.2007 -0.8425
120
+ vn -0.4868 0.0799 -0.8698
121
+ vn -0.4738 0.1147 -0.8731
122
+ vn -0.1247 -0.9517 -0.2806
123
+ vn -0.2313 -0.9228 -0.3082
124
+ usemtl Shape.16319
125
+ s off
126
+ f 27//1 19//1 48//1
127
+ f 7//2 3//2 8//2
128
+ f 5//3 8//3 9//3
129
+ f 4//4 3//4 10//4
130
+ f 3//2 7//2 10//2
131
+ f 3//4 4//4 11//4
132
+ f 5//5 4//5 12//5
133
+ f 4//4 10//4 12//4
134
+ f 8//2 3//2 14//2
135
+ f 9//3 8//3 14//3
136
+ f 4//5 5//5 16//5
137
+ f 5//3 9//3 16//3
138
+ f 3//4 11//4 20//4
139
+ f 14//6 3//6 20//6
140
+ f 5//5 12//5 23//5
141
+ f 12//4 10//4 24//4
142
+ f 6//7 23//7 25//7
143
+ f 23//8 18//8 25//8
144
+ f 2//9 10//9 26//9
145
+ f 10//2 7//2 26//2
146
+ f 15//10 2//10 28//10
147
+ f 2//11 26//11 28//11
148
+ f 26//12 19//12 28//12
149
+ f 19//13 27//13 28//13
150
+ f 13//14 22//14 29//14
151
+ f 5//5 23//5 29//5
152
+ f 9//3 14//3 30//3
153
+ f 20//15 9//15 30//15
154
+ f 14//16 20//16 30//16
155
+ f 2//17 15//17 31//17
156
+ f 24//18 2//18 31//18
157
+ f 15//19 25//19 31//19
158
+ f 25//20 18//20 31//20
159
+ f 7//21 8//21 32//21
160
+ f 27//22 7//22 32//22
161
+ f 10//23 2//23 33//23
162
+ f 2//24 24//24 33//24
163
+ f 24//4 10//4 33//4
164
+ f 8//3 5//3 34//3
165
+ f 21//25 1//25 34//25
166
+ f 5//26 29//26 34//26
167
+ f 29//27 22//27 34//27
168
+ f 13//28 6//28 35//28
169
+ f 21//29 34//29 36//29
170
+ f 34//30 22//30 36//30
171
+ f 6//31 13//31 37//31
172
+ f 23//32 6//32 37//32
173
+ f 13//33 29//33 37//33
174
+ f 29//34 23//34 37//34
175
+ f 16//35 9//35 38//35
176
+ f 11//36 16//36 38//36
177
+ f 20//4 11//4 38//4
178
+ f 17//37 27//37 39//37
179
+ f 32//38 21//38 39//38
180
+ f 27//39 32//39 39//39
181
+ f 36//40 17//40 39//40
182
+ f 21//41 36//41 39//41
183
+ f 22//42 13//42 40//42
184
+ f 13//43 35//43 40//43
185
+ f 35//44 17//44 40//44
186
+ f 17//45 36//45 40//45
187
+ f 36//46 22//46 40//46
188
+ f 9//47 20//47 41//47
189
+ f 38//48 9//48 41//48
190
+ f 20//4 38//4 41//4
191
+ f 23//49 12//49 42//49
192
+ f 18//50 23//50 42//50
193
+ f 12//4 24//4 42//4
194
+ f 24//51 31//51 42//51
195
+ f 31//52 18//52 42//52
196
+ f 11//4 4//4 43//4
197
+ f 4//53 16//53 43//53
198
+ f 16//54 11//54 43//54
199
+ f 27//55 17//55 44//55
200
+ f 15//56 28//56 44//56
201
+ f 28//57 27//57 44//57
202
+ f 25//58 15//58 45//58
203
+ f 15//59 44//59 45//59
204
+ f 44//60 17//60 45//60
205
+ f 1//61 21//61 46//61
206
+ f 32//62 8//62 46//62
207
+ f 21//63 32//63 46//63
208
+ f 34//64 1//64 46//64
209
+ f 8//3 34//3 46//3
210
+ f 6//65 25//65 47//65
211
+ f 35//66 6//66 47//66
212
+ f 17//67 35//67 47//67
213
+ f 45//68 17//68 47//68
214
+ f 25//69 45//69 47//69
215
+ f 26//2 7//2 48//2
216
+ f 19//70 26//70 48//70
217
+ f 7//71 27//71 48//71
dexart-release/assets/sapien/102697/new_objs/102697_link_3_5.obj ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender v2.79 (sub 0) OBJ File: ''
2
+ # www.blender.org
3
+ mtllib 102697_link_3_5.mtl
4
+ o Shape_IndexedFaceSet.005_Shape_IndexedFaceSet.16314
5
+ v -0.006959 0.055369 0.663225
6
+ v -0.035413 0.042434 0.789889
7
+ v -0.006985 0.088979 0.766613
8
+ v -0.224145 0.088979 0.575274
9
+ v -0.112985 0.042434 0.575296
10
+ v -0.244834 0.042434 0.624424
11
+ v -0.216380 0.083799 0.676129
12
+ v -0.128515 0.070873 0.575274
13
+ v -0.154338 0.047615 0.743337
14
+ v -0.061288 0.088979 0.779539
15
+ v -0.006959 0.042434 0.792466
16
+ v -0.006985 0.073461 0.683903
17
+ v -0.006959 0.042434 0.668378
18
+ v -0.260364 0.088979 0.590820
19
+ v -0.030254 0.068285 0.800239
20
+ v -0.221565 0.065707 0.678706
21
+ v -0.268103 0.042434 0.575274
22
+ v -0.146599 0.081220 0.745913
23
+ v -0.081951 0.055369 0.784714
24
+ v -0.006985 0.088979 0.789889
25
+ v -0.270683 0.060536 0.598550
26
+ v -0.102641 0.042434 0.764037
27
+ v -0.206061 0.055369 0.696785
28
+ v -0.150775 0.086767 0.728688
29
+ v -0.087137 0.076049 0.779539
30
+ v -0.154338 0.065707 0.745913
31
+ v -0.262944 0.081220 0.603725
32
+ v -0.273262 0.081220 0.575274
33
+ v -0.262944 0.047615 0.606323
34
+ v -0.032834 0.047615 0.797640
35
+ v -0.006985 0.068294 0.800239
36
+ v -0.244834 0.088979 0.619249
37
+ v -0.030254 0.083799 0.795064
38
+ v -0.006985 0.070878 0.676129
39
+ v -0.087137 0.050193 0.782138
40
+ v -0.006985 0.081215 0.722659
41
+ v -0.213800 0.078632 0.683881
42
+ v -0.242254 0.052781 0.645102
43
+ v -0.180186 0.081220 0.575274
44
+ v -0.143993 0.083799 0.745913
45
+ v -0.079372 0.083799 0.779539
46
+ v -0.209848 0.046828 0.679494
47
+ v -0.262944 0.088979 0.575274
48
+ v -0.009565 0.047615 0.660627
49
+ v -0.092296 0.068285 0.779539
50
+ v -0.159523 0.070873 0.740739
51
+ v -0.275868 0.068285 0.583047
52
+ v -0.273262 0.045027 0.575274
53
+ v -0.032834 0.055369 0.800239
54
+ v -0.125910 0.068285 0.575274
55
+ v -0.006985 0.063123 0.668378
56
+ v -0.074186 0.045027 0.782138
57
+ v -0.204443 0.044694 0.679584
58
+ vn -0.3528 -0.8802 0.3176
59
+ vn 0.0000 -1.0000 -0.0000
60
+ vn -0.0000 1.0000 -0.0000
61
+ vn 1.0000 0.0000 0.0000
62
+ vn 0.0000 0.0000 -1.0000
63
+ vn 1.0000 0.0008 0.0000
64
+ vn 1.0000 0.0006 0.0001
65
+ vn -0.6739 -0.1041 0.7314
66
+ vn -0.8426 0.1891 0.5042
67
+ vn -0.7059 0.6605 0.2560
68
+ vn -0.0513 -0.8223 0.5667
69
+ vn 1.0000 0.0006 0.0013
70
+ vn -0.1903 0.9645 0.1830
71
+ vn -0.0704 0.9943 0.0806
72
+ vn -0.5935 0.7366 0.3242
73
+ vn -0.6160 0.6947 0.3714
74
+ vn -0.0901 0.8767 0.4726
75
+ vn 0.0988 0.4453 0.8899
76
+ vn -0.0001 0.3164 0.9486
77
+ vn 1.0000 0.0023 -0.0008
78
+ vn 0.2531 0.9181 -0.3050
79
+ vn -0.3097 -0.7486 0.5862
80
+ vn -0.4927 -0.1227 0.8615
81
+ vn 1.0000 0.0017 -0.0003
82
+ vn 0.1515 0.9734 -0.1719
83
+ vn 1.0000 0.0019 -0.0004
84
+ vn 0.1659 0.9670 -0.1935
85
+ vn -0.7073 0.1481 0.6912
86
+ vn -0.7933 0.3506 0.4977
87
+ vn -0.8130 0.2852 0.5077
88
+ vn -0.8528 0.0077 0.5221
89
+ vn -0.7975 -0.2017 0.5686
90
+ vn -0.6081 -0.6770 0.4146
91
+ vn -0.8486 -0.2184 0.4819
92
+ vn 0.1702 0.9644 -0.2025
93
+ vn 0.1908 0.9528 -0.2362
94
+ vn -0.2498 0.9330 0.2591
95
+ vn -0.1520 0.9623 0.2257
96
+ vn -0.5596 0.5915 0.5805
97
+ vn -0.5675 0.5734 0.5909
98
+ vn -0.2916 0.2921 0.9108
99
+ vn -0.4189 0.4197 0.8052
100
+ vn -0.2076 0.7248 0.6569
101
+ vn -0.2872 0.3031 0.9086
102
+ vn -0.2434 0.8497 0.4677
103
+ vn -0.4183 0.4227 0.8039
104
+ vn -0.5264 -0.7109 0.4664
105
+ vn -0.5764 -0.6997 0.4220
106
+ vn -0.6036 -0.6544 0.4554
107
+ vn -0.5981 0.7953 0.0993
108
+ vn 0.6344 0.0453 -0.7717
109
+ vn 0.8968 -0.1637 -0.4110
110
+ vn 0.5171 -0.6211 -0.5890
111
+ vn -0.3140 0.1256 0.9411
112
+ vn -0.3097 0.2058 0.9283
113
+ vn -0.4499 0.2990 0.8416
114
+ vn -0.4707 0.2348 0.8505
115
+ vn -0.4460 0.0014 0.8950
116
+ vn -0.4762 -0.0095 0.8793
117
+ vn -0.5346 0.2667 0.8020
118
+ vn -0.6914 0.0290 0.7219
119
+ vn -0.6165 0.4454 0.6493
120
+ vn -0.7048 0.1501 0.6933
121
+ vn -0.8834 0.2281 0.4094
122
+ vn -0.8746 0.3668 0.3172
123
+ vn -0.9481 0.0000 -0.3179
124
+ vn -0.4393 -0.8740 0.2080
125
+ vn -0.4713 -0.8521 0.2276
126
+ vn -0.8850 -0.3363 0.3221
127
+ vn -0.9562 -0.1834 0.2281
128
+ vn -0.3008 0.0601 0.9518
129
+ vn 0.1250 -0.3153 0.9407
130
+ vn 0.1424 -0.2848 0.9480
131
+ vn 0.0000 0.0000 1.0000
132
+ vn -0.2970 -0.1701 0.9396
133
+ vn -0.2748 -0.3056 0.9117
134
+ vn 0.5887 0.2937 -0.7531
135
+ vn 0.0001 -0.0008 -1.0000
136
+ vn 0.9999 0.0100 -0.0100
137
+ vn 0.5062 0.6097 -0.6100
138
+ vn 0.5596 0.4600 -0.6894
139
+ vn 0.5341 0.5376 -0.6524
140
+ vn -0.1298 -0.9323 0.3377
141
+ vn -0.1701 -0.7922 0.5861
142
+ vn -0.3013 -0.7554 0.5819
143
+ vn -0.2439 -0.6115 0.7527
144
+ vn -0.1402 -0.9798 0.1428
145
+ vn -0.1678 -0.9699 0.1763
146
+ vn -0.3550 -0.8867 0.2963
147
+ usemtl Shape.16324
148
+ s off
149
+ f 9//1 42//1 53//1
150
+ f 5//2 2//2 6//2
151
+ f 3//3 4//3 10//3
152
+ f 2//2 5//2 11//2
153
+ f 1//4 11//4 13//4
154
+ f 11//2 5//2 13//2
155
+ f 10//3 4//3 14//3
156
+ f 5//2 6//2 17//2
157
+ f 4//5 8//5 17//5
158
+ f 1//6 3//6 20//6
159
+ f 3//3 10//3 20//3
160
+ f 11//7 1//7 20//7
161
+ f 6//2 2//2 22//2
162
+ f 23//8 9//8 26//8
163
+ f 21//9 16//9 27//9
164
+ f 4//5 17//5 28//5
165
+ f 27//10 14//10 28//10
166
+ f 2//11 11//11 30//11
167
+ f 11//12 20//12 31//12
168
+ f 10//3 14//3 32//3
169
+ f 7//13 24//13 32//13
170
+ f 24//14 10//14 32//14
171
+ f 14//15 27//15 32//15
172
+ f 27//16 7//16 32//16
173
+ f 20//17 10//17 33//17
174
+ f 31//18 20//18 33//18
175
+ f 15//19 31//19 33//19
176
+ f 12//20 1//20 34//20
177
+ f 8//21 12//21 34//21
178
+ f 9//22 22//22 35//22
179
+ f 26//23 9//23 35//23
180
+ f 3//24 1//24 36//24
181
+ f 4//25 3//25 36//25
182
+ f 1//26 12//26 36//26
183
+ f 12//27 4//27 36//27
184
+ f 16//28 23//28 37//28
185
+ f 7//29 27//29 37//29
186
+ f 27//30 16//30 37//30
187
+ f 16//31 21//31 38//31
188
+ f 23//32 16//32 38//32
189
+ f 29//33 6//33 38//33
190
+ f 21//34 29//34 38//34
191
+ f 8//5 4//5 39//5
192
+ f 4//35 12//35 39//35
193
+ f 12//36 8//36 39//36
194
+ f 24//37 7//37 40//37
195
+ f 10//38 24//38 40//38
196
+ f 7//39 37//39 40//39
197
+ f 37//40 18//40 40//40
198
+ f 25//41 15//41 41//41
199
+ f 18//42 25//42 41//42
200
+ f 33//43 10//43 41//43
201
+ f 15//44 33//44 41//44
202
+ f 10//45 40//45 41//45
203
+ f 40//46 18//46 41//46
204
+ f 9//47 23//47 42//47
205
+ f 38//48 6//48 42//48
206
+ f 23//49 38//49 42//49
207
+ f 14//3 4//3 43//3
208
+ f 4//5 28//5 43//5
209
+ f 28//50 14//50 43//50
210
+ f 5//51 1//51 44//51
211
+ f 1//52 13//52 44//52
212
+ f 13//53 5//53 44//53
213
+ f 19//54 15//54 45//54
214
+ f 15//55 25//55 45//55
215
+ f 25//56 18//56 45//56
216
+ f 18//57 26//57 45//57
217
+ f 35//58 19//58 45//58
218
+ f 26//59 35//59 45//59
219
+ f 26//60 18//60 46//60
220
+ f 23//61 26//61 46//61
221
+ f 18//62 37//62 46//62
222
+ f 37//63 23//63 46//63
223
+ f 21//64 27//64 47//64
224
+ f 27//65 28//65 47//65
225
+ f 47//66 28//66 48//66
226
+ f 17//67 6//67 48//67
227
+ f 28//5 17//5 48//5
228
+ f 6//68 29//68 48//68
229
+ f 29//69 21//69 48//69
230
+ f 21//70 47//70 48//70
231
+ f 15//71 19//71 49//71
232
+ f 30//72 11//72 49//72
233
+ f 11//73 31//73 49//73
234
+ f 31//74 15//74 49//74
235
+ f 19//75 35//75 49//75
236
+ f 35//76 30//76 49//76
237
+ f 1//77 5//77 50//77
238
+ f 5//78 17//78 50//78
239
+ f 17//5 8//5 50//5
240
+ f 34//79 1//79 51//79
241
+ f 8//80 34//80 51//80
242
+ f 1//81 50//81 51//81
243
+ f 50//82 8//82 51//82
244
+ f 22//83 2//83 52//83
245
+ f 2//84 30//84 52//84
246
+ f 35//85 22//85 52//85
247
+ f 30//86 35//86 52//86
248
+ f 6//87 22//87 53//87
249
+ f 22//88 9//88 53//88
250
+ f 42//89 6//89 53//89
dexart-release/assets/sapien/102697/new_objs/102697_link_3_6.obj ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender v2.79 (sub 0) OBJ File: ''
2
+ # www.blender.org
3
+ mtllib 102697_link_3_6.mtl
4
+ o Shape_IndexedFaceSet.006_Shape_IndexedFaceSet.16315
5
+ v 0.267139 0.042434 0.557173
6
+ v 0.174056 0.055369 0.316720
7
+ v 0.220597 0.088979 0.557173
8
+ v 0.117181 0.047615 0.557173
9
+ v 0.321419 0.088979 0.316720
10
+ v 0.176650 0.042434 0.316720
11
+ v 0.192149 0.073461 0.316720
12
+ v 0.272305 0.086387 0.557173
13
+ v 0.324013 0.042434 0.316720
14
+ v 0.132681 0.070873 0.549382
15
+ v 0.311086 0.052781 0.466641
16
+ v 0.272305 0.088979 0.316720
17
+ v 0.124942 0.042434 0.557173
18
+ v 0.308492 0.081220 0.461463
19
+ v 0.137869 0.073461 0.557173
20
+ v 0.282638 0.068285 0.549382
21
+ v 0.298159 0.088979 0.469231
22
+ v 0.334346 0.068290 0.319286
23
+ v 0.308492 0.045027 0.458873
24
+ v 0.124942 0.065707 0.554560
25
+ v 0.334346 0.052781 0.316697
26
+ v 0.184389 0.070873 0.321875
27
+ v 0.285232 0.050193 0.539048
28
+ v 0.324013 0.083799 0.358079
29
+ v 0.267139 0.088979 0.557173
30
+ v 0.137869 0.073461 0.551971
31
+ v 0.119775 0.057957 0.554560
32
+ v 0.174056 0.045027 0.316720
33
+ v 0.316253 0.042434 0.383948
34
+ v 0.331774 0.047610 0.321875
35
+ v 0.329180 0.083803 0.316697
36
+ v 0.272305 0.045027 0.557173
37
+ v 0.220597 0.088979 0.551971
38
+ v 0.119775 0.045027 0.551971
39
+ v 0.313680 0.088979 0.386514
40
+ v 0.282638 0.042434 0.520946
41
+ v 0.179222 0.063128 0.316720
42
+ vn -0.0002 0.0002 -1.0000
43
+ vn 0.0000 0.0000 1.0000
44
+ vn 0.0000 -1.0000 -0.0000
45
+ vn 0.0000 1.0000 0.0000
46
+ vn 0.8767 0.3663 0.3117
47
+ vn 0.9441 0.1404 0.2983
48
+ vn 0.6868 0.6920 0.2223
49
+ vn 0.9787 0.1197 0.1671
50
+ vn -0.3480 0.2786 0.8951
51
+ vn -0.5230 0.8497 0.0660
52
+ vn 0.0000 -0.0022 -1.0000
53
+ vn 0.9879 -0.0256 0.1532
54
+ vn -0.6125 0.7781 -0.1392
55
+ vn 0.9409 -0.0559 0.3340
56
+ vn 0.8012 -0.5355 0.2670
57
+ vn 0.9537 0.2610 0.1497
58
+ vn 0.4429 0.8828 0.1562
59
+ vn -0.1899 0.9808 -0.0438
60
+ vn -0.1844 0.9829 0.0000
61
+ vn -0.4464 0.8948 0.0000
62
+ vn -0.3652 0.9271 -0.0843
63
+ vn -0.4071 0.9087 -0.0925
64
+ vn -0.9578 0.1845 -0.2206
65
+ vn -0.4906 0.3271 0.8076
66
+ vn -0.9731 0.0000 -0.2302
67
+ vn -0.0001 -0.0001 -1.0000
68
+ vn 0.8869 -0.4394 0.1424
69
+ vn 0.6665 -0.6663 -0.3344
70
+ vn 0.9361 -0.3202 0.1452
71
+ vn 0.5256 -0.8486 0.0607
72
+ vn 0.6247 -0.7755 0.0915
73
+ vn 0.0000 0.0044 -1.0000
74
+ vn -0.0003 0.0014 -1.0000
75
+ vn 0.7031 0.1171 -0.7014
76
+ vn 0.9363 0.3313 0.1169
77
+ vn -0.0001 0.0000 -1.0000
78
+ vn 0.6020 0.0000 0.7985
79
+ vn 0.8242 -0.1871 0.5345
80
+ vn -0.1842 0.9821 -0.0405
81
+ vn -0.5501 -0.8240 0.1356
82
+ vn -0.3801 -0.9213 -0.0817
83
+ vn -0.8628 -0.4647 -0.1991
84
+ vn -0.6977 -0.6980 -0.1610
85
+ vn 0.6532 0.7472 0.1226
86
+ vn 0.6915 0.7120 0.1216
87
+ vn 0.5539 0.8303 0.0614
88
+ vn 0.6054 0.7923 0.0757
89
+ vn 0.6261 -0.7451 0.2297
90
+ vn 0.2367 -0.9698 0.0581
91
+ vn 0.4406 -0.8777 0.1885
92
+ vn 0.6228 -0.7475 0.2311
93
+ vn -0.5476 0.6851 -0.4804
94
+ vn -0.7588 0.6260 -0.1800
95
+ vn -0.8168 0.5439 -0.1923
96
+ vn -0.8165 0.5444 -0.1922
97
+ vn -0.0002 0.0001 -1.0000
98
+ usemtl Shape.16325
99
+ s off
100
+ f 7//1 31//1 37//1
101
+ f 1//2 3//2 4//2
102
+ f 3//2 1//2 8//2
103
+ f 1//3 6//3 9//3
104
+ f 3//4 5//4 12//4
105
+ f 1//2 4//2 13//2
106
+ f 6//3 1//3 13//3
107
+ f 4//2 3//2 15//2
108
+ f 14//5 8//5 16//5
109
+ f 11//6 14//6 16//6
110
+ f 5//4 3//4 17//4
111
+ f 8//7 14//7 17//7
112
+ f 14//8 11//8 18//8
113
+ f 4//9 15//9 20//9
114
+ f 15//10 10//10 20//10
115
+ f 9//11 6//11 21//11
116
+ f 18//12 11//12 21//12
117
+ f 20//13 10//13 22//13
118
+ f 11//14 16//14 23//14
119
+ f 19//15 11//15 23//15
120
+ f 14//16 18//16 24//16
121
+ f 3//2 8//2 25//2
122
+ f 17//4 3//4 25//4
123
+ f 8//17 17//17 25//17
124
+ f 12//18 7//18 26//18
125
+ f 15//19 3//19 26//19
126
+ f 10//20 15//20 26//20
127
+ f 7//21 22//21 26//21
128
+ f 22//22 10//22 26//22
129
+ f 2//23 4//23 27//23
130
+ f 4//24 20//24 27//24
131
+ f 4//25 2//25 28//25
132
+ f 21//26 6//26 28//26
133
+ f 1//3 9//3 29//3
134
+ f 11//27 19//27 30//27
135
+ f 9//28 21//28 30//28
136
+ f 21//29 11//29 30//29
137
+ f 29//30 9//30 30//30
138
+ f 19//31 29//31 30//31
139
+ f 12//32 5//32 31//32
140
+ f 7//33 12//33 31//33
141
+ f 18//34 21//34 31//34
142
+ f 24//35 18//35 31//35
143
+ f 28//36 2//36 31//36
144
+ f 21//36 28//36 31//36
145
+ f 8//2 1//2 32//2
146
+ f 16//37 8//37 32//37
147
+ f 23//38 16//38 32//38
148
+ f 3//4 12//4 33//4
149
+ f 26//19 3//19 33//19
150
+ f 12//39 26//39 33//39
151
+ f 13//40 4//40 34//40
152
+ f 6//41 13//41 34//41
153
+ f 4//42 28//42 34//42
154
+ f 28//43 6//43 34//43
155
+ f 5//4 17//4 35//4
156
+ f 17//44 14//44 35//44
157
+ f 14//45 24//45 35//45
158
+ f 31//46 5//46 35//46
159
+ f 24//47 31//47 35//47
160
+ f 19//48 23//48 36//48
161
+ f 1//3 29//3 36//3
162
+ f 29//49 19//49 36//49
163
+ f 32//50 1//50 36//50
164
+ f 23//51 32//51 36//51
165
+ f 22//52 7//52 37//52
166
+ f 20//53 22//53 37//53
167
+ f 2//54 27//54 37//54
168
+ f 27//55 20//55 37//55
169
+ f 31//56 2//56 37//56
dexart-release/assets/sapien/102697/new_objs/102697_link_4_11.obj ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender v2.79 (sub 0) OBJ File: ''
2
+ # www.blender.org
3
+ mtllib 102697_link_4_11.mtl
4
+ o Shape_IndexedFaceSet.011_Shape_IndexedFaceSet.7497
5
+ v -0.323117 -0.045453 0.089563
6
+ v -0.106232 0.036714 0.031262
7
+ v -0.106232 -0.014976 0.031262
8
+ v -0.207356 0.064630 -0.188472
9
+ v -0.342601 0.058882 0.201139
10
+ v -0.106232 -0.022362 -0.168161
11
+ v -0.269912 -0.053693 0.229603
12
+ v -0.099529 0.061840 -0.178611
13
+ v -0.268728 0.058882 0.223327
14
+ v -0.217054 -0.014976 -0.168161
15
+ v -0.114215 -0.048395 -0.086531
16
+ v -0.354324 -0.050570 0.218569
17
+ v -0.106232 -0.051916 0.009150
18
+ v -0.349976 0.051485 0.178989
19
+ v -0.239154 0.021942 0.215931
20
+ v -0.224454 0.036714 -0.168161
21
+ v -0.170368 -0.047738 -0.078248
22
+ v -0.331607 -0.019798 0.104076
23
+ v -0.246554 -0.022362 0.223327
24
+ v -0.202304 -0.022362 -0.168161
25
+ v -0.126343 0.050010 0.012102
26
+ v -0.246554 0.044099 0.223327
27
+ v -0.224454 -0.000205 -0.168161
28
+ vn -0.9305 0.0000 -0.3663
29
+ vn 0.9995 0.0000 0.0319
30
+ vn 0.0870 -0.1295 -0.9878
31
+ vn -0.0040 0.9999 0.0134
32
+ vn 0.0243 0.9996 0.0176
33
+ vn -0.2829 0.1803 0.9420
34
+ vn -0.1273 0.0565 0.9902
35
+ vn 0.9995 -0.0157 0.0262
36
+ vn 0.9966 -0.0810 -0.0135
37
+ vn 0.8380 -0.5383 -0.0897
38
+ vn -0.4515 0.8806 -0.1437
39
+ vn -0.9523 0.1448 0.2687
40
+ vn 0.8116 0.0000 0.5842
41
+ vn -0.8922 0.3024 -0.3355
42
+ vn -0.0793 -0.9951 -0.0587
43
+ vn -0.0330 -0.9990 -0.0300
44
+ vn -0.0269 -0.9992 -0.0280
45
+ vn -0.0169 -0.9992 -0.0354
46
+ vn -0.9542 -0.1811 -0.2380
47
+ vn -0.9782 -0.0375 -0.2042
48
+ vn -0.9321 0.1193 -0.3421
49
+ vn 0.7252 -0.4335 0.5349
50
+ vn 0.7686 -0.3286 0.5489
51
+ vn 0.8076 -0.0366 0.5886
52
+ vn -0.0000 -0.2274 -0.9738
53
+ vn -0.4300 -0.8588 -0.2785
54
+ vn -0.1163 -0.2322 -0.9657
55
+ vn 0.0000 -0.9527 -0.3038
56
+ vn -0.2235 -0.9559 -0.1904
57
+ vn -0.0489 -0.9657 -0.2552
58
+ vn 0.4655 0.8769 0.1198
59
+ vn 0.3830 0.8970 0.2205
60
+ vn 0.1907 0.9778 0.0875
61
+ vn 0.5197 0.7795 0.3497
62
+ vn 0.0368 0.0552 0.9978
63
+ vn 0.8066 0.0736 0.5865
64
+ vn 0.2595 0.0000 0.9657
65
+ vn 0.7069 0.0000 0.7073
66
+ vn -0.8241 -0.4128 -0.3879
67
+ vn -0.3724 -0.1866 -0.9091
68
+ vn -0.7650 0.0000 -0.6440
69
+ vn -0.9238 -0.0961 -0.3705
70
+ usemtl Shape.7503
71
+ s off
72
+ f 18//1 16//1 23//1
73
+ f 2//2 3//2 8//2
74
+ f 6//3 4//3 8//3
75
+ f 4//4 5//4 9//4
76
+ f 8//5 4//5 9//5
77
+ f 9//6 5//6 12//6
78
+ f 7//7 9//7 12//7
79
+ f 8//8 3//8 13//8
80
+ f 6//9 8//9 13//9
81
+ f 11//10 6//10 13//10
82
+ f 5//11 4//11 14//11
83
+ f 12//12 5//12 14//12
84
+ f 3//13 2//13 15//13
85
+ f 14//14 4//14 16//14
86
+ f 12//15 1//15 17//15
87
+ f 7//16 12//16 17//16
88
+ f 13//17 7//17 17//17
89
+ f 11//18 13//18 17//18
90
+ f 1//19 12//19 18//19
91
+ f 12//20 14//20 18//20
92
+ f 14//21 16//21 18//21
93
+ f 7//22 13//22 19//22
94
+ f 13//23 3//23 19//23
95
+ f 3//24 15//24 19//24
96
+ f 4//25 6//25 20//25
97
+ f 1//26 10//26 20//26
98
+ f 10//27 4//27 20//27
99
+ f 6//28 11//28 20//28
100
+ f 17//29 1//29 20//29
101
+ f 11//30 17//30 20//30
102
+ f 2//31 8//31 21//31
103
+ f 9//32 2//32 21//32
104
+ f 8//33 9//33 21//33
105
+ f 2//34 9//34 22//34
106
+ f 9//35 7//35 22//35
107
+ f 15//36 2//36 22//36
108
+ f 7//37 19//37 22//37
109
+ f 19//38 15//38 22//38
110
+ f 10//39 1//39 23//39
111
+ f 4//40 10//40 23//40
112
+ f 16//41 4//41 23//41
113
+ f 1//42 18//42 23//42
dexart-release/assets/sapien/102697/new_objs/102697_link_4_19.obj ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender v2.79 (sub 0) OBJ File: ''
2
+ # www.blender.org
3
+ mtllib 102697_link_4_19.mtl
4
+ o Shape_IndexedFaceSet.019_Shape_IndexedFaceSet.7505
5
+ v 0.418781 0.627924 -0.508767
6
+ v 0.428328 0.622978 -0.508489
7
+ v -0.056475 0.145332 -0.517066
8
+ v -0.076686 0.620168 -0.567029
9
+ v 0.366468 0.620168 -0.567029
10
+ v -0.076686 0.125411 -0.567029
11
+ v -0.054501 0.599577 -0.517235
12
+ v 0.320187 0.115689 -0.504922
13
+ v 0.322183 0.125411 -0.567029
14
+ v 0.418075 0.269311 -0.508853
15
+ v 0.411683 0.362817 -0.537486
16
+ v -0.067659 0.626136 -0.530364
17
+ v 0.388636 0.635034 -0.507943
18
+ v 0.389069 0.133686 -0.508902
19
+ v 0.351689 0.214056 -0.567029
20
+ v 0.314606 0.110688 -0.506236
21
+ v 0.413343 0.608120 -0.537486
22
+ v 0.390810 0.209582 -0.537486
23
+ v 0.432230 0.376647 -0.508872
24
+ v 0.373857 0.605402 -0.567029
25
+ v 0.359079 0.620168 -0.507943
26
+ v -0.067415 0.625895 -0.551722
27
+ vn 0.0041 0.9694 -0.2455
28
+ vn 0.4154 0.7751 -0.4762
29
+ vn 0.0000 0.0000 -1.0000
30
+ vn -0.0322 0.0005 0.9995
31
+ vn -0.9254 -0.0111 0.3788
32
+ vn -0.9710 -0.0000 0.2391
33
+ vn -0.7030 0.0033 0.7112
34
+ vn -0.0527 0.4214 0.9053
35
+ vn 0.0729 0.1956 0.9780
36
+ vn 0.2100 0.9266 -0.3119
37
+ vn 0.0149 0.0038 0.9999
38
+ vn 0.0612 -0.0134 0.9980
39
+ vn -0.0967 -0.9105 0.4021
40
+ vn -0.0476 -0.2038 0.9778
41
+ vn 0.0000 -0.9719 -0.2354
42
+ vn 0.2833 -0.9396 -0.1922
43
+ vn 0.1599 -0.4139 0.8961
44
+ vn 0.5278 0.6137 -0.5872
45
+ vn 0.8262 -0.1125 -0.5520
46
+ vn 0.6491 -0.2811 -0.7069
47
+ vn 0.8763 -0.1873 -0.4438
48
+ vn 0.5947 -0.0810 -0.7998
49
+ vn 0.5773 -0.1922 -0.7936
50
+ vn 0.0375 -0.0010 0.9993
51
+ vn 0.0502 -0.0064 0.9987
52
+ vn 0.8317 -0.1098 -0.5442
53
+ vn 0.8852 0.0147 -0.4650
54
+ vn 0.8135 -0.0055 -0.5815
55
+ vn 0.4969 -0.0281 -0.8673
56
+ vn 0.5624 0.2814 -0.7775
57
+ vn 0.5992 -0.0041 -0.8006
58
+ vn -0.0228 0.0077 0.9997
59
+ vn -0.0249 0.0495 0.9985
60
+ vn -0.0031 0.0062 1.0000
61
+ vn 0.0000 0.9366 -0.3504
62
+ vn -0.5068 0.8619 -0.0155
63
+ vn -0.0189 0.9998 -0.0115
64
+ usemtl Shape.7511
65
+ s off
66
+ f 13//1 5//1 22//1
67
+ f 1//2 2//2 5//2
68
+ f 4//3 5//3 6//3
69
+ f 7//4 3//4 8//4
70
+ f 6//3 5//3 9//3
71
+ f 6//5 3//5 12//5
72
+ f 4//6 6//6 12//6
73
+ f 3//7 7//7 12//7
74
+ f 12//8 7//8 13//8
75
+ f 2//9 1//9 13//9
76
+ f 1//10 5//10 13//10
77
+ f 8//11 2//11 13//11
78
+ f 10//12 8//12 14//12
79
+ f 9//3 5//3 15//3
80
+ f 3//13 6//13 16//13
81
+ f 8//14 3//14 16//14
82
+ f 6//15 9//15 16//15
83
+ f 9//16 14//16 16//16
84
+ f 14//17 8//17 16//17
85
+ f 5//18 2//18 17//18
86
+ f 11//19 10//19 18//19
87
+ f 14//20 9//20 18//20
88
+ f 10//21 14//21 18//21
89
+ f 15//22 11//22 18//22
90
+ f 9//23 15//23 18//23
91
+ f 2//24 8//24 19//24
92
+ f 8//25 10//25 19//25
93
+ f 10//26 11//26 19//26
94
+ f 17//27 2//27 19//27
95
+ f 11//28 17//28 19//28
96
+ f 11//29 15//29 20//29
97
+ f 15//3 5//3 20//3
98
+ f 5//30 17//30 20//30
99
+ f 17//31 11//31 20//31
100
+ f 7//32 8//32 21//32
101
+ f 13//33 7//33 21//33
102
+ f 8//34 13//34 21//34
103
+ f 5//35 4//35 22//35
104
+ f 4//36 12//36 22//36
105
+ f 12//37 13//37 22//37
dexart-release/assets/sapien/102697/new_objs/102697_link_4_3.mtl ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender MTL File: 'None'
2
+ # Material Count: 1
3
+
4
+ newmtl Shape.7495
5
+ Ns 400.000000
6
+ Ka 0.400000 0.400000 0.400000
7
+ Kd 0.168000 0.496000 0.216000
8
+ Ks 0.250000 0.250000 0.250000
9
+ Ke 0.000000 0.000000 0.000000
10
+ Ni 1.000000
11
+ d 0.500000
12
+ illum 2
dexart-release/assets/sapien/102697/new_objs/102697_link_4_33.mtl ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender MTL File: 'None'
2
+ # Material Count: 1
3
+
4
+ newmtl Shape.7525
5
+ Ns 400.000000
6
+ Ka 0.400000 0.400000 0.400000
7
+ Kd 0.272000 0.624000 0.536000
8
+ Ks 0.250000 0.250000 0.250000
9
+ Ke 0.000000 0.000000 0.000000
10
+ Ni 1.000000
11
+ d 0.500000
12
+ illum 2
dexart-release/assets/sapien/102697/new_objs/102697_link_4_4.mtl ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender MTL File: 'None'
2
+ # Material Count: 1
3
+
4
+ newmtl Shape.7496
5
+ Ns 400.000000
6
+ Ka 0.400000 0.400000 0.400000
7
+ Kd 0.720000 0.472000 0.504000
8
+ Ks 0.250000 0.250000 0.250000
9
+ Ke 0.000000 0.000000 0.000000
10
+ Ni 1.000000
11
+ d 0.500000
12
+ illum 2
dexart-release/assets/sapien/102697/new_objs/102697_link_4_8.mtl ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Blender MTL File: 'None'
2
+ # Material Count: 1
3
+
4
+ newmtl Shape.7500
5
+ Ns 400.000000
6
+ Ka 0.400000 0.400000 0.400000
7
+ Kd 0.184000 0.536000 0.280000
8
+ Ks 0.250000 0.250000 0.250000
9
+ Ke 0.000000 0.000000 0.000000
10
+ Ni 1.000000
11
+ d 0.500000
12
+ illum 2
dexart-release/assets/sapien/102697/result.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [{"text": "Toilet", "name": "Toilet", "id": 0, "children": [{"text": "button", "name": "button", "id": 1, "objs": ["original-8"]}, {"text": "lid", "name": "lid", "id": 2, "objs": ["original-4", "original-3"]}, {"text": "lid", "name": "lid", "id": 3, "objs": ["original-21", "original-20"]}, {"text": "seat", "name": "seat", "id": 4, "objs": ["original-18", "original-17"]}, {"text": "base_body", "name": "base_body", "id": 5, "objs": ["original-13", "original-6", "original-15", "original-11", "original-7", "original-10", "original-14"]}]}]
dexart-release/assets/sapien/102697/result_original.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [{"id": 0, "name": "Toilet", "text": "Toilet", "children": [{"id": 1, "name": "button", "text": "button", "objs": ["original-8"]}, {"id": 2, "name": "lid", "text": "lid", "objs": ["original-4", "original-3"]}, {"id": 3, "name": "lid", "text": "lid", "objs": ["original-21", "original-20"]}, {"id": 4, "name": "seat", "text": "seat", "objs": ["original-18", "original-17"]}]}]
dexart-release/assets/sapien/102697/semantics.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ link_0 hinge lever
2
+ link_1 slider pump_lid
3
+ link_2 hinge lid
4
+ link_3 hinge seat
5
+ link_4 static toilet_body
dexart-release/dexart.egg-info/PKG-INFO ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: dexart
3
+ Version: 0.1.0
4
+ Summary: UNKNOWN
5
+ Home-page: https://github.com/Kami-code/dexart-release
6
+ Author: Xiaolong Wang's Lab
7
+ License: UNKNOWN
8
+ Platform: UNKNOWN
9
+ License-File: LICENSE
10
+
11
+ UNKNOWN
12
+
dexart-release/dexart.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ dexart.egg-info/PKG-INFO
5
+ dexart.egg-info/SOURCES.txt
6
+ dexart.egg-info/dependency_links.txt
7
+ dexart.egg-info/requires.txt
8
+ dexart.egg-info/top_level.txt
9
+ stable_baselines3/__init__.py
10
+ stable_baselines3/pickle_utils.py
11
+ stable_baselines3/a2c/__init__.py
12
+ stable_baselines3/a2c/a2c.py
13
+ stable_baselines3/a2c/policies.py
14
+ stable_baselines3/common/__init__.py
15
+ stable_baselines3/common/base_class.py
16
+ stable_baselines3/common/buffers.py
17
+ stable_baselines3/common/callbacks.py
18
+ stable_baselines3/common/distributions.py
19
+ stable_baselines3/common/env_util.py
20
+ stable_baselines3/common/evaluation.py
21
+ stable_baselines3/common/logger.py
22
+ stable_baselines3/common/monitor.py
23
+ stable_baselines3/common/noise.py
24
+ stable_baselines3/common/on_policy_algorithm.py
25
+ stable_baselines3/common/policies.py
26
+ stable_baselines3/common/preprocessing.py
27
+ stable_baselines3/common/running_mean_std.py
28
+ stable_baselines3/common/save_util.py
29
+ stable_baselines3/common/torch_layers.py
30
+ stable_baselines3/common/type_aliases.py
31
+ stable_baselines3/common/utils.py
32
+ stable_baselines3/common/vec_env/__init__.py
33
+ stable_baselines3/common/vec_env/base_vec_env.py
34
+ stable_baselines3/common/vec_env/dummy_vec_env.py
35
+ stable_baselines3/common/vec_env/maniskill2_utils_common.py
36
+ stable_baselines3/common/vec_env/maniskill2_utils_wrappers_obs.py
37
+ stable_baselines3/common/vec_env/maniskill2_vec_env.py
38
+ stable_baselines3/common/vec_env/maniskill2_wrapper_obs.py
39
+ stable_baselines3/common/vec_env/stacked_observations.py
40
+ stable_baselines3/common/vec_env/subproc_vec_env.py
41
+ stable_baselines3/common/vec_env/util.py
42
+ stable_baselines3/common/vec_env/vec_check_nan.py
43
+ stable_baselines3/common/vec_env/vec_extract_dict_obs.py
44
+ stable_baselines3/common/vec_env/vec_frame_stack.py
45
+ stable_baselines3/common/vec_env/vec_monitor.py
46
+ stable_baselines3/common/vec_env/vec_normalize.py
47
+ stable_baselines3/common/vec_env/vec_transpose.py
48
+ stable_baselines3/common/vec_env/vec_video_recorder.py
49
+ stable_baselines3/ppo/__init__.py
50
+ stable_baselines3/ppo/policies.py
51
+ stable_baselines3/ppo/ppo.py
dexart-release/dexart.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
dexart-release/dexart.egg-info/requires.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transforms3d
2
+ sapien==2.2.1
3
+ numpy
4
+ open3d>=0.15.1
dexart-release/dexart.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ stable_baselines3
dexart-release/examples/gen_demonstration_expert.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5
+
6
+ import argparse
7
+ import zarr
8
+ import torch
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ import pytorch3d.ops as torch3d_ops
12
+ from dexart.env.task_setting import TRAIN_CONFIG, RANDOM_CONFIG
13
+ from dexart.env.create_env import create_env
14
+ from stable_baselines3 import PPO
15
+ # from examples.train import get_3d_policy_kwargs
16
+ from train import get_3d_policy_kwargs
17
+ from tqdm import tqdm
18
+ from termcolor import cprint
19
+
20
+
21
+ def parse_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument('--task_name', type=str, required=True)
24
+ parser.add_argument('--checkpoint_path', type=str, required=True)
25
+ parser.add_argument('--num_episodes', type=int, default=10, help='number of total episodes')
26
+ parser.add_argument('--use_test_set', dest='use_test_set', action='store_true', default=False)
27
+ parser.add_argument('--root_dir', type=str, default='data', help='directory to save data')
28
+ parser.add_argument('--img_size', type=int, default=84, help='image size')
29
+ parser.add_argument('--num_points', type=int, default=1024, help='number of points in point cloud')
30
+ args = parser.parse_args()
31
+ return args
32
+
33
+ def downsample_with_fps(points: np.ndarray, num_points: int = 512):
34
+ # fast point cloud sampling using torch3d
35
+ points = torch.from_numpy(points).unsqueeze(0).cuda()
36
+ num_points = torch.tensor([num_points]).cuda()
37
+ # remember to only use coord to sample
38
+ _, sampled_indices = torch3d_ops.sample_farthest_points(points=points[...,:3], K=num_points)
39
+ points = points.squeeze(0).cpu().numpy()
40
+ points = points[sampled_indices.squeeze(0).cpu().numpy()]
41
+ return points
42
+
43
+ def main():
44
+ args = parse_args()
45
+ task_name = args.task_name
46
+ use_test_set = args.use_test_set
47
+ checkpoint_path = args.checkpoint_path
48
+
49
+
50
+ save_dir = os.path.join(args.root_dir, 'dexart_'+args.task_name+'_expert.zarr')
51
+ if os.path.exists(save_dir):
52
+ cprint('Data already exists at {}'.format(save_dir), 'red')
53
+ cprint("If you want to overwrite, delete the existing directory first.", "red")
54
+ cprint("Do you want to overwrite? (y/n)", "red")
55
+ # user_input = input()
56
+ user_input = 'y'
57
+ if user_input == 'y':
58
+ cprint('Overwriting {}'.format(save_dir), 'red')
59
+ os.system('rm -rf {}'.format(save_dir))
60
+ else:
61
+ cprint('Exiting', 'red')
62
+ return
63
+ os.makedirs(save_dir, exist_ok=True)
64
+
65
+
66
+ if use_test_set:
67
+ indeces = TRAIN_CONFIG[task_name]['unseen']
68
+ cprint(f"using unseen instances {indeces}", 'yellow')
69
+ else:
70
+ indeces = TRAIN_CONFIG[task_name]['seen']
71
+ cprint(f"using seen instances {indeces}", 'yellow')
72
+
73
+ rand_pos = RANDOM_CONFIG[task_name]['rand_pos']
74
+ rand_degree = RANDOM_CONFIG[task_name]['rand_degree']
75
+ env = create_env(task_name=task_name,
76
+ use_visual_obs=True,
77
+ use_gui=False,
78
+ is_eval=True,
79
+ pc_noise=True,
80
+ pc_seg=True,
81
+ index=indeces,
82
+ img_type='robot',
83
+ rand_pos=rand_pos,
84
+ rand_degree=rand_degree)
85
+
86
+ policy = PPO.load(checkpoint_path, env, 'cuda:0',
87
+ policy_kwargs=get_3d_policy_kwargs(extractor_name='smallpn'),
88
+ check_obs_space=False, force_load=True)
89
+
90
+ eval_instances = len(env.instance_list)
91
+ num_episodes = args.num_episodes
92
+ cprint(f"generate {num_episodes} episodes in total", 'yellow')
93
+
94
+ success_list = []
95
+ reward_list = []
96
+
97
+ total_count = 0
98
+ img_arrays = []
99
+ point_cloud_arrays = []
100
+ depth_arrays = []
101
+ state_arrays = []
102
+ imagin_robot_arrays = []
103
+ action_arrays = []
104
+ episode_ends_arrays = []
105
+
106
+
107
+ with tqdm(total=num_episodes) as pbar:
108
+ num_success = 0
109
+ while num_success < num_episodes:
110
+
111
+ # obs dict keys: 'instance_1-seg_gt', 'instance_1-point_cloud',
112
+ # 'instance_1-rgb', 'imagination_robot', 'state', 'oracle_state'
113
+ obs = env.reset()
114
+ eval_success = False
115
+ reward_sum = 0
116
+
117
+ img_arrays_sub = []
118
+ point_cloud_arrays_sub = []
119
+ depth_arrays_sub = []
120
+ state_arrays_sub = []
121
+ imagin_robot_arrays_sub = []
122
+ action_arrays_sub = []
123
+ total_count_sub = 0
124
+ for j in range(env.horizon):
125
+
126
+ if isinstance(obs, dict):
127
+ for key, value in obs.items():
128
+ obs[key] = value[np.newaxis, :]
129
+ else:
130
+ obs = obs[np.newaxis, :]
131
+ action = policy.predict(observation=obs, deterministic=True)[0]
132
+
133
+ # fetch data
134
+ total_count_sub += 1
135
+ obs_state = obs['state'][0] # (32)
136
+ obs_imagin_robot = obs['imagination_robot'][0] # (96,7)
137
+ obs_point_cloud = obs['instance_1-point_cloud'][0] # (1024,3)
138
+ obs_depth = obs['instance_1-depth'][0] # (84,84)
139
+
140
+ if obs_point_cloud.shape[0] > args.num_points:
141
+ obs_point_cloud = downsample_with_fps(obs_point_cloud, num_points=args.num_points)
142
+ obs_image = obs['instance_1-rgb'][0] # (84,84,3), [0,1]
143
+
144
+
145
+ # to 0-255
146
+ obs_image = (obs_image*255).astype(np.uint8)
147
+
148
+ # interpolate to target image size
149
+ if obs_image.shape[0] != args.img_size:
150
+ obs_image = F.interpolate(torch.from_numpy(obs_image).permute(2,0,1).unsqueeze(0),
151
+ size=args.img_size).squeeze().permute(1,2,0).numpy()
152
+ # save data
153
+ img_arrays_sub.append(obs_image)
154
+ imagin_robot_arrays_sub.append(obs_imagin_robot)
155
+ point_cloud_arrays_sub.append(obs_point_cloud)
156
+ depth_arrays_sub.append(obs_depth)
157
+ state_arrays_sub.append(obs_state)
158
+ action_arrays_sub.append(action)
159
+
160
+ # step
161
+ obs, reward, done, _ = env.step(action)
162
+ reward_sum += reward
163
+ if env.is_eval_done:
164
+ eval_success = True
165
+ if done:
166
+ break
167
+
168
+ if eval_success:
169
+ total_count += total_count_sub
170
+ episode_ends_arrays.append(total_count) # the index of the last step of the episode
171
+ reward_list.append(reward_sum)
172
+ success_list.append(int(eval_success))
173
+
174
+ img_arrays.extend(img_arrays_sub)
175
+ imagin_robot_arrays.extend(imagin_robot_arrays_sub)
176
+ point_cloud_arrays.extend(point_cloud_arrays_sub)
177
+ depth_arrays.extend(depth_arrays_sub)
178
+ state_arrays.extend(state_arrays_sub)
179
+ action_arrays.extend(action_arrays_sub)
180
+
181
+ num_success += 1
182
+
183
+ pbar.update(1)
184
+ pbar.set_description(f"reward = {reward_sum}, success = {eval_success}")
185
+ else:
186
+ print("episode failed. continue.")
187
+ continue
188
+
189
+
190
+ cprint(f"reward_mean = {np.mean(reward_list)}, success rate = {np.mean(success_list)}", 'yellow')
191
+
192
+ ###############################
193
+ # save data
194
+ ###############################
195
+ # create zarr file
196
+ zarr_root = zarr.group(save_dir)
197
+ zarr_data = zarr_root.create_group('data')
198
+ zarr_meta = zarr_root.create_group('meta')
199
+ # save img, state, action arrays into data, and episode ends arrays into meta
200
+ img_arrays = np.stack(img_arrays, axis=0)
201
+ if img_arrays.shape[1] == 3: # make channel last
202
+ img_arrays = np.transpose(img_arrays, (0,2,3,1))
203
+ state_arrays = np.stack(state_arrays, axis=0)
204
+ imagin_robot_arrays = np.stack(imagin_robot_arrays, axis=0)
205
+ point_cloud_arrays = np.stack(point_cloud_arrays, axis=0)
206
+ depth_arrays = np.stack(depth_arrays, axis=0)
207
+ action_arrays = np.stack(action_arrays, axis=0)
208
+ episode_ends_arrays = np.array(episode_ends_arrays)
209
+
210
+ compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=1)
211
+ img_chunk_size = (env.horizon, img_arrays.shape[1], img_arrays.shape[2], img_arrays.shape[3])
212
+ imagin_robot_chunk_size = (env.horizon, imagin_robot_arrays.shape[1], imagin_robot_arrays.shape[2])
213
+ point_cloud_chunk_size = (env.horizon, point_cloud_arrays.shape[1], point_cloud_arrays.shape[2])
214
+ depth_chunk_size = (env.horizon, depth_arrays.shape[1], depth_arrays.shape[2])
215
+ state_chunk_size = (env.horizon, state_arrays.shape[1])
216
+ action_chunk_size = (env.horizon, action_arrays.shape[1])
217
+ zarr_data.create_dataset('img', data=img_arrays, chunks=img_chunk_size, dtype='uint8', overwrite=True, compressor=compressor)
218
+ zarr_data.create_dataset('state', data=state_arrays, chunks=state_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
219
+ zarr_data.create_dataset('imagin_robot', data=imagin_robot_arrays, chunks=imagin_robot_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
220
+ zarr_data.create_dataset('point_cloud', data=point_cloud_arrays, chunks=point_cloud_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
221
+ zarr_data.create_dataset('depth', data=depth_arrays, chunks=depth_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
222
+ zarr_data.create_dataset('action', data=action_arrays, chunks=action_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
223
+ zarr_meta.create_dataset('episode_ends', data=episode_ends_arrays, dtype='int64', overwrite=True, compressor=compressor)
224
+
225
+ # print shape
226
+ cprint(f'img shape: {img_arrays.shape}, range: [{np.min(img_arrays)}, {np.max(img_arrays)}]', 'green')
227
+ cprint(f'imagin_robot shape: {imagin_robot_arrays.shape}, range: [{np.min(imagin_robot_arrays)}, {np.max(imagin_robot_arrays)}]', 'green')
228
+ cprint(f'point_cloud shape: {point_cloud_arrays.shape}, range: [{np.min(point_cloud_arrays)}, {np.max(point_cloud_arrays)}]', 'green')
229
+ cprint(f'depth shape: {depth_arrays.shape}, range: [{np.min(depth_arrays)}, {np.max(depth_arrays)}]', 'green')
230
+ cprint(f'state shape: {state_arrays.shape}, range: [{np.min(state_arrays)}, {np.max(state_arrays)}]', 'green')
231
+ cprint(f'action shape: {action_arrays.shape}, range: [{np.min(action_arrays)}, {np.max(action_arrays)}]', 'green')
232
+ cprint(f'Saved zarr file to {save_dir}', 'green')
233
+
234
+
235
+ if __name__ == "__main__":
236
+ main()
237
+
238
+
dexart-release/examples/train.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4
+ import random
5
+ from collections import OrderedDict
6
+ import torch.nn as nn
7
+ import argparse
8
+ from dexart.env.create_env import create_env
9
+ from dexart.env.task_setting import TRAIN_CONFIG, IMG_CONFIG, RANDOM_CONFIG
10
+ from stable_baselines3.common.torch_layers import PointNetImaginationExtractorGP
11
+ from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
12
+ from stable_baselines3.ppo import PPO
13
+ import torch
14
+
15
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16
+
17
+
18
+ def get_3d_policy_kwargs(extractor_name):
19
+ feature_extractor_class = PointNetImaginationExtractorGP
20
+ feature_extractor_kwargs = {"pc_key": "instance_1-point_cloud", "gt_key": "instance_1-seg_gt",
21
+ "extractor_name": extractor_name,
22
+ "imagination_keys": [f'imagination_{key}' for key in IMG_CONFIG['robot'].keys()],
23
+ "state_key": "state"}
24
+
25
+ policy_kwargs = {
26
+ "features_extractor_class": feature_extractor_class,
27
+ "features_extractor_kwargs": feature_extractor_kwargs,
28
+ "net_arch": [dict(pi=[64, 64], vf=[64, 64])],
29
+ "activation_fn": nn.ReLU,
30
+ }
31
+ return policy_kwargs
32
+
33
+
34
+ if __name__ == '__main__':
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument('--n', type=int, default=10)
37
+ parser.add_argument('--workers', type=int, default=1)
38
+ parser.add_argument('--lr', type=float, default=3e-4)
39
+ parser.add_argument('--ep', type=int, default=10)
40
+ parser.add_argument('--bs', type=int, default=10)
41
+ parser.add_argument('--seed', type=int, default=100)
42
+ parser.add_argument('--iter', type=int, default=1000)
43
+ parser.add_argument('--freeze', dest='freeze', action='store_true', default=False)
44
+ parser.add_argument('--task_name', type=str, default="laptop")
45
+ parser.add_argument('--extractor_name', type=str, default="smallpn")
46
+ parser.add_argument('--pretrain_path', type=str, default=None)
47
+ args = parser.parse_args()
48
+
49
+ task_name = args.task_name
50
+ extractor_name = args.extractor_name
51
+ seed = args.seed if args.seed >= 0 else random.randint(0, 100000)
52
+ pretrain_path = args.pretrain_path
53
+ horizon = 200
54
+ env_iter = args.iter * horizon * args.n
55
+ print(f"freeze: {args.freeze}")
56
+
57
+ rand_pos = RANDOM_CONFIG[task_name]['rand_pos']
58
+ rand_degree = RANDOM_CONFIG[task_name]['rand_degree']
59
+
60
+
61
+ def create_env_fn():
62
+ seen_indeces = TRAIN_CONFIG[task_name]['seen']
63
+ environment = create_env(task_name=task_name,
64
+ use_visual_obs=True,
65
+ use_gui=False,
66
+ is_eval=False,
67
+ pc_noise=True,
68
+ index=seen_indeces,
69
+ img_type='robot',
70
+ rand_pos=rand_pos,
71
+ rand_degree=rand_degree
72
+ )
73
+ return environment
74
+
75
+
76
+ def create_eval_env_fn():
77
+ unseen_indeces = TRAIN_CONFIG[task_name]['unseen']
78
+ environment = create_env(task_name=task_name,
79
+ use_visual_obs=True,
80
+ use_gui=False,
81
+ is_eval=True,
82
+ pc_noise=True,
83
+ index=unseen_indeces,
84
+ img_type='robot',
85
+ rand_pos=rand_pos,
86
+ rand_degree=rand_degree)
87
+ return environment
88
+
89
+
90
+ env = SubprocVecEnv([create_env_fn] * args.workers, "spawn") # train on a list of envs.
91
+
92
+ model = PPO("PointCloudPolicy", env, verbose=1,
93
+ n_epochs=args.ep,
94
+ n_steps=(args.n // args.workers) * horizon,
95
+ learning_rate=args.lr,
96
+ batch_size=args.bs,
97
+ seed=seed,
98
+ policy_kwargs=get_3d_policy_kwargs(extractor_name=extractor_name),
99
+ min_lr=args.lr,
100
+ max_lr=args.lr,
101
+ adaptive_kl=0.02,
102
+ target_kl=0.2,
103
+ )
104
+
105
+ if pretrain_path is not None:
106
+ state_dict: OrderedDict = torch.load(pretrain_path)
107
+ model.policy.features_extractor.extractor.load_state_dict(state_dict, strict=False)
108
+ print("load pretrained model: ", pretrain_path)
109
+
110
+ rollout = int(model.num_timesteps / (horizon * args.n))
111
+
112
+ # after loading or init the model, then freeze it if needed
113
+ if args.freeze:
114
+ model.policy.features_extractor.extractor.eval()
115
+ for param in model.policy.features_extractor.extractor.parameters():
116
+ param.requires_grad = False
117
+ print("freeze model!")
118
+
119
+ model.learn(
120
+ total_timesteps=int(env_iter),
121
+ reset_num_timesteps=False,
122
+ iter_start=rollout,
123
+ callback=None
124
+ )
dexart-release/examples/utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: UTF-8 -*-
3
+ import numpy as np
4
+ from dexart.env.task_setting import ROBUSTNESS_INIT_CAMERA_CONFIG
5
+ import open3d as o3d
6
+
7
+ def visualize_observation(obs, use_seg=False, img_type=None):
8
+ def visualize_pc_with_seg_label(cloud):
9
+ pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(cloud[:, :3]))
10
+
11
+ def map(feature):
12
+ color = np.zeros((feature.shape[0], 3))
13
+ COLOR20 = np.array(
14
+ [[230, 25, 75], [60, 180, 75], [255, 225, 25], [0, 130, 200], [245, 130, 48],
15
+ [145, 30, 180], [70, 240, 240], [240, 50, 230], [210, 245, 60], [250, 190, 190],
16
+ [0, 128, 128], [230, 190, 255], [170, 110, 40], [255, 250, 200], [128, 0, 0],
17
+ [170, 255, 195], [128, 128, 0], [255, 215, 180], [0, 0, 128], [128, 128, 128]]) / 255
18
+ for i in range(feature.shape[0]):
19
+ for j in range(feature.shape[1]):
20
+ if feature[i, j] == 1:
21
+ color[i, :] = COLOR20[j, :]
22
+ return color
23
+
24
+ color = map(cloud[:, 3:])
25
+ pc.colors = o3d.utility.Vector3dVector(color)
26
+ return pc
27
+
28
+ pc = obs["instance_1-point_cloud"]
29
+ if use_seg:
30
+ gt_seg = obs["instance_1-seg_gt"]
31
+ pc = np.concatenate([pc, gt_seg], axis=1)
32
+ pc = visualize_pc_with_seg_label(pc)
33
+ if img_type == "robot":
34
+ robot_pc = obs["imagination_robot"]
35
+ pc += visualize_pc_with_seg_label(robot_pc)
36
+ else:
37
+ raise NotImplementedError
38
+ return pc
39
+
40
+
41
+ def get_viewpoint_camera_parameter():
42
+ robustness_init_camera_config = ROBUSTNESS_INIT_CAMERA_CONFIG['laptop']
43
+ r = robustness_init_camera_config['r']
44
+ phi = robustness_init_camera_config['phi']
45
+ theta = robustness_init_camera_config['theta']
46
+ center = robustness_init_camera_config['center']
47
+
48
+ x0, y0, z0 = center
49
+ # phi in [0, pi/2]
50
+ # theta in [0, 2 * pi]
51
+ x = x0 + r * np.sin(phi) * np.cos(theta)
52
+ y = y0 + r * np.sin(phi) * np.sin(theta)
53
+ z = z0 + r * np.cos(phi)
54
+
55
+ cam_pos = np.array([x, y, z])
56
+ forward = np.array([x0 - x, y0 - y, z0 - z])
57
+ forward /= np.linalg.norm(forward)
58
+
59
+ left = np.cross([0, 0, 1], forward)
60
+ left = left / np.linalg.norm(left)
61
+
62
+ up = np.cross(forward, left)
63
+ mat44 = np.eye(4)
64
+ mat44[:3, :3] = np.stack([forward, left, up], axis=1)
65
+ mat44[:3, 3] = cam_pos
66
+ return cam_pos, center, up, mat44
dexart-release/stable_baselines3/a2c/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from stable_baselines3.a2c.a2c import A2C
2
+ from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
dexart-release/stable_baselines3/a2c/a2c.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional, Type, Union
2
+
3
+ import torch as th
4
+ from gym import spaces
5
+ from torch.nn import functional as F
6
+
7
+ from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
8
+ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
9
+ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
10
+ from stable_baselines3.common.utils import explained_variance
11
+
12
+
13
+ class A2C(OnPolicyAlgorithm):
14
+ """
15
+ Advantage Actor Critic (A2C)
16
+
17
+ Paper: https://arxiv.org/abs/1602.01783
18
+ Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
19
+ and Stable Baselines (https://github.com/hill-a/stable-baselines)
20
+
21
+ Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
22
+
23
+ :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
24
+ :param env: The environment to learn from (if registered in Gym, can be str)
25
+ :param learning_rate: The learning rate, it can be a function
26
+ of the current progress remaining (from 1 to 0)
27
+ :param n_steps: The number of steps to run for each environment per update
28
+ (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
29
+ :param gamma: Discount factor
30
+ :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
31
+ Equivalent to classic advantage when set to 1.
32
+ :param ent_coef: Entropy coefficient for the loss calculation
33
+ :param vf_coef: Value function coefficient for the loss calculation
34
+ :param max_grad_norm: The maximum value for the gradient clipping
35
+ :param rms_prop_eps: RMSProp epsilon. It stabilizes square root computation in denominator
36
+ of RMSProp update
37
+ :param use_rms_prop: Whether to use RMSprop (default) or Adam as optimizer
38
+ :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
39
+ instead of action noise exploration (default: False)
40
+ :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
41
+ Default: -1 (only sample at the beginning of the rollout)
42
+ :param normalize_advantage: Whether to normalize or not the advantage
43
+ :param tensorboard_log: the log location for tensorboard (if None, no logging)
44
+ :param create_eval_env: Whether to create a second environment that will be
45
+ used for evaluating the agent periodically. (Only available when passing string for the environment)
46
+ :param policy_kwargs: additional arguments to be passed to the policy on creation
47
+ :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
48
+ :param seed: Seed for the pseudo random generators
49
+ :param device: Device (cpu, cuda, ...) on which the code should be run.
50
+ Setting it to auto, the code will be run on the GPU if possible.
51
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance
52
+ """
53
+
54
+ policy_aliases: Dict[str, Type[BasePolicy]] = {
55
+ "MlpPolicy": ActorCriticPolicy,
56
+ "CnnPolicy": ActorCriticCnnPolicy,
57
+ "MultiInputPolicy": MultiInputActorCriticPolicy,
58
+ }
59
+
60
+ def __init__(
61
+ self,
62
+ policy: Union[str, Type[ActorCriticPolicy]],
63
+ env: Union[GymEnv, str],
64
+ learning_rate: Union[float, Schedule] = 7e-4,
65
+ n_steps: int = 5,
66
+ gamma: float = 0.99,
67
+ gae_lambda: float = 1.0,
68
+ ent_coef: float = 0.0,
69
+ vf_coef: float = 0.5,
70
+ max_grad_norm: float = 0.5,
71
+ rms_prop_eps: float = 1e-5,
72
+ use_rms_prop: bool = True,
73
+ use_sde: bool = False,
74
+ sde_sample_freq: int = -1,
75
+ normalize_advantage: bool = False,
76
+ tensorboard_log: Optional[str] = None,
77
+ create_eval_env: bool = False,
78
+ policy_kwargs: Optional[Dict[str, Any]] = None,
79
+ verbose: int = 0,
80
+ seed: Optional[int] = None,
81
+ device: Union[th.device, str] = "auto",
82
+ _init_setup_model: bool = True,
83
+ ):
84
+
85
+ super().__init__(
86
+ policy,
87
+ env,
88
+ learning_rate=learning_rate,
89
+ n_steps=n_steps,
90
+ gamma=gamma,
91
+ gae_lambda=gae_lambda,
92
+ ent_coef=ent_coef,
93
+ vf_coef=vf_coef,
94
+ max_grad_norm=max_grad_norm,
95
+ use_sde=use_sde,
96
+ sde_sample_freq=sde_sample_freq,
97
+ tensorboard_log=tensorboard_log,
98
+ policy_kwargs=policy_kwargs,
99
+ verbose=verbose,
100
+ device=device,
101
+ create_eval_env=create_eval_env,
102
+ seed=seed,
103
+ _init_setup_model=False,
104
+ supported_action_spaces=(
105
+ spaces.Box,
106
+ spaces.Discrete,
107
+ spaces.MultiDiscrete,
108
+ spaces.MultiBinary,
109
+ ),
110
+ )
111
+
112
+ self.normalize_advantage = normalize_advantage
113
+
114
+ # Update optimizer inside the policy if we want to use RMSProp
115
+ # (original implementation) rather than Adam
116
+ if use_rms_prop and "optimizer_class" not in self.policy_kwargs:
117
+ self.policy_kwargs["optimizer_class"] = th.optim.RMSprop
118
+ self.policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=rms_prop_eps, weight_decay=0)
119
+
120
+ if _init_setup_model:
121
+ self._setup_model()
122
+
123
+ def train(self) -> None:
124
+ """
125
+ Update policy using the currently gathered
126
+ rollout buffer (one gradient step over whole data).
127
+ """
128
+ # Switch to train mode (this affects batch norm / dropout)
129
+ self.policy.set_training_mode(True)
130
+
131
+ # Update optimizer learning rate
132
+ self._update_learning_rate(self.policy.optimizer)
133
+
134
+ # This will only loop once (get all data in one go)
135
+ for rollout_data in self.rollout_buffer.get(batch_size=None):
136
+
137
+ actions = rollout_data.actions
138
+ if isinstance(self.action_space, spaces.Discrete):
139
+ # Convert discrete action from float to long
140
+ actions = actions.long().flatten()
141
+
142
+ values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
143
+ values = values.flatten()
144
+
145
+ # Normalize advantage (not present in the original implementation)
146
+ advantages = rollout_data.advantages
147
+ if self.normalize_advantage:
148
+ advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
149
+
150
+ # Policy gradient loss
151
+ policy_loss = -(advantages * log_prob).mean()
152
+
153
+ # Value loss using the TD(gae_lambda) target
154
+ value_loss = F.mse_loss(rollout_data.returns, values)
155
+
156
+ # Entropy loss favor exploration
157
+ if entropy is None:
158
+ # Approximate entropy when no analytical form
159
+ entropy_loss = -th.mean(-log_prob)
160
+ else:
161
+ entropy_loss = -th.mean(entropy)
162
+
163
+ loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
164
+
165
+ # Optimization step
166
+ self.policy.optimizer.zero_grad()
167
+ loss.backward()
168
+
169
+ # Clip grad norm
170
+ th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
171
+ self.policy.optimizer.step()
172
+
173
+ explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
174
+
175
+ self._n_updates += 1
176
+ self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
177
+ self.logger.record("train/explained_variance", explained_var)
178
+ self.logger.record("train/entropy_loss", entropy_loss.item())
179
+ self.logger.record("train/policy_loss", policy_loss.item())
180
+ self.logger.record("train/value_loss", value_loss.item())
181
+ if hasattr(self.policy, "log_std"):
182
+ self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
183
+
184
+ def learn(
185
+ self,
186
+ total_timesteps: int,
187
+ callback: MaybeCallback = None,
188
+ log_interval: int = 100,
189
+ eval_env: Optional[GymEnv] = None,
190
+ eval_freq: int = -1,
191
+ n_eval_episodes: int = 5,
192
+ tb_log_name: str = "A2C",
193
+ eval_log_path: Optional[str] = None,
194
+ reset_num_timesteps: bool = True,
195
+ ) -> "A2C":
196
+
197
+ return super().learn(
198
+ total_timesteps=total_timesteps,
199
+ callback=callback,
200
+ log_interval=log_interval,
201
+ eval_env=eval_env,
202
+ eval_freq=eval_freq,
203
+ n_eval_episodes=n_eval_episodes,
204
+ tb_log_name=tb_log_name,
205
+ eval_log_path=eval_log_path,
206
+ reset_num_timesteps=reset_num_timesteps,
207
+ )
dexart-release/stable_baselines3/a2c/policies.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # This file is here just to define MlpPolicy/CnnPolicy
2
+ # that work for A2C
3
+ from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
4
+
5
+ MlpPolicy = ActorCriticPolicy
6
+ CnnPolicy = ActorCriticCnnPolicy
7
+ MultiInputPolicy = MultiInputActorCriticPolicy
dexart-release/stable_baselines3/common/__init__.py ADDED
File without changes
dexart-release/stable_baselines3/common/base_class.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Abstract base classes for RL algorithms."""
2
+
3
+ import io
4
+ import pathlib
5
+ import time
6
+ from abc import ABC, abstractmethod
7
+ from collections import deque
8
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
9
+
10
+ import gym
11
+ import numpy as np
12
+ import torch as th
13
+
14
+ from stable_baselines3.common import utils
15
+ from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
16
+ from stable_baselines3.common.env_util import is_wrapped
17
+ from stable_baselines3.common.logger import Logger
18
+ from stable_baselines3.common.monitor import Monitor
19
+ from stable_baselines3.common.noise import ActionNoise
20
+ from stable_baselines3.common.policies import BasePolicy
21
+ from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
22
+ from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
23
+ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
24
+ from stable_baselines3.common.utils import (
25
+ check_for_correct_spaces,
26
+ get_device,
27
+ get_schedule_fn,
28
+ get_system_info,
29
+ set_random_seed,
30
+ update_learning_rate,
31
+ )
32
+ from stable_baselines3.common.vec_env import (
33
+ DummyVecEnv,
34
+ VecEnv,
35
+ VecNormalize,
36
+ VecTransposeImage,
37
+ is_vecenv_wrapped,
38
+ unwrap_vec_normalize,
39
+ )
40
+
41
+
42
def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]:
    """Instantiate an environment from its Gym id when a string is given.

    :param env: An environment instance, a registered Gym id, or None.
    :param verbose: Logging verbosity (>= 1 prints the creation message).
    :return: The (possibly newly created) Gym environment, or ``env`` unchanged.
    """
    # Anything that is not a string is passed through untouched (including None).
    if not isinstance(env, str):
        return env
    if verbose >= 1:
        print(f"Creating environment from the given name '{env}'")
    return gym.make(env)
54
+
55
+
56
class BaseAlgorithm(ABC):
    """
    Common skeleton shared by all RL algorithm implementations.

    :param policy: Policy object
    :param env: The environment to learn from
        (if registered in Gym, can be str. Can be None for loading trained models)
    :param learning_rate: learning rate for the optimizer,
        it can be a function of the current progress remaining (from 1 to 0)
    :param policy_kwargs: Additional arguments to be passed to the policy on creation
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param verbose: The verbosity level: 0 none, 1 training information, 2 debug
    :param device: Device on which the code should run.
        By default, it will try to use a Cuda compatible device and fallback to cpu
        if it is not possible.
    :param support_multi_env: Whether the algorithm supports training
        with multiple environments (as in A2C)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically (only available when passing a
        string for the environment).
    :param monitor_wrapper: When creating an environment, whether to wrap it
        or not in a Monitor wrapper.
    :param seed: Seed for the pseudo random generators
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param supported_action_spaces: The action spaces supported by the algorithm.
    """

    # Maps a policy alias such as "MlpPolicy" to a concrete policy class;
    # concrete algorithms populate this (see _get_policy_from_name()).
    policy_aliases: Dict[str, Type[BasePolicy]] = {}

    def __init__(
        self,
        policy: Type[BasePolicy],
        env: Union[GymEnv, str, None],
        learning_rate: Union[float, Schedule],
        policy_kwargs: Optional[Dict[str, Any]] = None,
        tensorboard_log: Optional[str] = None,
        verbose: int = 0,
        device: Union[th.device, str] = "auto",
        support_multi_env: bool = False,
        create_eval_env: bool = False,
        monitor_wrapper: bool = True,
        seed: Optional[int] = None,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
    ):
        # Resolve a string alias into an actual policy class when needed.
        self.policy_class = self._get_policy_from_name(policy) if isinstance(policy, str) else policy

        self.device = get_device(device)
        if verbose > 0:
            print(f"Using {self.device} device")

        self.env = None  # type: Optional[GymEnv]
        # Keep a handle on the VecNormalize wrapper (if any) for (un)normalization.
        self._vec_normalize_env = unwrap_vec_normalize(env)
        self.verbose = verbose
        self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
        self.observation_space = None  # type: Optional[gym.spaces.Space]
        self.action_space = None  # type: Optional[gym.spaces.Space]
        self.n_envs = None
        self.num_timesteps = 0
        # Total planned timesteps; used when updating schedules.
        self._total_timesteps = 0
        # Timestep counter value at the start of learn(), used to compute fps.
        self._num_timesteps_at_start = 0
        self.eval_env = None
        self.seed = seed
        self.action_noise = None  # type: Optional[ActionNoise]
        self.start_time = None
        self.policy = None
        self.learning_rate = learning_rate
        self.tensorboard_log = tensorboard_log
        self.lr_schedule = None  # type: Optional[Schedule]
        self._last_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
        self._last_episode_starts = None  # type: Optional[np.ndarray]
        # Unnormalized version of the last observation (only used with VecNormalize).
        self._last_original_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
        self._episode_num = 0
        # gSDE settings.
        self.use_sde = use_sde
        self.sde_sample_freq = sde_sample_freq
        # Remaining fraction of training (goes from 1 down to 0);
        # it drives the learning-rate schedule.
        self._current_progress_remaining = 1
        # Rolling buffers for episode statistics (logging).
        self.ep_info_buffer = None  # type: Optional[deque]
        self.ep_success_buffer = None  # type: Optional[deque]
        # Number of gradient updates so far (used e.g. for TD3 delayed updates).
        self._n_updates = 0  # type: int
        # The logger object.
        self._logger = None  # type: Logger
        # True once the user supplies a custom logger via set_logger().
        self._custom_logger = False

        # Create and wrap the env if one was given.
        if env is not None:
            if isinstance(env, str):
                if create_eval_env:
                    self.eval_env = maybe_make_env(env, self.verbose)

                env = maybe_make_env(env, self.verbose)
            env = self._wrap_env(env, self.verbose, monitor_wrapper)

            self.observation_space = env.observation_space
            self.action_space = env.action_space
            self.n_envs = env.num_envs
            self.env = env

            if supported_action_spaces is not None:
                assert isinstance(self.action_space, supported_action_spaces), (
                    f"The algorithm only supports {supported_action_spaces} as action spaces "
                    f"but {self.action_space} was provided"
                )

            if not support_multi_env and self.n_envs > 1:
                raise ValueError(
                    "Error: the model does not support multiple envs; it requires " "a single vectorized environment."
                )

            # Catch the common mistake of using MlpPolicy/CnnPolicy with a Dict space.
            if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict):
                raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")

            if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
                raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
187
+
188
+ @staticmethod
189
+ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
190
+ """ "
191
+ Wrap environment with the appropriate wrappers if needed.
192
+ For instance, to have a vectorized environment
193
+ or to re-order the image channels.
194
+
195
+ :param env:
196
+ :param verbose:
197
+ :param monitor_wrapper: Whether to wrap the env in a ``Monitor`` when possible.
198
+ :return: The wrapped environment.
199
+ """
200
+ if not isinstance(env, VecEnv):
201
+ if not is_wrapped(env, Monitor) and monitor_wrapper:
202
+ if verbose >= 1:
203
+ print("Wrapping the env with a `Monitor` wrapper")
204
+ env = Monitor(env)
205
+ if verbose >= 1:
206
+ print("Wrapping the env in a DummyVecEnv.")
207
+ env = DummyVecEnv([lambda: env])
208
+
209
+ # Make sure that dict-spaces are not nested (not supported)
210
+ check_for_nested_spaces(env.observation_space)
211
+
212
+ if not is_vecenv_wrapped(env, VecTransposeImage):
213
+ wrap_with_vectranspose = False
214
+ if isinstance(env.observation_space, gym.spaces.Dict):
215
+ # If even one of the keys is a image-space in need of transpose, apply transpose
216
+ # If the image spaces are not consistent (for instance one is channel first,
217
+ # the other channel last), VecTransposeImage will throw an error
218
+ for space in env.observation_space.spaces.values():
219
+ wrap_with_vectranspose = wrap_with_vectranspose or (
220
+ is_image_space(space) and not is_image_space_channels_first(space)
221
+ )
222
+ else:
223
+ wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first(
224
+ env.observation_space
225
+ )
226
+
227
+ if wrap_with_vectranspose:
228
+ if verbose >= 1:
229
+ print("Wrapping the env in a VecTransposeImage.")
230
+ env = VecTransposeImage(env)
231
+
232
+ return env
233
+
234
+ @abstractmethod
235
+ def _setup_model(self) -> None:
236
+ """Create networks, buffer and optimizers."""
237
+
238
+ def set_logger(self, logger: Logger) -> None:
239
+ """
240
+ Setter for for logger object.
241
+
242
+ .. warning::
243
+
244
+ When passing a custom logger object,
245
+ this will overwrite ``tensorboard_log`` and ``verbose`` settings
246
+ passed to the constructor.
247
+ """
248
+ self._logger = logger
249
+ # User defined logger
250
+ self._custom_logger = True
251
+
252
+ @property
253
+ def logger(self) -> Logger:
254
+ """Getter for the logger object."""
255
+ return self._logger
256
+
257
+ def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
258
+ """
259
+ Return the environment that will be used for evaluation.
260
+
261
+ :param eval_env:)
262
+ :return:
263
+ """
264
+ if eval_env is None:
265
+ eval_env = self.eval_env
266
+
267
+ if eval_env is not None:
268
+ eval_env = self._wrap_env(eval_env, self.verbose)
269
+ assert eval_env.num_envs == 1
270
+ return eval_env
271
+
272
+ def _setup_lr_schedule(self) -> None:
273
+ """Transform to callable if needed."""
274
+ self.lr_schedule = get_schedule_fn(self.learning_rate)
275
+
276
+ def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
277
+ """
278
+ Compute current progress remaining (starts from 1 and ends to 0)
279
+
280
+ :param num_timesteps: current number of timesteps
281
+ :param total_timesteps:
282
+ """
283
+ self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
284
+
285
+ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
286
+ """
287
+ Update the optimizers learning rate using the current learning rate schedule
288
+ and the current progress remaining (from 1 to 0).
289
+
290
+ :param optimizers:
291
+ An optimizer or a list of optimizers.
292
+ """
293
+ # Log the current learning rate
294
+ self.logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining))
295
+
296
+ if not isinstance(optimizers, list):
297
+ optimizers = [optimizers]
298
+ for optimizer in optimizers:
299
+ update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))
300
+
301
+ def _excluded_save_params(self) -> List[str]:
302
+ """
303
+ Returns the names of the parameters that should be excluded from being
304
+ saved by pickling. E.g. replay buffers are skipped by default
305
+ as they take up a lot of space. PyTorch variables should be excluded
306
+ with this so they can be stored with ``th.save``.
307
+
308
+ :return: List of parameters that should be excluded from being saved with pickle.
309
+ """
310
+ return [
311
+ "policy",
312
+ "device",
313
+ "env",
314
+ "eval_env",
315
+ "replay_buffer",
316
+ "rollout_buffer",
317
+ "_vec_normalize_env",
318
+ "_episode_storage",
319
+ "_logger",
320
+ "_custom_logger",
321
+ ]
322
+
323
+ def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
324
+ """
325
+ Get a policy class from its name representation.
326
+
327
+ The goal here is to standardize policy naming, e.g.
328
+ all algorithms can call upon "MlpPolicy" or "CnnPolicy",
329
+ and they receive respective policies that work for them.
330
+
331
+ :param policy_name: Alias of the policy
332
+ :return: A policy class (type)
333
+ """
334
+
335
+ if policy_name in self.policy_aliases:
336
+ return self.policy_aliases[policy_name]
337
+ else:
338
+ raise ValueError(f"Policy {policy_name} unknown")
339
+
340
+ def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
341
+ """
342
+ Get the name of the torch variables that will be saved with
343
+ PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
344
+ pickling strategy. This is to handle device placement correctly.
345
+
346
+ Names can point to specific variables under classes, e.g.
347
+ "policy.optimizer" would point to ``optimizer`` object of ``self.policy``
348
+ if this object.
349
+
350
+ :return:
351
+ List of Torch variables whose state dicts to save (e.g. th.nn.Modules),
352
+ and list of other Torch variables to store with ``th.save``.
353
+ """
354
+ state_dicts = ["policy"]
355
+
356
+ return state_dicts, []
357
+
358
+ def _init_callback(
359
+ self,
360
+ callback: MaybeCallback,
361
+ eval_env: Optional[VecEnv] = None,
362
+ eval_freq: int = 10000,
363
+ n_eval_episodes: int = 5,
364
+ log_path: Optional[str] = None,
365
+ ) -> BaseCallback:
366
+ """
367
+ :param callback: Callback(s) called at every step with state of the algorithm.
368
+ :param eval_freq: How many steps between evaluations; if None, do not evaluate.
369
+ :param n_eval_episodes: How many episodes to play per evaluation
370
+ :param n_eval_episodes: Number of episodes to rollout during evaluation.
371
+ :param log_path: Path to a folder where the evaluations will be saved
372
+ :return: A hybrid callback calling `callback` and performing evaluation.
373
+ """
374
+ # Convert a list of callbacks into a callback
375
+ if isinstance(callback, list):
376
+ callback = CallbackList(callback)
377
+
378
+ # Convert functional callback to object
379
+ if not isinstance(callback, BaseCallback):
380
+ callback = ConvertCallback(callback)
381
+
382
+ # Create eval callback in charge of the evaluation
383
+ if eval_env is not None:
384
+ eval_callback = EvalCallback(
385
+ eval_env,
386
+ best_model_save_path=log_path,
387
+ log_path=log_path,
388
+ eval_freq=eval_freq,
389
+ n_eval_episodes=n_eval_episodes,
390
+ )
391
+ callback = CallbackList([callback, eval_callback])
392
+
393
+ callback.init_callback(self)
394
+ return callback
395
+
396
+ def _setup_learn(
397
+ self,
398
+ total_timesteps: int,
399
+ eval_env: Optional[GymEnv],
400
+ callback: MaybeCallback = None,
401
+ eval_freq: int = 10000,
402
+ n_eval_episodes: int = 5,
403
+ log_path: Optional[str] = None,
404
+ reset_num_timesteps: bool = True,
405
+ tb_log_name: str = "run",
406
+ ) -> Tuple[int, BaseCallback]:
407
+ """
408
+ Initialize different variables needed for training.
409
+
410
+ :param total_timesteps: The total number of samples (env steps) to train on
411
+ :param eval_env: Environment to use for evaluation.
412
+ :param callback: Callback(s) called at every step with state of the algorithm.
413
+ :param eval_freq: How many steps between evaluations
414
+ :param n_eval_episodes: How many episodes to play per evaluation
415
+ :param log_path: Path to a folder where the evaluations will be saved
416
+ :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
417
+ :param tb_log_name: the name of the run for tensorboard log
418
+ :return:
419
+ """
420
+ self.start_time = time.time()
421
+
422
+ if self.ep_info_buffer is None or reset_num_timesteps:
423
+ # Initialize buffers if they don't exist, or reinitialize if resetting counters
424
+ self.ep_info_buffer = deque(maxlen=100)
425
+ self.ep_success_buffer = deque(maxlen=100)
426
+
427
+ if self.action_noise is not None:
428
+ self.action_noise.reset()
429
+
430
+ if reset_num_timesteps:
431
+ self.num_timesteps = 0
432
+ self._episode_num = 0
433
+ else:
434
+ # Make sure training timesteps are ahead of the internal counter
435
+ total_timesteps += self.num_timesteps
436
+ self._total_timesteps = total_timesteps
437
+ self._num_timesteps_at_start = self.num_timesteps
438
+
439
+ # Avoid resetting the environment when calling ``.learn()`` consecutive times
440
+ if reset_num_timesteps or self._last_obs is None:
441
+ self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch
442
+ self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
443
+ # Retrieve unnormalized observation for saving into the buffer
444
+ if self._vec_normalize_env is not None:
445
+ self._last_original_obs = self._vec_normalize_env.get_original_obs()
446
+
447
+ if eval_env is not None and self.seed is not None:
448
+ eval_env.seed(self.seed)
449
+
450
+ eval_env = self._get_eval_env(eval_env)
451
+
452
+ # Configure logger's outputs if no logger was passed
453
+ if not self._custom_logger:
454
+ self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
455
+
456
+ # Create eval callback if needed
457
+ callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)
458
+
459
+ return total_timesteps, callback
460
+
461
+ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
462
+ """
463
+ Retrieve reward, episode length, episode success and update the buffer
464
+ if using Monitor wrapper or a GoalEnv.
465
+
466
+ :param infos: List of additional information about the transition.
467
+ :param dones: Termination signals
468
+ """
469
+ if dones is None:
470
+ dones = np.array([False] * len(infos))
471
+ for idx, info in enumerate(infos):
472
+ maybe_ep_info = info.get("episode")
473
+ maybe_is_success = info.get("is_success")
474
+ if maybe_ep_info is not None:
475
+ self.ep_info_buffer.extend([maybe_ep_info])
476
+ if maybe_is_success is not None and dones[idx]:
477
+ self.ep_success_buffer.append(maybe_is_success)
478
+
479
+ def get_env(self) -> Optional[VecEnv]:
480
+ """
481
+ Returns the current environment (can be None if not defined).
482
+
483
+ :return: The current environment
484
+ """
485
+ return self.env
486
+
487
+ def get_vec_normalize_env(self) -> Optional[VecNormalize]:
488
+ """
489
+ Return the ``VecNormalize`` wrapper of the training env
490
+ if it exists.
491
+
492
+ :return: The ``VecNormalize`` env.
493
+ """
494
+ return self._vec_normalize_env
495
+
496
+ def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
497
+ """
498
+ Checks the validity of the environment, and if it is coherent, set it as the current environment.
499
+ Furthermore wrap any non vectorized env into a vectorized
500
+ checked parameters:
501
+ - observation_space
502
+ - action_space
503
+
504
+ :param env: The environment for learning a policy
505
+ :param force_reset: Force call to ``reset()`` before training
506
+ to avoid unexpected behavior.
507
+ See issue https://github.com/DLR-RM/stable-baselines3/issues/597
508
+ """
509
+ # if it is not a VecEnv, make it a VecEnv
510
+ # and do other transformations (dict obs, image transpose) if needed
511
+ env = self._wrap_env(env, self.verbose)
512
+ # Check that the observation spaces match
513
+ check_for_correct_spaces(env, self.observation_space, self.action_space)
514
+ # Update VecNormalize object
515
+ # otherwise the wrong env may be used, see https://github.com/DLR-RM/stable-baselines3/issues/637
516
+ self._vec_normalize_env = unwrap_vec_normalize(env)
517
+
518
+ # Discard `_last_obs`, this will force the env to reset before training
519
+ # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
520
+ if force_reset:
521
+ self._last_obs = None
522
+
523
+ self.n_envs = env.num_envs
524
+ self.env = env
525
+
526
+ @abstractmethod
527
+ def learn(
528
+ self,
529
+ total_timesteps: int,
530
+ callback: MaybeCallback = None,
531
+ log_interval: int = 100,
532
+ tb_log_name: str = "run",
533
+ eval_env: Optional[GymEnv] = None,
534
+ eval_freq: int = -1,
535
+ n_eval_episodes: int = 5,
536
+ eval_log_path: Optional[str] = None,
537
+ reset_num_timesteps: bool = True,
538
+ ) -> "BaseAlgorithm":
539
+ """
540
+ Return a trained model.
541
+
542
+ :param total_timesteps: The total number of samples (env steps) to train on
543
+ :param callback: callback(s) called at every step with state of the algorithm.
544
+ :param log_interval: The number of timesteps before logging.
545
+ :param tb_log_name: the name of the run for TensorBoard logging
546
+ :param eval_env: Environment that will be used to evaluate the agent
547
+ :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
548
+ :param n_eval_episodes: Number of episode to evaluate the agent
549
+ :param eval_log_path: Path to a folder where the evaluations will be saved
550
+ :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
551
+ :return: the trained model
552
+ """
553
+
554
+ def predict(
555
+ self,
556
+ observation: np.ndarray,
557
+ state: Optional[Tuple[np.ndarray, ...]] = None,
558
+ episode_start: Optional[np.ndarray] = None,
559
+ deterministic: bool = False,
560
+ ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
561
+ """
562
+ Get the policy action from an observation (and optional hidden state).
563
+ Includes sugar-coating to handle different observations (e.g. normalizing images).
564
+
565
+ :param observation: the input observation
566
+ :param state: The last hidden states (can be None, used in recurrent policies)
567
+ :param episode_start: The last masks (can be None, used in recurrent policies)
568
+ this correspond to beginning of episodes,
569
+ where the hidden states of the RNN must be reset.
570
+ :param deterministic: Whether or not to return deterministic actions.
571
+ :return: the model's action and the next hidden state
572
+ (used in recurrent policies)
573
+ """
574
+ return self.policy.predict(observation, state, episode_start, deterministic)
575
+
576
+ def set_random_seed(self, seed: Optional[int] = None) -> None:
577
+ """
578
+ Set the seed of the pseudo-random generators
579
+ (python, numpy, pytorch, gym, action_space)
580
+
581
+ :param seed:
582
+ """
583
+ if seed is None:
584
+ return
585
+ set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type)
586
+ self.action_space.seed(seed)
587
+ if self.env is not None:
588
+ self.env.seed(seed)
589
+ if self.eval_env is not None:
590
+ self.eval_env.seed(seed)
591
+
592
+ def set_parameters(
593
+ self,
594
+ load_path_or_dict: Union[str, Dict[str, Dict]],
595
+ exact_match: bool = True,
596
+ device: Union[th.device, str] = "auto",
597
+ ) -> None:
598
+ """
599
+ Load parameters from a given zip-file or a nested dictionary containing parameters for
600
+ different modules (see ``get_parameters``).
601
+
602
+ :param load_path_or_iter: Location of the saved data (path or file-like, see ``save``), or a nested
603
+ dictionary containing nn.Module parameters used by the policy. The dictionary maps
604
+ object names to a state-dictionary returned by ``torch.nn.Module.state_dict()``.
605
+ :param exact_match: If True, the given parameters should include parameters for each
606
+ module and each of their parameters, otherwise raises an Exception. If set to False, this
607
+ can be used to update only specific parameters.
608
+ :param device: Device on which the code should run.
609
+ """
610
+ params = None
611
+ if isinstance(load_path_or_dict, dict):
612
+ params = load_path_or_dict
613
+ else:
614
+ _, params, _ = load_from_zip_file(load_path_or_dict, device=device)
615
+
616
+ # Keep track which objects were updated.
617
+ # `_get_torch_save_params` returns [params, other_pytorch_variables].
618
+ # We are only interested in former here.
619
+ objects_needing_update = set(self._get_torch_save_params()[0])
620
+ updated_objects = set()
621
+
622
+ for name in params:
623
+ attr = None
624
+ try:
625
+ attr = recursive_getattr(self, name)
626
+ except Exception:
627
+ # What errors recursive_getattr could throw? KeyError, but
628
+ # possible something else too (e.g. if key is an int?).
629
+ # Catch anything for now.
630
+ raise ValueError(f"Key {name} is an invalid object name.")
631
+
632
+ if isinstance(attr, th.optim.Optimizer):
633
+ # Optimizers do not support "strict" keyword...
634
+ # Seems like they will just replace the whole
635
+ # optimizer state with the given one.
636
+ # On top of this, optimizer state-dict
637
+ # seems to change (e.g. first ``optim.step()``),
638
+ # which makes comparing state dictionary keys
639
+ # invalid (there is also a nesting of dictionaries
640
+ # with lists with dictionaries with ...), adding to the
641
+ # mess.
642
+ #
643
+ # TL;DR: We might not be able to reliably say
644
+ # if given state-dict is missing keys.
645
+ #
646
+ # Solution: Just load the state-dict as is, and trust
647
+ # the user has provided a sensible state dictionary.
648
+ attr.load_state_dict(params[name])
649
+ else:
650
+ # Assume attr is th.nn.Module
651
+ attr.load_state_dict(params[name], strict=exact_match)
652
+ updated_objects.add(name)
653
+
654
+ if exact_match and updated_objects != objects_needing_update:
655
+ raise ValueError(
656
+ "Names of parameters do not match agents' parameters: "
657
+ f"expected {objects_needing_update}, got {updated_objects}"
658
+ )
659
+
660
+ @classmethod
661
+ def load(
662
+ cls,
663
+ path: Union[str, pathlib.Path, io.BufferedIOBase],
664
+ env: Optional[GymEnv] = None,
665
+ device: Union[th.device, str] = "auto",
666
+ custom_objects: Optional[Dict[str, Any]] = None,
667
+ print_system_info: bool = False,
668
+ force_reset: bool = True,
669
+ check_obs_space: bool = True,
670
+ force_load: bool = False,
671
+ **kwargs,
672
+ ) -> "BaseAlgorithm":
673
+ """
674
+ Load the model from a zip-file.
675
+ Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
676
+ For an in-place load use ``set_parameters`` instead.
677
+
678
+ :param path: path to the file (or a file-like) where to
679
+ load the agent from
680
+ :param env: the new environment to run the loaded model on
681
+ (can be None if you only need prediction from a trained model) has priority over any saved environment
682
+ :param device: Device on which the code should run.
683
+ :param custom_objects: Dictionary of objects to replace
684
+ upon loading. If a variable is present in this dictionary as a
685
+ key, it will not be deserialized and the corresponding item
686
+ will be used instead. Similar to custom_objects in
687
+ ``keras.models.load_model``. Useful when you have an object in
688
+ file that can not be deserialized.
689
+ :param print_system_info: Whether to print system info from the saved model
690
+ and the current system info (useful to debug loading issues)
691
+ :param force_reset: Force call to ``reset()`` before training
692
+ to avoid unexpected behavior.
693
+ See https://github.com/DLR-RM/stable-baselines3/issues/597
694
+ :param kwargs: extra arguments to change the model when loading
695
+ :return: new model instance with loaded parameters
696
+ """
697
+ if print_system_info:
698
+ print("== CURRENT SYSTEM INFO ==")
699
+ get_system_info()
700
+
701
+ data, params, pytorch_variables = load_from_zip_file(
702
+ path, device=device, custom_objects=custom_objects, print_system_info=print_system_info
703
+ )
704
+
705
+ # Remove stored device information and replace with ours
706
+ if "policy_kwargs" in data:
707
+ if "device" in data["policy_kwargs"]:
708
+ del data["policy_kwargs"]["device"]
709
+
710
+ if not force_load:
711
+ if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
712
+ raise ValueError(
713
+ f"The specified policy kwargs do not equal the stored policy kwargs."
714
+ f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
715
+ )
716
+
717
+ if "observation_space" not in data or "action_space" not in data:
718
+ raise KeyError("The observation_space and action_space were not given, can't verify new environments")
719
+
720
+ if env is not None:
721
+ # Wrap first if needed
722
+ env = cls._wrap_env(env, data["verbose"])
723
+ # Check if given env is valid
724
+ if check_obs_space:
725
+ check_for_correct_spaces(env, data["observation_space"], data["action_space"])
726
+ # Discard `_last_obs`, this will force the env to reset before training
727
+ # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
728
+ if force_reset and data is not None:
729
+ data["_last_obs"] = None
730
+ else:
731
+ # Use stored env, if one exists. If not, continue as is (can be used for predict)
732
+ if "env" in data:
733
+ env = data["env"]
734
+
735
+ # noinspection PyArgumentList
736
+ model = cls( # pytype: disable=not-instantiable,wrong-keyword-args
737
+ policy=data["policy_class"],
738
+ env=env,
739
+ device=device,
740
+ _init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
741
+ )
742
+
743
+ # load parameters
744
+ if not force_load:
745
+ model.__dict__.update(data) # TODO: Helin cancelled this
746
+ model.__dict__.update(kwargs)
747
+ model._setup_model()
748
+
749
+ # put state_dicts back in place
750
+ model.set_parameters(params, exact_match=True, device=device)
751
+
752
+ # put other pytorch variables back in place
753
+ if pytorch_variables is not None:
754
+ for name in pytorch_variables:
755
+ # Skip if PyTorch variable was not defined (to ensure backward compatibility).
756
+ # This happens when using SAC/TQC.
757
+ # SAC has an entropy coefficient which can be fixed or optimized.
758
+ # If it is optimized, an additional PyTorch variable `log_ent_coef` is defined,
759
+ # otherwise it is initialized to `None`.
760
+ if pytorch_variables[name] is None:
761
+ continue
762
+ # Set the data attribute directly to avoid issue when using optimizers
763
+ # See https://github.com/DLR-RM/stable-baselines3/issues/391
764
+ recursive_setattr(model, name + ".data", pytorch_variables[name].data)
765
+
766
+ # Sample gSDE exploration matrix, so it uses the right device
767
+ # see issue #44
768
+ if model.use_sde:
769
+ model.policy.reset_noise() # pytype: disable=attribute-error
770
+ return model
771
+
772
+ def get_parameters(self) -> Dict[str, Dict]:
773
+ """
774
+ Return the parameters of the agent. This includes parameters from different networks, e.g.
775
+ critics (value functions) and policies (pi functions).
776
+
777
+ :return: Mapping of from names of the objects to PyTorch state-dicts.
778
+ """
779
+ state_dicts_names, _ = self._get_torch_save_params()
780
+ params = {}
781
+ for name in state_dicts_names:
782
+ attr = recursive_getattr(self, name)
783
+ # Retrieve state dict
784
+ params[name] = attr.state_dict()
785
+ return params
786
+
787
+ def save(
788
+ self,
789
+ path: Union[str, pathlib.Path, io.BufferedIOBase],
790
+ exclude: Optional[Iterable[str]] = None,
791
+ include: Optional[Iterable[str]] = None,
792
+ ) -> None:
793
+ """
794
+ Save all the attributes of the object and the model parameters in a zip-file.
795
+
796
+ :param path: path to the file where the rl agent should be saved
797
+ :param exclude: name of parameters that should be excluded in addition to the default ones
798
+ :param include: name of parameters that might be excluded but should be included anyway
799
+ """
800
+ # Copy parameter list so we don't mutate the original dict
801
+ data = self.__dict__.copy()
802
+
803
+ # Exclude is union of specified parameters (if any) and standard exclusions
804
+ if exclude is None:
805
+ exclude = []
806
+ exclude = set(exclude).union(self._excluded_save_params())
807
+
808
+ # Do not exclude params if they are specifically included
809
+ if include is not None:
810
+ exclude = exclude.difference(include)
811
+
812
+ state_dicts_names, torch_variable_names = self._get_torch_save_params()
813
+ all_pytorch_variables = state_dicts_names + torch_variable_names
814
+ for torch_var in all_pytorch_variables:
815
+ # We need to get only the name of the top most module as we'll remove that
816
+ var_name = torch_var.split(".")[0]
817
+ # Any params that are in the save vars must not be saved by data
818
+ exclude.add(var_name)
819
+
820
+ # Remove parameter entries of parameters which are to be excluded
821
+ for param_name in exclude:
822
+ data.pop(param_name, None)
823
+
824
+ # Build dict of torch variables
825
+ pytorch_variables = None
826
+ if torch_variable_names is not None:
827
+ pytorch_variables = {}
828
+ for name in torch_variable_names:
829
+ attr = recursive_getattr(self, name)
830
+ pytorch_variables[name] = attr
831
+
832
+ # Build dict of state_dicts
833
+ params_to_save = self.get_parameters()
834
+
835
+ save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
dexart-release/stable_baselines3/common/buffers.py ADDED
@@ -0,0 +1,1010 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Dict, Generator, List, Optional, Union
4
+ import stable_baselines3.pickle_utils as pickle_utils
5
+ import numpy as np
6
+ import torch as th
7
+ from gym import spaces
8
+
9
+ from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
10
+ from stable_baselines3.common.type_aliases import (
11
+ DictReplayBufferSamples,
12
+ DictRolloutBufferSamples,
13
+ DictSSLRolloutBufferSamples,
14
+ ReplayBufferSamples,
15
+ RolloutBufferSamples,
16
+ )
17
+
18
+ from stable_baselines3.common.vec_env import VecNormalize
19
+
20
+ try:
21
+ # Check memory used by replay buffer when possible
22
+ import psutil
23
+ except ImportError:
24
+ psutil = None
25
+
26
+
27
class BaseBuffer(ABC):
    """
    Base class that represents a buffer (rollout or replay).

    Stores up to ``buffer_size`` steps for ``n_envs`` parallel environments and
    tracks a write cursor ``pos`` plus a ``full`` flag for circular-buffer use.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device to which sampled values will be converted
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
    ):
        super().__init__()
        self.buffer_size = buffer_size
        self.observation_space = observation_space
        self.action_space = action_space
        # Per-key shape dict for Dict spaces, plain tuple otherwise
        self.obs_shape = get_obs_shape(observation_space)

        self.action_dim = get_action_dim(action_space)
        # Write cursor: index of the next slot to be overwritten
        self.pos = 0
        # True once the buffer has wrapped around at least once
        self.full = False
        self.device = device
        self.n_envs = n_envs

    @staticmethod
    def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
        """
        Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
        to convert shape from [n_steps, n_envs, ...] (where ... is the shape of the features)
        to [n_steps * n_envs, ...] (which maintains the order).

        :param arr: array of shape [n_steps, n_envs, ...]
        :return: array of shape [n_steps * n_envs, ...]
        """
        shape = arr.shape
        if len(shape) < 3:
            # Promote 2-D input to 3-D so the reshape below is uniform
            shape = shape + (1,)
        return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

    def size(self) -> int:
        """
        :return: The current number of stored steps (capped at ``buffer_size``)
        """
        if self.full:
            return self.buffer_size
        return self.pos

    def add(self, *args, **kwargs) -> None:
        """
        Add elements to the buffer. Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def extend(self, *args, **kwargs) -> None:
        """
        Add a new batch of transitions to the buffer, one element at a time.
        """
        # Do a for loop along the batch axis
        for data in zip(*args):
            self.add(*data)

    def reset(self) -> None:
        """
        Reset the buffer cursor and the ``full`` flag (storage is not cleared).
        """
        self.pos = 0
        self.full = False

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
        """
        Sample uniformly among the currently valid indices.

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv to normalize the
            observations/rewards when sampling
        :return: samples as produced by the subclass's ``_get_samples``
        """
        upper_bound = self.buffer_size if self.full else self.pos
        batch_inds = np.random.randint(0, upper_bound, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    @abstractmethod
    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
        """
        Gather the stored data at ``batch_inds`` and convert it to torch tensors.

        :param batch_inds: indices into the buffer
        :param env: optional VecNormalize wrapper for normalization
        :return: a named-tuple of torch tensors
        """
        raise NotImplementedError()

    def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
        """
        Convert a numpy array to a PyTorch tensor on ``self.device``.
        Note: it copies the data by default.

        :param array: numpy array to convert
        :param copy: Whether to copy the data
            (may be useful to avoid changing things by reference)
        :return: torch tensor on the buffer's device
        """
        if copy:
            return th.tensor(array).to(self.device)
        return th.as_tensor(array).to(self.device)

    @staticmethod
    def _normalize_obs(
        obs: Union[np.ndarray, Dict[str, np.ndarray]],
        env: Optional[VecNormalize] = None,
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        # Delegate normalization to the VecNormalize wrapper when provided
        if env is not None:
            return env.normalize_obs(obs)
        return obs

    @staticmethod
    def _normalize_reward(reward: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray:
        # Delegate normalization to the VecNormalize wrapper when provided
        if env is not None:
            return env.normalize_reward(reward).astype(np.float32)
        return reward
153
+
154
+
155
class ExpertBuffer(BaseBuffer):
    """
    Read-only buffer of expert demonstrations loaded from disk.

    Trajectories are loaded via ``pickle_utils.load_data``; their
    ``'observations'`` and ``'actions'`` entries are concatenated into two flat
    arrays. The buffer is immediately marked ``full`` and does not support
    ``add``.

    :param buffer_size: nominal max size (overwritten below by the dataset length)
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device for sampled tensors
    :param n_envs: Number of parallel environments
    :param dataset_path: path passed to ``pickle_utils.load_data``
    """
    def __init__(self,
                 buffer_size: int,
                 observation_space: spaces.Space,
                 action_space: spaces.Space,
                 device: Union[th.device, str] = "cpu",
                 n_envs: int = 1,
                 dataset_path=''
                 ):
        super(ExpertBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        # NOTE(review): pickle on external files — only load trusted demo datasets.
        data = pickle_utils.load_data(dataset_path)

        data_obs = []
        data_action = []

        # Hard-coded False: the memory-efficient branch in sample() below is
        # therefore dead code for this class.
        self.optimize_memory_usage = False

        for trajectory in data:
            # Debug leftover: prints the keys of every trajectory dict
            print(trajectory.keys())
            # assumes each trajectory is a dict with 'observations'/'actions'
            # arrays of equal length — checked by the assert below
            data_obs.append(trajectory['observations'])
            data_action.append(trajectory['actions'])
        self.observations = np.concatenate(data_obs, axis=0)
        self.actions = np.concatenate(data_action, axis=0)

        assert len(self.observations) == len(self.actions), "Demo Dataset Error: Obs num does not match Action num."
        print('Expert buffer info:', self.observations.shape, self.actions.shape)
        # Shrink/grow the buffer size to the actual dataset length and mark
        # the buffer full so BaseBuffer.sample covers the whole dataset.
        self.buffer_size = len(self.observations)
        self.full = True

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        # The demonstration dataset is immutable.
        assert False, "We do not expect user to use this method."

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        # Sample randomly the env idx
        # NOTE(review): only (observations, actions) are passed here, while
        # get_all_samples() below pads with three Nones — presumably
        # ReplayBufferSamples in this fork has optional trailing fields; verify
        # against stable_baselines3.common.type_aliases.
        data = (
            self._normalize_obs(self.observations[batch_inds, :], env),
            self.actions[batch_inds, :],
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the demonstration buffer.

        Since ``optimize_memory_usage`` is always False for this class, this
        always delegates to ``BaseBuffer.sample``; the branch below is kept for
        symmetry with ReplayBuffer.sample.

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: batch of (observation, action) pairs
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        # Do not sample the element with index `self.pos` as the transition is
        # invalid (we use only one array to store `obs` and `next_obs`)
        if self.full:
            batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def get_all_samples(self, env=None):
        """Return the entire dataset as a single (obs, actions) batch."""
        data = (
            self._normalize_obs(self.observations, env),
            self.actions,
        )
        # Trailing Nones fill next_observations/dones/rewards, which have no
        # meaning for flat expert data.
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)), None, None, None)
234
+
235
+
236
class ReplayBuffer(BaseBuffer):
    """
    Replay buffer used in off-policy algorithms like SAC/TD3.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device for sampled tensors
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        of the replay buffer which reduces by almost a factor two the memory used,
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
        and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        # Adjust buffer size: buffer_size counts total transitions, storage is
        # indexed per step across n_envs
        self.buffer_size = max(buffer_size // n_envs, 1)

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        self.optimize_memory_usage = optimize_memory_usage

        self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)

        if optimize_memory_usage:
            # `observations` contains also the next observation
            # (stored at index (pos + 1) % buffer_size, see add())
            self.next_observations = None
        else:
            self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)

        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)

        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes

            if self.next_observations is not None:
                total_memory_usage += self.next_observations.nbytes

            if total_memory_usage > mem_available:
                # Convert to GB for a readable warning
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Store one step of experience for every parallel env, advancing the
        circular write cursor.

        :param obs: observation at the current step
        :param next_obs: observation after taking ``action``
        :param action: action taken
        :param reward: reward received
        :param done: episode-termination flag
        :param infos: per-env info dicts (read for "TimeLimit.truncated")
        """
        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs,) + self.obs_shape)
            next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)

        # Same, for actions
        if isinstance(self.action_space, spaces.Discrete):
            action = action.reshape((self.n_envs, self.action_dim))

        # Copy to avoid modification by reference
        self.observations[self.pos] = np.array(obs).copy()

        if self.optimize_memory_usage:
            # Memory-efficient layout: next_obs lives in the following slot
            self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy()
        else:
            self.next_observations[self.pos] = np.array(next_obs).copy()

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.dones[self.pos] = np.array(done).copy()

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            # Wrap around: subsequent writes overwrite the oldest transitions
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the replay buffer.
        Custom sampling when using the memory efficient variant,
        as we should not sample the element with index `self.pos`
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: batch of transitions
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        # Do not sample the element with index `self.pos` as the transition is
        # invalid (we use only one array to store `obs` and `next_obs`)
        if self.full:
            batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        # Sample randomly the env idx (one env per sampled step)
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        if self.optimize_memory_usage:
            # next_obs is stored in the slot after the transition's slot
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)

        data = (
            self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
            self.actions[batch_inds, env_indices, :],
            next_obs,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
            self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
389
+
390
+
391
class RolloutBuffer(BaseBuffer):
    """
    Rollout buffer used in on-policy algorithms like A2C/PPO.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device for sampled tensors
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):

        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        # Storage arrays are allocated in reset()
        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        # Re-allocate all storage as [buffer_size, n_envs, ...] float32 arrays
        self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Arrays are back in [n_steps, n_envs, ...] layout until get() flattens them
        self.generator_ready = False
        super().reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
        where R is the sum of discounted reward with value bootstrap
        (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator has also two special cases:
        - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()

        # Backward recursion over the rollout (GAE telescoping)
        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                # Bootstrap with the externally provided last value estimate
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            # TD residual: r_t + gamma * V(s_{t+1}) * (1 - done) - V(s_t)
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward: reward received
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs,) + self.obs_shape)

        self.observations[self.pos] = np.array(obs).copy()
        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            # Rollout buffers never wrap: full means ready for get()
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        """
        Yield minibatches over a random permutation of the whole rollout.
        Requires the buffer to be full; flattens storage on first call.
        """
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data: flatten [n_steps, n_envs, ...] -> [n_steps * n_envs, ...]
        if not self.generator_ready:

            _tensor_names = [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
        # Assumes get() already flattened the storage to [n_steps * n_envs, ...]
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
559
+
560
+
561
class DictReplayBuffer(ReplayBuffer):
    """
    Dict Replay buffer used in off-policy algorithms like SAC/TD3.
    Extends the ReplayBuffer to use dictionary observations.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space (must be a Dict space)
    :param action_space: Action space
    :param device: PyTorch device for sampled tensors
    :param n_envs: Number of parallel environments
    :param optimize_memory_usage: Enable a memory efficient variant
        Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
        separately and treat the task as infinite horizon task.
        https://github.com/DLR-RM/stable-baselines3/issues/284
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        # Skip ReplayBuffer.__init__ (it allocates flat arrays); go straight
        # to BaseBuffer.__init__ and allocate per-key storage here instead.
        super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
        # buffer_size counts total transitions; storage is indexed per step
        self.buffer_size = max(buffer_size // n_envs, 1)

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage"
        # disabling as this adds quite a bit of complexity
        # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
        self.optimize_memory_usage = optimize_memory_usage

        # One array per observation key, each shaped [buffer_size, n_envs, *key_shape]
        self.observations = {
            key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space[key].dtype)
            for key, _obs_shape in self.obs_shape.items()
        }
        self.next_observations = {
            key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space[key].dtype)
            for key, _obs_shape in self.obs_shape.items()
        }

        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            obs_nbytes = 0
            for _, obs in self.observations.items():
                obs_nbytes += obs.nbytes

            total_memory_usage = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
            if self.next_observations is not None:
                next_obs_nbytes = 0
                # Fix: account the next_observations arrays, not observations
                # again (the two happen to be the same size today, but the
                # original loop double-counted the wrong dict).
                for _, obs in self.next_observations.items():
                    next_obs_nbytes += obs.nbytes
                total_memory_usage += next_obs_nbytes

            if total_memory_usage > mem_available:
                # Convert to GB for a readable warning
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: Dict[str, np.ndarray],
        next_obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """
        Store one step of experience (per-key dict observations), advancing the
        circular write cursor.
        """
        # Copy to avoid modification by reference
        for key in self.observations.keys():
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs[key] = obs[key].reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = np.array(obs[key])

        for key in self.next_observations.keys():
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                next_obs[key] = next_obs[key].reshape((self.n_envs,) + self.obs_shape[key])
            self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()

        # Same reshape, for actions
        if isinstance(self.action_space, spaces.Discrete):
            action = action.reshape((self.n_envs, self.action_dim))

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.dones[self.pos] = np.array(done).copy()

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            # Wrap around: subsequent writes overwrite the oldest transitions
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
        """
        Sample elements from the replay buffer.

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: batch of transitions with dict observations
        """
        # Bypass ReplayBuffer.sample (memory-efficient variant unsupported here)
        return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
        # Sample randomly the env idx (one env per sampled step)
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        # Normalize if needed and remove extra dimension (we are using only one env for now)
        obs_ = self._normalize_obs({key: obs[batch_inds, env_indices, :] for key, obs in self.observations.items()}, env)
        next_obs_ = self._normalize_obs(
            {key: obs[batch_inds, env_indices, :] for key, obs in self.next_observations.items()}, env
        )

        # Convert to torch tensor
        observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
        next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()}

        return DictReplayBufferSamples(
            observations=observations,
            actions=self.to_torch(self.actions[batch_inds, env_indices]),
            next_observations=next_observations,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            dones=self.to_torch(self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(
                -1, 1
            ),
            rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env)),
        )
715
+
716
+
717
+ class DictRolloutBuffer(RolloutBuffer):
718
+ """
719
+ Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
720
+ Extends the RolloutBuffer to use dictionary observations
721
+
722
+ It corresponds to ``buffer_size`` transitions collected
723
+ using the current policy.
724
+ This experience will be discarded after the policy update.
725
+ In order to use PPO objective, we also store the current value of each state
726
+ and the log probability of each taken action.
727
+
728
+ The term rollout here refers to the model-free notion and should not
729
+ be used with the concept of rollout used in model-based RL or planning.
730
+ Hence, it is only involved in policy and value function training but not action selection.
731
+
732
+ :param buffer_size: Max number of element in the buffer
733
+ :param observation_space: Observation space
734
+ :param action_space: Action space
735
+ :param device:
736
+ :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
737
+ Equivalent to Monte-Carlo advantage estimate when set to 1.
738
+ :param gamma: Discount factor
739
+ :param n_envs: Number of parallel environments
740
+ """
741
+
742
    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        """
        :param buffer_size: Max number of elements in the buffer
        :param observation_space: Observation space (must be a Dict space)
        :param action_space: Action space
        :param device: PyTorch device for sampled tensors
        :param gae_lambda: GAE bias/variance trade-off factor
        :param gamma: Discount factor
        :param n_envs: Number of parallel environments
        """
        # Skip RolloutBuffer.__init__ (it would call reset() before gae_lambda
        # is set); go straight to BaseBuffer.__init__ and initialize here.
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma
        # Storage is allocated by reset() below
        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
        self.generator_ready = False
        self.reset()
763
+
764
+ def reset(self) -> None:
765
+ assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
766
+ self.observations = {}
767
+ for key, obs_input_shape in self.obs_shape.items():
768
+ self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
769
+ self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
770
+ self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
771
+ self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
772
+ self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
773
+ self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
774
+ self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
775
+ self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
776
+ self.generator_ready = False
777
+ super(RolloutBuffer, self).reset()
778
+
779
+ def add(
780
+ self,
781
+ obs: Dict[str, np.ndarray],
782
+ action: np.ndarray,
783
+ reward: np.ndarray,
784
+ episode_start: np.ndarray,
785
+ value: th.Tensor,
786
+ log_prob: th.Tensor,
787
+ ) -> None:
788
+ """
789
+ :param obs: Observation
790
+ :param action: Action
791
+ :param reward:
792
+ :param episode_start: Start of episode signal.
793
+ :param value: estimated value of the current state
794
+ following the current policy.
795
+ :param log_prob: log probability of the action
796
+ following the current policy.
797
+ """
798
+ if len(log_prob.shape) == 0:
799
+ # Reshape 0-d tensor to avoid error
800
+ log_prob = log_prob.reshape(-1, 1)
801
+
802
+ for key in self.observations.keys():
803
+ obs_ = np.array(obs[key]).copy()
804
+ # Reshape needed when using multiple envs with discrete observations
805
+ # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
806
+ if isinstance(self.observation_space.spaces[key], spaces.Discrete):
807
+ obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
808
+ self.observations[key][self.pos] = obs_
809
+
810
+ self.actions[self.pos] = np.array(action).copy()
811
+ self.rewards[self.pos] = np.array(reward).copy()
812
+ self.episode_starts[self.pos] = np.array(episode_start).copy()
813
+ self.values[self.pos] = value.clone().cpu().numpy().flatten()
814
+ self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
815
+ self.pos += 1
816
+ if self.pos == self.buffer_size:
817
+ self.full = True
818
+
819
+ def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]:
820
+ assert self.full, ""
821
+ indices = np.random.permutation(self.buffer_size * self.n_envs)
822
+ # Prepare the data
823
+ if not self.generator_ready:
824
+
825
+ for key, obs in self.observations.items():
826
+ self.observations[key] = self.swap_and_flatten(obs)
827
+
828
+ _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]
829
+
830
+ for tensor in _tensor_names:
831
+ self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
832
+ self.generator_ready = True
833
+
834
+ # Return everything, don't create minibatches
835
+ if batch_size is None:
836
+ batch_size = self.buffer_size * self.n_envs
837
+
838
+ start_idx = 0
839
+ while start_idx < self.buffer_size * self.n_envs:
840
+ yield self._get_samples(indices[start_idx : start_idx + batch_size])
841
+ start_idx += batch_size
842
+
843
+ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples:
844
+
845
+ return DictRolloutBufferSamples(
846
+ observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
847
+ actions=self.to_torch(self.actions[batch_inds]),
848
+ old_values=self.to_torch(self.values[batch_inds].flatten()),
849
+ old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
850
+ advantages=self.to_torch(self.advantages[batch_inds].flatten()),
851
+ returns=self.to_torch(self.returns[batch_inds].flatten()),
852
+ )
853
+
854
+
855
class DictSSLRolloutBuffer(RolloutBuffer):
    """
    Dict Rollout buffer used in on-policy algorithms like A2C/PPO,
    extended with future (next-step) observations/actions for an SSL objective.
    Extends the RolloutBuffer to use dictionary observations

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device on which sampled tensors are created
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to Monte-Carlo advantage estimate when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "cpu",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):

        # Skip RolloutBuffer.__init__ (its reset() allocates a flat obs array
        # incompatible with dict observation spaces) and call BaseBuffer directly.
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma
        # Storage is allocated lazily by reset() below
        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
        # Set to True once arrays have been flattened for sampling (see get())
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        """Allocate (or re-allocate) all storage arrays and rewind the buffer position."""
        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
        # One float32 array of shape (buffer_size, n_envs, *obs_shape[key]) per observation key
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        # BaseBuffer.reset() just resets self.pos and self.full
        super(RolloutBuffer, self).reset()

    def add(
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
    ) -> None:
        """
        Store one transition (for all parallel envs) at the current position.

        :param obs: Observation
        :param action: Action
        :param reward: Reward received after taking the action
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key]).copy()
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        self.actions[self.pos] = np.array(action).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.episode_starts[self.pos] = np.array(episode_start).copy()
        # Move value/log_prob from the torch device to the numpy storage
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]:
        """
        Yield minibatches of transitions in a random order.

        The first call flattens the (buffer_size, n_envs, ...) storage to
        (buffer_size * n_envs, ...) in place; this is irreversible until reset().

        :param batch_size: Minibatch size; ``None`` yields the whole buffer at once.
        """
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:

            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictSSLRolloutBufferSamples:
        # Build the standard PPO sample plus lists of the observations/actions
        # 0..3 steps ahead of each sampled index, for the SSL auxiliary loss.

        def get_next_obs(obs, ind, i):
            # For every observation key, return the observation ``i`` steps ahead
            # of each sampled index, clipped at the end of the flattened buffer.
            # NOTE(review): the loop target rebinds the ``obs`` parameter
            # (shadowing). Harmless because ``obs.items()`` is evaluated once
            # before the loop, but the inner name should be renamed.
            # NOTE(review): after swap_and_flatten, the buffer is flattened across
            # envs, so ``ind + i`` can cross env/episode boundaries — presumably
            # acceptable for the SSL objective; confirm with the consumer.
            result = {}
            for key, obs in obs.items():
                future_batch_inds = np.clip(ind + i, 0, len(obs)-1)
                next_obs = self.to_torch(obs[future_batch_inds])
                result[key] = next_obs

            return result

        def get_next_action(actions, ind, i):
            # Action taken ``i`` steps ahead of each sampled index (same clipping
            # and boundary caveat as get_next_obs above).
            future_batch_inds = np.clip(ind + i, 0, len(actions) - 1)
            return self.to_torch(actions[future_batch_inds])


        # i = 0 duplicates the current observation/action; horizon is fixed to 4 steps
        next_observations = [get_next_obs(self.observations, batch_inds, i) for i in range(4)]
        next_actions = [get_next_action(self.actions, batch_inds, i) for i in range(4)]

        return DictSSLRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            next_observations=next_observations,
            actions=self.to_torch(self.actions[batch_inds]),
            next_actions=next_actions,
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )
1010
+
dexart-release/stable_baselines3/common/callbacks.py ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Callable, Dict, List, Optional, Union
5
+
6
+ import gym
7
+ import numpy as np
8
+
9
+ from stable_baselines3.common import base_class # pytype: disable=pyi-error
10
+ from stable_baselines3.common.evaluation import evaluate_policy
11
+ from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
12
+
13
+
14
class BaseCallback(ABC):
    """
    Base class for callback.

    Concrete callbacks implement ``_on_step`` (and optionally the other
    ``_on_*`` hooks); the public ``on_*`` methods are called by the RL
    algorithm during training and bookkeep shared state before delegating.

    :param verbose: Verbosity level (0: no output)
    """

    def __init__(self, verbose: int = 0):
        super().__init__()
        # The RL model
        self.model = None  # type: Optional[base_class.BaseAlgorithm]
        # An alias for self.model.get_env(), the environment used for training
        self.training_env = None  # type: Union[gym.Env, VecEnv, None]
        # Number of time the callback was called
        self.n_calls = 0  # type: int
        # n_envs * n times env.step() was called
        self.num_timesteps = 0  # type: int
        self.verbose = verbose
        # References to the caller's local/global variables (see on_training_start)
        self.locals: Dict[str, Any] = {}
        self.globals: Dict[str, Any] = {}
        # Set from the model in init_callback()
        self.logger = None
        # Sometimes, for event callback, it is useful
        # to have access to the parent object
        self.parent = None  # type: Optional[BaseCallback]

    # Type hint as string to avoid circular import
    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
        """
        Initialize the callback by saving references to the
        RL model and the training environment for convenience.
        """
        self.model = model
        self.training_env = model.get_env()
        self.logger = model.logger
        self._init_callback()

    def _init_callback(self) -> None:
        """Hook for subclasses; called once after the model reference is set."""
        pass

    def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        """Called by the algorithm once before the training loop starts."""
        # Those are reference and will be updated automatically
        self.locals = locals_
        self.globals = globals_
        self._on_training_start()

    def _on_training_start(self) -> None:
        """Hook for subclasses; called when training starts."""
        pass

    def on_rollout_start(self) -> None:
        """Called by the algorithm before each rollout collection."""
        self._on_rollout_start()

    def _on_rollout_start(self) -> None:
        """Hook for subclasses; called before each rollout."""
        pass

    @abstractmethod
    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        return True

    def on_step(self) -> bool:
        """
        This method will be called by the model after each call to ``env.step()``.

        For child callback (of an ``EventCallback``), this will be called
        when the event is triggered.

        :return: If the callback returns False, training is aborted early.
        """
        self.n_calls += 1
        # timesteps start at zero
        self.num_timesteps = self.model.num_timesteps

        return self._on_step()

    def on_training_end(self) -> None:
        """Called by the algorithm once after the training loop ends."""
        self._on_training_end()

    def _on_training_end(self) -> None:
        """Hook for subclasses; called when training ends."""
        pass

    def on_rollout_end(self) -> None:
        """Called by the algorithm after each rollout collection."""
        self._on_rollout_end()

    def _on_rollout_end(self) -> None:
        """Hook for subclasses; called after each rollout."""
        pass

    def update_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Update the references to the local variables.

        :param locals_: the local variables during rollout collection
        """
        self.locals.update(locals_)
        self.update_child_locals(locals_)

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Update the references to the local variables on sub callbacks.

        :param locals_: the local variables during rollout collection
        """
        pass
118
+
119
+
120
class EventCallback(BaseCallback):
    """
    Base class for triggering callback on event.

    :param callback: Callback that will be called
        when an event is triggered.
    :param verbose: Verbosity level (0: no output)
    """

    def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.callback = callback
        if callback is not None:
            # Let the child callback reach back to this one
            self.callback.parent = self

    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
        super().init_callback(model)
        if self.callback is None:
            return
        self.callback.init_callback(self.model)

    def _on_training_start(self) -> None:
        if self.callback is None:
            return
        self.callback.on_training_start(self.locals, self.globals)

    def _on_event(self) -> bool:
        """Fire the child callback; without one, the event keeps training going."""
        if self.callback is None:
            return True
        return self.callback.on_step()

    def _on_step(self) -> bool:
        return True

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Propagate the rollout locals to the child callback, if any.

        :param locals_: the local variables during rollout collection
        """
        if self.callback is not None:
            self.callback.update_locals(locals_)
161
+
162
+
163
class CallbackList(BaseCallback):
    """
    Class for chaining callbacks.

    :param callbacks: A list of callbacks that will be called
        sequentially.
    """

    def __init__(self, callbacks: List[BaseCallback]):
        super().__init__()
        assert isinstance(callbacks, list)
        self.callbacks = callbacks

    def _init_callback(self) -> None:
        for cb in self.callbacks:
            cb.init_callback(self.model)

    def _on_training_start(self) -> None:
        for cb in self.callbacks:
            cb.on_training_start(self.locals, self.globals)

    def _on_rollout_start(self) -> None:
        for cb in self.callbacks:
            cb.on_rollout_start()

    def _on_step(self) -> bool:
        # Every callback gets stepped (no short-circuit); training stops
        # as soon as any one of them returns False.
        step_results = [cb.on_step() for cb in self.callbacks]
        return all(step_results)

    def _on_rollout_end(self) -> None:
        for cb in self.callbacks:
            cb.on_rollout_end()

    def _on_training_end(self) -> None:
        for cb in self.callbacks:
            cb.on_training_end()

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Forward the rollout locals to every chained callback.

        :param locals_: the local variables during rollout collection
        """
        for cb in self.callbacks:
            cb.update_locals(locals_)
211
+
212
+
213
class CheckpointCallback(BaseCallback):
    """
    Callback for saving a model every ``save_freq`` calls
    to ``env.step()``.

    .. warning::

        When using multiple environments, each call to ``env.step()``
        will effectively correspond to ``n_envs`` steps.
        To account for that, you can use ``save_freq = max(save_freq // n_envs, 1)``

    :param save_freq: Save the model every ``save_freq`` callback calls
    :param save_path: Path to the folder where the model will be saved.
    :param name_prefix: Common prefix to the saved models
    :param verbose: Verbosity level (>1 prints each checkpoint path)
    """

    def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
        super().__init__(verbose)
        self.save_freq = save_freq
        self.save_path = save_path
        self.name_prefix = name_prefix

    def _init_callback(self) -> None:
        # Make sure the target folder exists before the first save
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.n_calls % self.save_freq != 0:
            return True
        checkpoint_name = f"{self.name_prefix}_{self.num_timesteps}_steps"
        path = os.path.join(self.save_path, checkpoint_name)
        self.model.save(path)
        if self.verbose > 1:
            print(f"Saving model checkpoint to {path}")
        return True
248
+
249
+
250
class ConvertCallback(BaseCallback):
    """
    Convert functional callback (old-style) to object.

    :param callback: Function of ``(locals, globals) -> bool``; ``None`` is a no-op.
    :param verbose: Verbosity level
    """

    def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0):
        super().__init__(verbose)
        self.callback = callback

    def _on_step(self) -> bool:
        # A missing function behaves like an always-continue callback
        if self.callback is None:
            return True
        return self.callback(self.locals, self.globals)
266
+
267
+
268
class EvalCallback(EventCallback):
    """
    Callback for evaluating an agent.

    .. warning::

        When using multiple environments, each call to ``env.step()``
        will effectively correspond to ``n_envs`` steps.
        To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)``

    :param eval_env: The environment used for initialization
    :param callback_on_new_best: Callback to trigger
        when there is a new best model according to the ``mean_reward``
    :param callback_after_eval: Callback to trigger after every evaluation
    :param n_eval_episodes: The number of episodes to test the agent
    :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
    :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
        will be saved. It will be updated at each evaluation.
    :param best_model_save_path: Path to a folder where the best model
        according to performance on the eval env will be saved.
    :param deterministic: Whether the evaluation should
        use a stochastic or deterministic actions.
    :param render: Whether to render or not the environment during evaluation
    :param verbose: Verbosity level
    :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
        wrapped with a Monitor wrapper)
    """

    def __init__(
        self,
        eval_env: Union[gym.Env, VecEnv],
        callback_on_new_best: Optional[BaseCallback] = None,
        callback_after_eval: Optional[BaseCallback] = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: Optional[str] = None,
        best_model_save_path: Optional[str] = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ):
        # callback_after_eval is stored as self.callback by EventCallback
        super().__init__(callback_after_eval, verbose=verbose)

        self.callback_on_new_best = callback_on_new_best
        if self.callback_on_new_best is not None:
            # Give access to the parent
            self.callback_on_new_best.parent = self

        self.n_eval_episodes = n_eval_episodes
        self.eval_freq = eval_freq
        # Best/last mean rewards start at -inf so the first evaluation always improves
        self.best_mean_reward = -np.inf
        self.last_mean_reward = -np.inf
        self.deterministic = deterministic
        self.render = render
        self.warn = warn

        # Convert to VecEnv for consistency
        if not isinstance(eval_env, VecEnv):
            eval_env = DummyVecEnv([lambda: eval_env])

        self.eval_env = eval_env
        self.best_model_save_path = best_model_save_path
        # Logs will be written in ``evaluations.npz``
        if log_path is not None:
            log_path = os.path.join(log_path, "evaluations")
        self.log_path = log_path
        # Per-evaluation histories, appended to and re-saved on each evaluation
        self.evaluations_results = []
        self.evaluations_timesteps = []
        self.evaluations_length = []
        # For computing success rate
        self._is_success_buffer = []
        self.evaluations_successes = []

    def _init_callback(self) -> None:
        # Does not work in some corner cases, where the wrapper is not the same
        # NOTE(review): the two literals below concatenate without a separating
        # space, so the message reads "...same type<env repr>"; consider adding
        # a space or comma in a behavior-changing fix.
        if not isinstance(self.training_env, type(self.eval_env)):
            warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}")

        # Create folders if needed
        if self.best_model_save_path is not None:
            os.makedirs(self.best_model_save_path, exist_ok=True)
        if self.log_path is not None:
            os.makedirs(os.path.dirname(self.log_path), exist_ok=True)

        # Init callback called on new best model
        if self.callback_on_new_best is not None:
            self.callback_on_new_best.init_callback(self.model)

    def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        """
        Callback passed to the ``evaluate_policy`` function
        in order to log the success rate (when applicable),
        for instance when using HER.

        :param locals_: locals of ``evaluate_policy`` (must contain ``info`` and ``done``)
        :param globals_: globals of ``evaluate_policy`` (unused)
        """
        info = locals_["info"]

        # Only record success at episode end
        if locals_["done"]:
            maybe_is_success = info.get("is_success")
            if maybe_is_success is not None:
                self._is_success_buffer.append(maybe_is_success)

    def _on_step(self) -> bool:
        # Runs a full evaluation every ``eval_freq`` calls; otherwise a no-op.

        continue_training = True

        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:

            # Sync training and eval env if there is VecNormalize
            if self.model.get_vec_normalize_env() is not None:
                try:
                    sync_envs_normalization(self.training_env, self.eval_env)
                except AttributeError:
                    raise AssertionError(
                        "Training and eval env are not wrapped the same way, "
                        "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
                        "and warning above."
                    )

            # Reset success rate buffer
            self._is_success_buffer = []

            episode_rewards, episode_lengths = evaluate_policy(
                self.model,
                self.eval_env,
                n_eval_episodes=self.n_eval_episodes,
                render=self.render,
                deterministic=self.deterministic,
                return_episode_rewards=True,
                warn=self.warn,
                callback=self._log_success_callback,
            )

            if self.log_path is not None:
                self.evaluations_timesteps.append(self.num_timesteps)
                self.evaluations_results.append(episode_rewards)
                self.evaluations_length.append(episode_lengths)

                kwargs = {}
                # Save success log if present
                if len(self._is_success_buffer) > 0:
                    self.evaluations_successes.append(self._is_success_buffer)
                    kwargs = dict(successes=self.evaluations_successes)

                # Overwrite the npz file with the full history each time
                np.savez(
                    self.log_path,
                    timesteps=self.evaluations_timesteps,
                    results=self.evaluations_results,
                    ep_lengths=self.evaluations_length,
                    **kwargs,
                )

            mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
            mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
            self.last_mean_reward = mean_reward

            if self.verbose > 0:
                print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
                print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
            # Add to current Logger
            self.logger.record("eval/mean_reward", float(mean_reward))
            self.logger.record("eval/mean_ep_length", mean_ep_length)

            if len(self._is_success_buffer) > 0:
                success_rate = np.mean(self._is_success_buffer)
                if self.verbose > 0:
                    print(f"Success rate: {100 * success_rate:.2f}%")
                self.logger.record("eval/success_rate", success_rate)

            # Dump log so the evaluation results are printed with the correct timestep
            self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
            self.logger.dump(self.num_timesteps)

            if mean_reward > self.best_mean_reward:
                if self.verbose > 0:
                    print("New best mean reward!")
                if self.best_model_save_path is not None:
                    self.model.save(os.path.join(self.best_model_save_path, "best_model"))
                self.best_mean_reward = mean_reward
                # Trigger callback on new best model, if needed
                if self.callback_on_new_best is not None:
                    continue_training = self.callback_on_new_best.on_step()

            # Trigger callback after every evaluation, if needed
            if self.callback is not None:
                continue_training = continue_training and self._on_event()

        return continue_training

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Update the references to the local variables.

        :param locals_: the local variables during rollout collection
        """
        if self.callback:
            self.callback.update_locals(locals_)
468
+
469
+
470
class StopTrainingOnRewardThreshold(BaseCallback):
    """
    Stop the training once a threshold in episodic reward
    has been reached (i.e. when the model is good enough).

    It must be used with the ``EvalCallback``.

    :param reward_threshold: Minimum expected reward per episode
        to stop training.
    :param verbose: Verbosity level
    """

    def __init__(self, reward_threshold: float, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.reward_threshold = reward_threshold

    def _on_step(self) -> bool:
        assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used with an ``EvalCallback``"
        # Convert np.bool_ to bool, otherwise callback() is False won't work
        continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
        if not continue_training and self.verbose > 0:
            print(
                f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
                f" is above the threshold {self.reward_threshold}"
            )
        return continue_training
496
+
497
+
498
class EveryNTimesteps(EventCallback):
    """
    Trigger a callback every ``n_steps`` timesteps

    :param n_steps: Number of timesteps between two trigger.
    :param callback: Callback that will be called
        when the event is triggered.
    """

    def __init__(self, n_steps: int, callback: BaseCallback):
        super().__init__(callback)
        self.n_steps = n_steps
        self.last_time_trigger = 0

    def _on_step(self) -> bool:
        elapsed = self.num_timesteps - self.last_time_trigger
        if elapsed < self.n_steps:
            return True
        # Enough timesteps have passed: remember this point and fire the event
        self.last_time_trigger = self.num_timesteps
        return self._on_event()
517
+
518
+
519
class StopTrainingOnMaxEpisodes(BaseCallback):
    """
    Stop the training once a maximum number of episodes are played.

    For multiple environments presumes that, the desired behavior is that the agent trains on each env for ``max_episodes``
    and in total for ``max_episodes * n_envs`` episodes.

    :param max_episodes: Maximum number of episodes to stop training.
    :param verbose: Select whether to print information about when training ended by reaching ``max_episodes``
    """

    def __init__(self, max_episodes: int, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_episodes = max_episodes
        # Scaled by n_envs in _init_callback() once the training env is known
        self._total_max_episodes = max_episodes
        # Running count of finished episodes across all envs
        self.n_episodes = 0

    def _init_callback(self) -> None:
        # At start set total max according to number of environments
        self._total_max_episodes = self.max_episodes * self.training_env.num_envs

    def _on_step(self) -> bool:
        # Check that the `dones` local variable is defined
        assert "dones" in self.locals, "`dones` variable is not defined, please check your code next to `callback.on_step()`"
        # Each True in `dones` marks one finished episode this step
        self.n_episodes += np.sum(self.locals["dones"]).item()

        continue_training = self.n_episodes < self._total_max_episodes

        if self.verbose > 0 and not continue_training:
            mean_episodes_per_env = self.n_episodes / self.training_env.num_envs
            mean_ep_str = (
                f"with an average of {mean_episodes_per_env:.2f} episodes per env" if self.training_env.num_envs > 1 else ""
            )

            print(
                f"Stopping training with a total of {self.num_timesteps} steps because the "
                f"{self.locals.get('tb_log_name')} model reached max_episodes={self.max_episodes}, "
                f"by playing for {self.n_episodes} episodes "
                f"{mean_ep_str}"
            )
        return continue_training
560
+
561
+
562
class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.

    It is possible to define a minimum number of evaluations before start to count evaluations without improvement.

    It must be used with the ``EvalCallback``.

    :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
    :param min_evals: Number of evaluations before start to count evaluations without improvements.
    :param verbose: Verbosity of the output (set to 1 for info messages)
    """

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        self.last_best_mean_reward = -np.inf
        self.no_improvement_evals = 0

    def _on_step(self) -> bool:
        assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``"

        continue_training = True

        # Only start counting once the warm-up evaluations are done
        if self.n_calls > self.min_evals:
            improved = self.parent.best_mean_reward > self.last_best_mean_reward
            if improved:
                self.no_improvement_evals = 0
            else:
                self.no_improvement_evals += 1
                continue_training = self.no_improvement_evals <= self.max_no_improvement_evals

        # Remember the best reward seen so far to detect future improvements
        self.last_best_mean_reward = self.parent.best_mean_reward

        if self.verbose > 0 and not continue_training:
            print(
                f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
            )

        return continue_training
dexart-release/stable_baselines3/common/distributions.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Probability distributions."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import gym
7
+ import torch as th
8
+ from gym import spaces
9
+ from torch import nn
10
+ from torch.distributions import Bernoulli, Categorical, Normal
11
+
12
+ from stable_baselines3.common.preprocessing import get_action_dim
13
+
14
+
15
class Distribution(ABC):
    """Abstract base class for action probability distributions."""

    def __init__(self):
        super().__init__()
        # Set by proba_distribution() in concrete subclasses.
        self.distribution = None

    @abstractmethod
    def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
        """Create the layers and parameters that represent the distribution.

        Subclasses must define this, but the arguments and return type vary between
        concrete classes."""

    @abstractmethod
    def proba_distribution(self, *args, **kwargs) -> "Distribution":
        """Set parameters of the distribution.

        :return: self
        """

    @abstractmethod
    def log_prob(self, x: th.Tensor) -> th.Tensor:
        """
        Returns the log likelihood

        :param x: the taken action
        :return: The log likelihood of the distribution
        """

    @abstractmethod
    def entropy(self) -> Optional[th.Tensor]:
        """
        Returns Shannon's entropy of the probability

        :return: the entropy, or None if no analytical form is known
        """

    @abstractmethod
    def sample(self) -> th.Tensor:
        """
        Returns a sample from the probability distribution

        :return: the stochastic action
        """

    @abstractmethod
    def mode(self) -> th.Tensor:
        """
        Returns the most likely action (deterministic output)
        from the probability distribution

        :return: the stochastic action
        """

    def get_actions(self, deterministic: bool = False) -> th.Tensor:
        """
        Return actions according to the probability distribution.

        :param deterministic: if True, return the mode instead of a sample
        :return: the chosen actions
        """
        return self.mode() if deterministic else self.sample()

    @abstractmethod
    def actions_from_params(self, *args, **kwargs) -> th.Tensor:
        """
        Returns samples from the probability distribution
        given its parameters.

        :return: actions
        """

    @abstractmethod
    def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
        """
        Returns samples and the associated log probabilities
        from the probability distribution given its parameters.

        :return: actions and log prob
        """
98
+
99
+
100
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
    """
    Continuous actions are usually considered to be independent,
    so we can sum components of the ``log_prob`` or the entropy.

    :param tensor: shape: (n_batch, n_actions) or (n_batch,)
    :return: shape: (n_batch,)
    """
    # Batched input: reduce over the action dimension only;
    # otherwise collapse everything to a scalar.
    if tensor.dim() > 1:
        return tensor.sum(dim=1)
    return tensor.sum()
113
+
114
+
115
class DiagGaussianDistribution(Distribution):
    """
    Gaussian distribution with a diagonal covariance matrix, for continuous actions.

    :param action_dim: Dimension of the action space.
    """

    def __init__(self, action_dim: int):
        super().__init__()
        self.action_dim = action_dim
        self.mean_actions = None
        self.log_std = None

    def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
        """
        Build the parts that parameterize the distribution: a linear head for the
        mean and a free parameter for the log standard deviation (log so that
        negative values are allowed).

        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
        :param log_std_init: Initial value for the log standard deviation
        :return: the mean head and the log-std parameter
        """
        mean_net = nn.Linear(latent_dim, self.action_dim)
        # TODO: allow action dependent std
        log_std_param = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
        return mean_net, log_std_param

    def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
        """
        Create the distribution given its parameters (mean, std).

        :param mean_actions: mean of the Gaussian
        :param log_std: log standard deviation
        :return: self
        """
        # Broadcast the (action_dim,) std to the batch shape of the means.
        std = th.ones_like(mean_actions) * log_std.exp()
        self.distribution = Normal(mean_actions, std)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """
        Log-likelihood of ``actions`` under the current distribution.
        ``proba_distribution()`` must have been called first.

        :param actions: actions to evaluate
        :return: summed per-dimension log probability, shape (n_batch,)
        """
        return sum_independent_dims(self.distribution.log_prob(actions))

    def entropy(self) -> th.Tensor:
        return sum_independent_dims(self.distribution.entropy())

    def sample(self) -> th.Tensor:
        # rsample(): reparametrization trick so gradients flow through the sample
        return self.distribution.rsample()

    def mode(self) -> th.Tensor:
        return self.distribution.mean

    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Refresh the distribution, then draw (or take the mode).
        self.proba_distribution(mean_actions, log_std)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Sample actions and compute their log probability in one call.

        :param mean_actions: mean of the Gaussian
        :param log_std: log standard deviation
        :return: sampled actions and their log probability
        """
        actions = self.actions_from_params(mean_actions, log_std)
        return actions, self.log_prob(actions)
193
+
194
+
195
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
    """
    Diagonal Gaussian followed by a tanh squashing, so actions stay in bounds.

    :param action_dim: Dimension of the action space.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, action_dim: int, epsilon: float = 1e-6):
        super().__init__(action_dim)
        # Avoid NaN (prevents division by zero or log of zero)
        self.epsilon = epsilon
        # Last pre-squash sample, reused by log_prob_from_params().
        self.gaussian_actions = None

    def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution":
        super().proba_distribution(mean_actions, log_std)
        return self

    def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
        if gaussian_actions is None:
            # Invert the squashing (atanh); clipped internally to avoid NaN.
            # Naive 0.5 * log((1 + x) / (1 - x)) would be numerically unstable.
            gaussian_actions = TanhBijector.inverse(actions)

        # Gaussian log-likelihood of the pre-squash actions
        log_prob = super().log_prob(gaussian_actions)
        # Squash correction (from original SAC implementation):
        # change-of-variables term, valid because tanh is bijective and differentiable
        log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1)
        return log_prob

    def entropy(self) -> Optional[th.Tensor]:
        # No analytical form; estimate with -log_prob.mean() instead.
        return None

    def sample(self) -> th.Tensor:
        # Reparametrized sample; keep the pre-squash value for log_prob reuse.
        self.gaussian_actions = super().sample()
        return th.tanh(self.gaussian_actions)

    def mode(self) -> th.Tensor:
        self.gaussian_actions = super().mode()
        # Squash the deterministic output as well.
        return th.tanh(self.gaussian_actions)

    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        action = self.actions_from_params(mean_actions, log_std)
        return action, self.log_prob(action, self.gaussian_actions)
247
+
248
+
249
class CategoricalDistribution(Distribution):
    """
    Categorical distribution for discrete actions.

    :param action_dim: Number of discrete actions
    """

    def __init__(self, action_dim: int):
        super().__init__()
        self.action_dim = action_dim

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Linear layer producing the logits of the Categorical distribution
        (apply a softmax to recover probabilities).

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: the logits layer
        """
        return nn.Linear(latent_dim, self.action_dim)

    def proba_distribution(self, action_logits: th.Tensor) -> "CategoricalDistribution":
        self.distribution = Categorical(logits=action_logits)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        return self.distribution.log_prob(actions)

    def entropy(self) -> th.Tensor:
        return self.distribution.entropy()

    def sample(self) -> th.Tensor:
        return self.distribution.sample()

    def mode(self) -> th.Tensor:
        # Most probable action index per batch element.
        return th.argmax(self.distribution.probs, dim=1)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Refresh the distribution, then draw (or take the mode).
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        actions = self.actions_from_params(action_logits)
        return actions, self.log_prob(actions)
298
+
299
+
300
class MultiCategoricalDistribution(Distribution):
    """
    MultiCategorical distribution for multi discrete actions.

    :param action_dims: List of sizes of discrete action spaces
    """

    def __init__(self, action_dims: List[int]):
        super().__init__()
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Linear layer producing the concatenated (flattened) logits of every
        sub-space; apply a softmax per sub-space to recover probabilities.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: the logits layer
        """
        return nn.Linear(latent_dim, sum(self.action_dims))

    def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution":
        # One Categorical per sub-space, each fed its own slice of the logits.
        chunks = th.split(action_logits, tuple(self.action_dims), dim=1)
        self.distribution = [Categorical(logits=chunk) for chunk in chunks]
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        # Pair each action column with its sub-distribution, then sum the log probs.
        per_dim = [dist.log_prob(act) for dist, act in zip(self.distribution, th.unbind(actions, dim=1))]
        return th.stack(per_dim, dim=1).sum(dim=1)

    def entropy(self) -> th.Tensor:
        entropies = [dist.entropy() for dist in self.distribution]
        return th.stack(entropies, dim=1).sum(dim=1)

    def sample(self) -> th.Tensor:
        return th.stack([dist.sample() for dist in self.distribution], dim=1)

    def mode(self) -> th.Tensor:
        modes = [th.argmax(dist.probs, dim=1) for dist in self.distribution]
        return th.stack(modes, dim=1)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Refresh the distribution, then draw (or take the mode).
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        actions = self.actions_from_params(action_logits)
        return actions, self.log_prob(actions)
353
+
354
+
355
class BernoulliDistribution(Distribution):
    """
    Bernoulli distribution for MultiBinary action spaces.

    :param action_dims: Number of binary actions
    """

    def __init__(self, action_dims: int):
        super().__init__()
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the layer that represents the distribution:
        it will be the logits of the Bernoulli distribution.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: the logits layer
        """
        action_logits = nn.Linear(latent_dim, self.action_dims)
        return action_logits

    def proba_distribution(self, action_logits: th.Tensor) -> "BernoulliDistribution":
        self.distribution = Bernoulli(logits=action_logits)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        # Binary dimensions are independent: sum the per-dimension log probs.
        return self.distribution.log_prob(actions).sum(dim=1)

    def entropy(self) -> th.Tensor:
        return self.distribution.entropy().sum(dim=1)

    def sample(self) -> th.Tensor:
        return self.distribution.sample()

    def mode(self) -> th.Tensor:
        # Most likely outcome per dimension: probability rounded to 0 or 1.
        return th.round(self.distribution.probs)

    def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        # Update the proba distribution
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        actions = self.actions_from_params(action_logits)
        log_prob = self.log_prob(actions)
        return actions, log_prob
403
+
404
+
405
class StateDependentNoiseDistribution(Distribution):
    """
    Distribution class for using generalized State Dependent Exploration (gSDE).
    Paper: https://arxiv.org/abs/2005.05719

    It is used to create the noise exploration matrix and
    compute the log probability of an action with that noise.

    :param action_dim: Dimension of the action space.
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,)
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this ensures bounds are satisfied.
    :param learn_features: Whether to learn features for gSDE or not.
        This will enable gradients to be backpropagated through the features
        ``latent_sde`` in the code.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(
        self,
        action_dim: int,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        learn_features: bool = False,
        epsilon: float = 1e-6,
    ):
        super().__init__()
        self.action_dim = action_dim
        # The attributes below are populated later by proba_distribution_net()
        # and proba_distribution(); they start as None placeholders.
        self.latent_sde_dim = None
        self.mean_actions = None
        self.log_std = None
        self.weights_dist = None
        self.exploration_mat = None
        self.exploration_matrices = None
        self._latent_sde = None
        self.use_expln = use_expln
        self.full_std = full_std
        self.epsilon = epsilon
        self.learn_features = learn_features
        if squash_output:
            # Optional tanh squashing of the final actions.
            self.bijector = TanhBijector(epsilon)
        else:
            self.bijector = None

    def get_std(self, log_std: th.Tensor) -> th.Tensor:
        """
        Get the standard deviation from the learned parameter
        (log of it by default). This ensures that the std is positive.

        :param log_std: learned log-std parameter
        :return: positive standard deviation tensor
        """
        if self.use_expln:
            # From gSDE paper, it allows to keep variance
            # above zero and prevent it from growing too fast
            below_threshold = th.exp(log_std) * (log_std <= 0)
            # Avoid NaN: zeros values that are below zero
            safe_log_std = log_std * (log_std > 0) + self.epsilon
            above_threshold = (th.log1p(safe_log_std) + 1.0) * (log_std > 0)
            std = below_threshold + above_threshold
        else:
            # Use normal exponential
            std = th.exp(log_std)

        if self.full_std:
            return std
        # Reduce the number of parameters:
        # broadcast the (latent_sde_dim, 1) std to the full matrix shape.
        return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std

    def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
        """
        Sample weights for the noise exploration matrix,
        using a centered Gaussian distribution.

        :param log_std: learned log-std parameter
        :param batch_size: number of matrices for parallel exploration
        """
        std = self.get_std(log_std)
        self.weights_dist = Normal(th.zeros_like(std), std)
        # Reparametrization trick to pass gradients
        self.exploration_mat = self.weights_dist.rsample()
        # Pre-compute matrices in case of parallel exploration
        self.exploration_matrices = self.weights_dist.rsample((batch_size,))

    def proba_distribution_net(
        self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None
    ) -> Tuple[nn.Module, nn.Parameter]:
        """
        Create the layers and parameter that represent the distribution:
        one output will be the deterministic action, the other parameter will be the
        standard deviation of the distribution that control the weights of the noise matrix.

        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
        :param log_std_init: Initial value for the log standard deviation
        :param latent_sde_dim: Dimension of the last layer of the features extractor
            for gSDE. By default, it is shared with the policy network.
        :return: the mean-action head and the log-std parameter
        """
        # Network for the deterministic action, it represents the mean of the distribution
        mean_actions_net = nn.Linear(latent_dim, self.action_dim)
        # When we learn features for the noise, the feature dimension
        # can be different between the policy and the noise network
        self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
        # Reduce the number of parameters if needed
        log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1)
        # Transform it to a parameter so it can be optimized
        log_std = nn.Parameter(log_std * log_std_init, requires_grad=True)
        # Sample an exploration matrix
        self.sample_weights(log_std)
        return mean_actions_net, log_std

    def proba_distribution(
        self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
    ) -> "StateDependentNoiseDistribution":
        """
        Create the distribution given its parameters (mean, std)

        :param mean_actions: mean of the action distribution
        :param log_std: learned log-std parameter
        :param latent_sde: features used to generate the state-dependent noise
        :return: self
        """
        # Stop gradient if we don't want to influence the features
        self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        # Per-action variance induced by the noise matrix: (features^2) @ (std^2)
        variance = th.mm(self._latent_sde**2, self.get_std(log_std) ** 2)
        self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        # If squashing is enabled, work with the pre-squash (Gaussian) actions.
        if self.bijector is not None:
            gaussian_actions = self.bijector.inverse(actions)
        else:
            gaussian_actions = actions
        # log likelihood for a gaussian
        log_prob = self.distribution.log_prob(gaussian_actions)
        # Sum along action dim
        log_prob = sum_independent_dims(log_prob)

        if self.bijector is not None:
            # Squash correction (from original SAC implementation)
            log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
        return log_prob

    def entropy(self) -> Optional[th.Tensor]:
        if self.bijector is not None:
            # No analytical form,
            # entropy needs to be estimated using -log_prob.mean()
            return None
        return sum_independent_dims(self.distribution.entropy())

    def sample(self) -> th.Tensor:
        # Deterministic mean plus state-dependent noise.
        noise = self.get_noise(self._latent_sde)
        actions = self.distribution.mean + noise
        if self.bijector is not None:
            return self.bijector.forward(actions)
        return actions

    def mode(self) -> th.Tensor:
        actions = self.distribution.mean
        if self.bijector is not None:
            return self.bijector.forward(actions)
        return actions

    def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
        latent_sde = latent_sde if self.learn_features else latent_sde.detach()
        # Default case: only one exploration matrix
        # (also the fallback when batch size does not match the
        # pre-computed matrices, e.g. during gradient updates)
        if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
            return th.mm(latent_sde, self.exploration_mat)
        # Use batch matrix multiplication for efficient computation
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = latent_sde.unsqueeze(1)
        # (batch_size, 1, n_actions)
        noise = th.bmm(latent_sde, self.exploration_matrices)
        return noise.squeeze(1)

    def actions_from_params(
        self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
    ) -> th.Tensor:
        # Update the proba distribution
        self.proba_distribution(mean_actions, log_std, latent_sde)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
        self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
    ) -> Tuple[th.Tensor, th.Tensor]:
        actions = self.actions_from_params(mean_actions, log_std, latent_sde)
        log_prob = self.log_prob(actions)
        return actions, log_prob
598
+
599
+
600
class TanhBijector:
    """
    Bijective transformation of a probability distribution
    using a squashing function (tanh)
    TODO: use Pyro instead (https://pyro.ai/)

    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    @staticmethod
    def forward(x: th.Tensor) -> th.Tensor:
        """Apply the squashing: tanh(x)."""
        return th.tanh(x)

    @staticmethod
    def atanh(x: th.Tensor) -> th.Tensor:
        """
        Inverse of Tanh

        Taken from Pyro: https://github.com/pyro-ppl/pyro
        0.5 * torch.log((1 + x ) / (1 - x))
        """
        # log1p-based form is numerically stabler than the naive expression.
        return 0.5 * (x.log1p() - (-x).log1p())

    @staticmethod
    def inverse(y: th.Tensor) -> th.Tensor:
        """
        Inverse tanh, with the input clipped to the open interval (-1, 1).

        :param y: squashed value
        :return: pre-squash value
        """
        eps = th.finfo(y.dtype).eps
        # Clip the action to avoid NaN
        clipped = y.clamp(min=-1.0 + eps, max=1.0 - eps)
        return TanhBijector.atanh(clipped)

    def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
        # Squash correction (from original SAC implementation)
        return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)
642
+
643
+
644
def make_proba_distribution(
    action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> Distribution:
    """
    Return an instance of Distribution for the correct type of action space

    :param action_space: the input action space
    :param use_sde: Force the use of StateDependentNoiseDistribution
        instead of DiagGaussianDistribution
    :param dist_kwargs: Keyword arguments to pass to the probability distribution
    :return: the appropriate Distribution object
    :raises NotImplementedError: for unsupported action space types
    """
    if dist_kwargs is None:
        dist_kwargs = {}

    if isinstance(action_space, spaces.Box):
        assert len(action_space.shape) == 1, "Error: the action space must be a vector"
        cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
        return cls(get_action_dim(action_space), **dist_kwargs)
    elif isinstance(action_space, spaces.Discrete):
        return CategoricalDistribution(action_space.n, **dist_kwargs)
    elif isinstance(action_space, spaces.MultiDiscrete):
        # ``nvec`` is a numpy array; convert it to a plain list of ints so that
        # ``MultiCategoricalDistribution.action_dims`` matches its declared
        # List[int] type and supports unambiguous equality checks
        # (e.g. the assert inside ``kl_divergence``).
        return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
    elif isinstance(action_space, spaces.MultiBinary):
        return BernoulliDistribution(action_space.n, **dist_kwargs)
    else:
        raise NotImplementedError(
            "Error: probability distribution, not implemented for action space "
            f"of type {type(action_space)}."
            " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
        )
675
+
676
+
677
def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor:
    """
    Wrapper for the PyTorch implementation of the full form KL Divergence

    :param dist_true: the p distribution
    :param dist_pred: the q distribution
    :return: KL(dist_true||dist_pred)
    """
    # KL Divergence for different distribution types is out of scope
    assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type"

    # MultiCategoricalDistribution is not a PyTorch Distribution subclass
    # so we need to implement it ourselves!
    if isinstance(dist_pred, MultiCategoricalDistribution):
        assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space"
        per_dim_kl = [
            th.distributions.kl_divergence(p, q)
            for p, q in zip(dist_true.distribution, dist_pred.distribution)
        ]
        return th.stack(per_dim_kl, dim=1).sum(dim=1)

    # Use the PyTorch kl_divergence implementation
    return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution)
dexart-release/stable_baselines3/common/env_util.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Callable, Dict, Optional, Type, Union
3
+
4
+ import gym
5
+
6
+ from stable_baselines3.common.monitor import Monitor
7
+ from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
8
+
9
+
10
def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
    """
    Walk the chain of gym wrappers around ``env`` and return the first one
    that is an instance of ``wrapper_class``.

    :param env: Environment to unwrap
    :param wrapper_class: Wrapper to look for
    :return: the matching wrapper, or None if ``env`` is not wrapped with it
    """
    current = env
    while isinstance(current, gym.Wrapper):
        if isinstance(current, wrapper_class):
            return current
        current = current.env
    return None
24
+
25
+
26
def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool:
    """
    Check if a given environment has been wrapped with a given wrapper.

    :param env: Environment to check
    :param wrapper_class: Wrapper class to look for
    :return: True if environment has been wrapped with ``wrapper_class``.
    """
    found = unwrap_wrapper(env, wrapper_class)
    return found is not None
35
+
36
+
37
def make_vec_env(
    env_id: Union[str, Type[gym.Env]],
    n_envs: int = 1,
    seed: Optional[int] = None,
    start_index: int = 0,
    monitor_dir: Optional[str] = None,
    wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None,
    vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
    monitor_kwargs: Optional[Dict[str, Any]] = None,
    wrapper_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
    """
    Create a wrapped, monitored ``VecEnv``.
    By default it uses a ``DummyVecEnv`` which is usually faster
    than a ``SubprocVecEnv``.

    :param env_id: the environment ID or the environment class
    :param n_envs: the number of environments you wish to have in parallel
    :param seed: the initial seed for the random number generator
    :param start_index: start rank index
    :param monitor_dir: Path to a folder where the monitor files will be saved.
        If None, no file will be written, however, the env will still be wrapped
        in a Monitor wrapper to provide additional information about training.
    :param wrapper_class: Additional wrapper to use on the environment.
        This can also be a function with single argument that wraps the environment in many things.
    :param env_kwargs: Optional keyword argument to pass to the env constructor
    :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
    :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
    :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
    :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor.
    :return: The wrapped environment
    """
    # Normalize all optional kwargs dicts to empty dicts.
    if env_kwargs is None:
        env_kwargs = {}
    if vec_env_kwargs is None:
        vec_env_kwargs = {}
    if monitor_kwargs is None:
        monitor_kwargs = {}
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    def _make_env_fn(rank: int):
        # Return a thunk that builds one monitored (and optionally wrapped) env.
        def _init():
            if isinstance(env_id, str):
                env = gym.make(env_id, **env_kwargs)
            else:
                env = env_id(**env_kwargs)
            if seed is not None:
                # Offset the seed per rank so parallel envs differ.
                env.seed(seed + rank)
                env.action_space.seed(seed + rank)
            # Always wrap in a Monitor for training statistics; a file is only
            # written when monitor_dir is given.
            if monitor_dir is not None:
                monitor_path = os.path.join(monitor_dir, str(rank))
                os.makedirs(monitor_dir, exist_ok=True)
            else:
                monitor_path = None
            env = Monitor(env, filename=monitor_path, **monitor_kwargs)
            # Optionally apply the user-provided wrapper last.
            if wrapper_class is not None:
                env = wrapper_class(env, **wrapper_kwargs)
            return env

        return _init

    if vec_env_cls is None:
        # Default: use a DummyVecEnv
        vec_env_cls = DummyVecEnv

    env_fns = [_make_env_fn(rank) for rank in range(start_index, start_index + n_envs)]
    return vec_env_cls(env_fns, **vec_env_kwargs)
dexart-release/stable_baselines3/common/evaluation.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import gym
5
+ import numpy as np
6
+
7
+ from stable_baselines3.common import base_class
8
+ from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
9
+
10
+
11
def evaluate_policy(
    model: "base_class.BaseAlgorithm",
    env: Union[gym.Env, VecEnv],
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
    """
    Runs policy for ``n_eval_episodes`` episodes and returns average reward.
    If a vector env is passed in, this divides the episodes to evaluate onto the
    different elements of the vector env. This static division of work is done to
    remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
    details and discussion.

    .. note::
        If environment has not been wrapped with ``Monitor`` wrapper, reward and
        episode lengths are counted as it appears with ``env.step`` calls. If
        the environment contains wrappers that modify rewards or episode lengths
        (e.g. reward scaling, early episode reset), these will affect the evaluation
        results as well. You can avoid this by wrapping environment with ``Monitor``
        wrapper before anything else.

    :param model: The RL agent you want to evaluate.
    :param env: The gym environment or ``VecEnv`` environment.
    :param n_eval_episodes: Number of episode to evaluate the agent
    :param deterministic: Whether to use deterministic or stochastic actions
    :param render: Whether to render the environment or not
    :param callback: callback function to do additional checks,
        called after each step. Gets locals() and globals() passed as parameters.
    :param reward_threshold: Minimum expected reward per episode,
        this will raise an error if the performance is not met
    :param return_episode_rewards: If True, a list of rewards and episode lengths
        per episode will be returned instead of the mean.
    :param warn: If True (default), warns user about lack of a Monitor wrapper in the
        evaluation environment.
    :return: Mean reward per episode, std of reward per episode.
        Returns ([float], [int]) when ``return_episode_rewards`` is True, first
        list containing per-episode rewards and second containing per-episode lengths
        (in number of steps).
    """
    is_monitor_wrapped = False
    # Avoid circular import
    from stable_baselines3.common.monitor import Monitor

    # Normalize a plain gym env to a 1-element VecEnv so the loop below
    # can treat everything as vectorized.
    if not isinstance(env, VecEnv):
        env = DummyVecEnv([lambda: env])

    is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]

    if not is_monitor_wrapped and warn:
        warnings.warn(
            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
            "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
            "Consider wrapping environment first with ``Monitor`` wrapper.",
            UserWarning,
        )

    n_envs = env.num_envs
    episode_rewards = []
    episode_lengths = []

    episode_counts = np.zeros(n_envs, dtype="int")
    # Divides episodes among different sub environments in the vector as evenly as possible
    episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")

    current_rewards = np.zeros(n_envs)
    current_lengths = np.zeros(n_envs, dtype="int")
    observations = env.reset()
    states = None
    episode_starts = np.ones((env.num_envs,), dtype=bool)
    # Run until every sub-env has completed its assigned number of episodes.
    while (episode_counts < episode_count_targets).any():
        actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
        observations, rewards, dones, infos = env.step(actions)
        current_rewards += rewards
        current_lengths += 1
        for i in range(n_envs):
            if episode_counts[i] < episode_count_targets[i]:

                # unpack values so that the callback can access the local variables
                # NOTE: the callback receives this frame's locals(); keep these
                # local names stable — they are effectively part of the API.
                reward = rewards[i]
                done = dones[i]
                info = infos[i]
                episode_starts[i] = done

                if callback is not None:
                    callback(locals(), globals())

                if dones[i]:
                    if is_monitor_wrapped:
                        # Atari wrapper can send a "done" signal when
                        # the agent loses a life, but it does not correspond
                        # to the true end of episode
                        if "episode" in info.keys():
                            # Do not trust "done" with episode endings.
                            # Monitor wrapper includes "episode" key in info if environment
                            # has been wrapped with it. Use those rewards instead.
                            episode_rewards.append(info["episode"]["r"])
                            episode_lengths.append(info["episode"]["l"])
                            # Only increment at the real end of an episode
                            episode_counts[i] += 1
                    else:
                        episode_rewards.append(current_rewards[i])
                        episode_lengths.append(current_lengths[i])
                        episode_counts[i] += 1
                    # Reset the per-env accumulators even when the Monitor
                    # "done" was not a true episode end.
                    current_rewards[i] = 0
                    current_lengths[i] = 0

        if render:
            env.render()

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    if reward_threshold is not None:
        assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
    if return_episode_rewards:
        return episode_rewards, episode_lengths
    return mean_reward, std_reward
dexart-release/stable_baselines3/common/logger.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import json
3
+ import os
4
+ import sys
5
+ import tempfile
6
+ import warnings
7
+ from collections import defaultdict
8
+ from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union
9
+
10
+ import numpy as np
11
+ import pandas
12
+ import torch as th
13
+ from matplotlib import pyplot as plt
14
+
15
+ try:
16
+ from torch.utils.tensorboard import SummaryWriter
17
+ except ImportError:
18
+ SummaryWriter = None
19
+
20
# Logging verbosity levels (numeric values mirror the stdlib ``logging`` module).
# ``Logger.set_level(DISABLED)`` silences all output.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
DISABLED = 50
25
+
26
+
27
class Video:
    """
    Container for a video to be logged: raw frames plus the playback rate.

    :param frames: frames to create the video from
    :param fps: frames per second
    """

    def __init__(self, frames: th.Tensor, fps: Union[float, int]):
        self.frames, self.fps = frames, fps
38
+
39
+
40
class Figure:
    """
    Container for a matplotlib figure to be logged, plus a flag telling the
    writer whether to close the figure once it has been logged.

    :param figure: figure to log
    :param close: if true, close the figure after logging it
    """

    def __init__(self, figure: plt.figure, close: bool):
        self.figure, self.close = figure, close
51
+
52
+
53
class Image:
    """
    Container for an image to be logged together with its data-format string.

    :param image: image to log
    :param dataformats: Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc.
        More info in add_image method doc at https://pytorch.org/docs/stable/tensorboard.html
        Gym envs normally use 'HWC' (channel last)
    """

    def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str):
        self.image, self.dataformats = image, dataformats
66
+
67
+
68
class FormatUnsupportedError(NotImplementedError):
    """
    Raised when a logged value cannot be rendered by one of the active
    output formats; the message names the offending format(s).

    :param unsupported_formats: A sequence of unsupported formats,
        for instance ``["stdout"]``.
    :param value_description: Description of the value that cannot be logged by this format.
    """

    def __init__(self, unsupported_formats: Sequence[str], value_description: str):
        # Singular/plural phrasing depending on how many formats are affected
        format_str = (
            f"formats {', '.join(unsupported_formats)} are"
            if len(unsupported_formats) > 1
            else f"format {unsupported_formats[0]} is"
        )
        super().__init__(
            f"The {format_str} not supported for the {value_description} value logged.\n"
            f"You can exclude formats via the `exclude` parameter of the logger's `record` function."
        )
87
+
88
+
89
class KVWriter:
    """
    Interface for key/value writers; concrete formats must implement
    both ``write`` and ``close``.
    """

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
              step: int = 0) -> None:
        """
        Write a dictionary to file

        :param key_values:
        :param key_excluded:
        :param step:
        """
        raise NotImplementedError

    def close(self) -> None:
        """
        Close owned resources
        """
        raise NotImplementedError
110
+
111
+
112
class SeqWriter:
    """
    Interface for writers that emit a plain sequence of values
    (e.g. free-form log messages) rather than key/value pairs.
    """

    def write_sequence(self, sequence: List) -> None:
        """
        write_sequence an array to file

        :param sequence:
        """
        raise NotImplementedError
124
+
125
+
126
class HumanOutputFormat(KVWriter, SeqWriter):
    """A human-readable output format producing ASCII tables of key-value pairs.

    Set attribute ``max_length`` to change the maximum length of keys and values
    to write to output (or specify it when calling ``__init__``).

    :param filename_or_file: the file to write the log to
    :param max_length: the maximum length of keys and values to write to output.
        Outputs longer than this will be truncated. An error will be raised
        if multiple keys are truncated to the same value. The maximum output
        width will be ``2*max_length + 7``. The default of 36 produces output
        no longer than 79 characters wide.
    """

    def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
        self.max_length = max_length
        if isinstance(filename_or_file, str):
            # Own the handle only when we opened it ourselves (see close()).
            self.file = open(filename_or_file, "wt")
            self.own_file = True
        else:
            assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}"
            self.file = filename_or_file
            self.own_file = False

    def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
        """
        Render the key/value pairs as an ASCII table and write it to the file.

        :param key_values: values to display
        :param key_excluded: per-key output-format exclusions
        :param step: unused by this format
        """
        # Create strings for printing
        key2str = {}
        # `tag` persists across loop iterations: once a "tag/" prefix is seen,
        # subsequent keys sharing that prefix are indented under it.
        tag = None
        for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):

            if excluded is not None and ("stdout" in excluded or "log" in excluded):
                continue

            elif isinstance(value, Video):
                raise FormatUnsupportedError(["stdout", "log"], "video")

            elif isinstance(value, Figure):
                raise FormatUnsupportedError(["stdout", "log"], "figure")

            elif isinstance(value, Image):
                raise FormatUnsupportedError(["stdout", "log"], "image")

            elif isinstance(value, float):
                # Align left
                value_str = f"{value:<8.3g}"
            else:
                value_str = str(value)

            if key.find("/") > 0:  # Find tag and add it to the dict
                tag = key[: key.find("/") + 1]
                key2str[self._truncate(tag)] = ""
            # Remove tag from key
            if tag is not None and tag in key:
                key = str(" " + key[len(tag):])

            truncated_key = self._truncate(key)
            if truncated_key in key2str:
                # Two distinct keys collapsed to the same truncated name would
                # silently overwrite each other — fail loudly instead.
                raise ValueError(
                    f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
                )
            key2str[truncated_key] = self._truncate(value_str)

        # Find max widths
        if len(key2str) == 0:
            warnings.warn("Tried to write empty key-value dict")
            return
        else:
            key_width = max(map(len, key2str.keys()))
            val_width = max(map(len, key2str.values()))

        # Write out the data
        dashes = "-" * (key_width + val_width + 7)
        lines = [dashes]
        for key, value in key2str.items():
            key_space = " " * (key_width - len(key))
            val_space = " " * (val_width - len(value))
            lines.append(f"| {key}{key_space} | {value}{val_space} |")
        lines.append(dashes)
        self.file.write("\n".join(lines) + "\n")

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, string: str) -> str:
        """Shorten ``string`` to ``max_length`` chars, ending in ``...`` if cut."""
        if len(string) > self.max_length:
            string = string[: self.max_length - 3] + "..."
        return string

    def write_sequence(self, sequence: List) -> None:
        """Write the sequence elements space-separated on one line."""
        sequence = list(sequence)
        for i, elem in enumerate(sequence):
            self.file.write(elem)
            if i < len(sequence) - 1:  # add space unless this is the last one
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self) -> None:
        """
        closes the file

        Only closes handles this writer opened itself (never e.g. ``sys.stdout``).
        """
        if self.own_file:
            self.file.close()
229
+
230
+
231
def filter_excluded_keys(
    key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str
) -> Dict[str, Any]:
    """
    Filters the keys specified by ``key_exclude`` for the specified format

    :param key_values: log dictionary to be filtered
    :param key_excluded: keys to be excluded per format
    :param _format: format for which this filter is run
    :return: dict without the excluded keys
    """
    kept = {}
    for key, value in key_values.items():
        exclusions = key_excluded.get(key)
        # A key is dropped only when it has a non-None exclusion entry
        # that mentions this format.
        if exclusions is not None and _format in exclusions:
            continue
        kept[key] = value
    return kept
247
+
248
+
249
class JSONOutputFormat(KVWriter):
    """
    Log to a file, in the JSON format

    :param filename: the file to write the log to
    """

    def __init__(self, filename: str):
        self.file = open(filename, "wt")

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
              step: int = 0) -> None:
        """Serialize the non-excluded key/value pairs as one JSON line."""

        def cast_to_json_serializable(value: Any):
            # Rich media values have no JSON representation
            for unsupported_type, description in ((Video, "video"), (Figure, "figure"), (Image, "image")):
                if isinstance(value, unsupported_type):
                    raise FormatUnsupportedError(["json"], description)
            if hasattr(value, "dtype"):
                if value.shape == () or len(value) == 1:
                    # if value is a dimensionless numpy array or of length 1, serialize as a float
                    return float(value)
                # otherwise, a value is a numpy array, serialize as a list or nested lists
                return value.tolist()
            return value

        serializable = {
            key: cast_to_json_serializable(value)
            for key, value in filter_excluded_keys(key_values, key_excluded, "json").items()
        }
        self.file.write(json.dumps(serializable) + "\n")
        self.file.flush()

    def close(self) -> None:
        """
        closes the file
        """
        self.file.close()
290
+
291
+
292
class CSVOutputFormat(KVWriter):
    """
    Log to a file, in a CSV format

    :param filename: the file to write the log to
    """

    def __init__(self, filename: str):
        # "w+t" (read/write) because write() re-reads and rewrites the whole
        # file whenever a previously unseen column appears.
        self.file = open(filename, "w+t")
        self.keys = []
        self.separator = ","
        self.quotechar = '"'

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
              step: int = 0) -> None:
        """
        Append one CSV row; when new columns appear, rewrite the header and
        pad all previous rows with empty cells for the new columns.

        :param key_values: values to log as one row
        :param key_excluded: per-key output-format exclusions
        :param step: unused by this format
        """
        # Add our current row to the history
        key_values = filter_excluded_keys(key_values, key_excluded, "csv")
        extra_keys = key_values.keys() - self.keys
        if extra_keys:
            self.keys.extend(extra_keys)
            # Re-read the whole file, then rewrite it in place with the
            # extended header and padded rows.
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, key) in enumerate(self.keys):
                if i > 0:
                    self.file.write(",")
                self.file.write(key)
            self.file.write("\n")
            for line in lines[1:]:
                # Strip the old newline, pad one empty cell per new column.
                self.file.write(line[:-1])
                self.file.write(self.separator * len(extra_keys))
                self.file.write("\n")
        for i, key in enumerate(self.keys):
            if i > 0:
                self.file.write(",")
            # Missing keys produce an empty cell (nothing is written).
            value = key_values.get(key)

            if isinstance(value, Video):
                raise FormatUnsupportedError(["csv"], "video")

            elif isinstance(value, Figure):
                raise FormatUnsupportedError(["csv"], "figure")

            elif isinstance(value, Image):
                raise FormatUnsupportedError(["csv"], "image")

            elif isinstance(value, str):
                # escape quotechars by prepending them with another quotechar
                value = value.replace(self.quotechar, self.quotechar + self.quotechar)

                # additionally wrap text with quotechars so that any delimiters in the text are ignored by csv readers
                self.file.write(self.quotechar + value + self.quotechar)

            elif value is not None:
                self.file.write(str(value))
        self.file.write("\n")
        self.file.flush()

    def close(self) -> None:
        """
        closes the file
        """
        self.file.close()
355
+
356
+
357
class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.

    :param folder: the folder to write the log to
    """

    def __init__(self, folder: str):
        # SummaryWriter is None when the tensorboard import failed at module load.
        assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
        self.writer = SummaryWriter(log_dir=folder)

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
              step: int = 0) -> None:
        """
        Dispatch each value to the matching ``SummaryWriter.add_*`` call.

        :param key_values: values to log
        :param key_excluded: per-key output-format exclusions
        :param step: global step associated with the values
        """
        for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):

            if excluded is not None and "tensorboard" in excluded:
                continue

            # NOTE: the checks below are deliberately independent ``if``s
            # (not elif): a value matching several categories would be logged
            # under each of them.
            if isinstance(value, np.ScalarType):
                if isinstance(value, str):
                    # str is considered a np.ScalarType
                    self.writer.add_text(key, value, step)
                else:
                    self.writer.add_scalar(key, value, step)

            if isinstance(value, th.Tensor):
                self.writer.add_histogram(key, value, step)

            if isinstance(value, Video):
                self.writer.add_video(key, value.frames, step, value.fps)

            if isinstance(value, Figure):
                self.writer.add_figure(key, value.figure, step, close=value.close)

            if isinstance(value, Image):
                self.writer.add_image(key, value.image, step, dataformats=value.dataformats)

        # Flush the output to the file
        self.writer.flush()

    def close(self) -> None:
        """
        closes the file
        """
        if self.writer:
            self.writer.close()
            self.writer = None
404
+
405
+
406
def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter:
    """
    return a logger for the requested format

    :param _format: the requested format to log to ('stdout', 'log', 'json' or 'csv' or 'tensorboard')
    :param log_dir: the logging directory
    :param log_suffix: the suffix for the log file
    :return: the logger
    """
    os.makedirs(log_dir, exist_ok=True)
    # Early returns per known format; anything else is rejected below.
    if _format == "stdout":
        return HumanOutputFormat(sys.stdout)
    if _format == "log":
        return HumanOutputFormat(os.path.join(log_dir, f"log{log_suffix}.txt"))
    if _format == "json":
        return JSONOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.json"))
    if _format == "csv":
        return CSVOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.csv"))
    if _format == "tensorboard":
        return TensorBoardOutputFormat(log_dir)
    raise ValueError(f"Unknown format specified: {_format}")
428
+
429
+
430
+ # ================================================================
431
+ # Backend
432
+ # ================================================================
433
+
434
+
435
class Logger:
    """
    The logger class.

    :param folder: the logging location
    :param output_formats: the list of output formats
    """

    def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
        self.name_to_value = defaultdict(float)  # values this iteration
        self.name_to_count = defaultdict(int)
        self.name_to_excluded = defaultdict(str)
        self.level = INFO
        self.dir = folder
        self.output_formats = output_formats

    def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
        """
        Log a value of some diagnostic
        Call this once for each diagnostic quantity, each iteration
        If called many times, last value will be used.

        :param key: save to log this key
        :param value: save to log this value
        :param exclude: outputs to be excluded
        """
        self.name_to_value[key] = value
        self.name_to_excluded[key] = exclude

    def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
        """
        The same as record(), but if called many times, values averaged.

        :param key: save to log this key
        :param value: save to log this value
        :param exclude: outputs to be excluded
        """
        if value is None:
            self.name_to_value[key] = None
            return
        running_mean = self.name_to_value[key]
        count = self.name_to_count[key]
        # Incremental running mean over all values recorded under this key
        self.name_to_value[key] = running_mean * count / (count + 1) + value / (count + 1)
        self.name_to_count[key] = count + 1
        self.name_to_excluded[key] = exclude

    def dump(self, step: int = 0) -> None:
        """
        Write all of the diagnostics from the current iteration
        """
        if self.level == DISABLED:
            return
        for writer in self.output_formats:
            if isinstance(writer, KVWriter):
                writer.write(self.name_to_value, self.name_to_excluded, step)

        # Start the next iteration with a clean slate
        self.name_to_value.clear()
        self.name_to_count.clear()
        self.name_to_excluded.clear()

    def log(self, *args, level: int = INFO) -> None:
        """
        Write the sequence of args, with no separators,
        to the console and output files (if you've configured an output file).

        level: int. (see logger.py docs) If the global logger level is higher than
                    the level argument here, don't print to stdout.

        :param args: log the arguments
        :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
        """
        if self.level <= level:
            self._do_log(args)

    def debug(self, *args) -> None:
        """
        Log *args* at the DEBUG level.

        :param args: log the arguments
        """
        self.log(*args, level=DEBUG)

    def info(self, *args) -> None:
        """
        Log *args* at the INFO level.

        :param args: log the arguments
        """
        self.log(*args, level=INFO)

    def warn(self, *args) -> None:
        """
        Log *args* at the WARN level.

        :param args: log the arguments
        """
        self.log(*args, level=WARN)

    def error(self, *args) -> None:
        """
        Log *args* at the ERROR level.

        :param args: log the arguments
        """
        self.log(*args, level=ERROR)

    # Configuration
    # ----------------------------------------
    def set_level(self, level: int) -> None:
        """
        Set logging threshold on current logger.

        :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
        """
        self.level = level

    def get_dir(self) -> str:
        """
        Get directory that log files are being written to.
        will be None if there is no output directory (i.e., if you didn't call start)

        :return: the logging directory
        """
        return self.dir

    def close(self) -> None:
        """
        closes the file
        """
        for writer in self.output_formats:
            writer.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args) -> None:
        """
        log to the requested format outputs

        :param args: the arguments to log
        """
        for writer in self.output_formats:
            if isinstance(writer, SeqWriter):
                writer.write_sequence(map(str, args))
585
+
586
+
587
def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger:
    """
    Configure the current logger.

    :param folder: the save location
        (if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time])
    :param format_strings: the output logging format
        (if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv'])
    :return: The logger object.
    """
    if folder is None:
        folder = os.getenv("SB3_LOGDIR")
    if folder is None:
        timestamp = datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f")
        folder = os.path.join(tempfile.gettempdir(), timestamp)
    assert isinstance(folder, str)
    os.makedirs(folder, exist_ok=True)

    log_suffix = ""
    if format_strings is None:
        format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",")

    # Drop empty strings (e.g. from a trailing comma in the env var)
    format_strings = [fmt for fmt in format_strings if fmt]
    output_formats = [make_output_format(fmt, folder, log_suffix) for fmt in format_strings]

    logger = Logger(folder=folder, output_formats=output_formats)
    # Only print when some files will be saved
    if len(format_strings) > 0 and format_strings != ["stdout"]:
        logger.log(f"Logging to {folder}")
    return logger
616
+
617
+
618
+ # ================================================================
619
+ # Readers
620
+ # ================================================================
621
+
622
+
623
def read_json(filename: str) -> pandas.DataFrame:
    """
    read a json file using pandas

    The file is expected to hold one JSON object per line (JSON-lines),
    as produced by ``JSONOutputFormat``.

    :param filename: the file path to read
    :return: the data in the json
    """
    with open(filename) as file_handler:
        records = [json.loads(line) for line in file_handler]
    return pandas.DataFrame(records)
635
+
636
+
637
def read_csv(filename: str) -> pandas.DataFrame:
    """
    read a csv file using pandas

    Lines starting with ``#`` (e.g. the Monitor JSON header) are skipped.

    :param filename: the file path to read
    :return: the data in the csv
    """
    frame = pandas.read_csv(filename, index_col=None, comment="#")
    return frame
dexart-release/stable_baselines3/common/monitor.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"]
2
+
3
+ import csv
4
+ import json
5
+ import os
6
+ import time
7
+ from glob import glob
8
+ from typing import Dict, List, Optional, Tuple, Union
9
+
10
+ import gym
11
+ import numpy as np
12
+ import pandas
13
+
14
+ from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
15
+
16
+
17
class Monitor(gym.Wrapper):
    """
    A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.

    :param env: The environment
    :param filename: the location to save a log file, can be None for no log
    :param allow_early_resets: allows the reset of the environment before it is done
    :param reset_keywords: extra keywords for the reset call,
        if extra parameters are needed at reset
    :param info_keywords: extra information to log, from the information return of env.step()
    """

    # File-name suffix used for monitor log files.
    EXT = "monitor.csv"

    def __init__(
        self,
        env: gym.Env,
        filename: Optional[str] = None,
        allow_early_resets: bool = True,
        reset_keywords: Tuple[str, ...] = (),
        info_keywords: Tuple[str, ...] = (),
    ):
        super().__init__(env=env)
        self.t_start = time.time()
        if filename is not None:
            self.results_writer = ResultsWriter(
                filename,
                # env.spec may be None for hand-constructed envs; then env_id is None
                header={"t_start": self.t_start, "env_id": env.spec and env.spec.id},
                extra_keys=reset_keywords + info_keywords,
            )
        else:
            self.results_writer = None
        self.reset_keywords = reset_keywords
        self.info_keywords = info_keywords
        self.allow_early_resets = allow_early_resets
        self.rewards = None  # per-step rewards of the current episode (None until reset)
        self.needs_reset = True
        self.episode_returns = []
        self.episode_lengths = []
        self.episode_times = []
        self.total_steps = 0
        self.current_reset_info = {}  # extra info about the current episode, that was passed in during reset()

    def reset(self, **kwargs) -> GymObs:
        """
        Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True

        :param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords
        :return: the first observation of the environment
        """
        if not self.allow_early_resets and not self.needs_reset:
            raise RuntimeError(
                "Tried to reset an environment before done. If you want to allow early resets, "
                "wrap your env with Monitor(env, path, allow_early_resets=True)"
            )
        self.rewards = []
        self.needs_reset = False
        # Capture the reset keywords so they can be attached to the episode record
        for key in self.reset_keywords:
            value = kwargs.get(key)
            if value is None:
                raise ValueError(f"Expected you to pass keyword argument {key} into reset")
            self.current_reset_info[key] = value
        return self.env.reset(**kwargs)

    def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
        """
        Step the environment with the given action

        :param action: the action
        :return: observation, reward, done, information
        """
        if self.needs_reset:
            raise RuntimeError("Tried to step environment that needs reset")
        observation, reward, done, info = self.env.step(action)
        self.rewards.append(reward)
        if done:
            self.needs_reset = True
            ep_rew = sum(self.rewards)
            ep_len = len(self.rewards)
            # "r"/"l"/"t" keys are the convention consumers (e.g. evaluate_policy) rely on
            ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)}
            for key in self.info_keywords:
                ep_info[key] = info[key]
            self.episode_returns.append(ep_rew)
            self.episode_lengths.append(ep_len)
            self.episode_times.append(time.time() - self.t_start)
            ep_info.update(self.current_reset_info)
            if self.results_writer:
                self.results_writer.write_row(ep_info)
            info["episode"] = ep_info
        self.total_steps += 1
        # The observation/reward/done are passed through unmodified
        return observation, reward, done, info

    def close(self) -> None:
        """
        Closes the environment
        """
        super().close()
        if self.results_writer is not None:
            self.results_writer.close()

    def get_total_steps(self) -> int:
        """
        Returns the total number of timesteps

        :return:
        """
        return self.total_steps

    def get_episode_rewards(self) -> List[float]:
        """
        Returns the rewards of all the episodes

        :return:
        """
        return self.episode_returns

    def get_episode_lengths(self) -> List[int]:
        """
        Returns the number of timesteps of all the episodes

        :return:
        """
        return self.episode_lengths

    def get_episode_times(self) -> List[float]:
        """
        Returns the runtime in seconds of all the episodes

        :return:
        """
        return self.episode_times
148
+
149
+
150
class LoadMonitorResultsError(Exception):
    """Raised when loading the monitor log fails."""
156
+
157
+
158
class ResultsWriter:
    """
    A result writer that saves the data from the `Monitor` class.

    The output file starts with a single ``#``-prefixed JSON header line,
    followed by a CSV with columns ``r`` (return), ``l`` (length), ``t`` (time)
    plus any ``extra_keys``.

    :param filename: the location to save a log file; if it is a directory or
        does not already end with ``monitor.csv``, the suffix is appended
    :param header: the header dictionary object of the saved csv
    :param extra_keys: the extra information to log, typically composed of
        ``reset_keywords`` and ``info_keywords``
    """

    def __init__(
        self,
        filename: str = "",
        header: Optional[Dict[str, Union[float, str]]] = None,
        extra_keys: Tuple[str, ...] = (),
    ):
        if header is None:
            header = {}
        # Normalize the target path so it always ends with "monitor.csv".
        if not filename.endswith(Monitor.EXT):
            if os.path.isdir(filename):
                filename = os.path.join(filename, Monitor.EXT)
            else:
                filename = filename + "." + Monitor.EXT
        # Prevent newline issue on Windows, see GH issue #692
        self.file_handler = open(filename, "wt", newline="\n")
        # JSON header line first, then the CSV header row; flush so readers
        # see a valid file immediately.
        self.file_handler.write("#%s\n" % json.dumps(header))
        self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
        self.logger.writeheader()
        self.file_handler.flush()

    def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None:
        """
        Write one episode entry to the CSV log and flush it to disk.

        :param epinfo: the information on episodic return, length, and time
        """
        if self.logger:
            self.logger.writerow(epinfo)
            self.file_handler.flush()

    def close(self) -> None:
        """
        Close the file handler.
        """
        self.file_handler.close()
203
+
204
+
205
def get_monitor_files(path: str) -> List[str]:
    """
    Get all the monitor files contained in the given folder.

    :param path: the logging folder
    :return: the paths of all files matching ``*monitor.csv``
    """
    pattern = os.path.join(path, "*" + Monitor.EXT)
    return glob(pattern)
213
+
214
+
215
def load_results(path: str) -> pandas.DataFrame:
    """
    Load all Monitor logs from a given directory path matching ``*monitor.csv``.

    :param path: the directory path containing the log file(s)
    :return: the logged data, concatenated and sorted by episode end time
    """
    monitor_files = get_monitor_files(path)
    if not monitor_files:
        raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
    data_frames = []
    headers = []
    for file_name in monitor_files:
        with open(file_name) as file_handler:
            # The first line is a "#"-prefixed JSON header; the rest is CSV.
            first_line = file_handler.readline()
            assert first_line[0] == "#"
            header = json.loads(first_line[1:])
            data_frame = pandas.read_csv(file_handler, index_col=None)
            headers.append(header)
            # Convert per-run relative times to absolute timestamps.
            data_frame["t"] += header["t_start"]
        data_frames.append(data_frame)
    data_frame = pandas.concat(data_frames)
    data_frame.sort_values("t", inplace=True)
    data_frame.reset_index(inplace=True)
    # Re-express times relative to the earliest run start across all files.
    data_frame["t"] -= min(header["t_start"] for header in headers)
    return data_frame
dexart-release/stable_baselines3/common/noise.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from abc import ABC, abstractmethod
3
+ from typing import Iterable, List, Optional
4
+
5
+ import numpy as np
6
+
7
+
8
class ActionNoise(ABC):
    """
    The action noise base class.
    """

    def __init__(self):
        super().__init__()

    def reset(self) -> None:
        """
        Call end of episode reset for the noise.
        """

    @abstractmethod
    def __call__(self) -> np.ndarray:
        raise NotImplementedError()
25
+
26
+
27
class NormalActionNoise(ActionNoise):
    """
    A Gaussian action noise.

    :param mean: the mean value of the noise
    :param sigma: the scale of the noise (std here)
    """

    def __init__(self, mean: np.ndarray, sigma: np.ndarray):
        super().__init__()
        self._mu = mean
        self._sigma = sigma

    def __call__(self) -> np.ndarray:
        # One independent Gaussian sample per action dimension.
        return np.random.normal(self._mu, self._sigma)

    def __repr__(self) -> str:
        return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
45
+
46
+
47
class OrnsteinUhlenbeckActionNoise(ActionNoise):
    """
    An Ornstein Uhlenbeck action noise, this is designed to approximate
    Brownian motion with friction.

    Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab

    :param mean: the mean of the noise
    :param sigma: the scale of the noise
    :param theta: the rate of mean reversion
    :param dt: the timestep for the noise
    :param initial_noise: the initial value for the noise output, (if None: 0)
    """

    def __init__(
        self,
        mean: np.ndarray,
        sigma: np.ndarray,
        theta: float = 0.15,
        dt: float = 1e-2,
        initial_noise: Optional[np.ndarray] = None,
    ):
        self._theta = theta
        self._mu = mean
        self._sigma = sigma
        self._dt = dt
        self.initial_noise = initial_noise
        self.noise_prev = np.zeros_like(self._mu)
        # Start the process from its configured initial position.
        self.reset()
        super().__init__()

    def __call__(self) -> np.ndarray:
        # Euler-Maruyama discretization of the OU process:
        #   x_{t+dt} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * N(0, 1)
        drift = self._theta * (self._mu - self.noise_prev) * self._dt
        diffusion = self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
        self.noise_prev = self.noise_prev + drift + diffusion
        return self.noise_prev

    def reset(self) -> None:
        """
        Reset the Ornstein Uhlenbeck noise, to the initial position.
        """
        if self.initial_noise is None:
            self.noise_prev = np.zeros_like(self._mu)
        else:
            self.noise_prev = self.initial_noise

    def __repr__(self) -> str:
        return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})"
94
+
95
+
96
class VectorizedActionNoise(ActionNoise):
    """
    A Vectorized action noise for parallel environments.

    Keeps one independent deep copy of ``base_noise`` per environment and
    stacks their samples along axis 0.

    :param base_noise: ActionNoise The noise generator to use
    :param n_envs: The number of parallel environments
    """

    def __init__(self, base_noise: ActionNoise, n_envs: int):
        try:
            self.n_envs = int(n_envs)
            assert self.n_envs > 0
        except (TypeError, AssertionError):
            raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0")

        self.base_noise = base_noise
        # One independent copy of the base noise process per environment.
        self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]

    def reset(self, indices: Optional[Iterable[int]] = None) -> None:
        """
        Reset all the noise processes, or those listed in indices.

        :param indices: Optional[Iterable[int]] The indices to reset. Default: None.
            If the parameter is None, then all processes are reset to their initial position.
        """
        if indices is None:
            indices = range(len(self.noises))

        for index in indices:
            self.noises[index].reset()

    def __repr__(self) -> str:
        # Fixed: the original format string had unbalanced parentheses
        # ("...{repr(self.base_noise)}), n_envs=...)"), yielding a malformed repr.
        return f"VecNoise(BaseNoise={repr(self.base_noise)}, n_envs={len(self.noises)})"

    def __call__(self) -> np.ndarray:
        """
        Generate and stack the action noise from each noise object.
        """
        noise = np.stack([noise() for noise in self.noises])
        return noise

    @property
    def base_noise(self) -> ActionNoise:
        return self._base_noise

    @base_noise.setter
    def base_noise(self, base_noise: ActionNoise) -> None:
        # Validate eagerly so a bad template fails at assignment, not at sampling time.
        if base_noise is None:
            raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
        if not isinstance(base_noise, ActionNoise):
            raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
        self._base_noise = base_noise

    @property
    def noises(self) -> List[ActionNoise]:
        return self._noises

    @noises.setter
    def noises(self, noises: List[ActionNoise]) -> None:
        noises = list(noises)  # raises TypeError if not iterable
        assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."

        # Every entry must match the type of the base noise template.
        different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))]

        if len(different_types):
            raise ValueError(
                f"Noise instances at indices {different_types} don't match the type of base_noise", type(self.base_noise)
            )

        self._noises = noises
        # Each process starts from its initial position.
        for noise in noises:
            noise.reset()
dexart-release/stable_baselines3/common/on_policy_algorithm.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union
3
+
4
+ import gym
5
+ import numpy as np
6
+ import torch as th
7
+
8
+ from stable_baselines3.common.base_class import BaseAlgorithm
9
+ from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
10
+ from stable_baselines3.common.callbacks import BaseCallback
11
+ from stable_baselines3.common.policies import ActorCriticPolicy
12
+ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
13
+ from stable_baselines3.common.utils import obs_as_tensor, safe_mean
14
+ from stable_baselines3.common.vec_env import VecEnv
15
+
16
+
17
class OnPolicyAlgorithm(BaseAlgorithm):
    """
    The base for On-Policy algorithms (ex: A2C/PPO).

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
        Equivalent to classic advantage when set to 1.
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param monitor_wrapper: When creating an environment, whether to wrap it
        or not in a Monitor wrapper.
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    :param supported_action_spaces: The action spaces supported by the algorithm.
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule],
        n_steps: int,
        gamma: float,
        gae_lambda: float,
        ent_coef: float,
        vf_coef: float,
        max_grad_norm: float,
        use_sde: bool,
        sde_sample_freq: int,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        monitor_wrapper: bool = True,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
    ):

        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            create_eval_env=create_eval_env,
            support_multi_env=True,
            seed=seed,
            tensorboard_log=tensorboard_log,
            supported_action_spaces=supported_action_spaces,
        )

        self.n_steps = n_steps
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.rollout_buffer = None

        # State for a custom "restore on performance drop" mechanism:
        # mean reward of the previous rollout, a restore-request flag, the two
        # most recent parameter snapshots, and a counter of consecutive restores.
        # NOTE(review): need_restore is never set to True anywhere in this class
        # (collect_rollouts only clears it), so the restore branch in learn()
        # appears dead unless a subclass or callback flips it — confirm.
        self.last_rollout_reward = -np.inf
        self.need_restore = False
        self.last_policy_saved: List[Dict] = [{}, {}]
        self.current_restore_step = 0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        # Build the learning-rate schedule, seed RNGs, and create the rollout
        # buffer and policy network.
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        # Dict observation spaces need the dict-aware rollout buffer.
        buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RolloutBuffer

        self.rollout_buffer = buffer_cls(
            self.n_steps,
            self.observation_space,
            self.action_space,
            device=self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
        )
        self.policy = self.policy_class(  # pytype:disable=not-instantiable
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            use_sde=self.use_sde,
            **self.policy_kwargs  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        Also tracks the mean per-episode reward of this rollout in
        ``self.last_rollout_reward``.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)
        last_episode_reward = self.last_rollout_reward
        self.last_rollout_reward = 0
        num_rollouts = 0  # number of episodes that finished during this rollout
        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)
        callback.on_rollout_start()
        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs = self.policy(obs_tensor)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.action_space, gym.spaces.Box):
                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos)
            n_steps += 1

            if isinstance(self.action_space, gym.spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstraping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(terminal_obs)[0]
                    rewards[idx] += self.gamma * terminal_value

                if done:
                    num_rollouts += 1

            # Accumulate total reward over all envs; averaged per episode below.
            self.last_rollout_reward += rewards.sum()

            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
            self._last_obs = new_obs
            self._last_episode_starts = dones
        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        # NOTE(review): if no episode finished during this rollout, num_rollouts
        # is 0 and this division fails (or yields inf/nan for numpy floats) —
        # presumably n_steps always spans at least one episode here; confirm.
        self.last_rollout_reward /= num_rollouts
        # NOTE(review): reward_gap is computed but never used in this class.
        reward_gap = last_episode_reward - self.last_rollout_reward
        self.need_restore = False
        self.current_restore_step = 0

        callback.on_rollout_end()

        return True

    def train(self) -> None:
        """
        Consume current rollout data and update policy parameters.
        Implemented by individual algorithms.
        """
        raise NotImplementedError

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "OnPolicyAlgorithm",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        iter_start=0,
    ) -> "OnPolicyAlgorithm":
        # ``iter_start`` allows resuming iteration numbering when continuing
        # a previous training run.
        iteration = iter_start

        total_timesteps, callback = self._setup_learn(
            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps,
            tb_log_name
        )

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:

            x = time.time()
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer,
                                                      n_rollout_steps=self.n_steps)
            print("Rollout time:", time.time() - x)

            if continue_training is False:
                break

            # Restore-on-drop: if a large performance drop was flagged, roll back
            # to the older of the two saved snapshots (at most 5 times in a row)
            # and skip this update; otherwise rotate the snapshot buffer.
            # NOTE(review): need_restore is cleared in collect_rollouts and never
            # set True in this class — confirm a subclass drives this flag.
            if self.need_restore and self.current_restore_step < 5:
                print(f"Large performance drop detected. Restore previous model.")
                self.set_parameters(self.last_policy_saved[0], exact_match=True, device=self.device)
                continue
            else:
                self.last_policy_saved[0] = self.last_policy_saved[1]
                self.last_policy_saved[1] = self.get_parameters()

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos

            if log_interval is not None and iteration % log_interval == 0:
                fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
                self.logger.record("time/iterations", iteration, exclude="wandb")
                self.logger.record("rollout/rollout_rew_mean", self.last_rollout_reward)
                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                    self.logger.record("rollout/ep_rew_mean",
                                       safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                    self.logger.record("rollout/ep_len_mean",
                                       safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                self.logger.record("time/fps", fps)
                self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
                self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                self.logger.dump(step=iteration)

            x = time.time()
            self.train()
            print("Train time:", time.time() - x)

        callback.on_training_end()

        return self

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        # Only the policy and its optimizer are saved via th.save.
        state_dicts = ["policy", "policy.optimizer"]

        return state_dicts, []

    def _excluded_save_params(self) -> List[str]:
        """
        Returns the names of the parameters that should be excluded from being
        saved by pickling. E.g. replay buffers are skipped by default
        as they take up a lot of space. PyTorch variables should be excluded
        with this so they can be stored with ``th.save``.

        :return: List of parameters that should be excluded from being saved with pickle.
        """
        # The parameter snapshots used by the restore mechanism are large and
        # reproducible, so they are not pickled.
        return super()._excluded_save_params() + ["last_policy_saved"]